1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.driver.executor.batch;
19
20 import lombok.Getter;
21 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
22 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
23 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
24 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
25 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
26 import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
27 import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
28 import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
29 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
30 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
31 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
32 import org.apache.shardingsphere.infra.metadata.user.Grantee;
33 import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
34 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
35 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
36
37 import java.sql.SQLException;
38 import java.sql.Statement;
39 import java.util.ArrayList;
40 import java.util.Collection;
41 import java.util.Collections;
42 import java.util.LinkedList;
43 import java.util.List;
44 import java.util.Map;
45 import java.util.Map.Entry;
46 import java.util.Optional;
47
48
49
50
51 public final class BatchPreparedStatementExecutor {
52
53 private final MetaDataContexts metaDataContexts;
54
55 private final JDBCExecutor jdbcExecutor;
56
57 private ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext;
58
59 @Getter
60 private final Collection<BatchExecutionUnit> batchExecutionUnits;
61
62 private int batchCount;
63
64 private final String databaseName;
65
66 public BatchPreparedStatementExecutor(final MetaDataContexts metaDataContexts, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) {
67 this.databaseName = databaseName;
68 this.metaDataContexts = metaDataContexts;
69 this.jdbcExecutor = jdbcExecutor;
70 executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", "")));
71 batchExecutionUnits = new LinkedList<>();
72 }
73
74
75
76
77
78
79 public void init(final ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext) {
80 this.executionGroupContext = executionGroupContext;
81 }
82
83
84
85
86
87
88 public void addBatchForExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
89 Collection<BatchExecutionUnit> batchExecutionUnits = createBatchExecutionUnits(executionUnits);
90 handleOldBatchExecutionUnits(batchExecutionUnits);
91 handleNewBatchExecutionUnits(batchExecutionUnits);
92 batchCount++;
93 }
94
95 private Collection<BatchExecutionUnit> createBatchExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
96 List<BatchExecutionUnit> result = new ArrayList<>(executionUnits.size());
97 for (ExecutionUnit each : executionUnits) {
98 BatchExecutionUnit batchExecutionUnit = new BatchExecutionUnit(each);
99 result.add(batchExecutionUnit);
100 }
101 return result;
102 }
103
104 private void handleOldBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
105 newExecutionUnits.forEach(this::reviseBatchExecutionUnits);
106 }
107
108 private void reviseBatchExecutionUnits(final BatchExecutionUnit batchExecutionUnit) {
109 for (BatchExecutionUnit each : batchExecutionUnits) {
110 if (each.equals(batchExecutionUnit)) {
111 reviseBatchExecutionUnit(each, batchExecutionUnit);
112 }
113 }
114 }
115
116 private void reviseBatchExecutionUnit(final BatchExecutionUnit oldBatchExecutionUnit, final BatchExecutionUnit newBatchExecutionUnit) {
117 oldBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters().addAll(newBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters());
118 oldBatchExecutionUnit.mapAddBatchCount(batchCount);
119 }
120
121 private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> newExecutionUnits) {
122 newExecutionUnits.removeAll(batchExecutionUnits);
123 for (BatchExecutionUnit each : newExecutionUnits) {
124 each.mapAddBatchCount(batchCount);
125 }
126 batchExecutionUnits.addAll(newExecutionUnits);
127 }
128
129
130
131
132
133
134
135
136 public int[] executeBatch(final SQLStatementContext sqlStatementContext) throws SQLException {
137 boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
138 JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
139 metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {
140
141 @Override
142 protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
143 return statement.executeBatch();
144 }
145
146 @SuppressWarnings("OptionalContainsCollection")
147 @Override
148 protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
149 return Optional.empty();
150 }
151 };
152 List<int[]> results = jdbcExecutor.execute(executionGroupContext, callback);
153 if (results.isEmpty()) {
154 return new int[0];
155 }
156 return isNeedAccumulate(sqlStatementContext) ? accumulate(results) : results.get(0);
157 }
158
159 private boolean isNeedAccumulate(final SQLStatementContext sqlStatementContext) {
160 for (DataNodeRuleAttribute each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
161 if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
162 return true;
163 }
164 }
165 return false;
166 }
167
168 private int[] accumulate(final List<int[]> executeResults) {
169 int[] result = new int[batchCount];
170 int count = 0;
171 for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
172 for (JDBCExecutionUnit eachUnit : each.getInputs()) {
173 accumulate(executeResults.get(count), result, eachUnit);
174 count++;
175 }
176 }
177 return result;
178 }
179
180 private void accumulate(final int[] executeResult, final int[] addBatchCounts, final JDBCExecutionUnit executionUnit) {
181 for (Entry<Integer, Integer> entry : getJDBCAndActualAddBatchCallTimesMap(executionUnit).entrySet()) {
182 int value = null == executeResult ? 0 : executeResult[entry.getValue()];
183 addBatchCounts[entry.getKey()] += value;
184 }
185 }
186
187 private Map<Integer, Integer> getJDBCAndActualAddBatchCallTimesMap(final JDBCExecutionUnit executionUnit) {
188 for (BatchExecutionUnit each : batchExecutionUnits) {
189 if (isSameDataSourceAndSQL(each, executionUnit)) {
190 return each.getJdbcAndActualAddBatchCallTimesMap();
191 }
192 }
193 return Collections.emptyMap();
194 }
195
196 private boolean isSameDataSourceAndSQL(final BatchExecutionUnit batchExecutionUnit, final JDBCExecutionUnit jdbcExecutionUnit) {
197 return batchExecutionUnit.getExecutionUnit().getDataSourceName().equals(jdbcExecutionUnit.getExecutionUnit().getDataSourceName())
198 && batchExecutionUnit.getExecutionUnit().getSqlUnit().getSql().equals(jdbcExecutionUnit.getExecutionUnit().getSqlUnit().getSql());
199 }
200
201
202
203
204
205
206 public List<Statement> getStatements() {
207 List<Statement> result = new LinkedList<>();
208 for (ExecutionGroup<JDBCExecutionUnit> eachGroup : executionGroupContext.getInputGroups()) {
209 for (JDBCExecutionUnit eachUnit : eachGroup.getInputs()) {
210 Statement storageResource = eachUnit.getStorageResource();
211 result.add(storageResource);
212 }
213 }
214 return result;
215 }
216
217
218
219
220
221
222
223 public List<List<Object>> getParameterSet(final Statement statement) {
224 for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
225 Optional<JDBCExecutionUnit> result = findJDBCExecutionUnit(statement, each);
226 if (result.isPresent()) {
227 return getParameterSets(result.get());
228 }
229 }
230 return Collections.emptyList();
231 }
232
233 private Optional<JDBCExecutionUnit> findJDBCExecutionUnit(final Statement statement, final ExecutionGroup<JDBCExecutionUnit> executionGroup) {
234 for (JDBCExecutionUnit each : executionGroup.getInputs()) {
235 if (each.getStorageResource().equals(statement)) {
236 return Optional.of(each);
237 }
238 }
239 return Optional.empty();
240 }
241
242 private List<List<Object>> getParameterSets(final JDBCExecutionUnit executionUnit) {
243 for (BatchExecutionUnit each : batchExecutionUnits) {
244 if (isSameDataSourceAndSQL(each, executionUnit)) {
245 return each.getParameterSets();
246 }
247 }
248 throw new IllegalStateException("Can not get value from parameter sets.");
249 }
250
251
252
253
254 public void clear() {
255 getStatements().clear();
256 executionGroupContext.getInputGroups().clear();
257 batchCount = 0;
258 batchExecutionUnits.clear();
259 }
260 }