1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19
20 import lombok.RequiredArgsConstructor;
21 import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
22 import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
23 import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
24 import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
25 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
26 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinitionFlag;
27 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPrepareOKPacket;
28 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPreparePacket;
29 import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
30 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
31 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
33 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34 import org.apache.shardingsphere.infra.binder.context.statement.type.dml.SelectStatementContext;
35 import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
36 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
37 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
38 import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
39 import org.apache.shardingsphere.infra.exception.mysql.exception.TooManyPlaceholdersException;
40 import org.apache.shardingsphere.infra.exception.mysql.exception.UnsupportedPreparedStatementException;
41 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
42 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
43 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
44 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
45 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
46 import org.apache.shardingsphere.parser.rule.SQLParserRule;
47 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
48 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
49 import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
50 import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
51 import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
52 import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIdGenerator;
53 import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.ParameterMarkerSegment;
54 import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
55
56 import java.util.ArrayList;
57 import java.util.Collection;
58 import java.util.Collections;
59 import java.util.LinkedList;
60 import java.util.List;
61 import java.util.Optional;
62 import java.util.concurrent.CopyOnWriteArrayList;
63
64
65
66
67 @RequiredArgsConstructor
68 public final class MySQLComStmtPrepareExecutor implements CommandExecutor {
69
70 private static final int MAX_PARAMETER_COUNT = 65535;
71
72 private final DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
73
74 private final MySQLComStmtPreparePacket packet;
75
76 private final ConnectionSession connectionSession;
77
78 @Override
79 public Collection<DatabasePacket> execute() {
80 failedIfContainsMultiStatements();
81 MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
82 SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
83 SQLStatement sqlStatement = sqlParserRule.getSQLParserEngine(databaseType).parse(packet.getSQL(), true);
84 ShardingSpherePreconditions.checkState(MySQLComStmtPrepareChecker.isAllowedStatement(sqlStatement), UnsupportedPreparedStatementException::new);
85 SQLStatementContext sqlStatementContext = new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
86 connectionSession.getCurrentDatabaseName(), packet.getHintValueContext()).bind(databaseType, sqlStatement, Collections.emptyList());
87 int statementId = MySQLStatementIdGenerator.getInstance().nextStatementId(connectionSession.getConnectionId());
88 MySQLServerPreparedStatement serverPreparedStatement = new MySQLServerPreparedStatement(packet.getSQL(), sqlStatementContext, packet.getHintValueContext(), new CopyOnWriteArrayList<>());
89 connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
90 return createPackets(sqlStatementContext, statementId, serverPreparedStatement);
91 }
92
93 private void failedIfContainsMultiStatements() {
94
95 if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY)
96 && MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.OPTION_MULTI_STATEMENTS_ATTRIBUTE_KEY).get()
97 && packet.getSQL().contains(";")) {
98 throw new UnsupportedPreparedStatementException();
99 }
100 }
101
102 private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlStatementContext, final int statementId, final MySQLServerPreparedStatement serverPreparedStatement) {
103 Collection<DatabasePacket> result = new LinkedList<>();
104 Collection<Projection> projections = getProjections(sqlStatementContext);
105 int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
106 ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
107 result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
108 int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
109 int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession, true);
110 if (parameterCount > 0) {
111 result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
112 result.add(new MySQLEofPacket(statusFlags));
113 }
114 if (!projections.isEmpty() && sqlStatementContext instanceof SelectStatementContext) {
115 result.addAll(createProjectionColumnDefinition41Packets((SelectStatementContext) sqlStatementContext, characterSet));
116 result.add(new MySQLEofPacket(statusFlags));
117 }
118 return result;
119 }
120
121 private Collection<Projection> getProjections(final SQLStatementContext sqlStatementContext) {
122 return sqlStatementContext instanceof SelectStatementContext ? ((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections() : Collections.emptyList();
123 }
124
125 private Collection<MySQLPacket> createParameterColumnDefinition41Packets(final SQLStatementContext sqlStatementContext, final int characterSet,
126 final MySQLServerPreparedStatement serverPreparedStatement) {
127 List<ShardingSphereColumn> columnsOfParameterMarkers =
128 MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatementContext.getSqlStatement(), getSchema(sqlStatementContext));
129 Collection<ParameterMarkerSegment> parameterMarkerSegments = sqlStatementContext.getSqlStatement().getParameterMarkers();
130 Collection<MySQLPacket> result = new ArrayList<>(parameterMarkerSegments.size());
131 Collection<Integer> paramColumnDefinitionFlags = new ArrayList<>(parameterMarkerSegments.size());
132 for (int index = 0; index < parameterMarkerSegments.size(); index++) {
133 ShardingSphereColumn column = null;
134 if (!columnsOfParameterMarkers.isEmpty() && index < columnsOfParameterMarkers.size()) {
135 column = columnsOfParameterMarkers.get(index);
136 }
137 if (null != column) {
138 int columnDefinitionFlag = calculateColumnDefinitionFlag(column);
139 result.add(createMySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())));
140 paramColumnDefinitionFlags.add(columnDefinitionFlag);
141 } else {
142 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
143 paramColumnDefinitionFlags.add(0);
144 }
145 }
146 serverPreparedStatement.getParameterColumnDefinitionFlags().addAll(paramColumnDefinitionFlags);
147 return result;
148 }
149
150 private Collection<MySQLPacket> createProjectionColumnDefinition41Packets(final SelectStatementContext selectStatementContext, final int characterSet) {
151 Collection<Projection> projections = selectStatementContext.getProjectionsContext().getExpandProjections();
152 ShardingSphereSchema schema = getSchema(selectStatementContext);
153 Collection<MySQLPacket> result = new ArrayList<>(projections.size());
154 for (Projection each : projections) {
155
156 if (each instanceof ColumnProjection) {
157 result.add(Optional.ofNullable(schema.getTable(((ColumnProjection) each).getOriginalTable().getValue()))
158 .map(table -> table.getColumn(((ColumnProjection) each).getOriginalColumn().getValue()))
159 .map(column -> createMySQLColumnDefinition41Packet(characterSet, calculateColumnDefinitionFlag(column), MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())))
160 .orElseGet(() -> createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING)));
161 } else {
162 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
163 }
164 }
165 return result;
166 }
167
168 private ShardingSphereSchema getSchema(final SQLStatementContext sqlStatementContext) {
169 String databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElseGet(connectionSession::getCurrentDatabaseName);
170 ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
171 return sqlStatementContext.getTablesContext().getSchemaName().map(database::getSchema)
172 .orElseGet(() -> database.getSchema(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())));
173 }
174
175 private int calculateColumnDefinitionFlag(final ShardingSphereColumn column) {
176 int result = 0;
177 result |= column.isPrimaryKey() ? MySQLColumnDefinitionFlag.PRIMARY_KEY.getValue() : 0;
178 result |= column.isUnsigned() ? MySQLColumnDefinitionFlag.UNSIGNED.getValue() : 0;
179 return result;
180 }
181
182 private MySQLColumnDefinition41Packet createMySQLColumnDefinition41Packet(final int characterSet, final int columnDefinitionFlag, final MySQLBinaryColumnType columnType) {
183 return new MySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, "", "", "", "", "", 0, columnType, 0, false);
184 }
185 }