View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19  
20  import lombok.AccessLevel;
21  import lombok.NoArgsConstructor;
22  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
23  import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.column.ColumnNotFoundException;
24  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
25  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
26  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
27  import org.apache.shardingsphere.infra.metadata.identifier.ShardingSphereIdentifier;
28  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.ColumnAssignmentSegment;
29  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
30  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
31  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
32  import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
33  import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.InsertStatement;
34  
35  import java.util.ArrayList;
36  import java.util.Collection;
37  import java.util.Collections;
38  import java.util.LinkedList;
39  import java.util.List;
40  import java.util.stream.Collectors;
41  
42  /**
43   * Parameter marker extractor for MySQL COM_STMT_PREPARE.
44   */
45  @NoArgsConstructor(access = AccessLevel.PRIVATE)
46  public final class MySQLComStmtPrepareParameterMarkerExtractor {
47      
48      /**
49       * TODO Support more statements and syntax.
50       * Find corresponding columns of parameter markers.
51       *
52       * @param sqlStatement SQL statement
53       * @param schema schema
54       * @return corresponding columns of parameter markers
55       */
56      public static List<ShardingSphereColumn> findColumnsOfParameterMarkers(final SQLStatement sqlStatement, final ShardingSphereSchema schema) {
57          return sqlStatement instanceof InsertStatement && ((InsertStatement) sqlStatement).getTable().isPresent()
58                  ? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema)
59                  : Collections.emptyList();
60      }
61      
62      private static List<ShardingSphereColumn> findColumnsOfParameterMarkersForInsert(final InsertStatement insertStatement, final ShardingSphereSchema schema) {
63          ShardingSphereTable table = schema.getTable(insertStatement.getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""));
64          List<String> columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table);
65          List<ShardingSphereColumn> result = getParameterMarkerColumns(insertStatement, table, columnNamesOfInsert);
66          insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> result.addAll(getOnDuplicateKeyParameterMarkerColumns(optional.getColumns(), table)));
67          return result;
68      }
69      
70      private static List<String> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
71          return insertStatement.getColumns().isEmpty()
72                  ? table.getColumnNames().stream().map(ShardingSphereIdentifier::getValue).collect(Collectors.toList())
73                  : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList());
74      }
75      
76      private static List<ShardingSphereColumn> getParameterMarkerColumns(final InsertStatement insertStatement, final ShardingSphereTable table, final List<String> columnNamesOfInsert) {
77          List<ShardingSphereColumn> result = new ArrayList<>(insertStatement.getParameterMarkers().size());
78          for (InsertValuesSegment each : insertStatement.getValues()) {
79              result.addAll(getParameterMarkerColumns(table, columnNamesOfInsert, each));
80          }
81          return result;
82      }
83      
84      private static List<ShardingSphereColumn> getParameterMarkerColumns(final ShardingSphereTable table, final List<String> columnNamesOfInsert, final InsertValuesSegment segment) {
85          List<ShardingSphereColumn> result = new LinkedList<>();
86          int index = 0;
87          for (ExpressionSegment each : segment.getValues()) {
88              if (each instanceof ParameterMarkerExpressionSegment) {
89                  String columnName = columnNamesOfInsert.get(index);
90                  ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
91                  result.add(table.getColumn(columnName));
92              }
93              index++;
94          }
95          return result;
96      }
97      
98      private static List<ShardingSphereColumn> getOnDuplicateKeyParameterMarkerColumns(final Collection<ColumnAssignmentSegment> onDuplicateKeyColumns, final ShardingSphereTable table) {
99          List<ShardingSphereColumn> result = new LinkedList<>();
100         for (ColumnAssignmentSegment each : onDuplicateKeyColumns) {
101             if (each.getValue() instanceof ParameterMarkerExpressionSegment) {
102                 String columnName = each.getColumns().iterator().next().getIdentifier().getValue();
103                 ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
104                 result.add(table.getColumn(columnName));
105             }
106         }
107         return result;
108     }
109 }