1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
19
20 import lombok.RequiredArgsConstructor;
21 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
23 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
24 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
25 import org.apache.shardingsphere.distsql.statement.DistSQLStatement;
26 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27 import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
28 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
29 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
30 import org.apache.shardingsphere.infra.parser.SQLParserEngine;
31 import org.apache.shardingsphere.parser.rule.SQLParserRule;
32 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
33 import org.apache.shardingsphere.proxy.backend.distsql.DistSQLStatementContext;
34 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
35 import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
36 import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
37 import org.apache.shardingsphere.sql.parser.statement.core.enums.ParameterMarkerType;
38 import org.apache.shardingsphere.sql.parser.statement.core.segment.SQLSegment;
39 import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.ParameterMarkerSegment;
40 import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
41 import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.DMLStatement;
42
43 import java.util.ArrayList;
44 import java.util.Collection;
45 import java.util.Collections;
46 import java.util.Comparator;
47 import java.util.List;
48
49
50
51
52 @RequiredArgsConstructor
53 public final class PostgreSQLComParseExecutor implements CommandExecutor {
54
55 private final PostgreSQLComParsePacket packet;
56
57 private final ConnectionSession connectionSession;
58
59 @Override
60 public Collection<DatabasePacket> execute() {
61 ShardingSphereMetaData metaData = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData();
62 DatabaseType databaseType = metaData.getDatabase(connectionSession.getUsedDatabaseName()).getProtocolType();
63 SQLParserEngine sqlParserEngine = metaData.getGlobalRuleMetaData().getSingleRule(SQLParserRule.class).getSQLParserEngine(databaseType);
64 String sql = packet.getSQL();
65 SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
66 String escapedSql = escape(sqlStatement, sql);
67 if (!escapedSql.equalsIgnoreCase(sql)) {
68 sqlStatement = sqlParserEngine.parse(escapedSql, true);
69 sql = escapedSql;
70 }
71 List<Integer> actualParameterMarkerIndexes = new ArrayList<>(sqlStatement.getParameterMarkers().size());
72 if (sqlStatement.getParameterCount() > 0) {
73 List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(sqlStatement.getParameterMarkers());
74 for (ParameterMarkerSegment each : parameterMarkerSegments) {
75 actualParameterMarkerIndexes.add(each.getParameterIndex());
76 }
77 sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
78 sqlStatement = sqlParserEngine.parse(sql, true);
79 }
80 List<PostgreSQLColumnType> paddedColumnTypes = paddingColumnTypes(sqlStatement.getParameterCount(), packet.readParameterTypes());
81 SQLStatementContext sqlStatementContext = sqlStatement instanceof DistSQLStatement
82 ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
83 : new SQLBindEngine(metaData, connectionSession.getCurrentDatabaseName(), packet.getHintValueContext()).bind(databaseType, sqlStatement, Collections.emptyList());
84 PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(
85 sql, sqlStatementContext, packet.getHintValueContext(), paddedColumnTypes, actualParameterMarkerIndexes);
86 connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(), serverPreparedStatement);
87 return Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
88 }
89
90 private String escape(final SQLStatement sqlStatement, final String sql) {
91 return sqlStatement instanceof DMLStatement ? sql.replace("?", "??") : sql;
92 }
93
94 private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
95 parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
96 StringBuilder result = new StringBuilder(sql);
97 for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
98 ParameterMarkerSegment each = parameterMarkerSegments.get(i);
99 result.replace(each.getStartIndex(), each.getStopIndex() + 1, ParameterMarkerType.QUESTION.getMarker());
100 }
101 return result.toString();
102 }
103
104 private List<PostgreSQLColumnType> paddingColumnTypes(final int parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
105 if (parameterCount == specifiedColumnTypes.size()) {
106 return specifiedColumnTypes;
107 }
108 List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
109 result.addAll(specifiedColumnTypes);
110 int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
111 for (int i = 0; i < unspecifiedCount; i++) {
112 result.add(PostgreSQLColumnType.UNSPECIFIED);
113 }
114 return result;
115 }
116 }