1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.infra.binder.context.statement.dml;
19
20 import lombok.Getter;
21 import org.apache.shardingsphere.infra.binder.context.aware.ParameterAware;
22 import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
23 import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.engine.GeneratedKeyContextEngine;
24 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext;
25 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
26 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.OnDuplicateUpdateContext;
27 import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
28 import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
29 import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
30 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
31 import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
32 import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
33 import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
34 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
35 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
36 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
37 import org.apache.shardingsphere.sql.parser.sql.common.enums.SubqueryType;
38 import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
39 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
40 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
41 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
42 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
43 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDuplicateKeyColumnsSegment;
44 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
45 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
46 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
47 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
48 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
49
50 import java.util.ArrayList;
51 import java.util.Collection;
52 import java.util.Collections;
53 import java.util.Iterator;
54 import java.util.LinkedList;
55 import java.util.List;
56 import java.util.Optional;
57 import java.util.concurrent.atomic.AtomicInteger;
58
59
60
61
62 @Getter
63 public final class InsertStatementContext extends CommonSQLStatementContext implements TableAvailable, ParameterAware {
64
65 private final TablesContext tablesContext;
66
67 private final List<String> columnNames;
68
69 private final ShardingSphereMetaData metaData;
70
71 private final String defaultDatabaseName;
72
73 private final List<String> insertColumnNames;
74
75 private final List<List<ExpressionSegment>> valueExpressions;
76
77 private List<InsertValueContext> insertValueContexts;
78
79 private InsertSelectContext insertSelectContext;
80
81 private OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext;
82
83 private GeneratedKeyContext generatedKeyContext;
84
85 public InsertStatementContext(final ShardingSphereMetaData metaData, final List<Object> params, final InsertStatement sqlStatement, final String defaultDatabaseName) {
86 super(sqlStatement);
87 this.metaData = metaData;
88 this.defaultDatabaseName = defaultDatabaseName;
89 insertColumnNames = getInsertColumnNames();
90 valueExpressions = getAllValueExpressions(sqlStatement);
91 AtomicInteger parametersOffset = new AtomicInteger(0);
92 insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions);
93 insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null);
94 onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
95 tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType());
96 ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
97 columnNames = containsInsertColumns() ? insertColumnNames
98 : Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
99 generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
100 }
101
102 private ShardingSphereSchema getSchema(final ShardingSphereMetaData metaData, final String defaultDatabaseName) {
103 String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName);
104 ShardingSpherePreconditions.checkNotNull(databaseName, NoDatabaseSelectedException::new);
105 ShardingSphereDatabase database = metaData.getDatabase(databaseName);
106 ShardingSpherePreconditions.checkNotNull(database, () -> new UnknownDatabaseException(databaseName));
107 String defaultSchema = new DatabaseTypeRegistry(getDatabaseType()).getDefaultSchemaName(databaseName);
108 return tablesContext.getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
109 }
110
111 private Collection<SimpleTableSegment> getAllSimpleTableSegments() {
112 TableExtractor tableExtractor = new TableExtractor();
113 tableExtractor.extractTablesFromInsert(getSqlStatement());
114 return tableExtractor.getRewriteTables();
115 }
116
117 private List<InsertValueContext> getInsertValueContexts(final List<Object> params, final AtomicInteger paramsOffset, final List<List<ExpressionSegment>> valueExpressions) {
118 List<InsertValueContext> result = new LinkedList<>();
119 for (Collection<ExpressionSegment> each : valueExpressions) {
120 InsertValueContext insertValueContext = new InsertValueContext(each, params, paramsOffset.get());
121 result.add(insertValueContext);
122 paramsOffset.addAndGet(insertValueContext.getParameterCount());
123 }
124 return result;
125 }
126
127 private Optional<InsertSelectContext> getInsertSelectContext(final ShardingSphereMetaData metaData, final List<Object> params,
128 final AtomicInteger paramsOffset, final String defaultDatabaseName) {
129 if (!getSqlStatement().getInsertSelect().isPresent()) {
130 return Optional.empty();
131 }
132 SubquerySegment insertSelectSegment = getSqlStatement().getInsertSelect().get();
133 SelectStatementContext selectStatementContext = new SelectStatementContext(metaData, params, insertSelectSegment.getSelect(), defaultDatabaseName);
134 selectStatementContext.setSubqueryType(SubqueryType.INSERT_SELECT_SUBQUERY);
135 InsertSelectContext insertSelectContext = new InsertSelectContext(selectStatementContext, params, paramsOffset.get());
136 paramsOffset.addAndGet(insertSelectContext.getParameterCount());
137 return Optional.of(insertSelectContext);
138 }
139
140 private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(final List<Object> params, final AtomicInteger parametersOffset) {
141 Optional<OnDuplicateKeyColumnsSegment> onDuplicateKeyColumnsSegment = InsertStatementHandler.getOnDuplicateKeyColumnsSegment(getSqlStatement());
142 if (!onDuplicateKeyColumnsSegment.isPresent()) {
143 return Optional.empty();
144 }
145 Collection<ColumnAssignmentSegment> onDuplicateKeyColumns = onDuplicateKeyColumnsSegment.get().getColumns();
146 OnDuplicateUpdateContext onDuplicateUpdateContext = new OnDuplicateUpdateContext(onDuplicateKeyColumns, params, parametersOffset.get());
147 parametersOffset.addAndGet(onDuplicateUpdateContext.getParameterCount());
148 return Optional.of(onDuplicateUpdateContext);
149 }
150
151
152
153
154
155
156 public Iterator<String> getDescendingColumnNames() {
157 return new LinkedList<>(columnNames).descendingIterator();
158 }
159
160
161
162
163
164
165 public List<List<Object>> getGroupedParameters() {
166 List<List<Object>> result = new LinkedList<>();
167 for (InsertValueContext each : insertValueContexts) {
168 result.add(each.getParameters());
169 }
170 if (null != insertSelectContext && !insertSelectContext.getParameters().isEmpty()) {
171 result.add(insertSelectContext.getParameters());
172 }
173 return result;
174 }
175
176
177
178
179
180
181 public List<Object> getOnDuplicateKeyUpdateParameters() {
182 return null == onDuplicateKeyUpdateValueContext ? new ArrayList<>() : onDuplicateKeyUpdateValueContext.getParameters();
183 }
184
185
186
187
188
189
190 public Optional<GeneratedKeyContext> getGeneratedKeyContext() {
191 return Optional.ofNullable(generatedKeyContext);
192 }
193
194
195
196
197
198
199 public boolean containsInsertColumns() {
200 InsertStatement insertStatement = getSqlStatement();
201 return !insertStatement.getColumns().isEmpty() || InsertStatementHandler.getSetAssignmentSegment(insertStatement).isPresent();
202 }
203
204
205
206
207
208
209 public int getValueListCount() {
210 InsertStatement insertStatement = getSqlStatement();
211 return InsertStatementHandler.getSetAssignmentSegment(insertStatement).isPresent() ? 1 : insertStatement.getValues().size();
212 }
213
214
215
216
217
218
219 public List<String> getInsertColumnNames() {
220 InsertStatement insertStatement = getSqlStatement();
221 return InsertStatementHandler.getSetAssignmentSegment(insertStatement).map(this::getColumnNamesForSetAssignment).orElseGet(() -> getColumnNamesForInsertColumns(insertStatement.getColumns()));
222 }
223
224 private List<String> getColumnNamesForSetAssignment(final SetAssignmentSegment setAssignment) {
225 List<String> result = new LinkedList<>();
226 for (ColumnAssignmentSegment each : setAssignment.getAssignments()) {
227 result.add(each.getColumns().get(0).getIdentifier().getValue().toLowerCase());
228 }
229 return result;
230 }
231
232 private List<String> getColumnNamesForInsertColumns(final Collection<ColumnSegment> columns) {
233 List<String> result = new LinkedList<>();
234 for (ColumnSegment each : columns) {
235 result.add(each.getIdentifier().getValue().toLowerCase());
236 }
237 return result;
238 }
239
240 private List<List<ExpressionSegment>> getAllValueExpressions(final InsertStatement insertStatement) {
241 Optional<SetAssignmentSegment> setAssignment = InsertStatementHandler.getSetAssignmentSegment(insertStatement);
242 return setAssignment
243 .map(optional -> Collections.singletonList(getAllValueExpressionsFromSetAssignment(optional))).orElseGet(() -> getAllValueExpressionsFromValues(insertStatement.getValues()));
244 }
245
246 private List<ExpressionSegment> getAllValueExpressionsFromSetAssignment(final SetAssignmentSegment setAssignment) {
247 List<ExpressionSegment> result = new ArrayList<>(setAssignment.getAssignments().size());
248 for (ColumnAssignmentSegment each : setAssignment.getAssignments()) {
249 result.add(each.getValue());
250 }
251 return result;
252 }
253
254 private List<List<ExpressionSegment>> getAllValueExpressionsFromValues(final Collection<InsertValuesSegment> values) {
255 List<List<ExpressionSegment>> result = new ArrayList<>(values.size());
256 for (InsertValuesSegment each : values) {
257 result.add(each.getValues());
258 }
259 return result;
260 }
261
262 @Override
263 public InsertStatement getSqlStatement() {
264 return (InsertStatement) super.getSqlStatement();
265 }
266
267 @Override
268 public Collection<SimpleTableSegment> getAllTables() {
269 return tablesContext.getSimpleTableSegments();
270 }
271
272 @Override
273 public void setUpParameters(final List<Object> params) {
274 AtomicInteger parametersOffset = new AtomicInteger(0);
275 insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions);
276 insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null);
277 onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
278 ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
279 generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
280 }
281 }