View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
23  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
24  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
25  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLBindCompletePacket;
26  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
27  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
28  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.execute.PostgreSQLComExecutePacket;
29  import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLCommandCompletePacket;
30  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
31  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
32  import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.PostgreSQLCommand;
33  
34  import java.sql.SQLException;
35  import java.util.ArrayList;
36  import java.util.Collection;
37  import java.util.LinkedList;
38  import java.util.List;
39  
40  /**
41   * Aggregated batched statements command executor for PostgreSQL.
42   */
43  @RequiredArgsConstructor
44  public final class PostgreSQLAggregatedBatchedStatementsCommandExecutor implements CommandExecutor {
45      
46      private final ConnectionSession connectionSession;
47      
48      private final List<PostgreSQLCommandPacket> packets;
49      
50      @Override
51      public Collection<DatabasePacket> execute() throws SQLException {
52          PostgreSQLServerPreparedStatement preparedStatement = getPreparedStatement();
53          PostgreSQLBatchedStatementsExecutor executor = new PostgreSQLBatchedStatementsExecutor(connectionSession, preparedStatement, readParameterSets(preparedStatement.getParameterTypes()));
54          Collection<DatabasePacket> result = new ArrayList<>(packets.size());
55          int totalInserted = executor.executeBatch();
56          int executePacketCount = executePacketCount();
57          for (PostgreSQLCommandPacket each : packets) {
58              if (each instanceof PostgreSQLComBindPacket) {
59                  result.add(PostgreSQLBindCompletePacket.getInstance());
60              }
61              if (each instanceof PostgreSQLComDescribePacket) {
62                  result.add(preparedStatement.describeRows().orElseGet(PostgreSQLNoDataPacket::getInstance));
63              }
64              if (each instanceof PostgreSQLComExecutePacket) {
65                  String tag = PostgreSQLCommand.valueOf(preparedStatement.getSqlStatementContext().getSqlStatement().getClass()).orElse(PostgreSQLCommand.INSERT).getTag();
66                  result.add(new PostgreSQLCommandCompletePacket(tag, 0 == executePacketCount ? 1 : totalInserted / executePacketCount));
67              }
68          }
69          return result;
70      }
71      
72      private PostgreSQLServerPreparedStatement getPreparedStatement() {
73          PostgreSQLComBindPacket bindPacket = (PostgreSQLComBindPacket) packets.get(0);
74          return connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(bindPacket.getStatementId());
75      }
76      
77      private List<List<Object>> readParameterSets(final List<PostgreSQLColumnType> parameterTypes) {
78          List<List<Object>> result = new LinkedList<>();
79          for (PostgreSQLCommandPacket each : packets) {
80              if (each instanceof PostgreSQLComBindPacket) {
81                  result.add(((PostgreSQLComBindPacket) each).readParameters(parameterTypes));
82              }
83          }
84          return result;
85      }
86      
87      private int executePacketCount() {
88          int result = 0;
89          for (PostgreSQLCommandPacket each : packets) {
90              if (each instanceof PostgreSQLComExecutePacket) {
91                  result++;
92              }
93          }
94          return result;
95      }
96  }