View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
23  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
24  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
25  import org.apache.shardingsphere.distsql.statement.DistSQLStatement;
26  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
28  import org.apache.shardingsphere.infra.parser.SQLParserEngine;
29  import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
30  import org.apache.shardingsphere.parser.rule.SQLParserRule;
31  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
32  import org.apache.shardingsphere.proxy.backend.distsql.DistSQLStatementContext;
33  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
34  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
35  import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
36  import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
37  import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
38  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
39  import org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
40  import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
41  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
42  
43  import java.util.ArrayList;
44  import java.util.Collection;
45  import java.util.Collections;
46  import java.util.Comparator;
47  import java.util.List;
48  
49  /**
50   * PostgreSQL command parse executor.
51   */
52  @RequiredArgsConstructor
53  public final class PostgreSQLComParseExecutor implements CommandExecutor {
54      
55      private final PostgreSQLComParsePacket packet;
56      
57      private final ConnectionSession connectionSession;
58      
59      @Override
60      public Collection<DatabasePacket> execute() {
61          SQLParserEngine sqlParserEngine = createShardingSphereSQLParserEngine(connectionSession.getDatabaseName());
62          String sql = packet.getSQL();
63          SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
64          String escapedSql = escape(sqlStatement, sql);
65          if (!escapedSql.equalsIgnoreCase(sql)) {
66              sqlStatement = sqlParserEngine.parse(escapedSql, true);
67              sql = escapedSql;
68          }
69          List<Integer> actualParameterMarkerIndexes = new ArrayList<>();
70          if (sqlStatement.getParameterCount() > 0) {
71              List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(((AbstractSQLStatement) sqlStatement).getParameterMarkerSegments());
72              for (ParameterMarkerSegment each : parameterMarkerSegments) {
73                  actualParameterMarkerIndexes.add(each.getParameterIndex());
74              }
75              sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
76              sqlStatement = sqlParserEngine.parse(sql, true);
77          }
78          List<PostgreSQLColumnType> paddedColumnTypes = paddingColumnTypes(sqlStatement.getParameterCount(), packet.readParameterTypes());
79          SQLStatementContext sqlStatementContext = sqlStatement instanceof DistSQLStatement ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
80                  : new SQLBindEngine(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData(), connectionSession.getDefaultDatabaseName(), packet.getHintValueContext())
81                          .bind(sqlStatement, Collections.emptyList());
82          PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(sql, sqlStatementContext, packet.getHintValueContext(), paddedColumnTypes,
83                  actualParameterMarkerIndexes);
84          connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(), serverPreparedStatement);
85          return Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
86      }
87      
88      private SQLParserEngine createShardingSphereSQLParserEngine(final String databaseName) {
89          MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
90          SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
91          return sqlParserRule.getSQLParserEngine(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
92      }
93      
94      private String escape(final SQLStatement sqlStatement, final String sql) {
95          if (sqlStatement instanceof DMLStatement) {
96              return sql.replace("?", "??");
97          }
98          return sql;
99      }
100     
101     private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
102         parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
103         StringBuilder result = new StringBuilder(sql);
104         for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
105             ParameterMarkerSegment each = parameterMarkerSegments.get(i);
106             result.replace(each.getStartIndex(), each.getStopIndex() + 1, ParameterMarkerType.QUESTION.getMarker());
107         }
108         return result.toString();
109     }
110     
111     private List<PostgreSQLColumnType> paddingColumnTypes(final int parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
112         if (parameterCount == specifiedColumnTypes.size()) {
113             return specifiedColumnTypes;
114         }
115         List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
116         result.addAll(specifiedColumnTypes);
117         int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
118         for (int i = 0; i < unspecifiedCount; i++) {
119             result.add(PostgreSQLColumnType.UNSPECIFIED);
120         }
121         return result;
122     }
123 }