1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended;
19
20 import lombok.RequiredArgsConstructor;
21 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
23 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
24 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
25 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLBindCompletePacket;
26 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
27 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
28 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.execute.PostgreSQLComExecutePacket;
29 import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
30 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
31 import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
32 import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.PostgreSQLCommand;
33
34 import java.sql.SQLException;
35 import java.util.ArrayList;
36 import java.util.Collection;
37 import java.util.LinkedList;
38 import java.util.List;
39
40
41
42
43 @RequiredArgsConstructor
44 public final class PostgreSQLAggregatedBatchedStatementsCommandExecutor implements CommandExecutor {
45
46 private final ConnectionSession connectionSession;
47
48 private final List<PostgreSQLCommandPacket> packets;
49
50 @Override
51 public Collection<DatabasePacket> execute() throws SQLException {
52 PostgreSQLServerPreparedStatement preparedStatement = getPreparedStatement();
53 PostgreSQLBatchedStatementsExecutor executor = new PostgreSQLBatchedStatementsExecutor(connectionSession, preparedStatement, readParameterSets(preparedStatement.getParameterTypes()));
54 Collection<DatabasePacket> result = new ArrayList<>(packets.size());
55 int totalInserted = executor.executeBatch();
56 int executePacketCount = executePacketCount();
57 for (PostgreSQLCommandPacket each : packets) {
58 if (each instanceof PostgreSQLComBindPacket) {
59 result.add(PostgreSQLBindCompletePacket.getInstance());
60 }
61 if (each instanceof PostgreSQLComDescribePacket) {
62 result.add(preparedStatement.describeRows().orElseGet(PostgreSQLNoDataPacket::getInstance));
63 }
64 if (each instanceof PostgreSQLComExecutePacket) {
65 String tag = PostgreSQLCommand.valueOf(preparedStatement.getSqlStatementContext().getSqlStatement().getClass()).orElse(PostgreSQLCommand.INSERT).getTag();
66 result.add(new PostgreSQLCommandCompletePacket(tag, 0 == executePacketCount ? 1 : totalInserted / executePacketCount));
67 }
68 }
69 return result;
70 }
71
72 private PostgreSQLServerPreparedStatement getPreparedStatement() {
73 PostgreSQLComBindPacket bindPacket = (PostgreSQLComBindPacket) packets.get(0);
74 return connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(bindPacket.getStatementId());
75 }
76
77 private List<List<Object>> readParameterSets(final List<PostgreSQLColumnType> parameterTypes) {
78 List<List<Object>> result = new LinkedList<>();
79 for (PostgreSQLCommandPacket each : packets) {
80 if (each instanceof PostgreSQLComBindPacket) {
81 result.add(((PostgreSQLComBindPacket) each).readParameters(parameterTypes));
82 }
83 }
84 return result;
85 }
86
87 private int executePacketCount() {
88 int result = 0;
89 for (PostgreSQLCommandPacket each : packets) {
90 if (each instanceof PostgreSQLComExecutePacket) {
91 result++;
92 }
93 }
94 return result;
95 }
96 }