1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.mysql.command.query.text.query;
19
20 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
21 import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
22 import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
23 import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
24 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
25 import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
26 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
27 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
28 import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
29 import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
30 import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
31 import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
32 import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
33 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
34 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
35 import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
36 import org.apache.shardingsphere.infra.executor.sql.execute.result.update.UpdateResult;
37 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
38 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
39 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
40 import org.apache.shardingsphere.infra.hint.HintValueContext;
41 import org.apache.shardingsphere.infra.hint.SQLHintUtils;
42 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
43 import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
44 import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
45 import org.apache.shardingsphere.infra.parser.SQLParserEngine;
46 import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
47 import org.apache.shardingsphere.infra.session.query.QueryContext;
48 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
49 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
50 import org.apache.shardingsphere.parser.rule.SQLParserRule;
51 import org.apache.shardingsphere.proxy.backend.connector.jdbc.statement.JDBCBackendStatement;
52 import org.apache.shardingsphere.proxy.backend.context.BackendExecutorContext;
53 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
54 import org.apache.shardingsphere.proxy.backend.handler.ProxyBackendHandler;
55 import org.apache.shardingsphere.proxy.backend.response.header.ResponseHeader;
56 import org.apache.shardingsphere.proxy.backend.response.header.update.UpdateResponseHeader;
57 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
58 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
59 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
60
61 import java.sql.Connection;
62 import java.sql.SQLException;
63 import java.sql.Statement;
64 import java.util.Arrays;
65 import java.util.Collection;
66 import java.util.Collections;
67 import java.util.HashMap;
68 import java.util.LinkedList;
69 import java.util.List;
70 import java.util.Map;
71 import java.util.Optional;
72 import java.util.regex.Pattern;
73
74
75
76
77 public final class MySQLMultiStatementsHandler implements ProxyBackendHandler {
78
79 private static final Pattern MULTI_UPDATE_STATEMENTS = Pattern.compile(";(?=\\s*update)", Pattern.CASE_INSENSITIVE);
80
81 private static final Pattern MULTI_DELETE_STATEMENTS = Pattern.compile(";(?=\\s*delete)", Pattern.CASE_INSENSITIVE);
82
83 private final KernelProcessor kernelProcessor = new KernelProcessor();
84
85 private final JDBCExecutor jdbcExecutor;
86
87 private final ConnectionSession connectionSession;
88
89 private final SQLStatement sqlStatementSample;
90
91 private final MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
92
93 private final Collection<QueryContext> multiSQLQueryContexts = new LinkedList<>();
94
95 public MySQLMultiStatementsHandler(final ConnectionSession connectionSession, final SQLStatement sqlStatementSample, final String sql) {
96 jdbcExecutor = new JDBCExecutor(BackendExecutorContext.getInstance().getExecutorEngine(), connectionSession.getConnectionContext());
97 connectionSession.getDatabaseConnectionManager().handleAutoCommit();
98 this.connectionSession = connectionSession;
99 this.sqlStatementSample = sqlStatementSample;
100 Pattern pattern = sqlStatementSample instanceof UpdateStatement ? MULTI_UPDATE_STATEMENTS : MULTI_DELETE_STATEMENTS;
101 SQLParserEngine sqlParserEngine = getSQLParserEngine();
102 for (String each : extractMultiStatements(pattern, sql)) {
103 SQLStatement eachSQLStatement = sqlParserEngine.parse(each, false);
104 multiSQLQueryContexts.add(createQueryContext(each, eachSQLStatement));
105 }
106 }
107
108 private SQLParserEngine getSQLParserEngine() {
109 MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
110 SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
111 return sqlParserRule.getSQLParserEngine(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
112 }
113
114 private List<String> extractMultiStatements(final Pattern pattern, final String sql) {
115
116 return Arrays.asList(pattern.split(sql));
117 }
118
119 private QueryContext createQueryContext(final String sql, final SQLStatement sqlStatement) {
120 HintValueContext hintValueContext = SQLHintUtils.extractHint(sql);
121 SQLStatementContext sqlStatementContext = new SQLBindEngine(metaDataContexts.getMetaData(), connectionSession.getDatabaseName(), hintValueContext).bind(sqlStatement, Collections.emptyList());
122 return new QueryContext(sqlStatementContext, sql, Collections.emptyList(), hintValueContext);
123 }
124
125 @Override
126 public ResponseHeader execute() throws SQLException {
127 Collection<ShardingSphereRule> rules = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName()).getRuleMetaData().getRules();
128 DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = new DriverExecutionPrepareEngine<>(JDBCDriverType.STATEMENT, metaDataContexts.getMetaData().getProps()
129 .<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY), connectionSession.getDatabaseConnectionManager(),
130 (JDBCBackendStatement) connectionSession.getStatementManager(), new StatementOption(false), rules,
131 metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName()).getResourceMetaData().getStorageUnits());
132 return executeMultiStatements(prepareEngine);
133 }
134
135 private UpdateResponseHeader executeMultiStatements(final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) throws SQLException {
136 Collection<ExecutionContext> executionContexts = createExecutionContexts();
137 Map<String, List<ExecutionUnit>> dataSourcesToExecutionUnits = buildDataSourcesToExecutionUnits(executionContexts);
138 ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext =
139 prepareEngine.prepare(executionContexts.iterator().next().getRouteContext(), samplingExecutionUnit(dataSourcesToExecutionUnits),
140 new ExecutionGroupReportContext(connectionSession.getProcessId(), connectionSession.getDatabaseName(), connectionSession.getGrantee()));
141 for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
142 for (JDBCExecutionUnit unit : each.getInputs()) {
143 prepareBatchedStatement(unit, dataSourcesToExecutionUnits);
144 }
145 }
146 return executeBatchedStatements(executionGroupContext);
147 }
148
149 private Collection<ExecutionContext> createExecutionContexts() {
150 Collection<ExecutionContext> result = new LinkedList<>();
151 for (QueryContext each : multiSQLQueryContexts) {
152 result.add(createExecutionContext(each));
153 }
154 return result;
155 }
156
157 private Map<String, List<ExecutionUnit>> buildDataSourcesToExecutionUnits(final Collection<ExecutionContext> executionContexts) {
158 Map<String, List<ExecutionUnit>> result = new HashMap<>();
159 for (ExecutionContext each : executionContexts) {
160 for (ExecutionUnit executionUnit : each.getExecutionUnits()) {
161 result.computeIfAbsent(executionUnit.getDataSourceName(), unused -> new LinkedList<>()).add(executionUnit);
162 }
163 }
164 return result;
165 }
166
167 private ExecutionContext createExecutionContext(final QueryContext queryContext) {
168 RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
169 ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
170 SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
171 return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
172 }
173
174 private Collection<ExecutionUnit> samplingExecutionUnit(final Map<String, List<ExecutionUnit>> dataSourcesToExecutionUnits) {
175 Collection<ExecutionUnit> result = new LinkedList<>();
176 for (List<ExecutionUnit> each : dataSourcesToExecutionUnits.values()) {
177 result.add(each.get(0));
178 }
179 return result;
180 }
181
182 private void prepareBatchedStatement(final JDBCExecutionUnit executionUnit, final Map<String, List<ExecutionUnit>> dataSourcesToExecutionUnits) throws SQLException {
183 Statement statement = executionUnit.getStorageResource();
184 for (ExecutionUnit each : dataSourcesToExecutionUnits.get(executionUnit.getExecutionUnit().getDataSourceName())) {
185 statement.addBatch(each.getSqlUnit().getSql());
186 }
187 }
188
189 private UpdateResponseHeader executeBatchedStatements(final ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext) throws SQLException {
190 boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
191 ResourceMetaData resourceMetaData = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName()).getResourceMetaData();
192 JDBCExecutorCallback<int[]> callback = new BatchedJDBCExecutorCallback(resourceMetaData, sqlStatementSample, isExceptionThrown);
193 List<int[]> executeResults = jdbcExecutor.execute(executionGroupContext, callback);
194 int updated = 0;
195 for (int[] eachResult : executeResults) {
196 for (int each : eachResult) {
197 updated += each;
198 }
199 }
200
201 return new UpdateResponseHeader(sqlStatementSample, Collections.singletonList(new UpdateResult(updated, 0L)));
202 }
203
204 private static final class BatchedJDBCExecutorCallback extends JDBCExecutorCallback<int[]> {
205
206 private BatchedJDBCExecutorCallback(final ResourceMetaData resourceMetaData, final SQLStatement sqlStatement, final boolean isExceptionThrown) {
207 super(TypedSPILoader.getService(DatabaseType.class, "MySQL"), resourceMetaData, sqlStatement, isExceptionThrown);
208 }
209
210 @Override
211 protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
212 try {
213 return statement.executeBatch();
214 } finally {
215 statement.close();
216 }
217 }
218
219 @SuppressWarnings("OptionalContainsCollection")
220 @Override
221 protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
222 return Optional.empty();
223 }
224 }
225 }