View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.db.protocol.opengauss.codec;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.CompositeByteBuf;
22  import io.netty.channel.ChannelHandlerContext;
23  import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
24  import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
25  import org.apache.shardingsphere.db.protocol.opengauss.packet.command.OpenGaussCommandPacketType;
26  import org.apache.shardingsphere.db.protocol.opengauss.packet.command.generic.OpenGaussErrorResponsePacket;
27  import org.apache.shardingsphere.db.protocol.packet.command.CommandPacketType;
28  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
29  import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLMessageSeverityLevel;
30  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketType;
31  import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
32  import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
33  import org.apache.shardingsphere.infra.exception.postgresql.vendor.PostgreSQLVendorError;
34  
35  import java.nio.charset.Charset;
36  import java.util.LinkedList;
37  import java.util.List;
38  
39  /**
40   * Database packet codec for openGauss.
41   */
42  public final class OpenGaussPacketCodecEngine implements DatabasePacketCodecEngine {
43      
44      private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
45      
46      private static final int SSL_REQUEST_CODE = (1234 << 16) + 5679;
47      
48      private static final int MESSAGE_TYPE_LENGTH = 1;
49      
50      private static final int PAYLOAD_LENGTH = 4;
51      
52      private boolean startupPhase = true;
53      
54      private final List<ByteBuf> pendingMessages = new LinkedList<>();
55      
56      @Override
57      public boolean isValidHeader(final int readableBytes) {
58          return readableBytes >= (startupPhase ? 0 : MESSAGE_TYPE_LENGTH) + PAYLOAD_LENGTH;
59      }
60      
61      @Override
62      public void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) {
63          while (isValidHeader(in.readableBytes())) {
64              if (startupPhase) {
65                  handleStartupPhase(in, out);
66                  return;
67              }
68              int payloadLength = in.getInt(in.readerIndex() + 1);
69              if (in.readableBytes() < MESSAGE_TYPE_LENGTH + payloadLength) {
70                  return;
71              }
72              byte type = in.getByte(in.readerIndex());
73              CommandPacketType commandPacketType = OpenGaussCommandPacketType.valueOf(type);
74              if (requireAggregation(commandPacketType)) {
75                  pendingMessages.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
76              } else if (pendingMessages.isEmpty()) {
77                  out.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
78              } else {
79                  handlePendingMessages(context, in, out, payloadLength);
80              }
81          }
82      }
83      
84      private void handleStartupPhase(final ByteBuf in, final List<Object> out) {
85          int readerIndex = in.readerIndex();
86          if (in.readableBytes() == SSL_REQUEST_PAYLOAD_LENGTH && SSL_REQUEST_PAYLOAD_LENGTH == in.getInt(readerIndex) && SSL_REQUEST_CODE == in.getInt(readerIndex + 4)) {
87              out.add(in.readRetainedSlice(SSL_REQUEST_PAYLOAD_LENGTH));
88              return;
89          }
90          if (in.readableBytes() == in.getInt(readerIndex)) {
91              out.add(in.readRetainedSlice(in.readableBytes()));
92              startupPhase = false;
93          }
94      }
95      
96      private boolean requireAggregation(final CommandPacketType commandPacketType) {
97          return OpenGaussCommandPacketType.isExtendedProtocolPacketType(commandPacketType)
98                  && PostgreSQLCommandPacketType.SYNC_COMMAND != commandPacketType && PostgreSQLCommandPacketType.FLUSH_COMMAND != commandPacketType;
99      }
100     
101     private void handlePendingMessages(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out, final int payloadLength) {
102         CompositeByteBuf result = context.alloc().compositeBuffer(pendingMessages.size() + 1);
103         result.addComponents(true, pendingMessages).addComponent(true, in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
104         out.add(result);
105         pendingMessages.clear();
106     }
107     
108     @Override
109     public void encode(final ChannelHandlerContext context, final DatabasePacket message, final ByteBuf out) {
110         boolean isIdentifierPacket = message instanceof PostgreSQLIdentifierPacket;
111         if (isIdentifierPacket) {
112             prepareMessageHeader(out, ((PostgreSQLIdentifierPacket) message).getIdentifier().getValue());
113         }
114         PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(out, context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).get());
115         try {
116             message.write(payload);
117             // CHECKSTYLE:OFF
118         } catch (final RuntimeException ex) {
119             // CHECKSTYLE:ON
120             payload.getByteBuf().resetWriterIndex();
121             // TODO consider what severity to use
122             OpenGaussErrorResponsePacket errorResponsePacket = new OpenGaussErrorResponsePacket(
123                     PostgreSQLMessageSeverityLevel.ERROR, PostgreSQLVendorError.SYSTEM_ERROR.getSqlState().getValue(), ex.getMessage());
124             isIdentifierPacket = true;
125             prepareMessageHeader(out, errorResponsePacket.getIdentifier().getValue());
126             errorResponsePacket.write(payload);
127         } finally {
128             if (isIdentifierPacket) {
129                 updateMessageLength(out);
130             }
131         }
132     }
133     
134     private void prepareMessageHeader(final ByteBuf out, final char type) {
135         out.writeByte(type);
136         out.writeInt(0);
137     }
138     
139     private void updateMessageLength(final ByteBuf out) {
140         out.setInt(1, out.readableBytes() - MESSAGE_TYPE_LENGTH);
141     }
142     
143     @Override
144     public PostgreSQLPacketPayload createPacketPayload(final ByteBuf message, final Charset charset) {
145         return new PostgreSQLPacketPayload(message, charset);
146     }
147 }