1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.infra.binder.context.statement.dml;
19
20 import com.google.common.base.Preconditions;
21 import lombok.Getter;
22 import lombok.Setter;
23 import org.apache.shardingsphere.infra.binder.context.aware.ParameterAware;
24 import org.apache.shardingsphere.infra.binder.context.segment.select.groupby.GroupByContext;
25 import org.apache.shardingsphere.infra.binder.context.segment.select.groupby.engine.GroupByContextEngine;
26 import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByContext;
27 import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByItem;
28 import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.engine.OrderByContextEngine;
29 import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
30 import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.engine.PaginationContextEngine;
31 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
33 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.engine.ProjectionsContextEngine;
34 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationDistinctProjection;
35 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationProjection;
36 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
37 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ParameterMarkerProjection;
38 import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.SubqueryProjection;
39 import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
40 import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
41 import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
42 import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
43 import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
44 import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
45 import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
46 import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
47 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
48 import org.apache.shardingsphere.infra.rule.attribute.table.TableMapperRuleAttribute;
49 import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
50 import org.apache.shardingsphere.sql.parser.sql.common.enums.SubqueryType;
51 import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
52 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
53 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
54 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
55 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
56 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
57 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment;
58 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment;
59 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
60 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.OrderByItemSegment;
61 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.TextOrderByItemSegment;
62 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
63 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
64 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
65 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
66 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
67 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
68 import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
69 import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
70 import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
71 import org.apache.shardingsphere.sql.parser.sql.common.util.SubqueryExtractUtils;
72 import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtils;
73 import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
74
75 import java.util.Collection;
76 import java.util.Collections;
77 import java.util.HashMap;
78 import java.util.LinkedList;
79 import java.util.List;
80 import java.util.Map;
81 import java.util.Optional;
82 import java.util.stream.Collectors;
83
84
85
86
87 @Getter
88 @Setter
89 public final class SelectStatementContext extends CommonSQLStatementContext implements TableAvailable, WhereAvailable, ParameterAware {
90
91 private final TablesContext tablesContext;
92
93 private final ProjectionsContext projectionsContext;
94
95 private final GroupByContext groupByContext;
96
97 private final OrderByContext orderByContext;
98
99 private final Map<Integer, SelectStatementContext> subqueryContexts;
100
101 private final Collection<WhereSegment> whereSegments = new LinkedList<>();
102
103 private final Collection<ColumnSegment> columnSegments = new LinkedList<>();
104
105 private final Collection<BinaryOperationExpression> joinConditions = new LinkedList<>();
106
107 private final boolean containsEnhancedTable;
108
109 private SubqueryType subqueryType;
110
111 private boolean needAggregateRewrite;
112
113 private PaginationContext paginationContext;
114
115 public SelectStatementContext(final ShardingSphereMetaData metaData, final List<Object> params, final SelectStatement sqlStatement, final String defaultDatabaseName) {
116 super(sqlStatement);
117 extractWhereSegments(whereSegments, sqlStatement);
118 ColumnExtractor.extractColumnSegments(columnSegments, whereSegments);
119 ExpressionExtractUtils.extractJoinConditions(joinConditions, whereSegments);
120 subqueryContexts = createSubqueryContexts(metaData, params, defaultDatabaseName);
121 tablesContext = new TablesContext(getAllTableSegments(), subqueryContexts, getDatabaseType());
122 groupByContext = new GroupByContextEngine().createGroupByContext(sqlStatement);
123 orderByContext = new OrderByContextEngine().createOrderBy(sqlStatement, groupByContext);
124 projectionsContext = new ProjectionsContextEngine(getDatabaseType()).createProjectionsContext(getSqlStatement().getProjections(), groupByContext, orderByContext);
125 paginationContext = new PaginationContextEngine(getDatabaseType()).createPaginationContext(sqlStatement, projectionsContext, params, whereSegments);
126 String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName);
127 containsEnhancedTable = isContainsEnhancedTable(metaData, databaseName, getTablesContext().getTableNames());
128 }
129
130 private boolean isContainsEnhancedTable(final ShardingSphereMetaData metaData, final String databaseName, final Collection<String> tableNames) {
131 for (TableMapperRuleAttribute each : getTableMapperRuleAttributes(metaData, databaseName)) {
132 for (String tableName : tableNames) {
133 if (each.getEnhancedTableNames().contains(tableName)) {
134 return true;
135 }
136 }
137 }
138 return false;
139 }
140
141 private Collection<TableMapperRuleAttribute> getTableMapperRuleAttributes(final ShardingSphereMetaData metaData, final String databaseName) {
142 if (null == databaseName) {
143 ShardingSpherePreconditions.checkMustEmpty(tablesContext.getSimpleTableSegments(), NoDatabaseSelectedException::new);
144 return Collections.emptyList();
145 }
146 ShardingSphereDatabase database = metaData.getDatabase(databaseName);
147 ShardingSpherePreconditions.checkNotNull(database, () -> new UnknownDatabaseException(databaseName));
148 return database.getRuleMetaData().getAttributes(TableMapperRuleAttribute.class);
149 }
150
151 private Map<Integer, SelectStatementContext> createSubqueryContexts(final ShardingSphereMetaData metaData, final List<Object> params, final String defaultDatabaseName) {
152 Collection<SubquerySegment> subquerySegments = SubqueryExtractUtils.getSubquerySegments(getSqlStatement());
153 Map<Integer, SelectStatementContext> result = new HashMap<>(subquerySegments.size(), 1F);
154 for (SubquerySegment each : subquerySegments) {
155 SelectStatementContext subqueryContext = new SelectStatementContext(metaData, params, each.getSelect(), defaultDatabaseName);
156 subqueryContext.setSubqueryType(each.getSubqueryType());
157 result.put(each.getStartIndex(), subqueryContext);
158 }
159 return result;
160 }
161
162
163
164
165
166
167 public boolean isContainsJoinQuery() {
168 return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof JoinTableSegment;
169 }
170
171
172
173
174
175
176 public boolean isContainsSubquery() {
177 return !subqueryContexts.isEmpty();
178 }
179
180
181
182
183
184
185 public boolean isContainsHaving() {
186 return getSqlStatement().getHaving().isPresent();
187 }
188
189
190
191
192
193
194 public boolean isContainsCombine() {
195 return getSqlStatement().getCombine().isPresent();
196 }
197
198
199
200
201
202
203 public boolean isContainsDollarParameterMarker() {
204 for (Projection each : projectionsContext.getProjections()) {
205 if (each instanceof ParameterMarkerProjection && ParameterMarkerType.DOLLAR == ((ParameterMarkerProjection) each).getParameterMarkerType()) {
206 return true;
207 }
208 }
209 for (ParameterMarkerExpressionSegment each : getParameterMarkerExpressions()) {
210 if (ParameterMarkerType.DOLLAR == each.getParameterMarkerType()) {
211 return true;
212 }
213 }
214 return false;
215 }
216
217 private Collection<ParameterMarkerExpressionSegment> getParameterMarkerExpressions() {
218 Collection<ExpressionSegment> expressions = new LinkedList<>();
219 for (WhereSegment each : whereSegments) {
220 expressions.add(each.getExpr());
221 }
222 return ExpressionExtractUtils.getParameterMarkerExpressions(expressions);
223 }
224
225
226
227
228
229
230 public boolean isContainsPartialDistinctAggregation() {
231 Collection<Projection> aggregationProjections = projectionsContext.getProjections().stream().filter(AggregationProjection.class::isInstance).collect(Collectors.toList());
232 Collection<AggregationDistinctProjection> aggregationDistinctProjections = projectionsContext.getAggregationDistinctProjections();
233 return aggregationProjections.size() > 1 && !aggregationDistinctProjections.isEmpty() && aggregationProjections.size() != aggregationDistinctProjections.size();
234 }
235
236
237
238
239
240
241 public void setIndexes(final Map<String, Integer> columnLabelIndexMap) {
242 setIndexForAggregationProjection(columnLabelIndexMap);
243 setIndexForOrderItem(columnLabelIndexMap, orderByContext.getItems());
244 setIndexForOrderItem(columnLabelIndexMap, groupByContext.getItems());
245 }
246
247 private void setIndexForAggregationProjection(final Map<String, Integer> columnLabelIndexMap) {
248 for (AggregationProjection each : projectionsContext.getAggregationProjections()) {
249 String columnLabel = SQLUtils.getExactlyValue(each.getAlias().map(IdentifierValue::getValue).orElse(each.getColumnName()));
250 Preconditions.checkState(columnLabelIndexMap.containsKey(columnLabel), "Can't find index: %s, please add alias for aggregate selections", each);
251 each.setIndex(columnLabelIndexMap.get(columnLabel));
252 for (AggregationProjection derived : each.getDerivedAggregationProjections()) {
253 String derivedColumnLabel = SQLUtils.getExactlyValue(derived.getAlias().map(IdentifierValue::getValue).orElse(each.getColumnName()));
254 Preconditions.checkState(columnLabelIndexMap.containsKey(derivedColumnLabel), "Can't find index: %s", derived);
255 derived.setIndex(columnLabelIndexMap.get(derivedColumnLabel));
256 }
257 }
258 }
259
260 private void setIndexForOrderItem(final Map<String, Integer> columnLabelIndexMap, final Collection<OrderByItem> orderByItems) {
261 for (OrderByItem each : orderByItems) {
262 if (each.getSegment() instanceof IndexOrderByItemSegment) {
263 each.setIndex(((IndexOrderByItemSegment) each.getSegment()).getColumnIndex());
264 continue;
265 }
266 if (each.getSegment() instanceof ColumnOrderByItemSegment && ((ColumnOrderByItemSegment) each.getSegment()).getColumn().getOwner().isPresent()) {
267 Optional<Integer> itemIndex = projectionsContext.findProjectionIndex(((ColumnOrderByItemSegment) each.getSegment()).getText());
268 if (itemIndex.isPresent()) {
269 each.setIndex(itemIndex.get());
270 continue;
271 }
272 }
273 String columnLabel = getAlias(each.getSegment()).orElseGet(() -> getOrderItemText((TextOrderByItemSegment) each.getSegment()));
274 Preconditions.checkState(columnLabelIndexMap.containsKey(columnLabel), "Can't find index: %s", each);
275 if (columnLabelIndexMap.containsKey(columnLabel)) {
276 each.setIndex(columnLabelIndexMap.get(columnLabel));
277 }
278 }
279 }
280
281 private Optional<String> getAlias(final OrderByItemSegment orderByItem) {
282 if (projectionsContext.isUnqualifiedShorthandProjection()) {
283 return Optional.empty();
284 }
285 String rawName = SQLUtils.getExactlyValue(((TextOrderByItemSegment) orderByItem).getText());
286 for (Projection each : projectionsContext.getProjections()) {
287 Optional<String> result = each.getAlias().map(IdentifierValue::getValue);
288 if (SQLUtils.getExactlyExpression(rawName).equalsIgnoreCase(SQLUtils.getExactlyExpression(SQLUtils.getExactlyValue(each.getExpression())))) {
289 return result;
290 }
291 if (rawName.equalsIgnoreCase(result.orElse(null))) {
292 return Optional.of(rawName);
293 }
294 if (isSameColumnName(each, rawName)) {
295 return result;
296 }
297 }
298 return Optional.empty();
299 }
300
301 private boolean isSameColumnName(final Projection projection, final String name) {
302 return projection instanceof ColumnProjection && name.equalsIgnoreCase(((ColumnProjection) projection).getName().getValue());
303 }
304
305 private String getOrderItemText(final TextOrderByItemSegment orderByItemSegment) {
306 if (orderByItemSegment instanceof ColumnOrderByItemSegment) {
307 return SQLUtils.getExactlyValue(((ColumnOrderByItemSegment) orderByItemSegment).getColumn().getIdentifier().getValue());
308 }
309 return SQLUtils.getExactlyValue(((ExpressionOrderByItemSegment) orderByItemSegment).getExpression());
310 }
311
312
313
314
315
316
317 public boolean isSameGroupByAndOrderByItems() {
318 return !groupByContext.getItems().isEmpty() && groupByContext.getItems().equals(orderByContext.getItems());
319 }
320
321
322
323
324
325
326
327 public Optional<ColumnProjection> findColumnProjection(final int columnIndex) {
328 List<Projection> expandProjections = projectionsContext.getExpandProjections();
329 if (expandProjections.size() < columnIndex) {
330 return Optional.empty();
331 }
332 Projection projection = expandProjections.get(columnIndex - 1);
333 if (projection instanceof ColumnProjection) {
334 return Optional.of((ColumnProjection) projection);
335 }
336 if (projection instanceof SubqueryProjection && ((SubqueryProjection) projection).getProjection() instanceof ColumnProjection) {
337 return Optional.of((ColumnProjection) ((SubqueryProjection) projection).getProjection());
338 }
339 return Optional.empty();
340 }
341
342 @Override
343 public SelectStatement getSqlStatement() {
344 return (SelectStatement) super.getSqlStatement();
345 }
346
347 @Override
348 public Collection<SimpleTableSegment> getAllTables() {
349 return tablesContext.getSimpleTableSegments();
350 }
351
352 @Override
353 public Collection<WhereSegment> getWhereSegments() {
354 return whereSegments;
355 }
356
357 @Override
358 public Collection<ColumnSegment> getColumnSegments() {
359 return columnSegments;
360 }
361
362 @Override
363 public Collection<BinaryOperationExpression> getJoinConditions() {
364 return joinConditions;
365 }
366
367 private void extractWhereSegments(final Collection<WhereSegment> whereSegments, final SelectStatement selectStatement) {
368 selectStatement.getWhere().ifPresent(whereSegments::add);
369 whereSegments.addAll(WhereExtractUtils.getSubqueryWhereSegments(selectStatement));
370 whereSegments.addAll(WhereExtractUtils.getJoinWhereSegments(selectStatement));
371 }
372
373 private Collection<TableSegment> getAllTableSegments() {
374 TableExtractor tableExtractor = new TableExtractor();
375 tableExtractor.extractTablesFromSelect(getSqlStatement());
376 Collection<TableSegment> result = new LinkedList<>(tableExtractor.getRewriteTables());
377 for (TableSegment each : tableExtractor.getTableContext()) {
378 if (each instanceof SubqueryTableSegment) {
379 result.add(each);
380 }
381 }
382 return result;
383 }
384
385
386
387
388
389
390 public boolean containsTableSubquery() {
391 return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof SubqueryTableSegment;
392 }
393
394 @Override
395 public void setUpParameters(final List<Object> params) {
396 paginationContext = new PaginationContextEngine(getDatabaseType()).createPaginationContext(getSqlStatement(), projectionsContext, params, whereSegments);
397 }
398 }