1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended;
19
20 import lombok.Getter;
21 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
22 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
23 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.execute.PostgreSQLComExecutePacket;
24 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
25 import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierTag;
26 import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
27
28 import java.util.List;
29
30 @Getter
31 public final class PostgreSQLAggregatedCommandPacket extends PostgreSQLCommandPacket {
32
33 private final List<PostgreSQLCommandPacket> packets;
34
35 private final boolean containsBatchedStatements;
36
37 private final int batchPacketBeginIndex;
38
39 private final int batchPacketEndIndex;
40
41 public PostgreSQLAggregatedCommandPacket(final List<PostgreSQLCommandPacket> packets) {
42 this.packets = packets;
43 String firstStatementId = null;
44 String firstPortal = null;
45 int parsePacketCount = 0;
46 int bindPacketCountForFirstStatement = 0;
47 int executePacketCountForFirstStatement = 0;
48 int batchPacketBeginIndex = -1;
49 int batchPacketEndIndex = -1;
50 int index = 0;
51 for (PostgreSQLCommandPacket each : packets) {
52 if (each instanceof PostgreSQLComParsePacket) {
53 if (++parsePacketCount > 1) {
54 break;
55 }
56 if (null == firstStatementId) {
57 firstStatementId = ((PostgreSQLComParsePacket) each).getStatementId();
58 } else if (!firstStatementId.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
59 break;
60 }
61 }
62 if (each instanceof PostgreSQLComBindPacket) {
63 if (-1 == batchPacketBeginIndex) {
64 batchPacketBeginIndex = index;
65 }
66 if (null == firstStatementId) {
67 firstStatementId = ((PostgreSQLComBindPacket) each).getStatementId();
68 } else if (!firstStatementId.equals(((PostgreSQLComBindPacket) each).getStatementId())) {
69 break;
70 }
71 if (null == firstPortal) {
72 firstPortal = ((PostgreSQLComBindPacket) each).getPortal();
73 } else if (!firstPortal.equals(((PostgreSQLComBindPacket) each).getPortal())) {
74 break;
75 }
76 bindPacketCountForFirstStatement++;
77 }
78 if (each instanceof PostgreSQLComExecutePacket) {
79 if (index > batchPacketEndIndex) {
80 batchPacketEndIndex = index;
81 }
82 if (null == firstPortal) {
83 firstPortal = ((PostgreSQLComExecutePacket) each).getPortal();
84 } else if (!firstPortal.equals(((PostgreSQLComExecutePacket) each).getPortal())) {
85 break;
86 }
87 executePacketCountForFirstStatement++;
88 }
89 index++;
90 }
91 this.batchPacketBeginIndex = batchPacketBeginIndex;
92 this.batchPacketEndIndex = batchPacketEndIndex;
93 containsBatchedStatements = bindPacketCountForFirstStatement == executePacketCountForFirstStatement && bindPacketCountForFirstStatement >= 3;
94 }
95
96 @Override
97 protected void write(final PostgreSQLPacketPayload payload) {
98 }
99
100 @Override
101 public PostgreSQLIdentifierTag getIdentifier() {
102 return () -> '?';
103 }
104 }