View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sql.parser.sql.common.util;
19  
20  import lombok.AccessLevel;
21  import lombok.NoArgsConstructor;
22  import org.apache.shardingsphere.sql.parser.sql.common.enums.LogicalOperator;
23  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
24  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
25  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
26  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
27  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenExpression;
28  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
29  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
30  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
31  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
32  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression;
33  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.TypeCastExpression;
34  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
35  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonTableExpressionSegment;
36  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
37  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubqueryExpressionSegment;
38  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
39  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.AggregationProjectionSegment;
40  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
41  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.IntervalExpressionProjection;
42  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
43  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
44  import org.apache.shardingsphere.sql.parser.sql.dialect.segment.mysql.match.MatchAgainstExpression;
45  import org.apache.shardingsphere.sql.parser.sql.dialect.segment.oracle.datetime.DatetimeExpression;
46  import org.apache.shardingsphere.sql.parser.sql.dialect.segment.oracle.join.OuterJoinExpression;
47  import org.apache.shardingsphere.sql.parser.sql.dialect.segment.oracle.multiset.MultisetExpression;
48  
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.Collections;
52  import java.util.LinkedList;
53  import java.util.List;
54  import java.util.Optional;
55  
56  /**
57   * Expression extract utility class.
58   */
59  @NoArgsConstructor(access = AccessLevel.PRIVATE)
60  public final class ExpressionExtractUtils {
61      
62      /**
63       * Get and predicate collection.
64       * 
65       * @param expression expression segment
66       * @return and predicate collection
67       */
68      public static Collection<AndPredicate> getAndPredicates(final ExpressionSegment expression) {
69          Collection<AndPredicate> result = new LinkedList<>();
70          extractAndPredicates(result, expression);
71          return result;
72      }
73      
74      private static void extractAndPredicates(final Collection<AndPredicate> result, final ExpressionSegment expression) {
75          if (!(expression instanceof BinaryOperationExpression)) {
76              result.add(createAndPredicate(expression));
77              return;
78          }
79          BinaryOperationExpression binaryExpression = (BinaryOperationExpression) expression;
80          Optional<LogicalOperator> logicalOperator = LogicalOperator.valueFrom(binaryExpression.getOperator());
81          if (logicalOperator.isPresent() && LogicalOperator.OR == logicalOperator.get()) {
82              extractAndPredicates(result, binaryExpression.getLeft());
83              extractAndPredicates(result, binaryExpression.getRight());
84          } else if (logicalOperator.isPresent() && LogicalOperator.AND == logicalOperator.get()) {
85              Collection<AndPredicate> predicates = getAndPredicates(binaryExpression.getRight());
86              for (AndPredicate each : getAndPredicates(binaryExpression.getLeft())) {
87                  extractCombinedAndPredicates(result, each, predicates);
88              }
89          } else {
90              result.add(createAndPredicate(expression));
91          }
92      }
93      
94      private static void extractCombinedAndPredicates(final Collection<AndPredicate> result, final AndPredicate current, final Collection<AndPredicate> predicates) {
95          for (AndPredicate each : predicates) {
96              AndPredicate predicate = new AndPredicate();
97              predicate.getPredicates().addAll(current.getPredicates());
98              predicate.getPredicates().addAll(each.getPredicates());
99              result.add(predicate);
100         }
101     }
102     
103     private static AndPredicate createAndPredicate(final ExpressionSegment expression) {
104         AndPredicate result = new AndPredicate();
105         result.getPredicates().add(expression);
106         return result;
107     }
108     
109     /**
110      * Get parameter marker expression collection.
111      * 
112      * @param expressions expression collection
113      * @return parameter marker expression collection
114      */
115     public static List<ParameterMarkerExpressionSegment> getParameterMarkerExpressions(final Collection<ExpressionSegment> expressions) {
116         List<ParameterMarkerExpressionSegment> result = new ArrayList<>();
117         extractParameterMarkerExpressions(result, expressions);
118         return result;
119     }
120     
121     private static void extractParameterMarkerExpressions(final List<ParameterMarkerExpressionSegment> result, final Collection<ExpressionSegment> expressions) {
122         for (ExpressionSegment each : expressions) {
123             if (each instanceof ParameterMarkerExpressionSegment) {
124                 result.add((ParameterMarkerExpressionSegment) each);
125             }
126             // TODO support more expression type if necessary
127             if (each instanceof BinaryOperationExpression) {
128                 extractParameterMarkerExpressions(result, Collections.singleton(((BinaryOperationExpression) each).getLeft()));
129                 extractParameterMarkerExpressions(result, Collections.singleton(((BinaryOperationExpression) each).getRight()));
130             }
131             if (each instanceof FunctionSegment) {
132                 extractParameterMarkerExpressions(result, ((FunctionSegment) each).getParameters());
133             }
134             if (each instanceof TypeCastExpression) {
135                 extractParameterMarkerExpressions(result, Collections.singleton(((TypeCastExpression) each).getExpression()));
136             }
137             if (each instanceof InExpression) {
138                 extractParameterMarkerExpressions(result, ((InExpression) each).getExpressionList());
139             }
140         }
141     }
142     
143     /**
144      * Extract join conditions.
145      * 
146      * @param joinConditions join conditions
147      * @param whereSegments where segments
148      */
149     public static void extractJoinConditions(final Collection<BinaryOperationExpression> joinConditions, final Collection<WhereSegment> whereSegments) {
150         for (WhereSegment each : whereSegments) {
151             if (each.getExpr() instanceof BinaryOperationExpression && ((BinaryOperationExpression) each.getExpr()).getLeft() instanceof ColumnSegment
152                     && ((BinaryOperationExpression) each.getExpr()).getRight() instanceof ColumnSegment) {
153                 joinConditions.add((BinaryOperationExpression) each.getExpr());
154             }
155         }
156     }
157     
158     /**
159      * Extract columns.
160      *
161      * @param expression expression
162      * @param containsSubQuery contains sub query or not
163      * @return columns
164      */
165     public static Collection<ColumnSegment> extractColumns(final ExpressionSegment expression, final boolean containsSubQuery) {
166         if (expression instanceof ColumnSegment) {
167             return Collections.singletonList((ColumnSegment) expression);
168         }
169         Collection<ColumnSegment> result = new LinkedList<>();
170         if (expression instanceof AggregationProjectionSegment) {
171             for (ExpressionSegment each : ((AggregationProjectionSegment) expression).getParameters()) {
172                 result.addAll(extractColumns(each, containsSubQuery));
173             }
174         }
175         if (expression instanceof BetweenExpression) {
176             result.addAll(extractColumns(((BetweenExpression) expression).getLeft(), containsSubQuery));
177             result.addAll(extractColumns(((BetweenExpression) expression).getBetweenExpr(), containsSubQuery));
178             result.addAll(extractColumns(((BetweenExpression) expression).getAndExpr(), containsSubQuery));
179         }
180         if (expression instanceof BinaryOperationExpression) {
181             result.addAll(extractColumns(((BinaryOperationExpression) expression).getLeft(), containsSubQuery));
182             result.addAll(extractColumns(((BinaryOperationExpression) expression).getRight(), containsSubQuery));
183         }
184         if (expression instanceof CaseWhenExpression) {
185             result.addAll(extractColumns(((CaseWhenExpression) expression).getCaseExpr(), containsSubQuery));
186             result.addAll(extractColumns(((CaseWhenExpression) expression).getElseExpr(), containsSubQuery));
187             ((CaseWhenExpression) expression).getWhenExprs().forEach(each -> result.addAll(extractColumns(each, containsSubQuery)));
188             ((CaseWhenExpression) expression).getThenExprs().forEach(each -> result.addAll(extractColumns(each, containsSubQuery)));
189         }
190         if (expression instanceof OuterJoinExpression) {
191             result.add(((OuterJoinExpression) expression).getColumnName());
192         }
193         if (expression instanceof CommonTableExpressionSegment) {
194             result.addAll(((CommonTableExpressionSegment) expression).getColumns());
195         }
196         if (expression instanceof DatetimeExpression) {
197             result.addAll(extractColumns(((DatetimeExpression) expression).getLeft(), containsSubQuery));
198             result.addAll(extractColumns(((DatetimeExpression) expression).getRight(), containsSubQuery));
199         }
200         if (expression instanceof ExpressionProjectionSegment) {
201             result.addAll(extractColumns(((ExpressionProjectionSegment) expression).getExpr(), containsSubQuery));
202         }
203         if (expression instanceof FunctionSegment) {
204             for (ExpressionSegment each : ((FunctionSegment) expression).getParameters()) {
205                 result.addAll(extractColumns(each, containsSubQuery));
206             }
207         }
208         if (expression instanceof InExpression) {
209             result.addAll(extractColumns(((InExpression) expression).getLeft(), containsSubQuery));
210             result.addAll(extractColumns(((InExpression) expression).getRight(), containsSubQuery));
211         }
212         if (expression instanceof IntervalExpressionProjection) {
213             result.addAll(extractColumns(((IntervalExpressionProjection) expression).getLeft(), containsSubQuery));
214             result.addAll(extractColumns(((IntervalExpressionProjection) expression).getRight(), containsSubQuery));
215             result.addAll(extractColumns(((IntervalExpressionProjection) expression).getMinus(), containsSubQuery));
216         }
217         if (expression instanceof ListExpression) {
218             for (ExpressionSegment each : ((ListExpression) expression).getItems()) {
219                 result.addAll(extractColumns(each, containsSubQuery));
220             }
221         }
222         if (expression instanceof MatchAgainstExpression) {
223             result.add(((MatchAgainstExpression) expression).getColumnName());
224             result.addAll(extractColumns(((MatchAgainstExpression) expression).getExpr(), containsSubQuery));
225         }
226         if (expression instanceof MultisetExpression) {
227             result.addAll(extractColumns(((MultisetExpression) expression).getLeft(), containsSubQuery));
228             result.addAll(extractColumns(((MultisetExpression) expression).getRight(), containsSubQuery));
229         }
230         if (expression instanceof NotExpression) {
231             result.addAll(extractColumns(((NotExpression) expression).getExpression(), containsSubQuery));
232         }
233         if (expression instanceof ValuesExpression) {
234             for (InsertValuesSegment each : ((ValuesExpression) expression).getRowConstructorList()) {
235                 each.getValues().forEach(value -> result.addAll(extractColumns(value, containsSubQuery)));
236             }
237         }
238         if (expression instanceof SubquerySegment && containsSubQuery) {
239             ColumnExtractor.extractFromSelectStatement(result, ((SubquerySegment) expression).getSelect(), true);
240         }
241         if (expression instanceof SubqueryExpressionSegment && containsSubQuery) {
242             ColumnExtractor.extractFromSelectStatement(result, ((SubqueryExpressionSegment) expression).getSubquery().getSelect(), true);
243         }
244         return result;
245     }
246 }