1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19
20 import lombok.RequiredArgsConstructor;
21 import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLBinaryColumnType;
22 import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
23 import org.apache.shardingsphere.db.protocol.mysql.packet.MySQLPacket;
24 import org.apache.shardingsphere.db.protocol.mysql.packet.command.admin.MySQLComSetOptionPacket;
25 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinition41Packet;
26 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.MySQLColumnDefinitionFlag;
27 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPrepareOKPacket;
28 import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.prepare.MySQLComStmtPreparePacket;
29 import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLEofPacket;
30 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
31 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
33 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34 import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
35 import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
36 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
37 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
38 import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
39 import org.apache.shardingsphere.infra.exception.mysql.exception.TooManyPlaceholdersException;
40 import org.apache.shardingsphere.infra.exception.mysql.exception.UnsupportedPreparedStatementException;
41 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
42 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
43 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
44 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
45 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
46 import org.apache.shardingsphere.parser.rule.SQLParserRule;
47 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
48 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
49 import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
50 import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
51 import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
52 import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIdGenerator;
53 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
54 import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
55 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
56
57 import java.util.ArrayList;
58 import java.util.Collection;
59 import java.util.Collections;
60 import java.util.LinkedList;
61 import java.util.List;
62 import java.util.Optional;
63 import java.util.concurrent.CopyOnWriteArrayList;
64
65
66
67
68 @RequiredArgsConstructor
69 public final class MySQLComStmtPrepareExecutor implements CommandExecutor {
70
71 private static final int MAX_PARAMETER_COUNT = 65535;
72
73 private final MySQLComStmtPreparePacket packet;
74
75 private final ConnectionSession connectionSession;
76
77 @Override
78 public Collection<DatabasePacket> execute() {
79 failedIfContainsMultiStatements();
80 MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
81 SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
82 DatabaseType databaseType = TypedSPILoader.getService(DatabaseType.class, "MySQL");
83 SQLStatement sqlStatement = sqlParserRule.getSQLParserEngine(databaseType).parse(packet.getSQL(), true);
84 if (!MySQLComStmtPrepareChecker.isAllowedStatement(sqlStatement)) {
85 throw new UnsupportedPreparedStatementException();
86 }
87 SQLStatementContext sqlStatementContext = new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(),
88 connectionSession.getDefaultDatabaseName(), packet.getHintValueContext()).bind(sqlStatement, Collections.emptyList());
89 int statementId = MySQLStatementIdGenerator.getInstance().nextStatementId(connectionSession.getConnectionId());
90 MySQLServerPreparedStatement serverPreparedStatement = new MySQLServerPreparedStatement(packet.getSQL(), sqlStatementContext, packet.getHintValueContext(), new CopyOnWriteArrayList<>());
91 connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, serverPreparedStatement);
92 return createPackets(sqlStatementContext, statementId, serverPreparedStatement);
93 }
94
95 private void failedIfContainsMultiStatements() {
96
97 if (connectionSession.getAttributeMap().hasAttr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS)
98 && MySQLComSetOptionPacket.MYSQL_OPTION_MULTI_STATEMENTS_ON == connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_OPTION_MULTI_STATEMENTS).get()
99 && packet.getSQL().contains(";")) {
100 throw new UnsupportedPreparedStatementException();
101 }
102 }
103
104 private Collection<DatabasePacket> createPackets(final SQLStatementContext sqlStatementContext, final int statementId, final MySQLServerPreparedStatement serverPreparedStatement) {
105 Collection<DatabasePacket> result = new LinkedList<>();
106 Collection<Projection> projections = getProjections(sqlStatementContext);
107 int parameterCount = sqlStatementContext.getSqlStatement().getParameterCount();
108 ShardingSpherePreconditions.checkState(parameterCount <= MAX_PARAMETER_COUNT, TooManyPlaceholdersException::new);
109 result.add(new MySQLComStmtPrepareOKPacket(statementId, projections.size(), parameterCount, 0));
110 int characterSet = connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get().getId();
111 int statusFlags = ServerStatusFlagCalculator.calculateFor(connectionSession);
112 if (parameterCount > 0) {
113 result.addAll(createParameterColumnDefinition41Packets(sqlStatementContext, characterSet, serverPreparedStatement));
114 result.add(new MySQLEofPacket(statusFlags));
115 }
116 if (!projections.isEmpty() && sqlStatementContext instanceof SelectStatementContext) {
117 result.addAll(createProjectionColumnDefinition41Packets((SelectStatementContext) sqlStatementContext, characterSet));
118 result.add(new MySQLEofPacket(statusFlags));
119 }
120 return result;
121 }
122
123 private Collection<Projection> getProjections(final SQLStatementContext sqlStatementContext) {
124 return sqlStatementContext instanceof SelectStatementContext ? ((SelectStatementContext) sqlStatementContext).getProjectionsContext().getExpandProjections() : Collections.emptyList();
125 }
126
127 private Collection<MySQLPacket> createParameterColumnDefinition41Packets(final SQLStatementContext sqlStatementContext, final int characterSet,
128 final MySQLServerPreparedStatement serverPreparedStatement) {
129 List<ShardingSphereColumn> columnsOfParameterMarkers =
130 MySQLComStmtPrepareParameterMarkerExtractor.findColumnsOfParameterMarkers(sqlStatementContext.getSqlStatement(), getSchema(sqlStatementContext));
131 Collection<ParameterMarkerSegment> parameterMarkerSegments = ((AbstractSQLStatement) sqlStatementContext.getSqlStatement()).getParameterMarkerSegments();
132 Collection<MySQLPacket> result = new ArrayList<>(parameterMarkerSegments.size());
133 Collection<Integer> paramColumnDefinitionFlags = new ArrayList<>(parameterMarkerSegments.size());
134 for (int index = 0; index < parameterMarkerSegments.size(); index++) {
135 ShardingSphereColumn column = columnsOfParameterMarkers.isEmpty() ? null : columnsOfParameterMarkers.get(index);
136 if (null != column) {
137 int columnDefinitionFlag = calculateColumnDefinitionFlag(column);
138 result.add(createMySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())));
139 paramColumnDefinitionFlags.add(columnDefinitionFlag);
140 } else {
141 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
142 paramColumnDefinitionFlags.add(0);
143 }
144 }
145 serverPreparedStatement.getParameterColumnDefinitionFlags().addAll(paramColumnDefinitionFlags);
146 return result;
147 }
148
149 private Collection<MySQLPacket> createProjectionColumnDefinition41Packets(final SelectStatementContext selectStatementContext, final int characterSet) {
150 Collection<Projection> projections = selectStatementContext.getProjectionsContext().getExpandProjections();
151 ShardingSphereSchema schema = getSchema(selectStatementContext);
152 Collection<MySQLPacket> result = new ArrayList<>(projections.size());
153 for (Projection each : projections) {
154
155 if (each instanceof ColumnProjection) {
156 result.add(Optional.ofNullable(schema.getTable(((ColumnProjection) each).getOriginalTable().getValue()))
157 .map(table -> table.getColumns().get(((ColumnProjection) each).getOriginalColumn().getValue()))
158 .map(column -> createMySQLColumnDefinition41Packet(characterSet, calculateColumnDefinitionFlag(column), MySQLBinaryColumnType.valueOfJDBCType(column.getDataType())))
159 .orElseGet(() -> createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING)));
160 } else {
161 result.add(createMySQLColumnDefinition41Packet(characterSet, 0, MySQLBinaryColumnType.VAR_STRING));
162 }
163 }
164 return result;
165 }
166
167 private ShardingSphereSchema getSchema(final SQLStatementContext sqlStatementContext) {
168 String databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElseGet(connectionSession::getDefaultDatabaseName);
169 ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
170 return sqlStatementContext.getTablesContext().getSchemaName().map(database::getSchema)
171 .orElseGet(() -> database.getSchema(new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName())));
172 }
173
174 private int calculateColumnDefinitionFlag(final ShardingSphereColumn column) {
175 int result = 0;
176 result |= column.isPrimaryKey() ? MySQLColumnDefinitionFlag.PRIMARY_KEY.getValue() : 0;
177 result |= column.isUnsigned() ? MySQLColumnDefinitionFlag.UNSIGNED.getValue() : 0;
178 return result;
179 }
180
181 private MySQLColumnDefinition41Packet createMySQLColumnDefinition41Packet(final int characterSet, final int columnDefinitionFlag, final MySQLBinaryColumnType columnType) {
182 return new MySQLColumnDefinition41Packet(characterSet, columnDefinitionFlag, "", "", "", "", "", 0, columnType, 0, false);
183 }
184 }