1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.mysql.authentication;
19
20 import com.google.common.base.Strings;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.epoll.EpollDomainSocketChannel;
24 import lombok.extern.slf4j.Slf4j;
25 import org.apache.shardingsphere.authentication.Authenticator;
26 import org.apache.shardingsphere.authentication.AuthenticatorFactory;
27 import org.apache.shardingsphere.authentication.result.AuthenticationResult;
28 import org.apache.shardingsphere.authentication.result.AuthenticationResultBuilder;
29 import org.apache.shardingsphere.authority.checker.AuthorityChecker;
30 import org.apache.shardingsphere.authority.rule.AuthorityRule;
31 import org.apache.shardingsphere.database.exception.core.exception.connection.AccessDeniedException;
32 import org.apache.shardingsphere.database.exception.core.exception.syntax.database.UnknownDatabaseException;
33 import org.apache.shardingsphere.database.exception.mysql.exception.DatabaseAccessDeniedException;
34 import org.apache.shardingsphere.database.exception.mysql.exception.HandshakeException;
35 import org.apache.shardingsphere.database.protocol.constant.CommonConstants;
36 import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
37 import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCharacterSets;
38 import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConnectionPhase;
39 import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLConstants;
40 import org.apache.shardingsphere.database.protocol.mysql.constant.MySQLStatusFlag;
41 import org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
42 import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
43 import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
44 import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthenticationPluginData;
45 import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
46 import org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
47 import org.apache.shardingsphere.database.protocol.mysql.payload.MySQLPacketPayload;
48 import org.apache.shardingsphere.database.protocol.payload.PacketPayload;
49 import org.apache.shardingsphere.infra.metadata.user.Grantee;
50 import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
51 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
52 import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
53 import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
54 import org.apache.shardingsphere.proxy.frontend.mysql.authentication.authenticator.MySQLAuthenticatorType;
55 import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIdGenerator;
56 import org.apache.shardingsphere.proxy.frontend.mysql.ssl.MySQLSSLRequestHandler;
57 import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
58
59 import java.net.InetSocketAddress;
60 import java.net.SocketAddress;
61 import java.util.Optional;
62
63
64
65
66 @Slf4j
67 public final class MySQLAuthenticationEngine implements AuthenticationEngine {
68
69 private final MySQLAuthenticationPluginData authPluginData = new MySQLAuthenticationPluginData();
70
71 private MySQLConnectionPhase connectionPhase = MySQLConnectionPhase.INITIAL_HANDSHAKE;
72
73 private byte[] authResponse;
74
75 private AuthenticationResult currentAuthResult;
76
77 @Override
78 public int handshake(final ChannelHandlerContext context) {
79 int result = ConnectionIdGenerator.getInstance().nextId();
80 connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
81 boolean sslEnabled = ProxySSLContext.getInstance().isSSLEnabled();
82 if (sslEnabled) {
83 context.pipeline().addFirst(MySQLSSLRequestHandler.class.getSimpleName(), new MySQLSSLRequestHandler());
84 }
85 context.writeAndFlush(new MySQLHandshakePacket(result, sslEnabled, authPluginData));
86 MySQLStatementIdGenerator.getInstance().registerConnection(result);
87 return result;
88 }
89
90 @Override
91 public AuthenticationResult authenticate(final ChannelHandlerContext context, final PacketPayload payload) {
92 AuthorityRule rule = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
93 if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == connectionPhase) {
94 currentAuthResult = authenticatePhaseFastPath(context, payload, rule);
95 if (!currentAuthResult.isFinished()) {
96 return currentAuthResult;
97 }
98 } else if (MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH == connectionPhase) {
99 authenticateMismatchedMethod((MySQLPacketPayload) payload);
100 }
101 Grantee grantee = new Grantee(currentAuthResult.getUsername(), getHostAddress(context));
102 if (!login(rule, grantee, authResponse)) {
103 throw new AccessDeniedException(currentAuthResult.getUsername(), grantee.getHostname(), 0 != authResponse.length);
104 }
105 if (!authorizeDatabase(rule, grantee, currentAuthResult.getDatabase())) {
106 throw new DatabaseAccessDeniedException(currentAuthResult.getUsername(), grantee.getHostname(), currentAuthResult.getDatabase());
107 }
108 writeOKPacket(context);
109 return AuthenticationResultBuilder.finished(grantee.getUsername(), grantee.getHostname(), currentAuthResult.getDatabase());
110 }
111
112 private AuthenticationResult authenticatePhaseFastPath(final ChannelHandlerContext context, final PacketPayload payload, final AuthorityRule rule) {
113 MySQLHandshakeResponse41Packet handshakeResponsePacket;
114 try {
115 handshakeResponsePacket = new MySQLHandshakeResponse41Packet((MySQLPacketPayload) payload);
116 } catch (final IndexOutOfBoundsException ex) {
117 if (log.isWarnEnabled()) {
118 log.warn("Received bad handshake from client {}: \n{}", context.channel(), ByteBufUtil.prettyHexDump(payload.getByteBuf().resetReaderIndex()));
119 }
120 throw new HandshakeException();
121 }
122 authResponse = handshakeResponsePacket.getAuthResponse();
123 setMultiStatementsOption(context, handshakeResponsePacket);
124 setCharacterSet(context, handshakeResponsePacket);
125 String database = handshakeResponsePacket.getDatabase();
126 if (!Strings.isNullOrEmpty(database) && !ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().containsDatabase(database)) {
127 throw new UnknownDatabaseException(database);
128 }
129 String username = handshakeResponsePacket.getUsername();
130 String hostname = getHostAddress(context);
131 ShardingSphereUser user = rule.findUser(new Grantee(username, hostname)).orElseGet(() -> new ShardingSphereUser(username, "", hostname));
132 Authenticator authenticator = new AuthenticatorFactory<>(MySQLAuthenticatorType.class, rule).newInstance(user);
133 if (0 == authResponse.length || isClientPluginAuthenticate(handshakeResponsePacket) && !authenticator.getAuthenticationMethodName().equals(handshakeResponsePacket.getAuthPluginName())) {
134 connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
135 context.writeAndFlush(new MySQLAuthSwitchRequestPacket(authenticator.getAuthenticationMethodName(), authPluginData));
136 return AuthenticationResultBuilder.continued(username, hostname, database);
137 }
138 return AuthenticationResultBuilder.finished(username, hostname, database);
139 }
140
141 private void setMultiStatementsOption(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
142 context.channel().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).set(handshakeResponsePacket.getMultiStatementsOption());
143 }
144
145 private void setCharacterSet(final ChannelHandlerContext context, final MySQLHandshakeResponse41Packet handshakeResponsePacket) {
146 MySQLCharacterSets characterSet = MySQLCharacterSets.findById(handshakeResponsePacket.getCharacterSet());
147 context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
148 context.channel().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
149 }
150
151 private boolean isClientPluginAuthenticate(final MySQLHandshakeResponse41Packet packet) {
152 return 0 != (packet.getCapabilityFlags() & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue());
153 }
154
155 private void authenticateMismatchedMethod(final MySQLPacketPayload payload) {
156 authResponse = new MySQLAuthSwitchResponsePacket(payload).getAuthPluginResponse();
157 }
158
159 private boolean login(final AuthorityRule rule, final Grantee grantee, final byte[] authenticationResponse) {
160 Optional<ShardingSphereUser> user = rule.findUser(grantee);
161 return user.isPresent()
162 && new AuthenticatorFactory<>(MySQLAuthenticatorType.class, rule).newInstance(user.get()).authenticate(user.get(), new Object[]{authenticationResponse, authPluginData});
163 }
164
165 private boolean authorizeDatabase(final AuthorityRule rule, final Grantee grantee, final String databaseName) {
166 return null == databaseName || new AuthorityChecker(rule, grantee).isAuthorized(databaseName);
167 }
168
169 private String getHostAddress(final ChannelHandlerContext context) {
170 if (context.channel() instanceof EpollDomainSocketChannel) {
171 return context.channel().parent().localAddress().toString();
172 }
173 SocketAddress socketAddress = context.channel().remoteAddress();
174 return socketAddress instanceof InetSocketAddress ? ((InetSocketAddress) socketAddress).getAddress().getHostAddress() : socketAddress.toString();
175 }
176
177 private void writeOKPacket(final ChannelHandlerContext context) {
178 context.writeAndFlush(new MySQLOKPacket(MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
179 }
180 }