1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.driver.jdbc.core.connection;
19
20 import com.google.common.base.Preconditions;
21 import com.google.common.collect.LinkedHashMultimap;
22 import com.google.common.collect.Multimap;
23 import com.google.common.collect.Sets;
24 import lombok.Getter;
25 import org.apache.shardingsphere.authority.rule.AuthorityRule;
26 import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
27 import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
28 import org.apache.shardingsphere.driver.jdbc.core.ShardingSphereSavepoint;
29 import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
30 import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
31 import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
32 import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
33 import org.apache.shardingsphere.infra.executor.sql.prepare.driver.OnlineDatabaseConnectionManager;
34 import org.apache.shardingsphere.infra.instance.metadata.InstanceMetaData;
35 import org.apache.shardingsphere.infra.instance.metadata.InstanceType;
36 import org.apache.shardingsphere.infra.instance.metadata.proxy.ProxyInstanceMetaData;
37 import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
38 import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
39 import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
40 import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
41 import org.apache.shardingsphere.metadata.persist.MetaDataBasedPersistService;
42 import org.apache.shardingsphere.mode.manager.ContextManager;
43 import org.apache.shardingsphere.traffic.rule.TrafficRule;
44 import org.apache.shardingsphere.transaction.ConnectionSavepointManager;
45 import org.apache.shardingsphere.transaction.ConnectionTransaction;
46 import org.apache.shardingsphere.transaction.rule.TransactionRule;
47
48 import javax.sql.DataSource;
49 import java.security.SecureRandom;
50 import java.sql.Connection;
51 import java.sql.SQLException;
52 import java.sql.Savepoint;
53 import java.util.ArrayList;
54 import java.util.Collection;
55 import java.util.Collections;
56 import java.util.LinkedHashMap;
57 import java.util.List;
58 import java.util.Map;
59 import java.util.Map.Entry;
60 import java.util.Optional;
61 import java.util.Random;
62
63
64
65
66 public final class DriverDatabaseConnectionManager implements OnlineDatabaseConnectionManager<Connection>, AutoCloseable {
67
68 private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();
69
70 private final Map<String, DataSource> physicalDataSourceMap = new LinkedHashMap<>();
71
72 private final Map<String, DataSource> trafficDataSourceMap = new LinkedHashMap<>();
73
74 private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
75
76 private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
77
78 private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
79
80 private final Random random = new SecureRandom();
81
82 @Getter
83 private final ConnectionContext connectionContext;
84
85 private final ContextManager contextManager;
86
87 private final String databaseName;
88
89 public DriverDatabaseConnectionManager(final String databaseName, final ContextManager contextManager) {
90 for (Entry<String, StorageUnit> entry : contextManager.getStorageUnits(databaseName).entrySet()) {
91 DataSource dataSource = entry.getValue().getDataSource();
92 String cacheKey = getKey(databaseName, entry.getKey());
93 dataSourceMap.put(cacheKey, dataSource);
94 physicalDataSourceMap.put(cacheKey, dataSource);
95 }
96 for (Entry<String, DataSource> entry : getTrafficDataSourceMap(databaseName, contextManager).entrySet()) {
97 String cacheKey = getKey(databaseName, entry.getKey());
98 dataSourceMap.put(cacheKey, entry.getValue());
99 trafficDataSourceMap.put(cacheKey, entry.getValue());
100 }
101 connectionContext = new ConnectionContext(cachedConnections::keySet);
102 connectionContext.setCurrentDatabase(databaseName);
103 this.contextManager = contextManager;
104 this.databaseName = databaseName;
105 }
106
107 private Map<String, DataSource> getTrafficDataSourceMap(final String databaseName, final ContextManager contextManager) {
108 TrafficRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
109 if (rule.getStrategyRules().isEmpty()) {
110 return Collections.emptyMap();
111 }
112 MetaDataBasedPersistService persistService = contextManager.getMetaDataContexts().getPersistService();
113 String actualDatabaseName = contextManager.getMetaDataContexts().getMetaData().getDatabase(databaseName).getName();
114 Map<String, DataSourcePoolProperties> propsMap = persistService.getDataSourceUnitService().load(actualDatabaseName);
115 Preconditions.checkState(!propsMap.isEmpty(), "Can not get data source properties from meta data.");
116 DataSourcePoolProperties propsSample = propsMap.values().iterator().next();
117 Collection<ShardingSphereUser> users = contextManager.getMetaDataContexts().getMetaData()
118 .getGlobalRuleMetaData().getSingleRule(AuthorityRule.class).getConfiguration().getUsers();
119 Collection<InstanceMetaData> instances = contextManager.getInstanceContext().getAllClusterInstances(InstanceType.PROXY, rule.getLabels()).values();
120 return DataSourcePoolCreator.create(createDataSourcePoolPropertiesMap(instances, users, propsSample, actualDatabaseName), true);
121 }
122
123 private Map<String, DataSourcePoolProperties> createDataSourcePoolPropertiesMap(final Collection<InstanceMetaData> instances, final Collection<ShardingSphereUser> users,
124 final DataSourcePoolProperties propsSample, final String schema) {
125 Map<String, DataSourcePoolProperties> result = new LinkedHashMap<>(instances.size(), 1F);
126 for (InstanceMetaData each : instances) {
127 result.put(each.getId(), createDataSourcePoolProperties((ProxyInstanceMetaData) each, users, propsSample, schema));
128 }
129 return result;
130 }
131
132 private DataSourcePoolProperties createDataSourcePoolProperties(final ProxyInstanceMetaData instanceMetaData, final Collection<ShardingSphereUser> users,
133 final DataSourcePoolProperties propsSample, final String schema) {
134 Map<String, Object> props = propsSample.getAllLocalProperties();
135 props.put("jdbcUrl", createJdbcUrl(instanceMetaData, schema, props));
136 ShardingSphereUser user = users.iterator().next();
137 props.put("username", user.getGrantee().getUsername());
138 props.put("password", user.getPassword());
139 return new DataSourcePoolProperties("com.zaxxer.hikari.HikariDataSource", props);
140 }
141
142 private String createJdbcUrl(final ProxyInstanceMetaData instanceMetaData, final String schema, final Map<String, Object> props) {
143 String jdbcUrl = String.valueOf(props.get("jdbcUrl"));
144 String jdbcUrlPrefix = jdbcUrl.substring(0, jdbcUrl.indexOf("//"));
145 String jdbcUrlSuffix = jdbcUrl.contains("?") ? jdbcUrl.substring(jdbcUrl.indexOf('?')) : "";
146 return String.format("%s//%s:%s/%s%s", jdbcUrlPrefix, instanceMetaData.getIp(), instanceMetaData.getPort(), schema, jdbcUrlSuffix);
147 }
148
149
150
151
152
153
154 public ConnectionTransaction getConnectionTransaction() {
155 TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
156 return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
157 }
158
159
160
161
162
163
164
165 public void setAutoCommit(final boolean autoCommit) throws SQLException {
166 methodInvocationRecorder.record("setAutoCommit", target -> target.setAutoCommit(autoCommit));
167 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setAutoCommit(autoCommit));
168 }
169
170
171
172
173
174
175 public void commit() throws SQLException {
176 ConnectionTransaction connectionTransaction = getConnectionTransaction();
177 try {
178 if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
179 forceExecuteTemplate.execute(cachedConnections.values(), Connection::rollback);
180 } else if (connectionTransaction.isLocalTransaction()) {
181 forceExecuteTemplate.execute(cachedConnections.values(), Connection::commit);
182 } else {
183 connectionTransaction.commit();
184 }
185 } finally {
186 for (Connection each : cachedConnections.values()) {
187 ConnectionSavepointManager.getInstance().transactionFinished(each);
188 }
189 }
190 }
191
192
193
194
195
196
197 public void rollback() throws SQLException {
198 ConnectionTransaction connectionTransaction = getConnectionTransaction();
199 try {
200 if (connectionTransaction.isLocalTransaction()) {
201 forceExecuteTemplate.execute(cachedConnections.values(), Connection::rollback);
202 } else {
203 connectionTransaction.rollback();
204 }
205 } finally {
206 for (Connection each : cachedConnections.values()) {
207 ConnectionSavepointManager.getInstance().transactionFinished(each);
208 }
209 }
210 }
211
212
213
214
215
216
217
218 public void rollback(final Savepoint savepoint) throws SQLException {
219 for (Connection each : cachedConnections.values()) {
220 ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
221 }
222 }
223
224
225
226
227
228
229
230
231 public Savepoint setSavepoint(final String savepointName) throws SQLException {
232 ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
233 for (Connection each : cachedConnections.values()) {
234 ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
235 }
236 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
237 return result;
238 }
239
240
241
242
243
244
245
246 public Savepoint setSavepoint() throws SQLException {
247 ShardingSphereSavepoint result = new ShardingSphereSavepoint();
248 for (Connection each : cachedConnections.values()) {
249 ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
250 }
251 methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
252 return result;
253 }
254
255
256
257
258
259
260
261 public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
262 for (Connection each : cachedConnections.values()) {
263 ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
264 }
265 }
266
267
268
269
270
271
272
273 public Optional<Integer> getTransactionIsolation() throws SQLException {
274 return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
275 }
276
277
278
279
280
281
282
283 public void setTransactionIsolation(final int level) throws SQLException {
284 methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
285 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
286 }
287
288
289
290
291
292
293
294 public void setReadOnly(final boolean readOnly) throws SQLException {
295 methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
296 forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
297 }
298
299
300
301
302
303
304
305
306 public boolean isValid(final int timeout) throws SQLException {
307 for (Connection each : cachedConnections.values()) {
308 if (!each.isValid(timeout)) {
309 return false;
310 }
311 }
312 return true;
313 }
314
315
316
317
318
319
320 public String getRandomPhysicalDataSourceName() {
321 return getRandomPhysicalDatabaseAndDataSourceName()[1];
322 }
323
324 private String[] getRandomPhysicalDatabaseAndDataSourceName() {
325 Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(physicalDataSourceMap.keySet(), cachedConnections.keySet());
326 Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? physicalDataSourceMap.keySet() : cachedPhysicalDataSourceNames;
327 return new ArrayList<>(databaseAndDatasourceNames).get(random.nextInt(databaseAndDatasourceNames.size())).split("\\.");
328 }
329
330
331
332
333
334
335
336 public Connection getRandomConnection() throws SQLException {
337 String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
338 return getConnections(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
339 }
340
341 @Override
342 public List<Connection> getConnections(final String dataSourceName, final int connectionOffset, final int connectionSize, final ConnectionMode connectionMode) throws SQLException {
343 return getConnections(connectionContext.getDatabaseName().orElse(databaseName), dataSourceName, connectionOffset, connectionSize, connectionMode);
344 }
345
346 private List<Connection> getConnections(final String currentDatabaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
347 final ConnectionMode connectionMode) throws SQLException {
348 String cacheKey = getKey(currentDatabaseName, dataSourceName);
349 DataSource dataSource = databaseName.equals(currentDatabaseName)
350 ? dataSourceMap.get(cacheKey)
351 : contextManager.getStorageUnits(currentDatabaseName).get(dataSourceName).getDataSource();
352 Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
353 Collection<Connection> connections;
354 synchronized (cachedConnections) {
355 connections = cachedConnections.get(cacheKey);
356 }
357 List<Connection> result;
358 int maxConnectionSize = connectionOffset + connectionSize;
359 if (connections.size() >= maxConnectionSize) {
360 result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
361 } else if (connections.isEmpty()) {
362 Collection<Connection> newConnections = createConnections(currentDatabaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
363 result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
364 synchronized (cachedConnections) {
365 cachedConnections.putAll(cacheKey, newConnections);
366 }
367 } else {
368 List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
369 allConnections.addAll(connections);
370 Collection<Connection> newConnections = createConnections(currentDatabaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
371 allConnections.addAll(newConnections);
372 result = allConnections.subList(connectionOffset, maxConnectionSize);
373 synchronized (cachedConnections) {
374 cachedConnections.putAll(cacheKey, newConnections);
375 }
376 }
377 return result;
378 }
379
380 private String getKey(final String databaseName, final String dataSourceName) {
381 return databaseName.toLowerCase() + "." + dataSourceName;
382 }
383
384 @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
385 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
386 final ConnectionMode connectionMode) throws SQLException {
387 if (1 == connectionSize) {
388 Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
389 methodInvocationRecorder.replay(connection);
390 return Collections.singletonList(connection);
391 }
392 if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
393 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
394 }
395 synchronized (dataSource) {
396 return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
397 }
398 }
399
400 private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
401 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
402 List<Connection> result = new ArrayList<>(connectionSize);
403 for (int i = 0; i < connectionSize; i++) {
404 try {
405 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
406 methodInvocationRecorder.replay(connection);
407 result.add(connection);
408 } catch (final SQLException ex) {
409 for (Connection each : result) {
410 each.close();
411 }
412 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
413 }
414 }
415 return result;
416 }
417
418 private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
419 final TransactionConnectionContext transactionConnectionContext) throws SQLException {
420 Optional<Connection> connectionInTransaction =
421 isRawJdbcDataSource(databaseName, dataSourceName) ? getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext) : Optional.empty();
422 return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
423 }
424
425 private boolean isRawJdbcDataSource(final String databaseName, final String dataSourceName) {
426 return !trafficDataSourceMap.containsKey(getKey(databaseName, dataSourceName));
427 }
428
429 @Override
430 public void close() throws SQLException {
431 try {
432 forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
433 } finally {
434 cachedConnections.clear();
435 }
436 }
437 }