View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sharding.route.engine.condition.engine;
19  
20  import com.google.common.collect.Range;
21  import lombok.RequiredArgsConstructor;
22  import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
23  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
24  import org.apache.shardingsphere.infra.binder.context.available.WhereContextAvailable;
25  import org.apache.shardingsphere.sharding.exception.data.ShardingValueDataTypeException;
26  import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
27  import org.apache.shardingsphere.sharding.route.engine.condition.Column;
28  import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
29  import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
30  import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
31  import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
32  import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
33  import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
34  import org.apache.shardingsphere.sharding.rule.ShardingRule;
35  import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
36  import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
37  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
38  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
39  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
40  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
41  import org.apache.shardingsphere.sql.parser.statement.core.util.SafeNumberOperationUtils;
42  import org.apache.shardingsphere.timeservice.core.rule.TimestampServiceRule;
43  
44  import java.util.ArrayList;
45  import java.util.Collection;
46  import java.util.Collections;
47  import java.util.HashMap;
48  import java.util.HashSet;
49  import java.util.LinkedList;
50  import java.util.List;
51  import java.util.Map;
52  import java.util.Map.Entry;
53  import java.util.Optional;
54  import java.util.Set;
55  
56  /**
57   * Sharding condition engine for where clause.
58   */
59  @RequiredArgsConstructor
60  public final class WhereClauseShardingConditionEngine {
61      
62      private final ShardingRule rule;
63      
64      private final TimestampServiceRule timestampServiceRule;
65      
66      /**
67       * Create sharding conditions.
68       *
69       * @param sqlStatementContext SQL statement context
70       * @param params SQL parameters
71       * @return sharding conditions
72       */
73      public List<ShardingCondition> createShardingConditions(final SQLStatementContext sqlStatementContext, final List<Object> params) {
74          if (!(sqlStatementContext instanceof WhereContextAvailable)) {
75              return Collections.emptyList();
76          }
77          List<ShardingCondition> result = new ArrayList<>();
78          for (WhereSegment each : SQLStatementContextExtractor.getAllWhereSegments(sqlStatementContext)) {
79              result.addAll(createShardingConditions(each.getExpr(), params));
80          }
81          return result;
82      }
83      
84      private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> params) {
85          Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(expression);
86          Collection<ShardingCondition> result = new LinkedList<>();
87          for (AndPredicate each : andPredicates) {
88              Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), params);
89              if (shardingConditionValues.isEmpty()) {
90                  return Collections.emptyList();
91              }
92              ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
93              // TODO remove startIndex when federation has perfect support for subquery
94              shardingCondition.setStartIndex(expression.getStartIndex());
95              result.add(shardingCondition);
96          }
97          return result;
98      }
99      
100     private Map<Column, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates, final List<Object> params) {
101         Map<Column, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1F);
102         for (ExpressionSegment each : predicates) {
103             for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
104                 String tableName = columnSegment.getColumnBoundInfo().getOriginalTable().getValue();
105                 Optional<String> shardingColumn = rule.findShardingColumn(columnSegment.getColumnBoundInfo().getOriginalColumn().getValue(), tableName);
106                 if (!shardingColumn.isPresent()) {
107                     continue;
108                 }
109                 Column column = new Column(shardingColumn.get(), tableName);
110                 Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, params, timestampServiceRule);
111                 if (!shardingConditionValue.isPresent()) {
112                     continue;
113                 }
114                 result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
115             }
116         }
117         return result;
118     }
119     
120     private ShardingCondition createShardingCondition(final Map<Column, Collection<ShardingConditionValue>> shardingConditionValues) {
121         ShardingCondition result = new ShardingCondition();
122         for (Entry<Column, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
123             try {
124                 ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
125                 if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
126                     return new AlwaysFalseShardingCondition();
127                 }
128                 result.getValues().add(shardingConditionValue);
129             } catch (final ClassCastException ignored) {
130                 throw new ShardingValueDataTypeException(entry.getKey());
131             }
132         }
133         return result;
134     }
135     
136     @SuppressWarnings({"unchecked", "rawtypes"})
137     private ShardingConditionValue mergeShardingConditionValues(final Column column, final Collection<ShardingConditionValue> shardingConditionValues) {
138         Collection<Comparable<?>> listValue = null;
139         Range<Comparable<?>> rangeValue = null;
140         Set<Integer> parameterMarkerIndexes = new HashSet<>();
141         for (ShardingConditionValue each : shardingConditionValues) {
142             parameterMarkerIndexes.addAll(each.getParameterMarkerIndexes());
143             if (each instanceof ListShardingConditionValue) {
144                 listValue = mergeListShardingValues(((ListShardingConditionValue) each).getValues(), listValue);
145                 if (listValue.isEmpty()) {
146                     return new AlwaysFalseShardingConditionValue();
147                 }
148             } else if (each instanceof RangeShardingConditionValue) {
149                 try {
150                     rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
151                 } catch (final IllegalArgumentException ex) {
152                     return new AlwaysFalseShardingConditionValue();
153                 }
154             }
155         }
156         if (null == listValue) {
157             return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue, new ArrayList<>(parameterMarkerIndexes));
158         }
159         if (null == rangeValue) {
160             return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
161         }
162         listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
163         return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue()
164                 : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
165     }
166     
167     private Collection<Comparable<?>> mergeListShardingValues(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
168         if (null == value2) {
169             return value1;
170         }
171         value1.retainAll(value2);
172         return value1;
173     }
174     
175     private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
176         return null == value2 ? value1 : SafeNumberOperationUtils.safeIntersection(value1, value2);
177     }
178     
179     private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
180         Collection<Comparable<?>> result = new LinkedList<>();
181         for (Comparable<?> each : listValue) {
182             if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
183                 result.add(each);
184             }
185         }
186         return result;
187     }
188 }