View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.parse;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
23  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLComParsePacket;
24  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.parse.PostgreSQLParseCompletePacket;
25  import org.apache.shardingsphere.distsql.statement.DistSQLStatement;
26  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
28  import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
29  import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
30  import org.apache.shardingsphere.infra.parser.SQLParserEngine;
31  import org.apache.shardingsphere.parser.rule.SQLParserRule;
32  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
33  import org.apache.shardingsphere.proxy.backend.distsql.DistSQLStatementContext;
34  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
35  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
36  import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
37  import org.apache.shardingsphere.sql.parser.statement.core.enums.ParameterMarkerType;
38  import org.apache.shardingsphere.sql.parser.statement.core.segment.SQLSegment;
39  import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.ParameterMarkerSegment;
40  import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
41  import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.DMLStatement;
42  
43  import java.util.ArrayList;
44  import java.util.Collection;
45  import java.util.Collections;
46  import java.util.Comparator;
47  import java.util.List;
48  
49  /**
50   * PostgreSQL command parse executor.
51   */
52  @RequiredArgsConstructor
53  public final class PostgreSQLComParseExecutor implements CommandExecutor {
54      
55      private final PostgreSQLComParsePacket packet;
56      
57      private final ConnectionSession connectionSession;
58      
59      @Override
60      public Collection<DatabasePacket> execute() {
61          ShardingSphereMetaData metaData = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData();
62          DatabaseType databaseType = metaData.getDatabase(connectionSession.getUsedDatabaseName()).getProtocolType();
63          SQLParserEngine sqlParserEngine = metaData.getGlobalRuleMetaData().getSingleRule(SQLParserRule.class).getSQLParserEngine(databaseType);
64          String sql = packet.getSQL();
65          SQLStatement sqlStatement = sqlParserEngine.parse(sql, true);
66          String escapedSql = escape(sqlStatement, sql);
67          if (!escapedSql.equalsIgnoreCase(sql)) {
68              sqlStatement = sqlParserEngine.parse(escapedSql, true);
69              sql = escapedSql;
70          }
71          List<Integer> actualParameterMarkerIndexes = new ArrayList<>(sqlStatement.getParameterMarkers().size());
72          if (sqlStatement.getParameterCount() > 0) {
73              List<ParameterMarkerSegment> parameterMarkerSegments = new ArrayList<>(sqlStatement.getParameterMarkers());
74              for (ParameterMarkerSegment each : parameterMarkerSegments) {
75                  actualParameterMarkerIndexes.add(each.getParameterIndex());
76              }
77              sql = convertSQLToJDBCStyle(parameterMarkerSegments, sql);
78              sqlStatement = sqlParserEngine.parse(sql, true);
79          }
80          List<PostgreSQLColumnType> paddedColumnTypes = paddingColumnTypes(sqlStatement.getParameterCount(), packet.readParameterTypes());
81          SQLStatementContext sqlStatementContext = sqlStatement instanceof DistSQLStatement
82                  ? new DistSQLStatementContext((DistSQLStatement) sqlStatement)
83                  : new SQLBindEngine(metaData, connectionSession.getCurrentDatabaseName(), packet.getHintValueContext()).bind(databaseType, sqlStatement, Collections.emptyList());
84          PostgreSQLServerPreparedStatement serverPreparedStatement = new PostgreSQLServerPreparedStatement(
85                  sql, sqlStatementContext, packet.getHintValueContext(), paddedColumnTypes, actualParameterMarkerIndexes);
86          connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(packet.getStatementId(), serverPreparedStatement);
87          return Collections.singleton(PostgreSQLParseCompletePacket.getInstance());
88      }
89      
90      private String escape(final SQLStatement sqlStatement, final String sql) {
91          return sqlStatement instanceof DMLStatement ? sql.replace("?", "??") : sql;
92      }
93      
94      private String convertSQLToJDBCStyle(final List<ParameterMarkerSegment> parameterMarkerSegments, final String sql) {
95          parameterMarkerSegments.sort(Comparator.comparingInt(SQLSegment::getStopIndex));
96          StringBuilder result = new StringBuilder(sql);
97          for (int i = parameterMarkerSegments.size() - 1; i >= 0; i--) {
98              ParameterMarkerSegment each = parameterMarkerSegments.get(i);
99              result.replace(each.getStartIndex(), each.getStopIndex() + 1, ParameterMarkerType.QUESTION.getMarker());
100         }
101         return result.toString();
102     }
103     
104     private List<PostgreSQLColumnType> paddingColumnTypes(final int parameterCount, final List<PostgreSQLColumnType> specifiedColumnTypes) {
105         if (parameterCount == specifiedColumnTypes.size()) {
106             return specifiedColumnTypes;
107         }
108         List<PostgreSQLColumnType> result = new ArrayList<>(parameterCount);
109         result.addAll(specifiedColumnTypes);
110         int unspecifiedCount = parameterCount - specifiedColumnTypes.size();
111         for (int i = 0; i < unspecifiedCount; i++) {
112             result.add(PostgreSQLColumnType.UNSPECIFIED);
113         }
114         return result;
115     }
116 }