View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.opengauss.authentication;
19  
20  import com.google.common.base.Strings;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.handler.ssl.SslHandler;
23  import org.apache.shardingsphere.authority.checker.AuthorityChecker;
24  import org.apache.shardingsphere.authority.rule.AuthorityRule;
25  import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
26  import org.apache.shardingsphere.db.protocol.constant.DatabaseProtocolServerInfo;
27  import org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
28  import org.apache.shardingsphere.db.protocol.opengauss.packet.authentication.OpenGaussAuthenticationHexData;
29  import org.apache.shardingsphere.db.protocol.opengauss.packet.authentication.OpenGaussAuthenticationSCRAMSha256Packet;
30  import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
31  import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLAuthenticationMethod;
32  import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLReadyForQueryPacket;
33  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLAuthenticationOKPacket;
34  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLComStartupPacket;
35  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLParameterStatusPacket;
36  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLPasswordMessagePacket;
37  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLRandomGenerator;
38  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLUnwillingPacket;
39  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLSSLWillingPacket;
40  import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.authentication.PostgreSQLMD5PasswordAuthenticationPacket;
41  import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
42  import org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
43  import org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
44  import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
45  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
46  import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
47  import org.apache.shardingsphere.infra.exception.postgresql.exception.authority.EmptyUsernameException;
48  import org.apache.shardingsphere.infra.exception.postgresql.exception.authority.InvalidPasswordException;
49  import org.apache.shardingsphere.infra.exception.postgresql.exception.authority.PrivilegeNotGrantedException;
50  import org.apache.shardingsphere.infra.exception.postgresql.exception.authority.UnknownUsernameException;
51  import org.apache.shardingsphere.infra.exception.postgresql.exception.protocol.ProtocolViolationException;
52  import org.apache.shardingsphere.infra.metadata.user.Grantee;
53  import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
54  import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
55  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
56  import org.apache.shardingsphere.proxy.backend.postgresql.handler.admin.executor.variable.charset.PostgreSQLCharacterSets;
57  import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
58  import org.apache.shardingsphere.authentication.result.AuthenticationResult;
59  import org.apache.shardingsphere.authentication.result.AuthenticationResultBuilder;
60  import org.apache.shardingsphere.authentication.Authenticator;
61  import org.apache.shardingsphere.authentication.AuthenticatorFactory;
62  import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
63  import org.apache.shardingsphere.proxy.frontend.opengauss.authentication.authenticator.OpenGaussAuthenticatorType;
64  import org.apache.shardingsphere.proxy.frontend.ssl.ProxySSLContext;
65  
66  import java.util.Optional;
67  
68  /**
69   * Authentication engine for openGauss.
70   * 
71   * @see <a href="https://opengauss.org/zh/blogs/blogs.html?post/douxin/sm3_for_opengauss/">SM3 for openGauss</a>
72   */
73  public final class OpenGaussAuthenticationEngine implements AuthenticationEngine {
74      
75      private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
76      
77      private static final int SSL_REQUEST_CODE = 80877103;
78      
79      private static final int PROTOCOL_350_SERVER_ITERATOR = 2048;
80      
81      private static final int PROTOCOL_351_SERVER_ITERATOR = 10000;
82      
83      private final OpenGaussAuthenticationHexData authHexData = new OpenGaussAuthenticationHexData();
84      
85      private boolean startupMessageReceived;
86      
87      private String clientEncoding;
88      
89      private int serverIteration;
90      
91      private byte[] md5Salt;
92      
93      private AuthenticationResult currentAuthResult;
94      
95      @Override
96      public int handshake(final ChannelHandlerContext context) {
97          return ConnectionIdGenerator.getInstance().nextId();
98      }
99      
100     @Override
101     public AuthenticationResult authenticate(final ChannelHandlerContext context, final PacketPayload payload) {
102         if (SSL_REQUEST_PAYLOAD_LENGTH == payload.getByteBuf().markReaderIndex().readInt() && SSL_REQUEST_CODE == payload.getByteBuf().readInt()) {
103             if (ProxySSLContext.getInstance().isSSLEnabled()) {
104                 SslHandler sslHandler = new SslHandler(ProxySSLContext.getInstance().newSSLEngine(context.alloc()), true);
105                 context.pipeline().addFirst(SslHandler.class.getSimpleName(), sslHandler);
106                 context.writeAndFlush(new PostgreSQLSSLWillingPacket());
107             } else {
108                 context.writeAndFlush(new PostgreSQLSSLUnwillingPacket());
109             }
110             return AuthenticationResultBuilder.continued();
111         }
112         payload.getByteBuf().resetReaderIndex();
113         AuthorityRule rule = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(AuthorityRule.class);
114         return startupMessageReceived ? processPasswordMessage(context, (PostgreSQLPacketPayload) payload, rule) : processStartupMessage(context, (PostgreSQLPacketPayload) payload, rule);
115     }
116     
117     private AuthenticationResult processPasswordMessage(final ChannelHandlerContext context, final PostgreSQLPacketPayload payload, final AuthorityRule rule) {
118         char messageType = (char) payload.readInt1();
119         ShardingSpherePreconditions.checkState(PostgreSQLMessagePacketType.PASSWORD_MESSAGE.getValue() == messageType,
120                 () -> new ProtocolViolationException("password", Character.toString(messageType)));
121         PostgreSQLPasswordMessagePacket passwordMessagePacket = new PostgreSQLPasswordMessagePacket(payload);
122         login(rule, passwordMessagePacket.getDigest());
123         context.write(new PostgreSQLAuthenticationOKPacket());
124         context.write(new PostgreSQLParameterStatusPacket("server_version",
125                 DatabaseProtocolServerInfo.getProtocolVersion(currentAuthResult.getDatabase(), TypedSPILoader.getService(DatabaseType.class, "openGauss"))));
126         context.write(new PostgreSQLParameterStatusPacket("client_encoding", clientEncoding));
127         context.write(new PostgreSQLParameterStatusPacket("server_encoding", "UTF8"));
128         context.write(new PostgreSQLParameterStatusPacket("integer_datetimes", "on"));
129         context.writeAndFlush(PostgreSQLReadyForQueryPacket.NOT_IN_TRANSACTION);
130         return AuthenticationResultBuilder.finished(currentAuthResult.getUsername(), "", currentAuthResult.getDatabase());
131     }
132     
133     private void login(final AuthorityRule rule, final String digest) {
134         String username = currentAuthResult.getUsername();
135         String databaseName = currentAuthResult.getDatabase();
136         ShardingSpherePreconditions.checkState(Strings.isNullOrEmpty(databaseName) || ProxyContext.getInstance().databaseExists(databaseName), () -> new UnknownDatabaseException(databaseName));
137         Grantee grantee = new Grantee(username, "%");
138         Optional<ShardingSphereUser> user = rule.findUser(grantee);
139         ShardingSpherePreconditions.checkState(user.isPresent(), () -> new UnknownUsernameException(username));
140         Authenticator authenticator = new AuthenticatorFactory<>(OpenGaussAuthenticatorType.class, rule).newInstance(user.get());
141         ShardingSpherePreconditions.checkState(login(authenticator, user.get(), digest), () -> new InvalidPasswordException(username));
142         ShardingSpherePreconditions.checkState(null == databaseName || new AuthorityChecker(rule, grantee).isAuthorized(databaseName), () -> new PrivilegeNotGrantedException(username, databaseName));
143     }
144     
145     private boolean login(final Authenticator authenticator, final ShardingSphereUser user, final String digest) {
146         if (PostgreSQLAuthenticationMethod.MD5.getMethodName().equals(authenticator.getAuthenticationMethodName())) {
147             return authenticator.authenticate(user, new Object[]{digest, md5Salt});
148         }
149         return authenticator.authenticate(user, new Object[]{digest, authHexData.getSalt(), authHexData.getNonce(), serverIteration});
150     }
151     
152     private AuthenticationResult processStartupMessage(final ChannelHandlerContext context, final PostgreSQLPacketPayload payload, final AuthorityRule rule) {
153         startupMessageReceived = true;
154         PostgreSQLComStartupPacket startupPacket = new PostgreSQLComStartupPacket(payload);
155         clientEncoding = startupPacket.getClientEncoding();
156         context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(PostgreSQLCharacterSets.findCharacterSet(clientEncoding));
157         String username = startupPacket.getUsername();
158         ShardingSpherePreconditions.checkNotEmpty(username, EmptyUsernameException::new);
159         context.writeAndFlush(getIdentifierPacket(username, rule, startupPacket.getVersion()));
160         currentAuthResult = AuthenticationResultBuilder.continued(username, "", startupPacket.getDatabase());
161         return currentAuthResult;
162     }
163     
164     private PostgreSQLIdentifierPacket getIdentifierPacket(final String username, final AuthorityRule rule, final int version) {
165         Optional<Authenticator> authenticator = rule.findUser(new Grantee(username, "")).map(optional -> new AuthenticatorFactory<>(OpenGaussAuthenticatorType.class, rule).newInstance(optional));
166         if (authenticator.isPresent() && PostgreSQLAuthenticationMethod.MD5.getMethodName().equals(authenticator.get().getAuthenticationMethodName())) {
167             md5Salt = PostgreSQLRandomGenerator.getInstance().generateRandomBytes(4);
168             return new PostgreSQLMD5PasswordAuthenticationPacket(md5Salt);
169         }
170         serverIteration = version == OpenGaussProtocolVersion.PROTOCOL_350.getVersion() ? PROTOCOL_350_SERVER_ITERATOR : PROTOCOL_351_SERVER_ITERATOR;
171         String password = rule.findUser(new Grantee(username, "%")).map(ShardingSphereUser::getPassword).orElse("");
172         return new OpenGaussAuthenticationSCRAMSha256Packet(version, serverIteration, authHexData, password);
173     }
174 }