View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.infra.binder.context.statement.dml;
19  
import lombok.Getter;
import org.apache.shardingsphere.infra.binder.context.aware.ParameterAware;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.engine.GeneratedKeyContextEngine;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.OnDuplicateUpdateContext;
import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.sql.common.enums.SubqueryType;
import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
58  
59  /**
60   * Insert SQL statement context.
61   */
62  @Getter
63  public final class InsertStatementContext extends CommonSQLStatementContext implements TableAvailable, ParameterAware {
64      
65      private final TablesContext tablesContext;
66      
67      private final List<String> columnNames;
68      
69      private final ShardingSphereMetaData metaData;
70      
71      private final String defaultDatabaseName;
72      
73      private final List<String> insertColumnNames;
74      
75      private final List<List<ExpressionSegment>> valueExpressions;
76      
77      private List<InsertValueContext> insertValueContexts;
78      
79      private InsertSelectContext insertSelectContext;
80      
81      private OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext;
82      
83      private GeneratedKeyContext generatedKeyContext;
84      
85      public InsertStatementContext(final ShardingSphereMetaData metaData, final List<Object> params, final InsertStatement sqlStatement, final String defaultDatabaseName) {
86          super(sqlStatement);
87          this.metaData = metaData;
88          this.defaultDatabaseName = defaultDatabaseName;
89          insertColumnNames = getInsertColumnNames();
90          valueExpressions = getAllValueExpressions(sqlStatement);
91          AtomicInteger parametersOffset = new AtomicInteger(0);
92          insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions);
93          insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null);
94          onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
95          tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType());
96          ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
97          columnNames = containsInsertColumns() ? insertColumnNames
98                  : Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
99          generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
100     }
101     
102     private ShardingSphereSchema getSchema(final ShardingSphereMetaData metaData, final String defaultDatabaseName) {
103         String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName);
104         ShardingSpherePreconditions.checkNotNull(databaseName, NoDatabaseSelectedException::new);
105         ShardingSphereDatabase database = metaData.getDatabase(databaseName);
106         ShardingSpherePreconditions.checkNotNull(database, () -> new UnknownDatabaseException(databaseName));
107         String defaultSchema = new DatabaseTypeRegistry(getDatabaseType()).getDefaultSchemaName(databaseName);
108         return tablesContext.getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
109     }
110     
111     private Collection<SimpleTableSegment> getAllSimpleTableSegments() {
112         TableExtractor tableExtractor = new TableExtractor();
113         tableExtractor.extractTablesFromInsert(getSqlStatement());
114         return tableExtractor.getRewriteTables();
115     }
116     
117     private List<InsertValueContext> getInsertValueContexts(final List<Object> params, final AtomicInteger paramsOffset, final List<List<ExpressionSegment>> valueExpressions) {
118         List<InsertValueContext> result = new LinkedList<>();
119         for (Collection<ExpressionSegment> each : valueExpressions) {
120             InsertValueContext insertValueContext = new InsertValueContext(each, params, paramsOffset.get());
121             result.add(insertValueContext);
122             paramsOffset.addAndGet(insertValueContext.getParameterCount());
123         }
124         return result;
125     }
126     
127     private Optional<InsertSelectContext> getInsertSelectContext(final ShardingSphereMetaData metaData, final List<Object> params,
128                                                                  final AtomicInteger paramsOffset, final String defaultDatabaseName) {
129         if (!getSqlStatement().getInsertSelect().isPresent()) {
130             return Optional.empty();
131         }
132         SubquerySegment insertSelectSegment = getSqlStatement().getInsertSelect().get();
133         SelectStatementContext selectStatementContext = new SelectStatementContext(metaData, params, insertSelectSegment.getSelect(), defaultDatabaseName);
134         selectStatementContext.setSubqueryType(SubqueryType.INSERT_SELECT_SUBQUERY);
135         InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get());
136         paramsOffset.addAndGet(insertSelectContext.getParameterCount());
137         return Optional.of(insertSelectContext);
138     }
139     
140     private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(final List<Object> params, final AtomicInteger parametersOffset) {
141         Optional<OnDuplicateKeyColumnsSegment> onDuplicateKeyColumnsSegment = InsertStatementHandler.getOnDuplicateKeyColumnsSegment(getSqlStatement());
142         if (!onDuplicateKeyColumnsSegment.isPresent()) {
143             return Optional.empty();
144         }
145         Collection<ColumnAssignmentSegment> onDuplicateKeyColumns = onDuplicateKeyColumnsSegment.get().getColumns();
146         OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(onDuplicateKeyColumns, params, parametersOffset.get());
147         parametersOffset.addAndGet(onDuplicateUpdateContext.getParameterCount());
148         return Optional.of(onDuplicateUpdateContext);
149     }
150     
151     /**
152      * Get column names for descending order.
153      *
154      * @return column names for descending order
155      */
156     public Iterator<String> getDescendingColumnNames() {
157         return new LinkedList<>(columnNames).descendingIterator();
158     }
159     
160     /**
161      * Get grouped parameters.
162      *
163      * @return grouped parameters
164      */
165     public List<List<Object>> getGroupedParameters() {
166         List<List<Object>> result = new LinkedList<>();
167         for (InsertValueContext each : insertValueContexts) {
168             result.add(each.getParameters());
169         }
170         if (null != insertSelectContext && !insertSelectContext.getParameters().isEmpty()) {
171             result.add(insertSelectContext.getParameters());
172         }
173         return result;
174     }
175     
176     /**
177      * Get on duplicate key update parameters.
178      *
179      * @return on duplicate key update parameters
180      */
181     public List<Object> getOnDuplicateKeyUpdateParameters() {
182         return null == onDuplicateKeyUpdateValueContext ? new ArrayList<>() : onDuplicateKeyUpdateValueContext.getParameters();
183     }
184     
185     /**
186      * Get generated key context.
187      *
188      * @return generated key context
189      */
190     public Optional<GeneratedKeyContext> getGeneratedKeyContext() {
191         return Optional.ofNullable(generatedKeyContext);
192     }
193     
194     /**
195      * Judge whether contains insert columns.
196      *
197      * @return contains insert columns or not
198      */
199     public boolean containsInsertColumns() {
200         InsertStatement insertStatement = getSqlStatement();
201         return !insertStatement.getColumns().isEmpty() || InsertStatementHandler.getSetAssignmentSegment(insertStatement).isPresent();
202     }
203     
204     /**
205      * Get value list count.
206      *
207      * @return value list count
208      */
209     public int getValueListCount() {
210         InsertStatement insertStatement = getSqlStatement();
211         return InsertStatementHandler.getSetAssignmentSegment(insertStatement).isPresent() ? 1 : insertStatement.getValues().size();
212     }
213     
214     /**
215      * Get insert column names.
216      *
217      * @return column names collection
218      */
219     public List<String> getInsertColumnNames() {
220         InsertStatement insertStatement = getSqlStatement();
221         return InsertStatementHandler.getSetAssignmentSegment(insertStatement).map(this::getColumnNamesForSetAssignment).orElseGet(() -> getColumnNamesForInsertColumns(insertStatement.getColumns()));
222     }
223     
224     private List<String> getColumnNamesForSetAssignment(final SetAssignmentSegment setAssignment) {
225         List<String> result = new LinkedList<>();
226         for (ColumnAssignmentSegment each : setAssignment.getAssignments()) {
227             result.add(each.getColumns().get(0).getIdentifier().getValue().toLowerCase());
228         }
229         return result;
230     }
231     
232     private List<String> getColumnNamesForInsertColumns(final Collection<ColumnSegment> columns) {
233         List<String> result = new LinkedList<>();
234         for (ColumnSegment each : columns) {
235             result.add(each.getIdentifier().getValue().toLowerCase());
236         }
237         return result;
238     }
239     
240     private List<List<ExpressionSegment>> getAllValueExpressions(final InsertStatement insertStatement) {
241         Optional<SetAssignmentSegment> setAssignment = InsertStatementHandler.getSetAssignmentSegment(insertStatement);
242         return setAssignment
243                 .map(optional -> Collections.singletonList(getAllValueExpressionsFromSetAssignment(optional))).orElseGet(() -> getAllValueExpressionsFromValues(insertStatement.getValues()));
244     }
245     
246     private List<ExpressionSegment> getAllValueExpressionsFromSetAssignment(final SetAssignmentSegment setAssignment) {
247         List<ExpressionSegment> result = new ArrayList<>(setAssignment.getAssignments().size());
248         for (ColumnAssignmentSegment each : setAssignment.getAssignments()) {
249             result.add(each.getValue());
250         }
251         return result;
252     }
253     
254     private List<List<ExpressionSegment>> getAllValueExpressionsFromValues(final Collection<InsertValuesSegment> values) {
255         List<List<ExpressionSegment>> result = new ArrayList<>(values.size());
256         for (InsertValuesSegment each : values) {
257             result.add(each.getValues());
258         }
259         return result;
260     }
261     
262     @Override
263     public InsertStatement getSqlStatement() {
264         return (InsertStatement) super.getSqlStatement();
265     }
266     
267     @Override
268     public Collection<SimpleTableSegment> getAllTables() {
269         return tablesContext.getSimpleTableSegments();
270     }
271     
272     @Override
273     public void setUpParameters(final List<Object> params) {
274         AtomicInteger parametersOffset = new AtomicInteger(0);
275         insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions);
276         insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null);
277         onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
278         ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
279         generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
280     }
281 }