1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.infra.database.mysql.checker;
19
import org.apache.shardingsphere.infra.database.core.checker.DialectDatabasePrivilegeChecker;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.database.core.exception.CheckDatabaseEnvironmentFailedException;
import org.apache.shardingsphere.infra.database.core.exception.MissingRequiredPrivilegeException;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.Locale;
import java.util.Map;
35
36
37
38
39 public final class MySQLDatabasePrivilegeChecker implements DialectDatabasePrivilegeChecker {
40
41 private static final String SHOW_GRANTS_SQL = "SHOW GRANTS";
42
43 private static final int MYSQL_MAJOR_VERSION_8 = 8;
44
45
46 private static final String[][] PIPELINE_REQUIRED_PRIVILEGES =
47 {{"ALL PRIVILEGES", "ON *.*"}, {"REPLICATION SLAVE", "REPLICATION CLIENT", "ON *.*"}, {"REPLICATION SLAVE", "BINLOG MONITOR", "ON *.*"}};
48
49 private static final String[][] XA_REQUIRED_PRIVILEGES = {{"ALL PRIVILEGES", "ON *.*"}, {"XA_RECOVER_ADMIN", "ON *.*"}};
50
51 private static final Map<PrivilegeCheckType, Collection<String>> REQUIRED_PRIVILEGES_FOR_MESSAGE = new EnumMap<>(PrivilegeCheckType.class);
52
53 static {
54 REQUIRED_PRIVILEGES_FOR_MESSAGE.put(PrivilegeCheckType.PIPELINE, Arrays.asList("REPLICATION SLAVE", "REPLICATION CLIENT"));
55 REQUIRED_PRIVILEGES_FOR_MESSAGE.put(PrivilegeCheckType.SELECT, Collections.singleton("SELECT ON DATABASE"));
56 REQUIRED_PRIVILEGES_FOR_MESSAGE.put(PrivilegeCheckType.XA, Collections.singleton("XA_RECOVER_ADMIN"));
57 }
58
59 @Override
60 public void check(final DataSource dataSource, final PrivilegeCheckType privilegeCheckType) {
61 try (Connection connection = dataSource.getConnection()) {
62 if (PrivilegeCheckType.XA == privilegeCheckType && MYSQL_MAJOR_VERSION_8 != connection.getMetaData().getDatabaseMajorVersion()) {
63 return;
64 }
65 checkPrivilege(connection, privilegeCheckType);
66 } catch (final SQLException ex) {
67 throw new CheckDatabaseEnvironmentFailedException(ex);
68 }
69 }
70
71 private void checkPrivilege(final Connection connection, final PrivilegeCheckType privilegeCheckType) {
72 try (
73 PreparedStatement preparedStatement = connection.prepareStatement(SHOW_GRANTS_SQL);
74 ResultSet resultSet = preparedStatement.executeQuery()) {
75 while (resultSet.next()) {
76 String privilege = resultSet.getString(1).toUpperCase();
77 if (matchPrivileges(privilege, getRequiredPrivileges(connection, privilegeCheckType))) {
78 return;
79 }
80 }
81 } catch (final SQLException ex) {
82 throw new CheckDatabaseEnvironmentFailedException(ex);
83 }
84 throw new MissingRequiredPrivilegeException(REQUIRED_PRIVILEGES_FOR_MESSAGE.get(privilegeCheckType));
85 }
86
87 private String[][] getRequiredPrivileges(final Connection connection, final PrivilegeCheckType privilegeCheckType) throws SQLException {
88 switch (privilegeCheckType) {
89 case PIPELINE:
90 return PIPELINE_REQUIRED_PRIVILEGES;
91 case SELECT:
92 return getSelectRequiredPrivilege(connection);
93 case XA:
94 return XA_REQUIRED_PRIVILEGES;
95 default:
96 return new String[0][0];
97 }
98 }
99
100 private String[][] getSelectRequiredPrivilege(final Connection connection) throws SQLException {
101 String onCatalog = String.format("ON `%s`.*", connection.getCatalog().toUpperCase());
102 return new String[][]{{"ALL PRIVILEGES", "ON *.*"}, {"SELECT", "ON *.*"}, {"ALL PRIVILEGES", onCatalog}, {"SELECT", onCatalog}};
103 }
104
105 private boolean matchPrivileges(final String grantedPrivileges, final String[][] requiredPrivileges) {
106 return Arrays.stream(requiredPrivileges).anyMatch(each -> Arrays.stream(each).allMatch(grantedPrivileges::contains));
107 }
108
109 @Override
110 public String getDatabaseType() {
111 return "MySQL";
112 }
113 }