View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
22  import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
23  import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
24  import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
25  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
26  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinitionFlag;
27  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPrepareOKPacket;
28  import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPreparePacket;
29  import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
30  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
31  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
33  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34  import org.apache.shardingsphere.infra.binder.context.statement.type.dml.SelectStatementContext;
35  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
36  import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
37  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
38  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
39  import org.apache.shardingsphere.infra.exception.mysql.exception.TooManyPlaceholdersException;
40  import org.apache.shardingsphere.infra.exception.mysql.exception.UnsupportedPreparedStatementException;
41  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
42  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
43  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
44  import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
45  import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
46  import org.apache.shardingsphere.parser.rule.SQLParserRule;
47  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
48  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
49  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
50  import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
51  import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
52  import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIdGenerator;
53  import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.ParameterMarkerSegment;
54  import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
55  
56  import java.util.ArrayList;
57  import java.util.Collection;
58  import java.util.Collections;
59  import java.util.LinkedList;
60  import java.util.List;
61  import java.util.Optional;
62  import java.util.concurrent.CopyOnWriteArrayList;
63  
64  /**
65   * COM_STMT_PREPARE command executor for MySQL.
66   */
67  @RequiredArgsConstructor
68  public final class MySQLComStmtPrepareExecutor implements CommandExecutor {
69      
70      private static final int MAX_PARAMETER_COUNT = 65535;
71      
72      private final DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
73      
74      private final MySQLComStmtPreparePacket packet;
75      
76      private final ConnectionSession connectionSession;
77      
78      @Override
79      public Collection<DatabasePacket> execute() {
80          failedIfContainsMultiStatements();
81          MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
82          SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
83          SQLStatement sqlStatement = sqlParserRule.getSQLParserEngine(databaseType).parse(packet.getSQL(), true);
84          ShardingSpherePreconditions.checkState(MySQLComStmtPrepareChecker.isAllowedStatement(sqlStatement), UnsupportedPreparedStatementException::new);
85          SQLStatementContext sqlStatementContext = new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
86                  connectionSession.getCurrentDatabaseName(), packet.getHintValueContext()).bind(databaseType, sqlStatement, Collections.emptyList());
87          int statementId = MySQLStatementIdGenerator.getInstance().nextStatementId(connectionSession.getConnectionId());
88          MySQLServerPreparedStatement serverPreparedStatement = new MySQLServerPreparedStatement(packet.getSQL(), sqlStatementContext, packet.getHintValueContext(), new CopyOnWriteArrayList<>());
89          connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
90          return createPackets(sqlStatementContext, statementId, serverPreparedStatement);
91      }
92      
93      private void failedIfContainsMultiStatements() {
94          // TODO Multi statements should be identified by SQL Parser instead of checking if sql contains ";".
95          if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
96                  && MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()
97                  && packet.getSQL().contains(";")) {
98              throw new UnsupportedPreparedStatementException();
99          }
100     }
101     
102     private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlStatementContext, final int statementId, final MySQLServerPreparedStatement serverPreparedStatement) {
103         Collection<DatabasePacket> result = new LinkedList<>();
104         Collection<Projection> projections = getProjections(sqlStatementContext);
105         int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
106         ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
107         result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
108         int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
109         int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession, true);
110         if (parameterCount > 0) {
111             result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
112             result.add(new MySQLEofPacket(statusFlags));
113         }
114         if (!projections.isEmpty() && sqlStatementContext instanceof SelectStatementContext) {
115             result.addAll(createProjectionColumnDefinition41Packets((SelectStatementContext) sqlStatementContext, characterSet));
116             result.add(new MySQLEofPacket(statusFlags));
117         }
118         return result;
119     }
120     
121     private Collection<Projection> getProjections(final SQLStatementContext sqlStatementContext) {
122         return sqlStatementContext instanceof SelectStatementContext ? ((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections() : Collections.emptyList();
123     }
124     
125     private Collection<MySQLPacket> createParameterColumnDefinition41Packets(final SQLStatementContext sqlStatementContext, final int characterSet,
126                                                                              final MySQLServerPreparedStatement serverPreparedStatement) {
127         List<ShardingSphereColumn> columnsOfParameterMarkers =
128                 MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatementContext.getSqlStatement(), getSchema(sqlStatementContext));
129         Collection<ParameterMarkerSegment> parameterMarkerSegments = sqlStatementContext.getSqlStatement().getParameterMarkers();
130         Collection<MySQLPacket> result = new ArrayList<>(parameterMarkerSegments.size());
131         Collection<Integer> paramColumnDefinitionFlags = new ArrayList<>(parameterMarkerSegments.size());
132         for (int index = 0; index < parameterMarkerSegments.size(); index++) {
133             ShardingSphereColumn column = null;
134             if (!columnsOfParameterMarkers.isEmpty() && index < columnsOfParameterMarkers.size()) {
135                 column = columnsOfParameterMarkers.get(index);
136             }
137             if (null != column) {
138                 int columnDefinitionFlag = calculateColumnDefinitionFlag(column);
139                 result.add(createMySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())));
140                 paramColumnDefinitionFlags.add(columnDefinitionFlag);
141             } else {
142                 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
143                 paramColumnDefinitionFlags.add(0);
144             }
145         }
146         serverPreparedStatement.getParameterColumnDefinitionFlags().addAll(paramColumnDefinitionFlags);
147         return result;
148     }
149     
150     private Collection<MySQLPacket> createProjectionColumnDefinition41Packets(final SelectStatementContext selectStatementContext, final int characterSet) {
151         Collection<Projection> projections = selectStatementContext.getProjectionsContext().getExpandProjections();
152         ShardingSphereSchema schema = getSchema(selectStatementContext);
153         Collection<MySQLPacket> result = new ArrayList<>(projections.size());
154         for (Projection each : projections) {
155             // TODO Calculate column definition flag for other projection types
156             if (each instanceof ColumnProjection) {
157                 result.add(Optional.ofNullable(schema.getTable(((ColumnProjection) each).getOriginalTable().getValue()))
158                         .map(table -> table.getColumn(((ColumnProjection) each).getOriginalColumn().getValue()))
159                         .map(column -> createMySQLColumnDefinition41Packet(characterSet, calculateColumnDefinitionFlag(column), MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())))
160                         .orElseGet(() -> createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING)));
161             } else {
162                 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
163             }
164         }
165         return result;
166     }
167     
168     private ShardingSphereSchema getSchema(final SQLStatementContext sqlStatementContext) {
169         String databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElseGet(connectionSession::getCurrentDatabaseName);
170         ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
171         return sqlStatementContext.getTablesContext().getSchemaName().map(database::getSchema)
172                 .orElseGet(() -> database.getSchema(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())));
173     }
174     
175     private int calculateColumnDefinitionFlag(final ShardingSphereColumn column) {
176         int result = 0;
177         result |= column.isPrimaryKey() ? MySQLColumnDefinitionFlag.PRIMARY_KEY.getValue() : 0;
178         result |= column.isUnsigned() ? MySQLColumnDefinitionFlag.UNSIGNED.getValue() : 0;
179         return result;
180     }
181     
182     private MySQLColumnDefinition41Packet createMySQLColumnDefinition41Packet(final int characterSet, final int columnDefinitionFlag, final MySQLBinaryColumnType columnType) {
183         return new MySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, "", "", "", "", "", 0, columnType, 0, false);
184     }
185 }