View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.db.protocol.postgresql.codec;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.CompositeByteBuf;
22  import io.netty.channel.ChannelHandlerContext;
23  import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
24  import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
25  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
26  import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLMessageSeverityLevel;
27  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketType;
28  import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLErrorResponsePacket;
29  import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
30  import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
31  import org.apache.shardingsphere.infra.exception.postgresql.vendor.PostgreSQLVendorError;
32  
33  import java.nio.charset.Charset;
34  import java.util.LinkedList;
35  import java.util.List;
36  
37  /**
38   * Database packet codec for PostgreSQL.
39   */
40  public final class PostgreSQLPacketCodecEngine implements DatabasePacketCodecEngine {
41      
42      private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
43      
44      private static final int SSL_REQUEST_CODE = (1234 << 16) + 5679;
45      
46      private static final int MESSAGE_TYPE_LENGTH = 1;
47      
48      private static final int PAYLOAD_LENGTH = 4;
49      
50      private boolean startupPhase = true;
51      
52      private final List<ByteBuf> pendingMessages = new LinkedList<>();
53      
54      @Override
55      public boolean isValidHeader(final int readableBytes) {
56          return readableBytes >= (startupPhase ? 0 : MESSAGE_TYPE_LENGTH) + PAYLOAD_LENGTH;
57      }
58      
59      @Override
60      public void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) {
61          while (isValidHeader(in.readableBytes())) {
62              if (startupPhase) {
63                  handleStartupPhase(in, out);
64                  return;
65              }
66              int payloadLength = in.getInt(in.readerIndex() + 1);
67              if (in.readableBytes() < MESSAGE_TYPE_LENGTH + payloadLength) {
68                  return;
69              }
70              byte type = in.getByte(in.readerIndex());
71              PostgreSQLCommandPacketType commandPacketType = PostgreSQLCommandPacketType.valueOf(type);
72              if (requireAggregation(commandPacketType)) {
73                  pendingMessages.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
74              } else if (pendingMessages.isEmpty()) {
75                  out.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
76              } else {
77                  handlePendingMessages(context, in, out, payloadLength);
78              }
79          }
80      }
81      
82      private void handleStartupPhase(final ByteBuf in, final List<Object> out) {
83          int readerIndex = in.readerIndex();
84          if (in.readableBytes() == SSL_REQUEST_PAYLOAD_LENGTH && SSL_REQUEST_PAYLOAD_LENGTH == in.getInt(readerIndex) && SSL_REQUEST_CODE == in.getInt(readerIndex + 4)) {
85              out.add(in.readRetainedSlice(SSL_REQUEST_PAYLOAD_LENGTH));
86              return;
87          }
88          if (in.readableBytes() == in.getInt(readerIndex)) {
89              out.add(in.readRetainedSlice(in.readableBytes()));
90              startupPhase = false;
91          }
92      }
93      
94      private boolean requireAggregation(final PostgreSQLCommandPacketType commandPacketType) {
95          return PostgreSQLCommandPacketType.isExtendedProtocolPacketType(commandPacketType)
96                  && PostgreSQLCommandPacketType.SYNC_COMMAND != commandPacketType && PostgreSQLCommandPacketType.FLUSH_COMMAND != commandPacketType;
97      }
98      
99      private void handlePendingMessages(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out, final int payloadLength) {
100         CompositeByteBuf result = context.alloc().compositeBuffer(pendingMessages.size() + 1);
101         result.addComponents(true, pendingMessages).addComponent(true, in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
102         out.add(result);
103         pendingMessages.clear();
104     }
105     
106     @Override
107     public void encode(final ChannelHandlerContext context, final DatabasePacket message, final ByteBuf out) {
108         boolean isIdentifierPacket = message instanceof PostgreSQLIdentifierPacket;
109         if (isIdentifierPacket) {
110             prepareMessageHeader(out, ((PostgreSQLIdentifierPacket) message).getIdentifier().getValue());
111         }
112         PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(out, context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).get());
113         try {
114             message.write(payload);
115             // CHECKSTYLE:OFF
116         } catch (final RuntimeException ex) {
117             // CHECKSTYLE:ON
118             payload.getByteBuf().resetWriterIndex();
119             // TODO consider what severity to use
120             PostgreSQLErrorResponsePacket errorResponsePacket = PostgreSQLErrorResponsePacket.newBuilder(
121                     PostgreSQLMessageSeverityLevel.ERROR, PostgreSQLVendorError.SYSTEM_ERROR, ex.getMessage()).build();
122             isIdentifierPacket = true;
123             prepareMessageHeader(out, errorResponsePacket.getIdentifier().getValue());
124             errorResponsePacket.write(payload);
125         } finally {
126             if (isIdentifierPacket) {
127                 updateMessageLength(out);
128             }
129         }
130     }
131     
132     private void prepareMessageHeader(final ByteBuf out, final char type) {
133         out.writeByte(type);
134         out.writeInt(0);
135     }
136     
137     private void updateMessageLength(final ByteBuf out) {
138         out.setInt(1, out.readableBytes() - MESSAGE_TYPE_LENGTH);
139     }
140     
141     @Override
142     public PostgreSQLPacketPayload createPacketPayload(final ByteBuf message, final Charset charset) {
143         return new PostgreSQLPacketPayload(message, charset);
144     }
145 }