1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.db.protocol.opengauss.codec;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.CompositeByteBuf;
22 import io.netty.channel.ChannelHandlerContext;
23 import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
24 import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
25 import org.apache.shardingsphere.db.protocol.opengauss.packet.command.OpenGaussCommandPacketType;
26 import org.apache.shardingsphere.db.protocol.opengauss.packet.command.generic.OpenGaussErrorResponsePacket;
27 import org.apache.shardingsphere.db.protocol.packet.command.CommandPacketType;
28 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
29 import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLMessageSeverityLevel;
30 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketType;
31 import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
32 import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
33 import org.apache.shardingsphere.infra.exception.postgresql.vendor.PostgreSQLVendorError;
34
35 import java.nio.charset.Charset;
36 import java.util.LinkedList;
37 import java.util.List;
38
39
40
41
42 public final class OpenGaussPacketCodecEngine implements DatabasePacketCodecEngine {
43
44 private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
45
46 private static final int SSL_REQUEST_CODE = (1234 << 16) + 5679;
47
48 private static final int MESSAGE_TYPE_LENGTH = 1;
49
50 private static final int PAYLOAD_LENGTH = 4;
51
52 private boolean startupPhase = true;
53
54 private final List<ByteBuf> pendingMessages = new LinkedList<>();
55
56 @Override
57 public boolean isValidHeader(final int readableBytes) {
58 return readableBytes >= (startupPhase ? 0 : MESSAGE_TYPE_LENGTH) + PAYLOAD_LENGTH;
59 }
60
61 @Override
62 public void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) {
63 while (isValidHeader(in.readableBytes())) {
64 if (startupPhase) {
65 handleStartupPhase(in, out);
66 return;
67 }
68 int payloadLength = in.getInt(in.readerIndex() + 1);
69 if (in.readableBytes() < MESSAGE_TYPE_LENGTH + payloadLength) {
70 return;
71 }
72 byte type = in.getByte(in.readerIndex());
73 CommandPacketType commandPacketType = OpenGaussCommandPacketType.valueOf(type);
74 if (requireAggregation(commandPacketType)) {
75 pendingMessages.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
76 } else if (pendingMessages.isEmpty()) {
77 out.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
78 } else {
79 handlePendingMessages(context, in, out, payloadLength);
80 }
81 }
82 }
83
84 private void handleStartupPhase(final ByteBuf in, final List<Object> out) {
85 int readerIndex = in.readerIndex();
86 if (in.readableBytes() == SSL_REQUEST_PAYLOAD_LENGTH && SSL_REQUEST_PAYLOAD_LENGTH == in.getInt(readerIndex) && SSL_REQUEST_CODE == in.getInt(readerIndex + 4)) {
87 out.add(in.readRetainedSlice(SSL_REQUEST_PAYLOAD_LENGTH));
88 return;
89 }
90 if (in.readableBytes() == in.getInt(readerIndex)) {
91 out.add(in.readRetainedSlice(in.readableBytes()));
92 startupPhase = false;
93 }
94 }
95
96 private boolean requireAggregation(final CommandPacketType commandPacketType) {
97 return OpenGaussCommandPacketType.isExtendedProtocolPacketType(commandPacketType)
98 && PostgreSQLCommandPacketType.SYNC_COMMAND != commandPacketType && PostgreSQLCommandPacketType.FLUSH_COMMAND != commandPacketType;
99 }
100
101 private void handlePendingMessages(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out, final int payloadLength) {
102 CompositeByteBuf result = context.alloc().compositeBuffer(pendingMessages.size() + 1);
103 result.addComponents(true, pendingMessages).addComponent(true, in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
104 out.add(result);
105 pendingMessages.clear();
106 }
107
108 @Override
109 public void encode(final ChannelHandlerContext context, final DatabasePacket message, final ByteBuf out) {
110 boolean isIdentifierPacket = message instanceof PostgreSQLIdentifierPacket;
111 if (isIdentifierPacket) {
112 prepareMessageHeader(out, ((PostgreSQLIdentifierPacket) message).getIdentifier().getValue());
113 }
114 PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(out, context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).get());
115 try {
116 message.write(payload);
117
118 } catch (final RuntimeException ex) {
119
120 payload.getByteBuf().resetWriterIndex();
121
122 OpenGaussErrorResponsePacket errorResponsePacket = new OpenGaussErrorResponsePacket(
123 PostgreSQLMessageSeverityLevel.ERROR, PostgreSQLVendorError.SYSTEM_ERROR.getSqlState().getValue(), ex.getMessage());
124 isIdentifierPacket = true;
125 prepareMessageHeader(out, errorResponsePacket.getIdentifier().getValue());
126 errorResponsePacket.write(payload);
127 } finally {
128 if (isIdentifierPacket) {
129 updateMessageLength(out);
130 }
131 }
132 }
133
134 private void prepareMessageHeader(final ByteBuf out, final char type) {
135 out.writeByte(type);
136 out.writeInt(0);
137 }
138
139 private void updateMessageLength(final ByteBuf out) {
140 out.setInt(1, out.readableBytes() - MESSAGE_TYPE_LENGTH);
141 }
142
143 @Override
144 public PostgreSQLPacketPayload createPacketPayload(final ByteBuf message, final Charset charset) {
145 return new PostgreSQLPacketPayload(message, charset);
146 }
147 }