1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.driver.jdbc.core.connection;
19
20 import com.google.common.base.Preconditions;
21 import com.google.common.collect.LinkedHashMultimap;
22 import com.google.common.collect.Multimap;
23 import com.google.common.collect.Sets;
24 import lombok.Getter;
25 import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
26 import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
27 import org.apache.shardingsphere.driver.jdbc.core.savepoint.ShardingSphereSavepoint;
28 import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
29 import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
30 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
31 import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
32 import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
33 import org.apache.shardingsphere.mode.manager.ContextManager;
34 import org.apache.shardingsphere.transaction.savepoint.ConnectionSavepointManager;
35 import org.apache.shardingsphere.transaction.ConnectionTransaction;
36 import org.apache.shardingsphere.transaction.rule.TransactionRule;
37
38 import javax.sql.DataSource;
39 import java.sql.Connection;
40 import java.sql.SQLException;
41 import java.sql.Savepoint;
42 import java.util.ArrayList;
43 import java.util.Collection;
44 import java.util.Collections;
45 import java.util.List;
46 import java.util.Map;
47 import java.util.Optional;
48 import java.util.concurrent.ThreadLocalRandom;
49 import java.util.stream.Collectors;
50
51
52
53
54 public final class DriverDatabaseConnectionManager implements DatabaseConnectionManager<Connection>, AutoCloseable {
55
56 private final String currentDatabaseName;
57
58 private final ContextManager contextManager;
59
60 private final Map<String, DataSource> dataSourceMap;
61
62 @Getter
63 private final ConnectionContext connectionContext;
64
65 private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
66
67 private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
68
69 private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
70
71 public DriverDatabaseConnectionManager(final String currentDatabaseName, final ContextManager contextManager) {
72 this.currentDatabaseName = currentDatabaseName;
73 this.contextManager = contextManager;
74 dataSourceMap = contextManager.getStorageUnits(currentDatabaseName).entrySet()
75 .stream().collect(Collectors.toMap(entry -> getKey(currentDatabaseName, entry.getKey()), entry -> entry.getValue().getDataSource()));
76 connectionContext = new ConnectionContext(cachedConnections::keySet);
77 connectionContext.setCurrentDatabaseName(currentDatabaseName);
78 }
79
80 private String getKey(final String databaseName, final String dataSourceName) {
81 return databaseName.toLowerCase() + "." + dataSourceName;
82 }
83
84
85
86
87
88
89 public ConnectionTransaction getConnectionTransaction() {
90 TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
91 return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
92 }
93
94
95
96
97
98
99
100 public void setAutoCommit(final boolean autoCommit) throws SQLException {
101 methodInvocationRecorder.record("setAutoCommit", connection -> connection.setAutoCommit(autoCommit));
102 forceExecuteTemplate.execute(getCachedConnections(), connection -> connection.setAutoCommit(autoCommit));
103 if (autoCommit) {
104 clearCachedConnections();
105 }
106 }
107
108 private Collection<Connection> getCachedConnections() {
109 return cachedConnections.values();
110 }
111
112
113
114
115
116
117 public void clearCachedConnections() throws SQLException {
118 try {
119 forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
120 } finally {
121 cachedConnections.clear();
122 }
123 }
124
125
126
127
128
129
130 public void begin() throws SQLException {
131 ConnectionTransaction connectionTransaction = getConnectionTransaction();
132 connectionContext.getTransactionContext().beginTransaction(connectionTransaction.getTransactionType().name(), connectionTransaction.getDistributedTransactionManager());
133 if (!connectionTransaction.isLocalTransaction()) {
134 close();
135 }
136 doBegin(connectionTransaction);
137 }
138
139 private void doBegin(final ConnectionTransaction connectionTransaction) throws SQLException {
140 if (connectionTransaction.isLocalTransaction()) {
141 setAutoCommit(false);
142 } else {
143 connectionTransaction.begin();
144 }
145 }
146
147
148
149
150
151
152 public void commit() throws SQLException {
153 ConnectionTransaction connectionTransaction = getConnectionTransaction();
154 if (!connectionContext.getTransactionContext().isInTransaction() && !connectionTransaction.isInDistributedTransaction()) {
155 return;
156 }
157 try {
158 if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
159 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
160 } else if (connectionTransaction.isLocalTransaction()) {
161 forceExecuteTemplate.execute(getCachedConnections(), Connection::commit);
162 } else {
163 connectionTransaction.commit();
164 }
165 } finally {
166 clear();
167 }
168 }
169
170
171
172
173
174
175 public void rollback() throws SQLException {
176 ConnectionTransaction connectionTransaction = getConnectionTransaction();
177 if (!connectionContext.getTransactionContext().isInTransaction() && !connectionTransaction.isInDistributedTransaction()) {
178 return;
179 }
180 try {
181 if (connectionTransaction.isLocalTransaction()) {
182 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
183 } else {
184 connectionTransaction.rollback();
185 }
186 } finally {
187 clear();
188 }
189 }
190
191
192
193
194
195
196
197 public void rollback(final Savepoint savepoint) throws SQLException {
198 for (Connection each : getCachedConnections()) {
199 ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
200 }
201 }
202
203 private void clear() {
204 methodInvocationRecorder.remove("setSavepoint");
205 for (Connection each : getCachedConnections()) {
206 ConnectionSavepointManager.getInstance().transactionFinished(each);
207 }
208 connectionContext.close();
209 }
210
211
212
213
214
215
216
217
218 public Savepoint setSavepoint(final String savepointName) throws SQLException {
219 ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
220 for (Connection each : getCachedConnections()) {
221 ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
222 }
223 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
224 return result;
225 }
226
227
228
229
230
231
232
233 public Savepoint setSavepoint() throws SQLException {
234 ShardingSphereSavepoint result = new ShardingSphereSavepoint();
235 for (Connection each : getCachedConnections()) {
236 ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
237 }
238 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
239 return result;
240 }
241
242
243
244
245
246
247
248 public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
249 methodInvocationRecorder.remove("setSavepoint");
250 for (Connection each : getCachedConnections()) {
251 ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
252 }
253 }
254
255
256
257
258
259
260
261 public Optional<Integer> getTransactionIsolation() throws SQLException {
262 return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
263 }
264
265
266
267
268
269
270
271 public void setTransactionIsolation(final int level) throws SQLException {
272 methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
273 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
274 }
275
276
277
278
279
280
281
282 public void setReadOnly(final boolean readOnly) throws SQLException {
283 methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
284 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
285 }
286
287
288
289
290
291
292
293
294 public boolean isValid(final int timeout) throws SQLException {
295 for (Connection each : cachedConnections.values()) {
296 if (!each.isValid(timeout)) {
297 return false;
298 }
299 }
300 return true;
301 }
302
303
304
305
306
307
308 public String getRandomPhysicalDataSourceName() {
309 return getRandomPhysicalDatabaseAndDataSourceName()[1];
310 }
311
312 private String[] getRandomPhysicalDatabaseAndDataSourceName() {
313 Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(dataSourceMap.keySet(), cachedConnections.keySet());
314 Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? dataSourceMap.keySet() : cachedPhysicalDataSourceNames;
315 return new ArrayList<>(databaseAndDatasourceNames).get(ThreadLocalRandom.current().nextInt(databaseAndDatasourceNames.size())).split("\\.");
316 }
317
318
319
320
321
322
323
324 public Connection getRandomConnection() throws SQLException {
325 String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
326 return getConnections0(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
327 }
328
329 @Override
330 public List<Connection> getConnections(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
331 final ConnectionMode connectionMode) throws SQLException {
332 return getConnections0(databaseName, dataSourceName, connectionOffset, connectionSize, connectionMode);
333 }
334
335 private List<Connection> getConnections0(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
336 final ConnectionMode connectionMode) throws SQLException {
337 String cacheKey = getKey(databaseName, dataSourceName);
338 DataSource dataSource = currentDatabaseName.equals(databaseName) ? dataSourceMap.get(cacheKey) : contextManager.getStorageUnits(databaseName).get(dataSourceName).getDataSource();
339 Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
340 Collection<Connection> connections;
341 synchronized (cachedConnections) {
342 connections = cachedConnections.get(cacheKey);
343 }
344 List<Connection> result;
345 int maxConnectionSize = connectionOffset + connectionSize;
346 if (connections.size() >= maxConnectionSize) {
347 result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
348 } else if (connections.isEmpty()) {
349 Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
350 result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
351 synchronized (cachedConnections) {
352 cachedConnections.putAll(cacheKey, newConnections);
353 }
354 } else {
355 List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
356 allConnections.addAll(connections);
357 Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
358 allConnections.addAll(newConnections);
359 result = allConnections.subList(connectionOffset, maxConnectionSize);
360 synchronized (cachedConnections) {
361 cachedConnections.putAll(cacheKey, newConnections);
362 }
363 }
364 return result;
365 }
366
367 @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
368 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
369 final ConnectionMode connectionMode) throws SQLException {
370 if (1 == connectionSize) {
371 Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
372 try {
373 methodInvocationRecorder.replay(connection);
374 } catch (final SQLException ex) {
375 connection.close();
376 throw ex;
377 }
378 return Collections.singletonList(connection);
379 }
380 if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
381 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
382 }
383 synchronized (dataSource) {
384 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
385 }
386 }
387
388 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
389 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
390 List<Connection> result = new ArrayList<>(connectionSize);
391 for (int i = 0; i < connectionSize; i++) {
392 try {
393 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
394 methodInvocationRecorder.replay(connection);
395 result.add(connection);
396 } catch (final SQLException ex) {
397 for (Connection each : result) {
398 each.close();
399 }
400 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
401 }
402 }
403 return result;
404 }
405
406 private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
407 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
408 Optional<Connection> connectionInTransaction = getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext);
409 return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
410 }
411
412 @Override
413 public void close() throws SQLException {
414 clearCachedConnections();
415 }
416 }