1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.driver.jdbc.core.connection;
19
20 import com.google.common.base.Preconditions;
21 import com.google.common.collect.LinkedHashMultimap;
22 import com.google.common.collect.Multimap;
23 import com.google.common.collect.Sets;
24 import lombok.Getter;
25 import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
26 import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
27 import org.apache.shardingsphere.driver.jdbc.core.savepoint.ShardingSphereSavepoint;
28 import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
29 import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
30 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
31 import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
32 import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
33 import org.apache.shardingsphere.mode.manager.ContextManager;
34 import org.apache.shardingsphere.transaction.savepoint.ConnectionSavepointManager;
35 import org.apache.shardingsphere.transaction.ConnectionTransaction;
36 import org.apache.shardingsphere.transaction.api.TransactionType;
37 import org.apache.shardingsphere.transaction.rule.TransactionRule;
38
39 import javax.sql.DataSource;
40 import java.sql.Connection;
41 import java.sql.SQLException;
42 import java.sql.Savepoint;
43 import java.util.ArrayList;
44 import java.util.Collection;
45 import java.util.Collections;
46 import java.util.List;
47 import java.util.Map;
48 import java.util.Optional;
49 import java.util.concurrent.ThreadLocalRandom;
50 import java.util.stream.Collectors;
51
52
53
54
55 public final class DriverDatabaseConnectionManager implements DatabaseConnectionManager<Connection>, AutoCloseable {
56
57 private final String currentDatabaseName;
58
59 private final ContextManager contextManager;
60
61 private final Map<String, DataSource> dataSourceMap;
62
63 @Getter
64 private final ConnectionContext connectionContext;
65
66 private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
67
68 private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
69
70 private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
71
72 public DriverDatabaseConnectionManager(final String currentDatabaseName, final ContextManager contextManager) {
73 this.currentDatabaseName = currentDatabaseName;
74 this.contextManager = contextManager;
75 dataSourceMap = contextManager.getStorageUnits(currentDatabaseName).entrySet()
76 .stream().collect(Collectors.toMap(entry -> getKey(currentDatabaseName, entry.getKey()), entry -> entry.getValue().getDataSource()));
77 connectionContext = new ConnectionContext(cachedConnections::keySet);
78 connectionContext.setCurrentDatabaseName(currentDatabaseName);
79 }
80
81 private String getKey(final String databaseName, final String dataSourceName) {
82 return databaseName.toLowerCase() + "." + dataSourceName;
83 }
84
85
86
87
88
89
90 public ConnectionTransaction getConnectionTransaction() {
91 TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
92 return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
93 }
94
95
96
97
98
99
100
101 public void setAutoCommit(final boolean autoCommit) throws SQLException {
102 methodInvocationRecorder.record("setAutoCommit", connection -> connection.setAutoCommit(autoCommit));
103 forceExecuteTemplate.execute(getCachedConnections(), connection -> connection.setAutoCommit(autoCommit));
104 if (autoCommit) {
105 clearCachedConnections();
106 }
107 }
108
109 private Collection<Connection> getCachedConnections() {
110 return cachedConnections.values();
111 }
112
113
114
115
116
117
118 public void clearCachedConnections() throws SQLException {
119 try {
120 forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
121 } finally {
122 cachedConnections.clear();
123 }
124 }
125
126
127
128
129
130
131 public void begin() throws SQLException {
132 ConnectionTransaction connectionTransaction = getConnectionTransaction();
133 if (TransactionType.isDistributedTransaction(connectionTransaction.getTransactionType())) {
134 close();
135 connectionTransaction.begin();
136 }
137 connectionContext.getTransactionContext().beginTransaction(String.valueOf(connectionTransaction.getTransactionType()), connectionTransaction.getDistributedTransactionManager());
138 }
139
140
141
142
143
144
145 public void commit() throws SQLException {
146 ConnectionTransaction connectionTransaction = getConnectionTransaction();
147 try {
148 if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
149 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
150 } else if (connectionTransaction.isLocalTransaction()) {
151 forceExecuteTemplate.execute(getCachedConnections(), Connection::commit);
152 } else {
153 connectionTransaction.commit();
154 }
155 } finally {
156 methodInvocationRecorder.remove("setSavepoint");
157 for (Connection each : getCachedConnections()) {
158 ConnectionSavepointManager.getInstance().transactionFinished(each);
159 }
160 connectionContext.close();
161 clearCachedConnections();
162 }
163 }
164
165
166
167
168
169
170 public void rollback() throws SQLException {
171 ConnectionTransaction connectionTransaction = getConnectionTransaction();
172 try {
173 if (connectionTransaction.isLocalTransaction()) {
174 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
175 } else {
176 connectionTransaction.rollback();
177 }
178 } finally {
179 methodInvocationRecorder.remove("setSavepoint");
180 for (Connection each : getCachedConnections()) {
181 ConnectionSavepointManager.getInstance().transactionFinished(each);
182 }
183 connectionContext.close();
184 clearCachedConnections();
185 }
186 }
187
188
189
190
191
192
193
194 public void rollback(final Savepoint savepoint) throws SQLException {
195 for (Connection each : getCachedConnections()) {
196 ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
197 }
198 }
199
200
201
202
203
204
205
206
207 public Savepoint setSavepoint(final String savepointName) throws SQLException {
208 ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
209 for (Connection each : getCachedConnections()) {
210 ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
211 }
212 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
213 return result;
214 }
215
216
217
218
219
220
221
222 public Savepoint setSavepoint() throws SQLException {
223 ShardingSphereSavepoint result = new ShardingSphereSavepoint();
224 for (Connection each : getCachedConnections()) {
225 ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
226 }
227 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
228 return result;
229 }
230
231
232
233
234
235
236
237 public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
238 methodInvocationRecorder.remove("setSavepoint");
239 for (Connection each : getCachedConnections()) {
240 ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
241 }
242 }
243
244
245
246
247
248
249
250 public Optional<Integer> getTransactionIsolation() throws SQLException {
251 return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
252 }
253
254
255
256
257
258
259
260 public void setTransactionIsolation(final int level) throws SQLException {
261 methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
262 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
263 }
264
265
266
267
268
269
270
271 public void setReadOnly(final boolean readOnly) throws SQLException {
272 methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
273 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
274 }
275
276
277
278
279
280
281
282
283 public boolean isValid(final int timeout) throws SQLException {
284 for (Connection each : cachedConnections.values()) {
285 if (!each.isValid(timeout)) {
286 return false;
287 }
288 }
289 return true;
290 }
291
292
293
294
295
296
297 public String getRandomPhysicalDataSourceName() {
298 return getRandomPhysicalDatabaseAndDataSourceName()[1];
299 }
300
301 private String[] getRandomPhysicalDatabaseAndDataSourceName() {
302 Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(dataSourceMap.keySet(), cachedConnections.keySet());
303 Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? dataSourceMap.keySet() : cachedPhysicalDataSourceNames;
304 return new ArrayList<>(databaseAndDatasourceNames).get(ThreadLocalRandom.current().nextInt(databaseAndDatasourceNames.size())).split("\\.");
305 }
306
307
308
309
310
311
312
313 public Connection getRandomConnection() throws SQLException {
314 String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
315 return getConnections0(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
316 }
317
318 @Override
319 public List<Connection> getConnections(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
320 final ConnectionMode connectionMode) throws SQLException {
321 return getConnections0(databaseName, dataSourceName, connectionOffset, connectionSize, connectionMode);
322 }
323
324 private List<Connection> getConnections0(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
325 final ConnectionMode connectionMode) throws SQLException {
326 String cacheKey = getKey(databaseName, dataSourceName);
327 DataSource dataSource = currentDatabaseName.equals(databaseName) ? dataSourceMap.get(cacheKey) : contextManager.getStorageUnits(databaseName).get(dataSourceName).getDataSource();
328 Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
329 Collection<Connection> connections;
330 synchronized (cachedConnections) {
331 connections = cachedConnections.get(cacheKey);
332 }
333 List<Connection> result;
334 int maxConnectionSize = connectionOffset + connectionSize;
335 if (connections.size() >= maxConnectionSize) {
336 result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
337 } else if (connections.isEmpty()) {
338 Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
339 result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
340 synchronized (cachedConnections) {
341 cachedConnections.putAll(cacheKey, newConnections);
342 }
343 } else {
344 List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
345 allConnections.addAll(connections);
346 Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
347 allConnections.addAll(newConnections);
348 result = allConnections.subList(connectionOffset, maxConnectionSize);
349 synchronized (cachedConnections) {
350 cachedConnections.putAll(cacheKey, newConnections);
351 }
352 }
353 return result;
354 }
355
356 @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
357 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
358 final ConnectionMode connectionMode) throws SQLException {
359 if (1 == connectionSize) {
360 Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
361 try {
362 methodInvocationRecorder.replay(connection);
363 } catch (final SQLException ex) {
364 connection.close();
365 throw ex;
366 }
367 return Collections.singletonList(connection);
368 }
369 if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
370 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
371 }
372 synchronized (dataSource) {
373 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
374 }
375 }
376
377 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
378 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
379 List<Connection> result = new ArrayList<>(connectionSize);
380 for (int i = 0; i < connectionSize; i++) {
381 try {
382 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
383 methodInvocationRecorder.replay(connection);
384 result.add(connection);
385 } catch (final SQLException ex) {
386 for (Connection each : result) {
387 each.close();
388 }
389 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
390 }
391 }
392 return result;
393 }
394
395 private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
396 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
397 Optional<Connection> connectionInTransaction = getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext);
398 return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
399 }
400
401 @Override
402 public void close() throws SQLException {
403 clearCachedConnections();
404 }
405 }