View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.infra.binder.statement.dml;
19  
20  import com.cedarsoftware.util.CaseInsensitiveMap;
21  import lombok.SneakyThrows;
22  import org.apache.shardingsphere.infra.binder.enums.SegmentType;
23  import org.apache.shardingsphere.infra.binder.segment.column.InsertColumnsSegmentBinder;
24  import org.apache.shardingsphere.infra.binder.segment.expression.ExpressionSegmentBinder;
25  import org.apache.shardingsphere.infra.binder.segment.expression.impl.ColumnSegmentBinder;
26  import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinder;
27  import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
28  import org.apache.shardingsphere.infra.binder.segment.parameter.ParameterMarkerSegmentBinder;
29  import org.apache.shardingsphere.infra.binder.segment.where.WhereSegmentBinder;
30  import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
31  import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
32  import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
33  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
34  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
35  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
36  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
37  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
38  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
39  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
40  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
41  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
42  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
43  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
44  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
45  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
46  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
47  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
48  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
49  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
50  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
51  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
52  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
53  
54  import java.util.ArrayList;
55  import java.util.Collection;
56  import java.util.Collections;
57  import java.util.LinkedHashMap;
58  import java.util.LinkedList;
59  import java.util.List;
60  import java.util.Map;
61  
62  /**
63   * Merge statement binder.
64   */
65  public final class MergeStatementBinder implements SQLStatementBinder<MergeStatement> {
66      
67      @Override
68      public MergeStatement bind(final MergeStatement sqlStatement, final ShardingSphereMetaData metaData, final String defaultDatabaseName) {
69          return bind(sqlStatement, metaData, defaultDatabaseName, Collections.emptyMap());
70      }
71      
72      @SneakyThrows
73      private MergeStatement bind(final MergeStatement sqlStatement, final ShardingSphereMetaData metaData, final String defaultDatabaseName,
74                                  final Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
75          MergeStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
76          SQLStatementBinderContext statementBinderContext = new SQLStatementBinderContext(metaData, defaultDatabaseName, sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
77          statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
78          Map<String, TableSegmentBinderContext> targetTableBinderContexts = new CaseInsensitiveMap<>();
79          TableSegment boundedTargetTableSegment = TableSegmentBinder.bind(sqlStatement.getTarget(), statementBinderContext, targetTableBinderContexts, Collections.emptyMap());
80          Map<String, TableSegmentBinderContext> sourceTableBinderContexts = new CaseInsensitiveMap<>();
81          TableSegment boundedSourceTableSegment = TableSegmentBinder.bind(sqlStatement.getSource(), statementBinderContext, sourceTableBinderContexts, Collections.emptyMap());
82          result.setTarget(boundedTargetTableSegment);
83          result.setSource(boundedSourceTableSegment);
84          Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
85          tableBinderContexts.putAll(sourceTableBinderContexts);
86          tableBinderContexts.putAll(targetTableBinderContexts);
87          if (null != sqlStatement.getExpression()) {
88              ExpressionWithParamsSegment expression = new ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(), sqlStatement.getExpression().getStopIndex(),
89                      ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(), SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts, Collections.emptyMap()));
90              expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments());
91              result.setExpression(expression);
92          }
93          sqlStatement.getInsert().ifPresent(
94                  optional -> result.setInsert(bindMergeInsert(optional, (SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
95          sqlStatement.getUpdate().ifPresent(
96                  optional -> result.setUpdate(bindMergeUpdate(optional, (SimpleTableSegment) boundedTargetTableSegment, statementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
97          addParameterMarkerSegments(result);
98          result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
99          return result;
100     }
101     
102     private void addParameterMarkerSegments(final MergeStatement mergeStatement) {
103         // TODO bind parameter marker segments for merge statement
104         mergeStatement.addParameterMarkerSegments(getSourceSubqueryTableProjectionParameterMarkers(mergeStatement.getSource()));
105         mergeStatement.getInsert().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
106         mergeStatement.getUpdate().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
107     }
108     
109     private Collection<ParameterMarkerSegment> getSourceSubqueryTableProjectionParameterMarkers(final TableSegment tableSegment) {
110         if (!(tableSegment instanceof SubqueryTableSegment)) {
111             return Collections.emptyList();
112         }
113         SubqueryTableSegment subqueryTable = (SubqueryTableSegment) tableSegment;
114         Collection<ParameterMarkerSegment> result = new LinkedList<>();
115         for (ProjectionSegment each : subqueryTable.getSubquery().getSelect().getProjections().getProjections()) {
116             if (each instanceof ParameterMarkerExpressionSegment) {
117                 result.add((ParameterMarkerSegment) each);
118             }
119         }
120         return result;
121     }
122     
123     @SneakyThrows
124     private InsertStatement bindMergeInsert(final InsertStatement sqlStatement, final SimpleTableSegment tableSegment, final SQLStatementBinderContext statementBinderContext,
125                                             final Map<String, TableSegmentBinderContext> targetTableBinderContexts, final Map<String, TableSegmentBinderContext> sourceTableBinderContexts) {
126         SQLStatementBinderContext insertStatementBinderContext = new SQLStatementBinderContext(statementBinderContext.getMetaData(), statementBinderContext.getDefaultDatabaseName(),
127                 statementBinderContext.getDatabaseType(), statementBinderContext.getVariableNames());
128         insertStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
129         insertStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
130         InsertStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
131         result.setTable(tableSegment);
132         sqlStatement.getInsertColumns()
133                 .ifPresent(optional -> result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, targetTableBinderContexts)));
134         sqlStatement.getInsertSelect().ifPresent(result::setInsertSelect);
135         Collection<InsertValuesSegment> insertValues = new LinkedList<>();
136         Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
137         List<ColumnSegment> columnSegments = new ArrayList<>(result.getInsertColumns().map(InsertColumnsSegment::getColumns)
138                 .orElseGet(() -> getVisibleColumns(targetTableBinderContexts.values().iterator().next().getProjectionSegments())));
139         for (InsertValuesSegment each : sqlStatement.getValues()) {
140             List<ExpressionSegment> values = new LinkedList<>();
141             int index = 0;
142             for (ExpressionSegment expression : each.getValues()) {
143                 values.add(ExpressionSegmentBinder.bind(expression, SegmentType.VALUES, insertStatementBinderContext, targetTableBinderContexts, sourceTableBinderContexts));
144                 if (expression instanceof ParameterMarkerSegment) {
145                     parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression, columnSegments.get(index).getColumnBoundedInfo());
146                 }
147                 index++;
148             }
149             insertValues.add(new InsertValuesSegment(each.getStartIndex(), each.getStopIndex(), values));
150         }
151         result.getValues().addAll(insertValues);
152         InsertStatementHandler.getOnDuplicateKeyColumnsSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setOnDuplicateKeyColumnsSegment(result, optional));
153         InsertStatementHandler.getSetAssignmentSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setSetAssignmentSegment(result, optional));
154         InsertStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setWithSegment(result, optional));
155         InsertStatementHandler.getOutputSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setOutputSegment(result, optional));
156         InsertStatementHandler.getMultiTableInsertType(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableInsertType(result, optional));
157         InsertStatementHandler.getMultiTableInsertIntoSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableInsertIntoSegment(result, optional));
158         InsertStatementHandler.getMultiTableConditionalIntoSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setMultiTableConditionalIntoSegment(result, optional));
159         InsertStatementHandler.getReturningSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setReturningSegment(result, optional));
160         InsertStatementHandler.getWhereSegment(sqlStatement).ifPresent(optional -> InsertStatementHandler.setWhereSegment(result,
161                 WhereSegmentBinder.bind(optional, insertStatementBinderContext, targetTableBinderContexts, sourceTableBinderContexts)));
162         result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(), parameterMarkerSegmentBoundedInfos));
163         result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
164         return result;
165     }
166     
167     private Collection<ColumnSegment> getVisibleColumns(final Collection<ProjectionSegment> projectionSegments) {
168         Collection<ColumnSegment> result = new LinkedList<>();
169         for (ProjectionSegment each : projectionSegments) {
170             if (each instanceof ColumnProjectionSegment && each.isVisible()) {
171                 result.add(((ColumnProjectionSegment) each).getColumn());
172             }
173         }
174         return result;
175     }
176     
177     @SneakyThrows
178     private UpdateStatement bindMergeUpdate(final UpdateStatement sqlStatement, final SimpleTableSegment tableSegment, final SQLStatementBinderContext statementBinderContext,
179                                             final Map<String, TableSegmentBinderContext> targetTableBinderContexts, final Map<String, TableSegmentBinderContext> sourceTableBinderContexts) {
180         UpdateStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
181         result.setTable(tableSegment);
182         Collection<ColumnAssignmentSegment> assignments = new LinkedList<>();
183         SQLStatementBinderContext updateStatementBinderContext = new SQLStatementBinderContext(statementBinderContext.getMetaData(), statementBinderContext.getDefaultDatabaseName(),
184                 statementBinderContext.getDatabaseType(), statementBinderContext.getVariableNames());
185         updateStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
186         updateStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
187         Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
188         for (ColumnAssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
189             List<ColumnSegment> columnSegments = new ArrayList<>(each.getColumns().size());
190             each.getColumns().forEach(column -> columnSegments.add(
191                     ColumnSegmentBinder.bind(column, SegmentType.SET_ASSIGNMENT, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
192             ExpressionSegment expression = ExpressionSegmentBinder.bind(each.getValue(), SegmentType.SET_ASSIGNMENT, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap());
193             ColumnAssignmentSegment columnAssignmentSegment = new ColumnAssignmentSegment(each.getStartIndex(), each.getStopIndex(), columnSegments, expression);
194             assignments.add(columnAssignmentSegment);
195             if (expression instanceof ParameterMarkerSegment) {
196                 parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression, columnAssignmentSegment.getColumns().get(0).getColumnBoundedInfo());
197             }
198         }
199         SetAssignmentSegment setAssignmentSegment = new SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(), sqlStatement.getSetAssignment().getStopIndex(), assignments);
200         result.setSetAssignment(setAssignmentSegment);
201         sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
202         UpdateStatementHandler.getDeleteWhereSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setDeleteWhereSegment(result,
203                 WhereSegmentBinder.bind(optional, updateStatementBinderContext, targetTableBinderContexts, Collections.emptyMap())));
204         UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setOrderBySegment(result, optional));
205         UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setLimitSegment(result, optional));
206         UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> UpdateStatementHandler.setWithSegment(result, optional));
207         result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(), parameterMarkerSegmentBoundedInfos));
208         result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
209         return result;
210     }
211 }