View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.driver.executor.batch;
19  
20  import lombok.Getter;
21  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
22  import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
23  import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
24  import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
25  import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
26  import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
27  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
28  import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
29  import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
30  import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
31  import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
32  import org.apache.shardingsphere.infra.metadata.user.Grantee;
33  import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
34  import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
35  import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
36  
37  import java.sql.SQLException;
38  import java.sql.Statement;
39  import java.util.ArrayList;
40  import java.util.Collection;
41  import java.util.Collections;
42  import java.util.LinkedList;
43  import java.util.List;
44  import java.util.Map;
45  import java.util.Map.Entry;
46  import java.util.Optional;
47  
/**
 * Prepared statement executor to process add batch.
 */
public final class BatchPreparedStatementExecutor {
    
    private final MetaDataContexts metaDataContexts;
    
    private final JDBCExecutor jdbcExecutor;
    
    // Placeholder installed by the constructor; replaced wholesale by init() once execution groups are prepared.
    private ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext;
    
    // Accumulated batch units; repeated units (equal data source + SQL) get their parameters merged rather than duplicated.
    @Getter
    private final Collection<BatchExecutionUnit> batchExecutionUnits;
    
    // Number of addBatchForExecutionUnits() calls so far; sizes the caller-visible result array in accumulate().
    private int batchCount;
    
    private final String databaseName;
66      public BatchPreparedStatementExecutor(final MetaDataContexts metaDataContexts, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) {
67          this.databaseName = databaseName;
68          this.metaDataContexts = metaDataContexts;
69          this.jdbcExecutor = jdbcExecutor;
70          executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", "")));
71          batchExecutionUnits = new LinkedList<>();
72      }
73      
    /**
     * Initialize executor.
     *
     * @param executionGroupContext execution group context
     */
    public void init(final ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext) {
        // Wholesale replacement: the placeholder from the constructor (or any earlier init) is discarded.
        this.executionGroupContext = executionGroupContext;
    }
82      
83      /**
84       * Add batch for execution units.
85       *
86       * @param executionUnits execution units
87       */
88      public void addBatchForExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
89          Collection<BatchExecutionUnit> batchExecutionUnits = createBatchExecutionUnits(executionUnits);
90          handleOldBatchExecutionUnits(batchExecutionUnits);
91          handleNewBatchExecutionUnits(batchExecutionUnits);
92          batchCount++;
93      }
94      
95      private Collection<BatchExecutionUnit> createBatchExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
96          List<BatchExecutionUnit> result = new ArrayList<>(executionUnits.size());
97          for (ExecutionUnit each : executionUnits) {
98              BatchExecutionUnit batchExecutionUnit = new BatchExecutionUnit(each);
99              result.add(batchExecutionUnit);
100         }
101         return result;
102     }
103     
104     private void handleOldBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
105         newExecutionUnits.forEach(this::reviseBatchExecutionUnits);
106     }
107     
108     private void reviseBatchExecutionUnits(final BatchExecutionUnit batchExecutionUnit) {
109         for (BatchExecutionUnit each : batchExecutionUnits) {
110             if (each.equals(batchExecutionUnit)) {
111                 reviseBatchExecutionUnit(each, batchExecutionUnit);
112             }
113         }
114     }
115     
    // Merge the new unit into its existing twin: append its parameter row, then record that the
    // current addBatch call (batchCount) maps onto this JDBC-level batch entry.
    private void reviseBatchExecutionUnit(final BatchExecutionUnit oldBatchExecutionUnit, final BatchExecutionUnit newBatchExecutionUnit) {
        oldBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters().addAll(newBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters());
        oldBatchExecutionUnit.mapAddBatchCount(batchCount);
    }
120     
121     private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
122         newExecutionUnits.removeAll(batchExecutionUnits);
123         for (BatchExecutionUnit each : newExecutionUnits) {
124             each.mapAddBatchCount(batchCount);
125         }
126         batchExecutionUnits.addAll(newExecutionUnits);
127     }
128     
    /**
     * Execute batch.
     *
     * @param sqlStatementContext SQL statement context
     * @return execute results
     * @throws SQLException SQL exception
     */
    public int[] executeBatch(final SQLStatementContext sqlStatementContext) throws SQLException {
        // Capture the current exception-propagation policy so the callback honors it on executor threads.
        boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
        JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
                metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {
            
            @Override
            protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
                return statement.executeBatch();
            }
            
            @SuppressWarnings("OptionalContainsCollection")
            @Override
            protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
                // No sane fallback result for batch execution; empty lets the executor surface the failure.
                return Optional.empty();
            }
        };
        List<int[]> results = jdbcExecutor.execute(executionGroupContext, callback);
        if (results.isEmpty()) {
            return new int[0];
        }
        // When a data-node rule splits one logical unit across nodes, fold the per-node counts
        // back into the caller's addBatch order; otherwise the single result is already correct.
        return isNeedAccumulate(sqlStatementContext) ? accumulate(results) : results.get(0);
    }
