View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.driver.jdbc.core.connection;
19  
20  import com.google.common.base.Preconditions;
21  import com.google.common.collect.LinkedHashMultimap;
22  import com.google.common.collect.Multimap;
23  import com.google.common.collect.Sets;
24  import lombok.Getter;
25  import org.apache.shardingsphere.authority.rule.AuthorityRule;
26  import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
27  import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
28  import org.apache.shardingsphere.driver.jdbc.core.ShardingSphereSavepoint;
29  import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
30  import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
31  import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
32  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
33  import org.apache.shardingsphere.infra.executor.sql.prepare.driver.OnlineDatabaseConnectionManager;
34  import org.apache.shardingsphere.infra.instance.metadata.InstanceMetaData;
35  import org.apache.shardingsphere.infra.instance.metadata.InstanceType;
36  import org.apache.shardingsphere.infra.instance.metadata.proxy.ProxyInstanceMetaData;
37  import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
38  import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
39  import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
40  import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
41  import org.apache.shardingsphere.metadata.persist.MetaDataBasedPersistService;
42  import org.apache.shardingsphere.mode.manager.ContextManager;
43  import org.apache.shardingsphere.traffic.rule.TrafficRule;
44  import org.apache.shardingsphere.transaction.ConnectionSavepointManager;
45  import org.apache.shardingsphere.transaction.ConnectionTransaction;
46  import org.apache.shardingsphere.transaction.rule.TransactionRule;
47  
48  import javax.sql.DataSource;
49  import java.security.SecureRandom;
50  import java.sql.Connection;
51  import java.sql.SQLException;
52  import java.sql.Savepoint;
53  import java.util.ArrayList;
54  import java.util.Collection;
55  import java.util.Collections;
56  import java.util.LinkedHashMap;
57  import java.util.List;
58  import java.util.Map;
59  import java.util.Map.Entry;
60  import java.util.Optional;
61  import java.util.Random;
62  
63  /**
64   * Database connection manager of ShardingSphere-JDBC.
65   */
66  public final class DriverDatabaseConnectionManager implements OnlineDatabaseConnectionManager<Connection>, AutoCloseable {
67      
68      private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();
69      
70      private final Map<String, DataSource> physicalDataSourceMap = new LinkedHashMap<>();
71      
72      private final Map<String, DataSource> trafficDataSourceMap = new LinkedHashMap<>();
73      
74      private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
75      
76      private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
77      
78      private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
79      
80      private final Random random = new SecureRandom();
81      
82      @Getter
83      private final ConnectionContext connectionContext;
84      
85      private final ContextManager contextManager;
86      
87      private final String databaseName;
88      
89      public DriverDatabaseConnectionManager(final String databaseName, final ContextManager contextManager) {
90          for (Entry<String, StorageUnit> entry : contextManager.getStorageUnits(databaseName).entrySet()) {
91              DataSource dataSource = entry.getValue().getDataSource();
92              String cacheKey = getKey(databaseName, entry.getKey());
93              dataSourceMap.put(cacheKey, dataSource);
94              physicalDataSourceMap.put(cacheKey, dataSource);
95          }
96          for (Entry<String, DataSource> entry : getTrafficDataSourceMap(databaseName, contextManager).entrySet()) {
97              String cacheKey = getKey(databaseName, entry.getKey());
98              dataSourceMap.put(cacheKey, entry.getValue());
99              trafficDataSourceMap.put(cacheKey, entry.getValue());
100         }
101         connectionContext = new ConnectionContext(cachedConnections::keySet);
102         connectionContext.setCurrentDatabase(databaseName);
103         this.contextManager = contextManager;
104         this.databaseName = databaseName;
105     }
106     
107     private Map<String, DataSource> getTrafficDataSourceMap(final String databaseName, final ContextManager contextManager) {
108         TrafficRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
109         if (rule.getStrategyRules().isEmpty()) {
110             return Collections.emptyMap();
111         }
112         MetaDataBasedPersistService persistService = contextManager.getMetaDataContexts().getPersistService();
113         String actualDatabaseName = contextManager.getMetaDataContexts().getMetaData().getDatabase(databaseName).getName();
114         Map<String, DataSourcePoolProperties> propsMap = persistService.getDataSourceUnitService().load(actualDatabaseName);
115         Preconditions.checkState(!propsMap.isEmpty(), "Can not get data source properties from meta data.");
116         DataSourcePoolProperties propsSample = propsMap.values().iterator().next();
117         Collection<ShardingSphereUser> users = contextManager.getMetaDataContexts().getMetaData()
118                 .getGlobalRuleMetaData().getSingleRule(AuthorityRule.class).getConfiguration().getUsers();
119         Collection<InstanceMetaData> instances = contextManager.getInstanceContext().getAllClusterInstances(InstanceType.PROXY, rule.getLabels()).values();
120         return DataSourcePoolCreator.create(createDataSourcePoolPropertiesMap(instances, users, propsSample, actualDatabaseName), true);
121     }
122     
123     private Map<String, DataSourcePoolProperties> createDataSourcePoolPropertiesMap(final Collection<InstanceMetaData> instances, final Collection<ShardingSphereUser> users,
124                                                                                     final DataSourcePoolProperties propsSample, final String schema) {
125         Map<String, DataSourcePoolProperties> result = new LinkedHashMap<>(instances.size(), 1F);
126         for (InstanceMetaData each : instances) {
127             result.put(each.getId(), createDataSourcePoolProperties((ProxyInstanceMetaData) each, users, propsSample, schema));
128         }
129         return result;
130     }
131     
132     private DataSourcePoolProperties createDataSourcePoolProperties(final ProxyInstanceMetaData instanceMetaData, final Collection<ShardingSphereUser> users,
133                                                                     final DataSourcePoolProperties propsSample, final String schema) {
134         Map<String, Object> props = propsSample.getAllLocalProperties();
135         props.put("jdbcUrl", createJdbcUrl(instanceMetaData, schema, props));
136         ShardingSphereUser user = users.iterator().next();
137         props.put("username", user.getGrantee().getUsername());
138         props.put("password", user.getPassword());
139         return new DataSourcePoolProperties("com.zaxxer.hikari.HikariDataSource", props);
140     }
141     
142     private String createJdbcUrl(final ProxyInstanceMetaData instanceMetaData, final String schema, final Map<String, Object> props) {
143         String jdbcUrl = String.valueOf(props.get("jdbcUrl"));
144         String jdbcUrlPrefix = jdbcUrl.substring(0, jdbcUrl.indexOf("//"));
145         String jdbcUrlSuffix = jdbcUrl.contains("?") ? jdbcUrl.substring(jdbcUrl.indexOf('?')) : "";
146         return String.format("%s//%s:%s/%s%s", jdbcUrlPrefix, instanceMetaData.getIp(), instanceMetaData.getPort(), schema, jdbcUrlSuffix);
147     }
148     
149     /**
150      * Get connection transaction.
151      *
152      * @return connection transaction
153      */
154     public ConnectionTransaction getConnectionTransaction() {
155         TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
156         return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
157     }
158     
159     /**
160      * Set auto commit.
161      *
162      * @param autoCommit auto commit
163      * @throws SQLException SQL exception
164      */
165     public void setAutoCommit(final boolean autoCommit) throws SQLException {
166         methodInvocationRecorder.record("setAutoCommit", target -> target.setAutoCommit(autoCommit));
167         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setAutoCommit(autoCommit));
168     }
169     
170     /**
171      * Commit.
172      *
173      * @throws SQLException SQL exception
174      */
175     public void commit() throws SQLException {
176         ConnectionTransaction connectionTransaction = getConnectionTransaction();
177         try {
178             if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
179                 forceExecuteTemplate.execute(cachedConnections.values(), Connection::rollback);
180             } else if (connectionTransaction.isLocalTransaction()) {
181                 forceExecuteTemplate.execute(cachedConnections.values(), Connection::commit);
182             } else {
183                 connectionTransaction.commit();
184             }
185         } finally {
186             for (Connection each : cachedConnections.values()) {
187                 ConnectionSavepointManager.getInstance().transactionFinished(each);
188             }
189         }
190     }
191     
192     /**
193      * Rollback.
194      *
195      * @throws SQLException SQL exception
196      */
197     public void rollback() throws SQLException {
198         ConnectionTransaction connectionTransaction = getConnectionTransaction();
199         try {
200             if (connectionTransaction.isLocalTransaction()) {
201                 forceExecuteTemplate.execute(cachedConnections.values(), Connection::rollback);
202             } else {
203                 connectionTransaction.rollback();
204             }
205         } finally {
206             for (Connection each : cachedConnections.values()) {
207                 ConnectionSavepointManager.getInstance().transactionFinished(each);
208             }
209         }
210     }
211     
212     /**
213      * Rollback to savepoint.
214      *
215      * @param savepoint savepoint
216      * @throws SQLException SQL exception
217      */
218     public void rollback(final Savepoint savepoint) throws SQLException {
219         for (Connection each : cachedConnections.values()) {
220             ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
221         }
222     }
223     
224     /**
225      * Set savepoint.
226      *
227      * @param savepointName savepoint name
228      * @return savepoint savepoint
229      * @throws SQLException SQL exception
230      */
231     public Savepoint setSavepoint(final String savepointName) throws SQLException {
232         ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
233         for (Connection each : cachedConnections.values()) {
234             ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
235         }
236         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
237         return result;
238     }
239     
240     /**
241      * Set savepoint.
242      *
243      * @return savepoint savepoint
244      * @throws SQLException SQL exception
245      */
246     public Savepoint setSavepoint() throws SQLException {
247         ShardingSphereSavepoint result = new ShardingSphereSavepoint();
248         for (Connection each : cachedConnections.values()) {
249             ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
250         }
251         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
252         return result;
253     }
254     
255     /**
256      * Release savepoint.
257      *
258      * @param savepoint savepoint
259      * @throws SQLException SQL exception
260      */
261     public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
262         for (Connection each : cachedConnections.values()) {
263             ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
264         }
265     }
266     
267     /**
268      * Get transaction isolation.
269      *
270      * @return transaction isolation level
271      * @throws SQLException SQL exception
272      */
273     public Optional<Integer> getTransactionIsolation() throws SQLException {
274         return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
275     }
276     
277     /**
278      * Set transaction isolation.
279      *
280      * @param level transaction isolation level
281      * @throws SQLException SQL exception
282      */
283     public void setTransactionIsolation(final int level) throws SQLException {
284         methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
285         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
286     }
287     
288     /**
289      * Set read only.
290      *
291      * @param readOnly read only
292      * @throws SQLException SQL exception
293      */
294     public void setReadOnly(final boolean readOnly) throws SQLException {
295         methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
296         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
297     }
298     
299     /**
300      * Whether connection valid.
301      *
302      * @param timeout timeout
303      * @return connection valid or not
304      * @throws SQLException SQL exception
305      */
306     public boolean isValid(final int timeout) throws SQLException {
307         for (Connection each : cachedConnections.values()) {
308             if (!each.isValid(timeout)) {
309                 return false;
310             }
311         }
312         return true;
313     }
314     
315     /**
316      * Get random physical data source name.
317      *
318      * @return random physical data source name
319      */
320     public String getRandomPhysicalDataSourceName() {
321         return getRandomPhysicalDatabaseAndDataSourceName()[1];
322     }
323     
324     private String[] getRandomPhysicalDatabaseAndDataSourceName() {
325         Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(physicalDataSourceMap.keySet(), cachedConnections.keySet());
326         Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? physicalDataSourceMap.keySet() : cachedPhysicalDataSourceNames;
327         return new ArrayList<>(databaseAndDatasourceNames).get(random.nextInt(databaseAndDatasourceNames.size())).split("\\.");
328     }
329     
330     /**
331      * Get random connection.
332      *
333      * @return random connection
334      * @throws SQLException SQL exception
335      */
336     public Connection getRandomConnection() throws SQLException {
337         String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
338         return getConnections(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
339     }
340     
341     @Override
342     public List<Connection> getConnections(final String dataSourceName, final int connectionOffset, final int connectionSize, final ConnectionMode connectionMode) throws SQLException {
343         return getConnections(connectionContext.getDatabaseName().orElse(databaseName), dataSourceName, connectionOffset, connectionSize, connectionMode);
344     }
345     
346     private List<Connection> getConnections(final String currentDatabaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
347                                             final ConnectionMode connectionMode) throws SQLException {
348         String cacheKey = getKey(currentDatabaseName, dataSourceName);
349         DataSource dataSource = databaseName.equals(currentDatabaseName)
350                 ? dataSourceMap.get(cacheKey)
351                 : contextManager.getStorageUnits(currentDatabaseName).get(dataSourceName).getDataSource();
352         Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
353         Collection<Connection> connections;
354         synchronized (cachedConnections) {
355             connections = cachedConnections.get(cacheKey);
356         }
357         List<Connection> result;
358         int maxConnectionSize = connectionOffset + connectionSize;
359         if (connections.size() >= maxConnectionSize) {
360             result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
361         } else if (connections.isEmpty()) {
362             Collection<Connection> newConnections = createConnections(currentDatabaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
363             result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
364             synchronized (cachedConnections) {
365                 cachedConnections.putAll(cacheKey, newConnections);
366             }
367         } else {
368             List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
369             allConnections.addAll(connections);
370             Collection<Connection> newConnections = createConnections(currentDatabaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
371             allConnections.addAll(newConnections);
372             result = allConnections.subList(connectionOffset, maxConnectionSize);
373             synchronized (cachedConnections) {
374                 cachedConnections.putAll(cacheKey, newConnections);
375             }
376         }
377         return result;
378     }
379     
380     private String getKey(final String databaseName, final String dataSourceName) {
381         return databaseName.toLowerCase() + "." + dataSourceName;
382     }
383     
384     @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
385     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
386                                                final ConnectionMode connectionMode) throws SQLException {
387         if (1 == connectionSize) {
388             Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
389             methodInvocationRecorder.replay(connection);
390             return Collections.singletonList(connection);
391         }
392         if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
393             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
394         }
395         synchronized (dataSource) {
396             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
397         }
398     }
399     
400     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
401                                                final TransactionConnectionContext transactionConnectionContext) throws SQLException {
402         List<Connection> result = new ArrayList<>(connectionSize);
403         for (int i = 0; i < connectionSize; i++) {
404             try {
405                 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
406                 methodInvocationRecorder.replay(connection);
407                 result.add(connection);
408             } catch (final SQLException ex) {
409                 for (Connection each : result) {
410                     each.close();
411                 }
412                 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
413             }
414         }
415         return result;
416     }
417     
418     private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
419                                         final TransactionConnectionContext transactionConnectionContext) throws SQLException {
420         Optional<Connection> connectionInTransaction =
421                 isRawJdbcDataSource(databaseName, dataSourceName) ? getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext) : Optional.empty();
422         return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
423     }
424     
425     private boolean isRawJdbcDataSource(final String databaseName, final String dataSourceName) {
426         return !trafficDataSourceMap.containsKey(getKey(databaseName, dataSourceName));
427     }
428     
429     @Override
430     public void close() throws SQLException {
431         try {
432             forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
433         } finally {
434             cachedConnections.clear();
435         }
436     }
437 }