View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.infra.binder.context.statement.dml;
19  
20  import com.google.common.base.Preconditions;
21  import lombok.Getter;
22  import lombok.Setter;
23  import org.apache.shardingsphere.infra.binder.context.aware.ParameterAware;
24  import org.apache.shardingsphere.infra.binder.context.segment.select.groupby.GroupByContext;
25  import org.apache.shardingsphere.infra.binder.context.segment.select.groupby.engine.GroupByContextEngine;
26  import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByContext;
27  import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.OrderByItem;
28  import org.apache.shardingsphere.infra.binder.context.segment.select.orderby.engine.OrderByContextEngine;
29  import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.PaginationContext;
30  import org.apache.shardingsphere.infra.binder.context.segment.select.pagination.engine.PaginationContextEngine;
31  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
32  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.ProjectionsContext;
33  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.engine.ProjectionsContextEngine;
34  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationDistinctProjection;
35  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationProjection;
36  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
37  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ParameterMarkerProjection;
38  import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.SubqueryProjection;
39  import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
40  import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
41  import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
42  import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
43  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
44  import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.NoDatabaseSelectedException;
45  import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.database.UnknownDatabaseException;
46  import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
47  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
48  import org.apache.shardingsphere.infra.rule.attribute.table.TableMapperRuleAttribute;
49  import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
50  import org.apache.shardingsphere.sql.parser.sql.common.enums.SubqueryType;
51  import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
52  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
53  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
54  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
55  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
56  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
57  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment;
58  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ExpressionOrderByItemSegment;
59  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
60  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.OrderByItemSegment;
61  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.TextOrderByItemSegment;
62  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
63  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
64  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
65  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
66  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
67  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
68  import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
69  import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
70  import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
71  import org.apache.shardingsphere.sql.parser.sql.common.util.SubqueryExtractUtils;
72  import org.apache.shardingsphere.sql.parser.sql.common.util.WhereExtractUtils;
73  import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
74  
75  import java.util.Collection;
76  import java.util.Collections;
77  import java.util.HashMap;
78  import java.util.LinkedList;
79  import java.util.List;
80  import java.util.Map;
81  import java.util.Optional;
82  import java.util.stream.Collectors;
83  
84  /**
85   * Select SQL statement context.
86   */
87  @Getter
88  @Setter
89  public final class SelectStatementContext extends CommonSQLStatementContext implements TableAvailable, WhereAvailable, ParameterAware {
90      
91      private final TablesContext tablesContext;
92      
93      private final ProjectionsContext projectionsContext;
94      
95      private final GroupByContext groupByContext;
96      
97      private final OrderByContext orderByContext;
98      
99      private final Map<Integer, SelectStatementContext> subqueryContexts;
100     
101     private final Collection<WhereSegment> whereSegments = new LinkedList<>();
102     
103     private final Collection<ColumnSegment> columnSegments = new LinkedList<>();
104     
105     private final Collection<BinaryOperationExpression> joinConditions = new LinkedList<>();
106     
107     private final boolean containsEnhancedTable;
108     
109     private SubqueryType subqueryType;
110     
111     private boolean needAggregateRewrite;
112     
113     private PaginationContext paginationContext;
114     
115     public SelectStatementContext(final ShardingSphereMetaData metaData, final List<Object> params, final SelectStatement sqlStatement, final String defaultDatabaseName) {
116         super(sqlStatement);
117         extractWhereSegments(whereSegments, sqlStatement);
118         ColumnExtractor.extractColumnSegments(columnSegments, whereSegments);
119         ExpressionExtractUtils.extractJoinConditions(joinConditions, whereSegments);
120         subqueryContexts = createSubqueryContexts(metaData, params, defaultDatabaseName);
121         tablesContext = new TablesContext(getAllTableSegments(), subqueryContexts, getDatabaseType());
122         groupByContext = new GroupByContextEngine().createGroupByContext(sqlStatement);
123         orderByContext = new OrderByContextEngine().createOrderBy(sqlStatement, groupByContext);
124         projectionsContext = new ProjectionsContextEngine(getDatabaseType()).createProjectionsContext(getSqlStatement().getProjections(), groupByContext, orderByContext);
125         paginationContext = new PaginationContextEngine(getDatabaseType()).createPaginationContext(sqlStatement, projectionsContext, params, whereSegments);
126         String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName);
127         containsEnhancedTable = isContainsEnhancedTable(metaData, databaseName, getTablesContext().getTableNames());
128     }
129     
130     private boolean isContainsEnhancedTable(final ShardingSphereMetaData metaData, final String databaseName, final Collection<String> tableNames) {
131         for (TableMapperRuleAttribute each : getTableMapperRuleAttributes(metaData, databaseName)) {
132             for (String tableName : tableNames) {
133                 if (each.getEnhancedTableNames().contains(tableName)) {
134                     return true;
135                 }
136             }
137         }
138         return false;
139     }
140     
141     private Collection<TableMapperRuleAttribute> getTableMapperRuleAttributes(final ShardingSphereMetaData metaData, final String databaseName) {
142         if (null == databaseName) {
143             ShardingSpherePreconditions.checkMustEmpty(tablesContext.getSimpleTableSegments(), NoDatabaseSelectedException::new);
144             return Collections.emptyList();
145         }
146         ShardingSphereDatabase database = metaData.getDatabase(databaseName);
147         ShardingSpherePreconditions.checkNotNull(database, () -> new UnknownDatabaseException(databaseName));
148         return database.getRuleMetaData().getAttributes(TableMapperRuleAttribute.class);
149     }
150     
151     private Map<Integer, SelectStatementContext> createSubqueryContexts(final ShardingSphereMetaData metaData, final List<Object> params, final String defaultDatabaseName) {
152         Collection<SubquerySegment> subquerySegments = SubqueryExtractUtils.getSubquerySegments(getSqlStatement());
153         Map<Integer, SelectStatementContext> result = new HashMap<>(subquerySegments.size(), 1F);
154         for (SubquerySegment each : subquerySegments) {
155             SelectStatementContext subqueryContext = new SelectStatementContext(metaData, params, each.getSelect(), defaultDatabaseName);
156             subqueryContext.setSubqueryType(each.getSubqueryType());
157             result.put(each.getStartIndex(), subqueryContext);
158         }
159         return result;
160     }
161     
162     /**
163      * Judge whether contains join query or not.
164      *
165      * @return whether contains join query or not
166      */
167     public boolean isContainsJoinQuery() {
168         return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof JoinTableSegment;
169     }
170     
171     /**
172      * Judge whether contains subquery or not.
173      *
174      * @return whether contains subquery or not
175      */
176     public boolean isContainsSubquery() {
177         return !subqueryContexts.isEmpty();
178     }
179     
180     /**
181      * Judge whether contains having or not.
182      *
183      * @return whether contains having or not
184      */
185     public boolean isContainsHaving() {
186         return getSqlStatement().getHaving().isPresent();
187     }
188     
189     /**
190      * Judge whether contains combine or not.
191      *
192      * @return whether contains combine or not
193      */
194     public boolean isContainsCombine() {
195         return getSqlStatement().getCombine().isPresent();
196     }
197     
198     /**
199      * Judge whether contains dollar parameter marker or not.
200      * 
201      * @return whether contains dollar parameter marker or not
202      */
203     public boolean isContainsDollarParameterMarker() {
204         for (Projection each : projectionsContext.getProjections()) {
205             if (each instanceof ParameterMarkerProjection && ParameterMarkerType.DOLLAR == ((ParameterMarkerProjection) each).getParameterMarkerType()) {
206                 return true;
207             }
208         }
209         for (ParameterMarkerExpressionSegment each : getParameterMarkerExpressions()) {
210             if (ParameterMarkerType.DOLLAR == each.getParameterMarkerType()) {
211                 return true;
212             }
213         }
214         return false;
215     }
216     
217     private Collection<ParameterMarkerExpressionSegment> getParameterMarkerExpressions() {
218         Collection<ExpressionSegment> expressions = new LinkedList<>();
219         for (WhereSegment each : whereSegments) {
220             expressions.add(each.getExpr());
221         }
222         return ExpressionExtractUtils.getParameterMarkerExpressions(expressions);
223     }
224     
225     /**
226      * Judge whether contains partial distinct aggregation.
227      * 
228      * @return whether contains partial distinct aggregation
229      */
230     public boolean isContainsPartialDistinctAggregation() {
231         Collection<Projection> aggregationProjections = projectionsContext.getProjections().stream().filter(AggregationProjection.class::isInstance).collect(Collectors.toList());
232         Collection<AggregationDistinctProjection> aggregationDistinctProjections = projectionsContext.getAggregationDistinctProjections();
233         return aggregationProjections.size() > 1 && !aggregationDistinctProjections.isEmpty() && aggregationProjections.size() != aggregationDistinctProjections.size();
234     }
235     
236     /**
237      * Set indexes.
238      *
239      * @param columnLabelIndexMap map for column label and index
240      */
241     public void setIndexes(final Map<String, Integer> columnLabelIndexMap) {
242         setIndexForAggregationProjection(columnLabelIndexMap);
243         setIndexForOrderItem(columnLabelIndexMap, orderByContext.getItems());
244         setIndexForOrderItem(columnLabelIndexMap, groupByContext.getItems());
245     }
246     
247     private void setIndexForAggregationProjection(final Map<String, Integer> columnLabelIndexMap) {
248         for (AggregationProjection each : projectionsContext.getAggregationProjections()) {
249             String columnLabel = SQLUtils.getExactlyValue(each.getAlias().map(IdentifierValue::getValue).orElse(each.getColumnName()));
250             Preconditions.checkState(columnLabelIndexMap.containsKey(columnLabel), "Can't find index: %s, please add alias for aggregate selections", each);
251             each.setIndex(columnLabelIndexMap.get(columnLabel));
252             for (AggregationProjection derived : each.getDerivedAggregationProjections()) {
253                 String derivedColumnLabel = SQLUtils.getExactlyValue(derived.getAlias().map(IdentifierValue::getValue).orElse(each.getColumnName()));
254                 Preconditions.checkState(columnLabelIndexMap.containsKey(derivedColumnLabel), "Can't find index: %s", derived);
255                 derived.setIndex(columnLabelIndexMap.get(derivedColumnLabel));
256             }
257         }
258     }
259     
260     private void setIndexForOrderItem(final Map<String, Integer> columnLabelIndexMap, final Collection<OrderByItem> orderByItems) {
261         for (OrderByItem each : orderByItems) {
262             if (each.getSegment() instanceof IndexOrderByItemSegment) {
263                 each.setIndex(((IndexOrderByItemSegment) each.getSegment()).getColumnIndex());
264                 continue;
265             }
266             if (each.getSegment() instanceof ColumnOrderByItemSegment && ((ColumnOrderByItemSegment) each.getSegment()).getColumn().getOwner().isPresent()) {
267                 Optional<Integer> itemIndex = projectionsContext.findProjectionIndex(((ColumnOrderByItemSegment) each.getSegment()).getText());
268                 if (itemIndex.isPresent()) {
269                     each.setIndex(itemIndex.get());
270                     continue;
271                 }
272             }
273             String columnLabel = getAlias(each.getSegment()).orElseGet(() -> getOrderItemText((TextOrderByItemSegment) each.getSegment()));
274             Preconditions.checkState(columnLabelIndexMap.containsKey(columnLabel), "Can't find index: %s", each);
275             if (columnLabelIndexMap.containsKey(columnLabel)) {
276                 each.setIndex(columnLabelIndexMap.get(columnLabel));
277             }
278         }
279     }
280     
281     private Optional<String> getAlias(final OrderByItemSegment orderByItem) {
282         if (projectionsContext.isUnqualifiedShorthandProjection()) {
283             return Optional.empty();
284         }
285         String rawName = SQLUtils.getExactlyValue(((TextOrderByItemSegment) orderByItem).getText());
286         for (Projection each : projectionsContext.getProjections()) {
287             Optional<String> result = each.getAlias().map(IdentifierValue::getValue);
288             if (SQLUtils.getExactlyExpression(rawName).equalsIgnoreCase(SQLUtils.getExactlyExpression(SQLUtils.getExactlyValue(each.getExpression())))) {
289                 return result;
290             }
291             if (rawName.equalsIgnoreCase(result.orElse(null))) {
292                 return Optional.of(rawName);
293             }
294             if (isSameColumnName(each, rawName)) {
295                 return result;
296             }
297         }
298         return Optional.empty();
299     }
300     
301     private boolean isSameColumnName(final Projection projection, final String name) {
302         return projection instanceof ColumnProjection && name.equalsIgnoreCase(((ColumnProjection) projection).getName().getValue());
303     }
304     
305     private String getOrderItemText(final TextOrderByItemSegment orderByItemSegment) {
306         if (orderByItemSegment instanceof ColumnOrderByItemSegment) {
307             return SQLUtils.getExactlyValue(((ColumnOrderByItemSegment) orderByItemSegment).getColumn().getIdentifier().getValue());
308         }
309         return SQLUtils.getExactlyValue(((ExpressionOrderByItemSegment) orderByItemSegment).getExpression());
310     }
311     
312     /**
313      * Judge group by and order by sequence is same or not.
314      *
315      * @return group by and order by sequence is same or not
316      */
317     public boolean isSameGroupByAndOrderByItems() {
318         return !groupByContext.getItems().isEmpty() && groupByContext.getItems().equals(orderByContext.getItems());
319     }
320     
321     /**
322      * Find column projection.
323      * 
324      * @param columnIndex column index
325      * @return find column projection
326      */
327     public Optional<ColumnProjection> findColumnProjection(final int columnIndex) {
328         List<Projection> expandProjections = projectionsContext.getExpandProjections();
329         if (expandProjections.size() < columnIndex) {
330             return Optional.empty();
331         }
332         Projection projection = expandProjections.get(columnIndex - 1);
333         if (projection instanceof ColumnProjection) {
334             return Optional.of((ColumnProjection) projection);
335         }
336         if (projection instanceof SubqueryProjection && ((SubqueryProjection) projection).getProjection() instanceof ColumnProjection) {
337             return Optional.of((ColumnProjection) ((SubqueryProjection) projection).getProjection());
338         }
339         return Optional.empty();
340     }
341     
342     @Override
343     public SelectStatement getSqlStatement() {
344         return (SelectStatement) super.getSqlStatement();
345     }
346     
347     @Override
348     public Collection<SimpleTableSegment> getAllTables() {
349         return tablesContext.getSimpleTableSegments();
350     }
351     
352     @Override
353     public Collection<WhereSegment> getWhereSegments() {
354         return whereSegments;
355     }
356     
357     @Override
358     public Collection<ColumnSegment> getColumnSegments() {
359         return columnSegments;
360     }
361     
362     @Override
363     public Collection<BinaryOperationExpression> getJoinConditions() {
364         return joinConditions;
365     }
366     
367     private void extractWhereSegments(final Collection<WhereSegment> whereSegments, final SelectStatement selectStatement) {
368         selectStatement.getWhere().ifPresent(whereSegments::add);
369         whereSegments.addAll(WhereExtractUtils.getSubqueryWhereSegments(selectStatement));
370         whereSegments.addAll(WhereExtractUtils.getJoinWhereSegments(selectStatement));
371     }
372     
373     private Collection<TableSegment> getAllTableSegments() {
374         TableExtractor tableExtractor = new TableExtractor();
375         tableExtractor.extractTablesFromSelect(getSqlStatement());
376         Collection<TableSegment> result = new LinkedList<>(tableExtractor.getRewriteTables());
377         for (TableSegment each : tableExtractor.getTableContext()) {
378             if (each instanceof SubqueryTableSegment) {
379                 result.add(each);
380             }
381         }
382         return result;
383     }
384     
385     /**
386      * Judge whether sql statement contains table subquery segment or not.
387      *
388      * @return whether sql statement contains table subquery segment or not
389      */
390     public boolean containsTableSubquery() {
391         return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof SubqueryTableSegment;
392     }
393     
394     @Override
395     public void setUpParameters(final List<Object> params) {
396         paginationContext = new PaginationContextEngine(getDatabaseType()).createPaginationContext(getSqlStatement(), projectionsContext, params, whereSegments);
397     }
398 }