158     
159     private boolean isNeedAccumulate(final SQLStatementContext sqlStatementContext) {
160         for (DataNodeRuleAttribute each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
161             if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
162                 return true;
163             }
164         }
165         return false;
166     }
167     
168     private int[] accumulate(final List<int[]> executeResults) {
169         int[] result = new int[batchCount];
170         int count = 0;
171         for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
172             for (JDBCExecutionUnit eachUnit : each.getInputs()) {
173                 accumulate(executeResults.get(count), result, eachUnit);
174                 count++;
175             }
176         }
177         return result;
178     }
179     
180     private void accumulate(final int[] executeResult, final int[] addBatchCounts, final JDBCExecutionUnit executionUnit) {
181         for (Entry<Integer, Integer> entry : getJDBCAndActualAddBatchCallTimesMap(executionUnit).entrySet()) {
182             int value = null == executeResult ? 0 : executeResult[entry.getValue()];
183             addBatchCounts[entry.getKey()] += value;
184         }
185     }
186     
187     private Map<Integer, Integer> getJDBCAndActualAddBatchCallTimesMap(final JDBCExecutionUnit executionUnit) {
188         for (BatchExecutionUnit each : batchExecutionUnits) {
189             if (isSameDataSourceAndSQL(each, executionUnit)) {
190                 return each.getJdbcAndActualAddBatchCallTimesMap();
191             }
192         }
193         return Collections.emptyMap();
194     }
195     
196     private boolean isSameDataSourceAndSQL(final BatchExecutionUnit batchExecutionUnit, final JDBCExecutionUnit jdbcExecutionUnit) {
197         return batchExecutionUnit.getExecutionUnit().getDataSourceName().equals(jdbcExecutionUnit.getExecutionUnit().getDataSourceName())
198                 && batchExecutionUnit.getExecutionUnit().getSqlUnit().getSql().equals(jdbcExecutionUnit.getExecutionUnit().getSqlUnit().getSql());
199     }
200     
201     /**
202      * Get statements.
203      *
204      * @return statements
205      */
206     public List<Statement> getStatements() {
207         List<Statement> result = new LinkedList<>();
208         for (ExecutionGroup<JDBCExecutionUnit> eachGroup : executionGroupContext.getInputGroups()) {
209             for (JDBCExecutionUnit eachUnit : eachGroup.getInputs()) {
210                 Statement storageResource = eachUnit.getStorageResource();
211                 result.add(storageResource);
212             }
213         }
214         return result;
215     }
216     
217     /**
218      * Get parameter sets.
219      *
220      * @param statement statement
221      * @return parameter sets
222      */
223     public List<List<Object>> getParameterSet(final Statement statement) {
224         for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
225             Optional<JDBCExecutionUnit> result = findJDBCExecutionUnit(statement, each);
226             if (result.isPresent()) {
227                 return getParameterSets(result.get());
228             }
229         }
230         return Collections.emptyList();
231     }
232     
233     private Optional<JDBCExecutionUnit> findJDBCExecutionUnit(final Statement statement, final ExecutionGroup<JDBCExecutionUnit> executionGroup) {
234         for (JDBCExecutionUnit each : executionGroup.getInputs()) {
235             if (each.getStorageResource().equals(statement)) {
236                 return Optional.of(each);
237             }
238         }
239         return Optional.empty();
240     }
241     
242     private List<List<Object>> getParameterSets(final JDBCExecutionUnit executionUnit) {
243         for (BatchExecutionUnit each : batchExecutionUnits) {
244             if (isSameDataSourceAndSQL(each, executionUnit)) {
245                 return each.getParameterSets();
246             }
247         }
248         throw new IllegalStateException("Can not get value from parameter sets.");
249     }
250     
251     /**
252      * Clear.
253      */
254     public void clear() {
255         getStatements().clear();
256         executionGroupContext.getInputGroups().clear();
257         batchCount = 0;
258         batchExecutionUnits.clear();
259     }
260 }