View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sharding.route.engine.condition.engine;
19  
20  import com.google.common.collect.Range;
21  import lombok.RequiredArgsConstructor;
22  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
23  import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
24  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
25  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
26  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
27  import org.apache.shardingsphere.sharding.exception.data.ShardingValueDataTypeException;
28  import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
29  import org.apache.shardingsphere.sharding.route.engine.condition.Column;
30  import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
31  import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
32  import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
33  import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
34  import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
35  import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
36  import org.apache.shardingsphere.sharding.rule.ShardingRule;
37  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
38  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
39  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
40  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
41  import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
42  import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
43  import org.apache.shardingsphere.sql.parser.sql.common.util.SafeNumberOperationUtils;
44  import org.apache.shardingsphere.timeservice.core.rule.TimestampServiceRule;
45  
46  import java.util.ArrayList;
47  import java.util.Collection;
48  import java.util.Collections;
49  import java.util.HashMap;
50  import java.util.HashSet;
51  import java.util.LinkedList;
52  import java.util.List;
53  import java.util.Map;
54  import java.util.Map.Entry;
55  import java.util.Optional;
56  import java.util.Set;
57  
58  /**
59   * Sharding condition engine for where clause.
60   */
61  @RequiredArgsConstructor
62  public final class WhereClauseShardingConditionEngine {
63      
64      private final ShardingSphereDatabase database;
65      
66      private final ShardingRule shardingRule;
67      
68      private final TimestampServiceRule timestampServiceRule;
69      
70      /**
71       * Create sharding conditions.
72       *
73       * @param sqlStatementContext SQL statement context
74       * @param params SQL parameters
75       * @return sharding conditions
76       */
77      public List<ShardingCondition> createShardingConditions(final SQLStatementContext sqlStatementContext, final List<Object> params) {
78          if (!(sqlStatementContext instanceof WhereAvailable)) {
79              return Collections.emptyList();
80          }
81          Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
82          String defaultSchemaName = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName());
83          ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName()
84                  .map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchemaName));
85          Map<String, String> columnExpressionTableNames = sqlStatementContext.getTablesContext().findTableNamesByColumnSegment(columnSegments, schema);
86          List<ShardingCondition> result = new ArrayList<>();
87          for (WhereSegment each : ((WhereAvailable) sqlStatementContext).getWhereSegments()) {
88              result.addAll(createShardingConditions(each.getExpr(), params, columnExpressionTableNames));
89          }
90          return result;
91      }
92      
93      private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> params, final Map<String, String> columnExpressionTableNames) {
94          Collection<AndPredicate> andPredicates = ExpressionExtractUtils.getAndPredicates(expression);
95          Collection<ShardingCondition> result = new LinkedList<>();
96          for (AndPredicate each : andPredicates) {
97              Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), params, columnExpressionTableNames);
98              if (shardingConditionValues.isEmpty()) {
99                  return Collections.emptyList();
100             }
101             ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
102             // TODO remove startIndex when federation has perfect support for subquery
103             shardingCondition.setStartIndex(expression.getStartIndex());
104             result.add(shardingCondition);
105         }
106         return result;
107     }
108     
109     private Map<Column, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates,
110                                                                                             final List<Object> params, final Map<String, String> columnTableNames) {
111         Map<Column, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1F);
112         for (ExpressionSegment each : predicates) {
113             for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
114                 Optional<String> tableName = Optional.ofNullable(columnTableNames.get(columnSegment.getExpression()));
115                 Optional<String> shardingColumn = tableName.flatMap(optional -> shardingRule.findShardingColumn(columnSegment.getIdentifier().getValue(), optional));
116                 if (!tableName.isPresent() || !shardingColumn.isPresent()) {
117                     continue;
118                 }
119                 Column column = new Column(shardingColumn.get(), tableName.get());
120                 Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, params, timestampServiceRule);
121                 if (!shardingConditionValue.isPresent()) {
122                     continue;
123                 }
124                 result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
125             }
126         }
127         return result;
128     }
129     
130     private ShardingCondition createShardingCondition(final Map<Column, Collection<ShardingConditionValue>> shardingConditionValues) {
131         ShardingCondition result = new ShardingCondition();
132         for (Entry<Column, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
133             try {
134                 ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
135                 if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
136                     return new AlwaysFalseShardingCondition();
137                 }
138                 result.getValues().add(shardingConditionValue);
139             } catch (final ClassCastException ignored) {
140                 throw new ShardingValueDataTypeException(entry.getKey());
141             }
142         }
143         return result;
144     }
145     
146     @SuppressWarnings({"unchecked", "rawtypes"})
147     private ShardingConditionValue mergeShardingConditionValues(final Column column, final Collection<ShardingConditionValue> shardingConditionValues) {
148         Collection<Comparable<?>> listValue = null;
149         Range<Comparable<?>> rangeValue = null;
150         Set<Integer> parameterMarkerIndexes = new HashSet<>();
151         for (ShardingConditionValue each : shardingConditionValues) {
152             parameterMarkerIndexes.addAll(each.getParameterMarkerIndexes());
153             if (each instanceof ListShardingConditionValue) {
154                 listValue = mergeListShardingValues(((ListShardingConditionValue) each).getValues(), listValue);
155                 if (listValue.isEmpty()) {
156                     return new AlwaysFalseShardingConditionValue();
157                 }
158             } else if (each instanceof RangeShardingConditionValue) {
159                 try {
160                     rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
161                 } catch (final IllegalArgumentException ex) {
162                     return new AlwaysFalseShardingConditionValue();
163                 }
164             }
165         }
166         if (null == listValue) {
167             return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue, new ArrayList<>(parameterMarkerIndexes));
168         }
169         if (null == rangeValue) {
170             return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
171         }
172         listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
173         return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue()
174                 : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
175     }
176     
177     private Collection<Comparable<?>> mergeListShardingValues(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
178         if (null == value2) {
179             return value1;
180         }
181         value1.retainAll(value2);
182         return value1;
183     }
184     
185     private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
186         return null == value2 ? value1 : SafeNumberOperationUtils.safeIntersection(value1, value2);
187     }
188     
189     private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
190         Collection<Comparable<?>> result = new LinkedList<>();
191         for (Comparable<?> each : listValue) {
192             if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
193                 result.add(each);
194             }
195         }
196         return result;
197     }
198 }