1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.db.protocol.postgresql.codec;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.CompositeByteBuf;
22 import io.netty.channel.ChannelHandlerContext;
23 import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
24 import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
25 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
26 import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLMessageSeverityLevel;
27 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacketType;
28 import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLErrorResponsePacket;
29 import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
30 import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
31 import org.apache.shardingsphere.infra.exception.postgresql.vendor.PostgreSQLVendorError;
32
33 import java.nio.charset.Charset;
34 import java.util.LinkedList;
35 import java.util.List;
36
37
38
39
40 public final class PostgreSQLPacketCodecEngine implements DatabasePacketCodecEngine {
41
42 private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
43
44 private static final int SSL_REQUEST_CODE = (1234 << 16) + 5679;
45
46 private static final int MESSAGE_TYPE_LENGTH = 1;
47
48 private static final int PAYLOAD_LENGTH = 4;
49
50 private boolean startupPhase = true;
51
52 private final List<ByteBuf> pendingMessages = new LinkedList<>();
53
54 @Override
55 public boolean isValidHeader(final int readableBytes) {
56 return readableBytes >= (startupPhase ? 0 : MESSAGE_TYPE_LENGTH) + PAYLOAD_LENGTH;
57 }
58
59 @Override
60 public void decode(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out) {
61 while (isValidHeader(in.readableBytes())) {
62 if (startupPhase) {
63 handleStartupPhase(in, out);
64 return;
65 }
66 int payloadLength = in.getInt(in.readerIndex() + 1);
67 if (in.readableBytes() < MESSAGE_TYPE_LENGTH + payloadLength) {
68 return;
69 }
70 byte type = in.getByte(in.readerIndex());
71 PostgreSQLCommandPacketType commandPacketType = PostgreSQLCommandPacketType.valueOf(type);
72 if (requireAggregation(commandPacketType)) {
73 pendingMessages.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
74 } else if (pendingMessages.isEmpty()) {
75 out.add(in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
76 } else {
77 handlePendingMessages(context, in, out, payloadLength);
78 }
79 }
80 }
81
82 private void handleStartupPhase(final ByteBuf in, final List<Object> out) {
83 int readerIndex = in.readerIndex();
84 if (in.readableBytes() == SSL_REQUEST_PAYLOAD_LENGTH && SSL_REQUEST_PAYLOAD_LENGTH == in.getInt(readerIndex) && SSL_REQUEST_CODE == in.getInt(readerIndex + 4)) {
85 out.add(in.readRetainedSlice(SSL_REQUEST_PAYLOAD_LENGTH));
86 return;
87 }
88 if (in.readableBytes() == in.getInt(readerIndex)) {
89 out.add(in.readRetainedSlice(in.readableBytes()));
90 startupPhase = false;
91 }
92 }
93
94 private boolean requireAggregation(final PostgreSQLCommandPacketType commandPacketType) {
95 return PostgreSQLCommandPacketType.isExtendedProtocolPacketType(commandPacketType)
96 && PostgreSQLCommandPacketType.SYNC_COMMAND != commandPacketType && PostgreSQLCommandPacketType.FLUSH_COMMAND != commandPacketType;
97 }
98
99 private void handlePendingMessages(final ChannelHandlerContext context, final ByteBuf in, final List<Object> out, final int payloadLength) {
100 CompositeByteBuf result = context.alloc().compositeBuffer(pendingMessages.size() + 1);
101 result.addComponents(true, pendingMessages).addComponent(true, in.readRetainedSlice(MESSAGE_TYPE_LENGTH + payloadLength));
102 out.add(result);
103 pendingMessages.clear();
104 }
105
106 @Override
107 public void encode(final ChannelHandlerContext context, final DatabasePacket message, final ByteBuf out) {
108 boolean isIdentifierPacket = message instanceof PostgreSQLIdentifierPacket;
109 if (isIdentifierPacket) {
110 prepareMessageHeader(out, ((PostgreSQLIdentifierPacket) message).getIdentifier().getValue());
111 }
112 PostgreSQLPacketPayload payload = new PostgreSQLPacketPayload(out, context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).get());
113 try {
114 message.write(payload);
115
116 } catch (final RuntimeException ex) {
117
118 payload.getByteBuf().resetWriterIndex();
119
120 PostgreSQLErrorResponsePacket errorResponsePacket = PostgreSQLErrorResponsePacket.newBuilder(
121 PostgreSQLMessageSeverityLevel.ERROR, PostgreSQLVendorError.SYSTEM_ERROR, ex.getMessage()).build();
122 isIdentifierPacket = true;
123 prepareMessageHeader(out, errorResponsePacket.getIdentifier().getValue());
124 errorResponsePacket.write(payload);
125 } finally {
126 if (isIdentifierPacket) {
127 updateMessageLength(out);
128 }
129 }
130 }
131
132 private void prepareMessageHeader(final ByteBuf out, final char type) {
133 out.writeByte(type);
134 out.writeInt(0);
135 }
136
137 private void updateMessageLength(final ByteBuf out) {
138 out.setInt(1, out.readableBytes() - MESSAGE_TYPE_LENGTH);
139 }
140
141 @Override
142 public PostgreSQLPacketPayload createPacketPayload(final ByteBuf message, final Charset charset) {
143 return new PostgreSQLPacketPayload(message, charset);
144 }
145 }