View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
22  import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
23  import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
24  import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
25  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
26  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinitionFlag;
27  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPrepareOKPacket;
28  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPreparePacket;
29  import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
30  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
31  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
33  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34  import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
35  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
36  import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
37  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
38  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
39  import org.apache.shardingsphere.infra.exception.mysql.exception.TooManyPlaceholdersException;
40  import org.apache.shardingsphere.infra.exception.mysql.exception.UnsupportedPreparedStatementException;
41  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
42  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
43  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
44  import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
45  import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
46  import org.apache.shardingsphere.parser.rule.SQLParserRule;
47  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
48  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
49  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
50  import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
51  import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
52  import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIdGenerator;
53  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
54  import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
55  import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
56  
57  import java.util.ArrayList;
58  import java.util.Collection;
59  import java.util.Collections;
60  import java.util.LinkedList;
61  import java.util.List;
62  import java.util.Optional;
63  import java.util.concurrent.CopyOnWriteArrayList;
64  
65  /**
66   * COM_STMT_PREPARE command executor for MySQL.
67   */
68  @RequiredArgsConstructor
69  public final class MySQLComStmtPrepareExecutor implements CommandExecutor {
70      
71      private static final int MAX_PARAMETER_COUNT = 65535;
72      
73      private final MySQLComStmtPreparePacket packet;
74      
75      private final ConnectionSession connectionSession;
76      
77      @Override
78      public Collection<DatabasePacket> execute() {
79          failedIfContainsMultiStatements();
80          MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
81          SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
82          DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
83          SQLStatement sqlStatement = sqlParserRule.getSQLParserEngine(databaseType).parse(packet.getSQL(), true);
84          if (!MySQLComStmtPrepareChecker.isAllowedStatement(sqlStatement)) {
85              throw new UnsupportedPreparedStatementException();
86          }
87          SQLStatementContext sqlStatementContext = new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
88                  connectionSession.getDefaultDatabaseName(), packet.getHintValueContext()).bind(sqlStatement, Collections.emptyList());
89          int statementId = MySQLStatementIdGenerator.getInstance().nextStatementId(connectionSession.getConnectionId());
90          MySQLServerPreparedStatement serverPreparedStatement = new MySQLServerPreparedStatement(packet.getSQL(), sqlStatementContext, packet.getHintValueContext(), new CopyOnWriteArrayList<>());
91          connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
92          return createPackets(sqlStatementContext, statementId, serverPreparedStatement);
93      }
94      
95      private void failedIfContainsMultiStatements() {
96          // TODO Multi statements should be identified by SQL Parser instead of checking if sql contains ";".
97          if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
98                  && MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()
99                  && packet.getSQL().contains(";")) {
100             throw new UnsupportedPreparedStatementException();
101         }
102     }
103     
104     private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlStatementContext, final int statementId, final MySQLServerPreparedStatement serverPreparedStatement) {
105         Collection<DatabasePacket> result = new LinkedList<>();
106         Collection<Projection> projections = getProjections(sqlStatementContext);
107         int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
108         ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
109         result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
110         int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
111         int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession);
112         if (parameterCount > 0) {
113             result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
114             result.add(new MySQLEofPacket(statusFlags));
115         }
116         if (!projections.isEmpty() && sqlStatementContext instanceof SelectStatementContext) {
117             result.addAll(createProjectionColumnDefinition41Packets((SelectStatementContext) sqlStatementContext, characterSet));
118             result.add(new MySQLEofPacket(statusFlags));
119         }
120         return result;
121     }
122     
123     private Collection<Projection> getProjections(final SQLStatementContext sqlStatementContext) {
124         return sqlStatementContext instanceof SelectStatementContext ? ((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections() : Collections.emptyList();
125     }
126     
127     private Collection<MySQLPacket> createParameterColumnDefinition41Packets(final SQLStatementContext sqlStatementContext, final int characterSet,
128                                                                              final MySQLServerPreparedStatement serverPreparedStatement) {
129         List<ShardingSphereColumn> columnsOfParameterMarkers =
130                 MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatementContext.getSqlStatement(), getSchema(sqlStatementContext));
131         Collection<ParameterMarkerSegment> parameterMarkerSegments = ((AbstractSQLStatement) sqlStatementContext.getSqlStatement()).getParameterMarkerSegments();
132         Collection<MySQLPacket> result = new ArrayList<>(parameterMarkerSegments.size());
133         Collection<Integer> paramColumnDefinitionFlags = new ArrayList<>(parameterMarkerSegments.size());
134         for (int index = 0; index < parameterMarkerSegments.size(); index++) {
135             ShardingSphereColumn column = columnsOfParameterMarkers.isEmpty() ? null : columnsOfParameterMarkers.get(index);
136             if (null != column) {
137                 int columnDefinitionFlag = calculateColumnDefinitionFlag(column);
138                 result.add(createMySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())));
139                 paramColumnDefinitionFlags.add(columnDefinitionFlag);
140             } else {
141                 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
142                 paramColumnDefinitionFlags.add(0);
143             }
144         }
145         serverPreparedStatement.getParameterColumnDefinitionFlags().addAll(paramColumnDefinitionFlags);
146         return result;
147     }
148     
149     private Collection<MySQLPacket> createProjectionColumnDefinition41Packets(final SelectStatementContext selectStatementContext, final int characterSet) {
150         Collection<Projection> projections = selectStatementContext.getProjectionsContext().getExpandProjections();
151         ShardingSphereSchema schema = getSchema(selectStatementContext);
152         Collection<MySQLPacket> result = new ArrayList<>(projections.size());
153         for (Projection each : projections) {
154             // TODO Calculate column definition flag for other projection types
155             if (each instanceof ColumnProjection) {
156                 result.add(Optional.ofNullable(schema.getTable(((ColumnProjection) each).getOriginalTable().getValue()))
157                         .map(table -> table.getColumns().get(((ColumnProjection) each).getOriginalColumn().getValue()))
158                         .map(column -> createMySQLColumnDefinition41Packet(characterSet, calculateColumnDefinitionFlag(column), MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())))
159                         .orElseGet(() -> createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING)));
160             } else {
161                 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
162             }
163         }
164         return result;
165     }
166     
167     private ShardingSphereSchema getSchema(final SQLStatementContext sqlStatementContext) {
168         String databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElseGet(connectionSession::getDefaultDatabaseName);
169         ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
170         return sqlStatementContext.getTablesContext().getSchemaName().map(database::getSchema)
171                 .orElseGet(() -> database.getSchema(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())));
172     }
173     
174     private int calculateColumnDefinitionFlag(final ShardingSphereColumn column) {
175         int result = 0;
176         result |= column.isPrimaryKey() ? MySQLColumnDefinitionFlag.PRIMARY_KEY.getValue() : 0;
177         result |= column.isUnsigned() ? MySQLColumnDefinitionFlag.UNSIGNED.getValue() : 0;
178         return result;
179     }
180     
181     private MySQLColumnDefinition41Packet createMySQLColumnDefinition41Packet(final int characterSet, final int columnDefinitionFlag, final MySQLBinaryColumnType columnType) {
182         return new MySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, "", "", "", "", "", 0, columnType, 0, false);
183     }
184 }