View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.driver.jdbc.core.connection;
19  
20  import com.google.common.base.Preconditions;
21  import com.google.common.collect.LinkedHashMultimap;
22  import com.google.common.collect.Multimap;
23  import com.google.common.collect.Sets;
24  import lombok.Getter;
25  import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
26  import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
27  import org.apache.shardingsphere.driver.jdbc.core.savepoint.ShardingSphereSavepoint;
28  import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
29  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
30  import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
31  import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
32  import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
33  import org.apache.shardingsphere.mode.manager.ContextManager;
34  import org.apache.shardingsphere.transaction.savepoint.ConnectionSavepointManager;
35  import org.apache.shardingsphere.transaction.ConnectionTransaction;
36  import org.apache.shardingsphere.transaction.rule.TransactionRule;
37  
38  import javax.sql.DataSource;
39  import java.sql.Connection;
40  import java.sql.SQLException;
41  import java.sql.Savepoint;
42  import java.util.ArrayList;
43  import java.util.Collection;
44  import java.util.Collections;
45  import java.util.List;
46  import java.util.Map;
47  import java.util.Optional;
48  import java.util.concurrent.ThreadLocalRandom;
49  import java.util.stream.Collectors;
50  
51  /**
52   * Database connection manager of ShardingSphere-JDBC.
53   */
54  public final class DriverDatabaseConnectionManager implements DatabaseConnectionManager<Connection>, AutoCloseable {
55      
    // Name of the database this manager was created for; it is also set as the current database of connectionContext.
    private final String currentDatabaseName;
    
    // Used to look up storage units per database and the global transaction rule.
    private final ContextManager contextManager;
    
    // Data sources of the current database, keyed by getKey(databaseName, dataSourceName), i.e. "<lower-cased database>.<data source>".
    private final Map<String, DataSource> dataSourceMap;
    
    // Context of this logical connection; constructed with a supplier over the key set of cachedConnections.
    @Getter
    private final ConnectionContext connectionContext;
    
    // Physical connections opened so far, grouped under the same key format as dataSourceMap.
    private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
    
    // Remembers setter-style calls (auto commit, isolation, read only, savepoint) so they can be replayed on connections created later.
    private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
    
    // Applies one callback to every connection in a collection.
    // NOTE(review): presumably keeps going when a single connection fails - confirm against ForceExecuteTemplate.
    private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
70      
71      public DriverDatabaseConnectionManager(final String currentDatabaseName, final ContextManager contextManager) {
72          this.currentDatabaseName = currentDatabaseName;
73          this.contextManager = contextManager;
74          dataSourceMap = contextManager.getStorageUnits(currentDatabaseName).entrySet()
75                  .stream().collect(Collectors.toMap(entry -> getKey(currentDatabaseName, entry.getKey()), entry -> entry.getValue().getDataSource()));
76          connectionContext = new ConnectionContext(cachedConnections::keySet);
77          connectionContext.setCurrentDatabaseName(currentDatabaseName);
78      }
79      
80      private String getKey(final String databaseName, final String dataSourceName) {
81          return databaseName.toLowerCase() + "." + dataSourceName;
82      }
83      
84      /**
85       * Get connection transaction.
86       *
87       * @return connection transaction
88       */
89      public ConnectionTransaction getConnectionTransaction() {
90          TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
91          return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
92      }
93      
94      /**
95       * Set auto commit.
96       *
97       * @param autoCommit auto commit
98       * @throws SQLException SQL exception
99       */
100     public void setAutoCommit(final boolean autoCommit) throws SQLException {
101         methodInvocationRecorder.record("setAutoCommit", connection -> connection.setAutoCommit(autoCommit));
102         forceExecuteTemplate.execute(getCachedConnections(), connection -> connection.setAutoCommit(autoCommit));
103         if (autoCommit) {
104             clearCachedConnections();
105         }
106     }
107     
108     private Collection<Connection> getCachedConnections() {
109         return cachedConnections.values();
110     }
111     
112     /**
113      * Clear cached connections.
114      *
115      * @throws SQLException SQL exception
116      */
117     public void clearCachedConnections() throws SQLException {
118         try {
119             forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
120         } finally {
121             cachedConnections.clear();
122         }
123     }
124     
125     /**
126      * Begin transaction.
127      *
128      * @throws SQLException SQL exception
129      */
130     public void begin() throws SQLException {
131         ConnectionTransaction connectionTransaction = getConnectionTransaction();
132         connectionContext.getTransactionContext().beginTransaction(connectionTransaction.getTransactionType().name(), connectionTransaction.getDistributedTransactionManager());
133         if (!connectionTransaction.isLocalTransaction()) {
134             close();
135         }
136         doBegin(connectionTransaction);
137     }
138     
139     private void doBegin(final ConnectionTransaction connectionTransaction) throws SQLException {
140         if (connectionTransaction.isLocalTransaction()) {
141             setAutoCommit(false);
142         } else {
143             connectionTransaction.begin();
144         }
145     }
146     
147     /**
148      * Commit.
149      *
150      * @throws SQLException SQL exception
151      */
152     public void commit() throws SQLException {
153         ConnectionTransaction connectionTransaction = getConnectionTransaction();
154         if (!connectionContext.getTransactionContext().isInTransaction() && !connectionTransaction.isInDistributedTransaction()) {
155             return;
156         }
157         try {
158             if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
159                 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
160             } else if (connectionTransaction.isLocalTransaction()) {
161                 forceExecuteTemplate.execute(getCachedConnections(), Connection::commit);
162             } else {
163                 connectionTransaction.commit();
164             }
165         } finally {
166             clear();
167         }
168     }
169     
170     /**
171      * Rollback.
172      *
173      * @throws SQLException SQL exception
174      */
175     public void rollback() throws SQLException {
176         ConnectionTransaction connectionTransaction = getConnectionTransaction();
177         if (!connectionContext.getTransactionContext().isInTransaction() && !connectionTransaction.isInDistributedTransaction()) {
178             return;
179         }
180         try {
181             if (connectionTransaction.isLocalTransaction()) {
182                 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
183             } else {
184                 connectionTransaction.rollback();
185             }
186         } finally {
187             clear();
188         }
189     }
190     
191     /**
192      * Rollback to savepoint.
193      *
194      * @param savepoint savepoint
195      * @throws SQLException SQL exception
196      */
197     public void rollback(final Savepoint savepoint) throws SQLException {
198         for (Connection each : getCachedConnections()) {
199             ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
200         }
201     }
202     
203     private void clear() {
204         methodInvocationRecorder.remove("setSavepoint");
205         for (Connection each : getCachedConnections()) {
206             ConnectionSavepointManager.getInstance().transactionFinished(each);
207         }
208         connectionContext.close();
209     }
210     
211     /**
212      * Set savepoint.
213      *
214      * @param savepointName savepoint name
215      * @return savepoint savepoint
216      * @throws SQLException SQL exception
217      */
218     public Savepoint setSavepoint(final String savepointName) throws SQLException {
219         ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
220         for (Connection each : getCachedConnections()) {
221             ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
222         }
223         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
224         return result;
225     }
226     
227     /**
228      * Set savepoint.
229      *
230      * @return savepoint savepoint
231      * @throws SQLException SQL exception
232      */
233     public Savepoint setSavepoint() throws SQLException {
234         ShardingSphereSavepoint result = new ShardingSphereSavepoint();
235         for (Connection each : getCachedConnections()) {
236             ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
237         }
238         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
239         return result;
240     }
241     
242     /**
243      * Release savepoint.
244      *
245      * @param savepoint savepoint
246      * @throws SQLException SQL exception
247      */
248     public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
249         methodInvocationRecorder.remove("setSavepoint");
250         for (Connection each : getCachedConnections()) {
251             ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
252         }
253     }
254     
255     /**
256      * Get transaction isolation.
257      *
258      * @return transaction isolation level
259      * @throws SQLException SQL exception
260      */
261     public Optional<Integer> getTransactionIsolation() throws SQLException {
262         return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
263     }
264     
265     /**
266      * Set transaction isolation.
267      *
268      * @param level transaction isolation level
269      * @throws SQLException SQL exception
270      */
271     public void setTransactionIsolation(final int level) throws SQLException {
272         methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
273         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
274     }
275     
276     /**
277      * Set read only.
278      *
279      * @param readOnly read only
280      * @throws SQLException SQL exception
281      */
282     public void setReadOnly(final boolean readOnly) throws SQLException {
283         methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
284         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
285     }
286     
287     /**
288      * Whether connection valid.
289      *
290      * @param timeout timeout
291      * @return connection valid or not
292      * @throws SQLException SQL exception
293      */
294     public boolean isValid(final int timeout) throws SQLException {
295         for (Connection each : cachedConnections.values()) {
296             if (!each.isValid(timeout)) {
297                 return false;
298             }
299         }
300         return true;
301     }
302     
303     /**
304      * Get random physical data source name.
305      *
306      * @return random physical data source name
307      */
308     public String getRandomPhysicalDataSourceName() {
309         return getRandomPhysicalDatabaseAndDataSourceName()[1];
310     }
311     
312     private String[] getRandomPhysicalDatabaseAndDataSourceName() {
313         Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(dataSourceMap.keySet(), cachedConnections.keySet());
314         Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? dataSourceMap.keySet() : cachedPhysicalDataSourceNames;
315         return new ArrayList<>(databaseAndDatasourceNames).get(ThreadLocalRandom.current().nextInt(databaseAndDatasourceNames.size())).split("\\.");
316     }
317     
318     /**
319      * Get random connection.
320      *
321      * @return random connection
322      * @throws SQLException SQL exception
323      */
324     public Connection getRandomConnection() throws SQLException {
325         String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
326         return getConnections0(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
327     }
328     
329     @Override
330     public List<Connection> getConnections(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
331                                            final ConnectionMode connectionMode) throws SQLException {
332         return getConnections0(databaseName, dataSourceName, connectionOffset, connectionSize, connectionMode);
333     }
334     
335     private List<Connection> getConnections0(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
336                                              final ConnectionMode connectionMode) throws SQLException {
337         String cacheKey = getKey(databaseName, dataSourceName);
338         DataSource dataSource = currentDatabaseName.equals(databaseName) ? dataSourceMap.get(cacheKey) : contextManager.getStorageUnits(databaseName).get(dataSourceName).getDataSource();
339         Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
340         Collection<Connection> connections;
341         synchronized (cachedConnections) {
342             connections = cachedConnections.get(cacheKey);
343         }
344         List<Connection> result;
345         int maxConnectionSize = connectionOffset + connectionSize;
346         if (connections.size() >= maxConnectionSize) {
347             result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
348         } else if (connections.isEmpty()) {
349             Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
350             result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
351             synchronized (cachedConnections) {
352                 cachedConnections.putAll(cacheKey, newConnections);
353             }
354         } else {
355             List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
356             allConnections.addAll(connections);
357             Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
358             allConnections.addAll(newConnections);
359             result = allConnections.subList(connectionOffset, maxConnectionSize);
360             synchronized (cachedConnections) {
361                 cachedConnections.putAll(cacheKey, newConnections);
362             }
363         }
364         return result;
365     }
366     
367     @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
368     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
369                                                final ConnectionMode connectionMode) throws SQLException {
370         if (1 == connectionSize) {
371             Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
372             try {
373                 methodInvocationRecorder.replay(connection);
374             } catch (final SQLException ex) {
375                 connection.close();
376                 throw ex;
377             }
378             return Collections.singletonList(connection);
379         }
380         if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
381             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
382         }
383         synchronized (dataSource) {
384             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
385         }
386     }
387     
388     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
389                                                final TransactionConnectionContext transactionConnectionContext) throws SQLException {
390         List<Connection> result = new ArrayList<>(connectionSize);
391         for (int i = 0; i < connectionSize; i++) {
392             try {
393                 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
394                 methodInvocationRecorder.replay(connection);
395                 result.add(connection);
396             } catch (final SQLException ex) {
397                 for (Connection each : result) {
398                     each.close();
399                 }
400                 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
401             }
402         }
403         return result;
404     }
405     
406     private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
407                                         final TransactionConnectionContext transactionConnectionContext) throws SQLException {
408         Optional<Connection> connectionInTransaction = getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext);
409         return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
410     }
411     
412     @Override
413     public void close() throws SQLException {
414         clearCachedConnections();
415     }
416 }