1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.prepare;
19
20 import lombok.AccessLevel;
21 import lombok.NoArgsConstructor;
22 import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
23 import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.column.ColumnNotFoundException;
24 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
25 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
26 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
27 import org.apache.shardingsphere.infra.metadata.identifier.ShardingSphereIdentifier;
28 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.ColumnAssignmentSegment;
29 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
30 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
31 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
32 import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
33 import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.InsertStatement;
34
35 import java.util.ArrayList;
36 import java.util.Collection;
37 import java.util.Collections;
38 import java.util.LinkedList;
39 import java.util.List;
40 import java.util.stream.Collectors;
41
42
43
44
45 @NoArgsConstructor(access = AccessLevel.PRIVATE)
46 public final class MySQLComStmtPrepareParameterMarkerExtractor {
47
48
49
50
51
52
53
54
55
56 public static List<ShardingSphereColumn> findColumnsOfParameterMarkers(final SQLStatement sqlStatement, final ShardingSphereSchema schema) {
57 return sqlStatement instanceof InsertStatement && ((InsertStatement) sqlStatement).getTable().isPresent()
58 ? findColumnsOfParameterMarkersForInsert((InsertStatement) sqlStatement, schema)
59 : Collections.emptyList();
60 }
61
62 private static List<ShardingSphereColumn> findColumnsOfParameterMarkersForInsert(final InsertStatement insertStatement, final ShardingSphereSchema schema) {
63 ShardingSphereTable table = schema.getTable(insertStatement.getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""));
64 List<String> columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table);
65 List<ShardingSphereColumn> result = getParameterMarkerColumns(insertStatement, table, columnNamesOfInsert);
66 insertStatement.getOnDuplicateKeyColumns().ifPresent(optional -> result.addAll(getOnDuplicateKeyParameterMarkerColumns(optional.getColumns(), table)));
67 return result;
68 }
69
70 private static List<String> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
71 return insertStatement.getColumns().isEmpty()
72 ? table.getColumnNames().stream().map(ShardingSphereIdentifier::getValue).collect(Collectors.toList())
73 : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList());
74 }
75
76 private static List<ShardingSphereColumn> getParameterMarkerColumns(final InsertStatement insertStatement, final ShardingSphereTable table, final List<String> columnNamesOfInsert) {
77 List<ShardingSphereColumn> result = new ArrayList<>(insertStatement.getParameterMarkers().size());
78 for (InsertValuesSegment each : insertStatement.getValues()) {
79 result.addAll(getParameterMarkerColumns(table, columnNamesOfInsert, each));
80 }
81 return result;
82 }
83
84 private static List<ShardingSphereColumn> getParameterMarkerColumns(final ShardingSphereTable table, final List<String> columnNamesOfInsert, final InsertValuesSegment segment) {
85 List<ShardingSphereColumn> result = new LinkedList<>();
86 int index = 0;
87 for (ExpressionSegment each : segment.getValues()) {
88 if (each instanceof ParameterMarkerExpressionSegment) {
89 String columnName = columnNamesOfInsert.get(index);
90 ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
91 result.add(table.getColumn(columnName));
92 }
93 index++;
94 }
95 return result;
96 }
97
98 private static List<ShardingSphereColumn> getOnDuplicateKeyParameterMarkerColumns(final Collection<ColumnAssignmentSegment> onDuplicateKeyColumns, final ShardingSphereTable table) {
99 List<ShardingSphereColumn> result = new LinkedList<>();
100 for (ColumnAssignmentSegment each : onDuplicateKeyColumns) {
101 if (each.getValue() instanceof ParameterMarkerExpressionSegment) {
102 String columnName = each.getColumns().iterator().next().getIdentifier().getValue();
103 ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(table.getName(), columnName));
104 result.add(table.getColumn(columnName));
105 }
106 }
107 return result;
108 }
109 }