1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
19
20 import lombok.RequiredArgsConstructor;
21 import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
23 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
24 import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
25 import org.apache.shardingsphere.distsql.statement.DistSQLStatement;
26 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27 import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
28 import org.apache.shardingsphere.infra.parser.SQLParserEngine;
29 import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
30 import org.apache.shardingsphere.parser.rule.SQLParserRule;
31 import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
32 import org.apache.shardingsphere.proxy.backend.distsql.DistSQLStatementContext;
33 import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
34 import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
35 import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
36 import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
37 import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
38 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
39 import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
40 import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
41 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
42
43 import java.util.ArrayList;
44 import java.util.Collection;
45 import java.util.Collections;
46 import java.util.Comparator;
47 import java.util.List;
48
49
50
51
52 @RequiredArgsConstructor
53 public final class PostgreSQLComParseExecutor implements CommandExecutor {
54
55 private final PostgreSQLComParsePacket packet;
56
57 private final ConnectionSession connectionSession;
58
59 @Override
60 public Collection<DatabasePacket> execute() {
61 SQLParserEngine sqlParserEngine = createShardingSphereSQLParserEngine(connectionSession.getDatabaseName());
62 String sql = packet.getSQL();
63 SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
64 String escapedSql = escape(sqlStatement, sql);
65 if (!escapedSql.equalsIgnoreCase(sql)) {
66 sqlStatement = sqlParserEngine.parse(escapedSql, true);
67 sql = escapedSql;
68 }
69 List<Integer> actualParameterMarkerIndexes = new ArrayList<>();
70 if (sqlStatement.getParameterCount() > 0) {
71 List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
72 for (ParameterMarkerSegment each : parameterMarkerSegments) {
73 actualParameterMarkerIndexes.add(each.getParameterIndex());
74 }
75 sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
76 sqlStatement = sqlParserEngine.parse(sql, true);
77 }
78 List<PostgreSQLColumnType> paddedColumnTypes = paddingColumnTypes(sqlStatement.getParameterCount(), packet.readParameterTypes());
79 SQLStatementContext sqlStatementContext = sqlStatement instanceof DistSQLStatement ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
80 : new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(), connectionSession.getDefaultDatabaseName(), packet.getHintValueContext())
81 .bind(sqlStatement, Collections.emptyList());
82 PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, packet.getHintValueContext(), paddedColumnTypes,
83 actualParameterMarkerIndexes);
84 connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(), serverPreparedStatement);
85 return Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
86 }
87
88 private SQLParserEngine createShardingSphereSQLParserEngine(final String databaseName) {
89 MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
90 SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
91 return sqlParserRule.getSQLParserEngine(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
92 }
93
94 private String escape(final SQLStatement sqlStatement, final String sql) {
95 if (sqlStatement instanceof DMLStatement) {
96 return sql.replace("?", "??");
97 }
98 return sql;
99 }
100
101 private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
102 parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
103 StringBuilder result = new StringBuilder(sql);
104 for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
105 ParameterMarkerSegment each = parameterMarkerSegments.get(i);
106 result.replace(each.getStartIndex(), each.getStopIndex() + 1, ParameterMarkerType.QUESTION.getMarker());
107 }
108 return result.toString();
109 }
110
111 private List<PostgreSQLColumnType> paddingColumnTypes(final int parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
112 if (parameterCount == specifiedColumnTypes.size()) {
113 return specifiedColumnTypes;
114 }
115 List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
116 result.addAll(specifiedColumnTypes);
117 int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
118 for (int i = 0; i < unspecifiedCount; i++) {
119 result.add(PostgreSQLColumnType.UNSPECIFIED);
120 }
121 return result;
122 }
123 }