1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.infra.binder.statement.dml;
19
20 import com.cedarsoftware.util.CaseInsensitiveMap;
21 import lombok.SneakyThrows;
22 import org.apache.shardingsphere.infra.binder.enums.SegmentType;
23 import org.apache.shardingsphere.infra.binder.segment.column.InsertColumnsSegmentBinder;
24 import org.apache.shardingsphere.infra.binder.segment.expression.ExpressionSegmentBinder;
25 import org.apache.shardingsphere.infra.binder.segment.expression.impl.ColumnSegmentBinder;
26 import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinder;
27 import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
28 import org.apache.shardingsphere.infra.binder.segment.parameter.ParameterMarkerSegmentBinder;
29 import org.apache.shardingsphere.infra.binder.segment.where.WhereSegmentBinder;
30 import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
31 import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
32 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
33 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
34 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
35 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
36 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
37 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
38 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
39 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
40 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
41 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
42 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
43 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
44 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
45 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
46 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
47 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
48 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
49 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
50 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
51 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
52 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
53
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.Collections;
57 import java.util.LinkedHashMap;
58 import java.util.LinkedList;
59 import java.util.List;
60 import java.util.Map;
61
62
63
64
65 public final class MergeStatementBinder implements SQLStatementBinder<MergeStatement> {
66
67 @Override
68 public MergeStatement bind(final MergeStatement sqlStatement, final ShardingSphereMetaData metaData, final String defaultDatabaseName) {
69 return bind(sqlStatement, metaData, defaultDatabaseName, Collections.emptyMap());
70 }
71
72 @SneakyThrows
73 private MergeStatement bind(final MergeStatement sqlStatement, final ShardingSphereMetaData metaData, final String defaultDatabaseName,
74 final Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
75 MergeStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
76 SQLStatementBinderContext statementBinderContext = new SQLStatementBinderContext(metaData, defaultDatabaseName, sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
77 statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
78 Map<String, TableSegmentBinderContext> targetTableBinderContexts = new CaseInsensitiveMap<>();
79 TableSegment boundedTargetTableSegment = TableSegmentBinder.bind(sqlStatement.getTarget(), statementBinderContext, targetTableBinderContexts, Collections.emptyMap());
80 Map<String, TableSegmentBinderContext> sourceTableBinderContexts = new CaseInsensitiveMap<>();
81 TableSegment boundedSourceTableSegment = TableSegmentBinder.bind(sqlStatement.getSource(), statementBinderContext, sourceTableBinderContexts, Collections.emptyMap());
82 result.setTarget(boundedTargetTableSegment);
83 result.setSource(boundedSourceTableSegment);
84 Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
85 tableBinderContexts.putAll(sourceTableBinderContexts);
86 tableBinderContexts.putAll(targetTableBinderContexts);
87 if (null != sqlStatement.getExpression()) {
88 ExpressionWithParamsSegment expression = new ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(), sqlStatement.getExpression().getStopIndex(),
89 ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
90 expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments());
91 result.setExpression(expression);
92 }
93 sqlStatement.getInsert().ifPresent(
94 optional -> result.setInsert(bindMergeInsert(optional, (SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
95 sqlStatement.getUpdate().ifPresent(
96 optional -> result.setUpdate(bindMergeUpdate(optional, (SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
97 addParameterMarkerSegments(result);
98 result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
99 return result;
100 }
101
102 private void addParameterMarkerSegments(final MergeStatement mergeStatement) {
103
104 mergeStatement.addParameterMarkerSegments(getSourceSubqueryTableProjectionParameterMarkers(mergeStatement.getSource()));
105 mergeStatement.getInsert().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
106 mergeStatement.getUpdate().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
107 }
108
109 private Collection<ParameterMarkerSegment> getSourceSubqueryTableProjectionParameterMarkers(final TableSegment tableSegment) {
110 if (!(tableSegment instanceof SubqueryTableSegment)) {
111 return Collections.emptyList();
112 }
113 SubqueryTableSegment subqueryTable = (SubqueryTableSegment) tableSegment;
114 Collection<ParameterMarkerSegment> result = new LinkedList<>();
115 for (ProjectionSegment each : subqueryTable.getSubquery().getSelect().getProjections().getProjections()) {
116 if (each instanceof ParameterMarkerExpressionSegment) {
117 result.add((ParameterMarkerSegment) each);
118 }
119 }
120 return result;
121 }
122
123 @SneakyThrows
124 private InsertStatement bindMergeInsert(final InsertStatement sqlStatement, final SimpleTableSegment tableSegment, final SQLStatementBinderContext statementBinderContext,
125 final Map<String, TableSegmentBinderContext> targetTableBinderContexts, final Map<String, TableSegmentBinderContext> sourceTableBinderContexts) {
126 SQLStatementBinderContext insertStatementBinderContext = new SQLStatementBinderContext(statementBinderContext.getMetaData(), statementBinderContext.getDefaultDatabaseName(),
127 statementBinderContext.getDatabaseType(), statementBinderContext.getVariableNames());
128 insertStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
129 insertStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
130 InsertStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
131 result.setTable(tableSegment);
132 sqlStatement.getInsertColumns()
133 .ifPresent(optional -> result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, targetTableBinderContexts)));
134 sqlStatement.getInsertSelect().ifPresent(result::setInsertSelect);
135 Collection<InsertValuesSegment> insertValues = new LinkedList<>();
136 Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
137 List<ColumnSegment> columnSegments = new ArrayList<>(result.getInsertColumns().map(InsertColumnsSegment::getColumns)
138 .orElseGet(() -> getVisibleColumns(targetTableBinderContexts.values().iterator().next().getProjectionSegments())));
139 for (InsertValuesSegment each : sqlStatement.getValues()) {
140 List<ExpressionSegment> values = new LinkedList<>();
141 int index = 0;
142 for (ExpressionSegment expression : each.getValues()) {
143 values.add(ExpressionSegmentBinder.bind(expression, SegmentType.VALUES, insertStatementBinderContext, targetTableBinderContexts, sourceTableBinderContexts));
144 if (expression instanceof ParameterMarkerSegment) {
145 parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression, columnSegments.get(index).getColumnBoundedInfo());
146 }
147 index++;
148 }
149 insertValues.add(new InsertValuesSegment(each.getStartIndex(), each.getStopIndex(), values));
150 }
151 result.getValues().addAll(insertValues);
152 InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
153 InsertStatementHandler.getSetAssignmentSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setSetAssignmentSegment(result, optional));
154 InsertStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setWithSegment(result, optional));
155 InsertStatementHandler.getOutputSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setOutputSegment(result, optional));
156 InsertStatementHandler.getMultiTableInsertType(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableInsertType(result, optional));
157 InsertStatementHandler.getMultiTableInsertIntoSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableInsertIntoSegment(result, optional));
158 InsertStatementHandler.getMultiTableConditionalIntoSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableConditionalIntoSegment(result, optional));
159 InsertStatementHandler.getReturningSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setReturningSegment(result, optional));
160 InsertStatementHandler.getWhereSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setWhereSegment(result,
161 WhereSegmentBinder.bind(optional, insertStatementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
162 result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(), parameterMarkerSegmentBoundedInfos));
163 result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
164 return result;
165 }
166
167 private Collection<ColumnSegment> getVisibleColumns(final Collection<ProjectionSegment> projectionSegments) {
168 Collection<ColumnSegment> result = new LinkedList<>();
169 for (ProjectionSegment each : projectionSegments) {
170 if (each instanceof ColumnProjectionSegment && each.isVisible()) {
171 result.add(((ColumnProjectionSegment) each).getColumn());
172 }
173 }
174 return result;
175 }
176
177 @SneakyThrows
178 private UpdateStatement bindMergeUpdate(final UpdateStatement sqlStatement, final SimpleTableSegment tableSegment, final SQLStatementBinderContext statementBinderContext,
179 final Map<String, TableSegmentBinderContext> targetTableBinderContexts, final Map<String, TableSegmentBinderContext> sourceTableBinderContexts) {
180 UpdateStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
181 result.setTable(tableSegment);
182 Collection<ColumnAssignmentSegment> assignments = new LinkedList<>();
183 SQLStatementBinderContext updateStatementBinderContext = new SQLStatementBinderContext(statementBinderContext.getMetaData(), statementBinderContext.getDefaultDatabaseName(),
184 statementBinderContext.getDatabaseType(), statementBinderContext.getVariableNames());
185 updateStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
186 updateStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
187 Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
188 for (ColumnAssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
189 List<ColumnSegment> columnSegments = new ArrayList<>(each.getColumns().size());
190 each.getColumns().forEach(column -> columnSegments.add(
191 ColumnSegmentBinder.bind(column, SegmentType.SET_ASSIGNMENT, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
192 ExpressionSegment expression = ExpressionSegmentBinder.bind(each.getValue(), SegmentType.SET_ASSIGNMENT, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap());
193 ColumnAssignmentSegment columnAssignmentSegment = new ColumnAssignmentSegment(each.getStartIndex(), each.getStopIndex(), columnSegments, expression);
194 assignments.add(columnAssignmentSegment);
195 if (expression instanceof ParameterMarkerSegment) {
196 parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression, columnAssignmentSegment.getColumns().get(0).getColumnBoundedInfo());
197 }
198 }
199 SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), sqlStatement.getSetAssignment().getStopIndex(), assignments);
200 result.setSetAssignment(setAssignmentSegment);
201 sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
202 UpdateStatementHandler.getDeleteWhereSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setDeleteWhereSegment(result,
203 WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
204 UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setOrderBySegment(result, optional));
205 UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setLimitSegment(result, optional));
206 UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setWithSegment(result, optional));
207 result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(), parameterMarkerSegmentBoundedInfos));
208 result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
209 return result;
210 }
211 }