View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.driver.jdbc.core.connection;
19  
20  import com.google.common.base.Preconditions;
21  import com.google.common.collect.LinkedHashMultimap;
22  import com.google.common.collect.Multimap;
23  import com.google.common.collect.Sets;
24  import lombok.Getter;
25  import org.apache.shardingsphere.driver.jdbc.adapter.executor.ForceExecuteTemplate;
26  import org.apache.shardingsphere.driver.jdbc.adapter.invocation.MethodInvocationRecorder;
27  import org.apache.shardingsphere.driver.jdbc.core.savepoint.ShardingSphereSavepoint;
28  import org.apache.shardingsphere.infra.exception.kernel.connection.OverallConnectionNotEnoughException;
29  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
30  import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DatabaseConnectionManager;
31  import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
32  import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
33  import org.apache.shardingsphere.mode.manager.ContextManager;
34  import org.apache.shardingsphere.transaction.savepoint.ConnectionSavepointManager;
35  import org.apache.shardingsphere.transaction.ConnectionTransaction;
36  import org.apache.shardingsphere.transaction.api.TransactionType;
37  import org.apache.shardingsphere.transaction.rule.TransactionRule;
38  
39  import javax.sql.DataSource;
40  import java.sql.Connection;
41  import java.sql.SQLException;
42  import java.sql.Savepoint;
43  import java.util.ArrayList;
44  import java.util.Collection;
45  import java.util.Collections;
46  import java.util.List;
47  import java.util.Map;
48  import java.util.Optional;
49  import java.util.concurrent.ThreadLocalRandom;
50  import java.util.stream.Collectors;
51  
52  /**
53   * Database connection manager of ShardingSphere-JDBC.
54   */
55  public final class DriverDatabaseConnectionManager implements DatabaseConnectionManager<Connection>, AutoCloseable {
56      
57      private final String currentDatabaseName;
58      
59      private final ContextManager contextManager;
60      
61      private final Map<String, DataSource> dataSourceMap;
62      
63      @Getter
64      private final ConnectionContext connectionContext;
65      
66      private final Multimap<String, Connection> cachedConnections = LinkedHashMultimap.create();
67      
68      private final MethodInvocationRecorder<Connection> methodInvocationRecorder = new MethodInvocationRecorder<>();
69      
70      private final ForceExecuteTemplate<Connection> forceExecuteTemplate = new ForceExecuteTemplate<>();
71      
72      public DriverDatabaseConnectionManager(final String currentDatabaseName, final ContextManager contextManager) {
73          this.currentDatabaseName = currentDatabaseName;
74          this.contextManager = contextManager;
75          dataSourceMap = contextManager.getStorageUnits(currentDatabaseName).entrySet()
76                  .stream().collect(Collectors.toMap(entry -> getKey(currentDatabaseName, entry.getKey()), entry -> entry.getValue().getDataSource()));
77          connectionContext = new ConnectionContext(cachedConnections::keySet);
78          connectionContext.setCurrentDatabaseName(currentDatabaseName);
79      }
80      
81      private String getKey(final String databaseName, final String dataSourceName) {
82          return databaseName.toLowerCase() + "." + dataSourceName;
83      }
84      
85      /**
86       * Get connection transaction.
87       *
88       * @return connection transaction
89       */
90      public ConnectionTransaction getConnectionTransaction() {
91          TransactionRule rule = contextManager.getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class);
92          return new ConnectionTransaction(rule, connectionContext.getTransactionContext());
93      }
94      
95      /**
96       * Set auto commit.
97       *
98       * @param autoCommit auto commit
99       * @throws SQLException SQL exception
100      */
101     public void setAutoCommit(final boolean autoCommit) throws SQLException {
102         methodInvocationRecorder.record("setAutoCommit", connection -> connection.setAutoCommit(autoCommit));
103         forceExecuteTemplate.execute(getCachedConnections(), connection -> connection.setAutoCommit(autoCommit));
104         if (autoCommit) {
105             clearCachedConnections();
106         }
107     }
108     
109     private Collection<Connection> getCachedConnections() {
110         return cachedConnections.values();
111     }
112     
113     /**
114      * Clear cached connections.
115      *
116      * @throws SQLException SQL exception
117      */
118     public void clearCachedConnections() throws SQLException {
119         try {
120             forceExecuteTemplate.execute(cachedConnections.values(), Connection::close);
121         } finally {
122             cachedConnections.clear();
123         }
124     }
125     
126     /**
127      * Begin transaction.
128      *
129      * @throws SQLException SQL exception
130      */
131     public void begin() throws SQLException {
132         ConnectionTransaction connectionTransaction = getConnectionTransaction();
133         if (TransactionType.isDistributedTransaction(connectionTransaction.getTransactionType())) {
134             close();
135             connectionTransaction.begin();
136         }
137         connectionContext.getTransactionContext().beginTransaction(String.valueOf(connectionTransaction.getTransactionType()), connectionTransaction.getDistributedTransactionManager());
138     }
139     
140     /**
141      * Commit.
142      *
143      * @throws SQLException SQL exception
144      */
145     public void commit() throws SQLException {
146         ConnectionTransaction connectionTransaction = getConnectionTransaction();
147         try {
148             if (connectionTransaction.isLocalTransaction() && connectionContext.getTransactionContext().isExceptionOccur()) {
149                 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
150             } else if (connectionTransaction.isLocalTransaction()) {
151                 forceExecuteTemplate.execute(getCachedConnections(), Connection::commit);
152             } else {
153                 connectionTransaction.commit();
154             }
155         } finally {
156             methodInvocationRecorder.remove("setSavepoint");
157             for (Connection each : getCachedConnections()) {
158                 ConnectionSavepointManager.getInstance().transactionFinished(each);
159             }
160             connectionContext.close();
161             clearCachedConnections();
162         }
163     }
164     
165     /**
166      * Rollback.
167      *
168      * @throws SQLException SQL exception
169      */
170     public void rollback() throws SQLException {
171         ConnectionTransaction connectionTransaction = getConnectionTransaction();
172         try {
173             if (connectionTransaction.isLocalTransaction()) {
174                 forceExecuteTemplate.execute(getCachedConnections(), Connection::rollback);
175             } else {
176                 connectionTransaction.rollback();
177             }
178         } finally {
179             methodInvocationRecorder.remove("setSavepoint");
180             for (Connection each : getCachedConnections()) {
181                 ConnectionSavepointManager.getInstance().transactionFinished(each);
182             }
183             connectionContext.close();
184             clearCachedConnections();
185         }
186     }
187     
188     /**
189      * Rollback to savepoint.
190      *
191      * @param savepoint savepoint
192      * @throws SQLException SQL exception
193      */
194     public void rollback(final Savepoint savepoint) throws SQLException {
195         for (Connection each : getCachedConnections()) {
196             ConnectionSavepointManager.getInstance().rollbackToSavepoint(each, savepoint.getSavepointName());
197         }
198     }
199     
200     /**
201      * Set savepoint.
202      *
203      * @param savepointName savepoint name
204      * @return savepoint savepoint
205      * @throws SQLException SQL exception
206      */
207     public Savepoint setSavepoint(final String savepointName) throws SQLException {
208         ShardingSphereSavepoint result = new ShardingSphereSavepoint(savepointName);
209         for (Connection each : getCachedConnections()) {
210             ConnectionSavepointManager.getInstance().setSavepoint(each, savepointName);
211         }
212         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, savepointName));
213         return result;
214     }
215     
216     /**
217      * Set savepoint.
218      *
219      * @return savepoint savepoint
220      * @throws SQLException SQL exception
221      */
222     public Savepoint setSavepoint() throws SQLException {
223         ShardingSphereSavepoint result = new ShardingSphereSavepoint();
224         for (Connection each : getCachedConnections()) {
225             ConnectionSavepointManager.getInstance().setSavepoint(each, result.getSavepointName());
226         }
227         methodInvocationRecorder.record("setSavepoint", target -> ConnectionSavepointManager.getInstance().setSavepoint(target, result.getSavepointName()));
228         return result;
229     }
230     
231     /**
232      * Release savepoint.
233      *
234      * @param savepoint savepoint
235      * @throws SQLException SQL exception
236      */
237     public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
238         methodInvocationRecorder.remove("setSavepoint");
239         for (Connection each : getCachedConnections()) {
240             ConnectionSavepointManager.getInstance().releaseSavepoint(each, savepoint.getSavepointName());
241         }
242     }
243     
244     /**
245      * Get transaction isolation.
246      *
247      * @return transaction isolation level
248      * @throws SQLException SQL exception
249      */
250     public Optional<Integer> getTransactionIsolation() throws SQLException {
251         return cachedConnections.values().isEmpty() ? Optional.empty() : Optional.of(cachedConnections.values().iterator().next().getTransactionIsolation());
252     }
253     
254     /**
255      * Set transaction isolation.
256      *
257      * @param level transaction isolation level
258      * @throws SQLException SQL exception
259      */
260     public void setTransactionIsolation(final int level) throws SQLException {
261         methodInvocationRecorder.record("setTransactionIsolation", connection -> connection.setTransactionIsolation(level));
262         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setTransactionIsolation(level));
263     }
264     
265     /**
266      * Set read only.
267      *
268      * @param readOnly read only
269      * @throws SQLException SQL exception
270      */
271     public void setReadOnly(final boolean readOnly) throws SQLException {
272         methodInvocationRecorder.record("setReadOnly", connection -> connection.setReadOnly(readOnly));
273         forceExecuteTemplate.execute(cachedConnections.values(), connection -> connection.setReadOnly(readOnly));
274     }
275     
276     /**
277      * Whether connection valid.
278      *
279      * @param timeout timeout
280      * @return connection valid or not
281      * @throws SQLException SQL exception
282      */
283     public boolean isValid(final int timeout) throws SQLException {
284         for (Connection each : cachedConnections.values()) {
285             if (!each.isValid(timeout)) {
286                 return false;
287             }
288         }
289         return true;
290     }
291     
292     /**
293      * Get random physical data source name.
294      *
295      * @return random physical data source name
296      */
297     public String getRandomPhysicalDataSourceName() {
298         return getRandomPhysicalDatabaseAndDataSourceName()[1];
299     }
300     
301     private String[] getRandomPhysicalDatabaseAndDataSourceName() {
302         Collection<String> cachedPhysicalDataSourceNames = Sets.intersection(dataSourceMap.keySet(), cachedConnections.keySet());
303         Collection<String> databaseAndDatasourceNames = cachedPhysicalDataSourceNames.isEmpty() ? dataSourceMap.keySet() : cachedPhysicalDataSourceNames;
304         return new ArrayList<>(databaseAndDatasourceNames).get(ThreadLocalRandom.current().nextInt(databaseAndDatasourceNames.size())).split("\\.");
305     }
306     
307     /**
308      * Get random connection.
309      *
310      * @return random connection
311      * @throws SQLException SQL exception
312      */
313     public Connection getRandomConnection() throws SQLException {
314         String[] databaseAndDataSourceName = getRandomPhysicalDatabaseAndDataSourceName();
315         return getConnections0(databaseAndDataSourceName[0], databaseAndDataSourceName[1], 0, 1, ConnectionMode.MEMORY_STRICTLY).get(0);
316     }
317     
318     @Override
319     public List<Connection> getConnections(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
320                                            final ConnectionMode connectionMode) throws SQLException {
321         return getConnections0(databaseName, dataSourceName, connectionOffset, connectionSize, connectionMode);
322     }
323     
324     private List<Connection> getConnections0(final String databaseName, final String dataSourceName, final int connectionOffset, final int connectionSize,
325                                              final ConnectionMode connectionMode) throws SQLException {
326         String cacheKey = getKey(databaseName, dataSourceName);
327         DataSource dataSource = currentDatabaseName.equals(databaseName) ? dataSourceMap.get(cacheKey) : contextManager.getStorageUnits(databaseName).get(dataSourceName).getDataSource();
328         Preconditions.checkNotNull(dataSource, "Missing the data source name: '%s'", dataSourceName);
329         Collection<Connection> connections;
330         synchronized (cachedConnections) {
331             connections = cachedConnections.get(cacheKey);
332         }
333         List<Connection> result;
334         int maxConnectionSize = connectionOffset + connectionSize;
335         if (connections.size() >= maxConnectionSize) {
336             result = new ArrayList<>(connections).subList(connectionOffset, maxConnectionSize);
337         } else if (connections.isEmpty()) {
338             Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize, connectionMode);
339             result = new ArrayList<>(newConnections).subList(connectionOffset, maxConnectionSize);
340             synchronized (cachedConnections) {
341                 cachedConnections.putAll(cacheKey, newConnections);
342             }
343         } else {
344             List<Connection> allConnections = new ArrayList<>(maxConnectionSize);
345             allConnections.addAll(connections);
346             Collection<Connection> newConnections = createConnections(databaseName, dataSourceName, dataSource, maxConnectionSize - connections.size(), connectionMode);
347             allConnections.addAll(newConnections);
348             result = allConnections.subList(connectionOffset, maxConnectionSize);
349             synchronized (cachedConnections) {
350                 cachedConnections.putAll(cacheKey, newConnections);
351             }
352         }
353         return result;
354     }
355     
356     @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
357     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
358                                                final ConnectionMode connectionMode) throws SQLException {
359         if (1 == connectionSize) {
360             Connection connection = createConnection(databaseName, dataSourceName, dataSource, connectionContext.getTransactionContext());
361             try {
362                 methodInvocationRecorder.replay(connection);
363             } catch (final SQLException ex) {
364                 connection.close();
365                 throw ex;
366             }
367             return Collections.singletonList(connection);
368         }
369         if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
370             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
371         }
372         synchronized (dataSource) {
373             return createConnections(databaseName, dataSourceName, dataSource, connectionSize, connectionContext.getTransactionContext());
374         }
375     }
376     
377     private List<Connection> createConnections(final String databaseName, final String dataSourceName, final DataSource dataSource, final int connectionSize,
378                                                final TransactionConnectionContext transactionConnectionContext) throws SQLException {
379         List<Connection> result = new ArrayList<>(connectionSize);
380         for (int i = 0; i < connectionSize; i++) {
381             try {
382                 Connection connection = createConnection(databaseName, dataSourceName, dataSource, transactionConnectionContext);
383                 methodInvocationRecorder.replay(connection);
384                 result.add(connection);
385             } catch (final SQLException ex) {
386                 for (Connection each : result) {
387                     each.close();
388                 }
389                 throw new OverallConnectionNotEnoughException(connectionSize, result.size(), ex).toSQLException();
390             }
391         }
392         return result;
393     }
394     
395     private Connection createConnection(final String databaseName, final String dataSourceName, final DataSource dataSource,
396                                         final TransactionConnectionContext transactionConnectionContext) throws SQLException {
397         Optional<Connection> connectionInTransaction = getConnectionTransaction().getConnection(databaseName, dataSourceName, transactionConnectionContext);
398         return connectionInTransaction.isPresent() ? connectionInTransaction.get() : dataSource.getConnection();
399     }
400     
401     @Override
402     public void close() throws SQLException {
403         clearCachedConnections();
404     }
405 }