1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sharding.route.engine.condition.engine;
19
20 import com.google.common.collect.Range;
21 import lombok.RequiredArgsConstructor;
22 import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
23 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
24 import org.apache.shardingsphere.infra.binder.context.available.WhereContextAvailable;
25 import org.apache.shardingsphere.sharding.exception.data.ShardingValueDataTypeException;
26 import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
27 import org.apache.shardingsphere.sharding.route.engine.condition.Column;
28 import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
29 import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
30 import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
31 import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
32 import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
33 import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
34 import org.apache.shardingsphere.sharding.rule.ShardingRule;
35 import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
36 import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
37 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
38 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
39 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
40 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
41 import org.apache.shardingsphere.sql.parser.statement.core.util.SafeNumberOperationUtils;
42 import org.apache.shardingsphere.timeservice.core.rule.TimestampServiceRule;
43
44 import java.util.ArrayList;
45 import java.util.Collection;
46 import java.util.Collections;
47 import java.util.HashMap;
48 import java.util.HashSet;
49 import java.util.LinkedList;
50 import java.util.List;
51 import java.util.Map;
52 import java.util.Map.Entry;
53 import java.util.Optional;
54 import java.util.Set;
55
56
57
58
59 @RequiredArgsConstructor
60 public final class WhereClauseShardingConditionEngine {
61
62 private final ShardingRule rule;
63
64 private final TimestampServiceRule timestampServiceRule;
65
66
67
68
69
70
71
72
73 public List<ShardingCondition> createShardingConditions(final SQLStatementContext sqlStatementContext, final List<Object> params) {
74 if (!(sqlStatementContext instanceof WhereContextAvailable)) {
75 return Collections.emptyList();
76 }
77 List<ShardingCondition> result = new ArrayList<>();
78 for (WhereSegment each : SQLStatementContextExtractor.getAllWhereSegments(sqlStatementContext)) {
79 result.addAll(createShardingConditions(each.getExpr(), params));
80 }
81 return result;
82 }
83
84 private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> params) {
85 Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(expression);
86 Collection<ShardingCondition> result = new LinkedList<>();
87 for (AndPredicate each : andPredicates) {
88 Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), params);
89 if (shardingConditionValues.isEmpty()) {
90 return Collections.emptyList();
91 }
92 ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
93
94 shardingCondition.setStartIndex(expression.getStartIndex());
95 result.add(shardingCondition);
96 }
97 return result;
98 }
99
100 private Map<Column, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates, final List<Object> params) {
101 Map<Column, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1F);
102 for (ExpressionSegment each : predicates) {
103 for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
104 String tableName = columnSegment.getColumnBoundInfo().getOriginalTable().getValue();
105 Optional<String> shardingColumn = rule.findShardingColumn(columnSegment.getColumnBoundInfo().getOriginalColumn().getValue(), tableName);
106 if (!shardingColumn.isPresent()) {
107 continue;
108 }
109 Column column = new Column(shardingColumn.get(), tableName);
110 Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, params, timestampServiceRule);
111 if (!shardingConditionValue.isPresent()) {
112 continue;
113 }
114 result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
115 }
116 }
117 return result;
118 }
119
120 private ShardingCondition createShardingCondition(final Map<Column, Collection<ShardingConditionValue>> shardingConditionValues) {
121 ShardingCondition result = new ShardingCondition();
122 for (Entry<Column, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
123 try {
124 ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
125 if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
126 return new AlwaysFalseShardingCondition();
127 }
128 result.getValues().add(shardingConditionValue);
129 } catch (final ClassCastException ignored) {
130 throw new ShardingValueDataTypeException(entry.getKey());
131 }
132 }
133 return result;
134 }
135
136 @SuppressWarnings({"unchecked", "rawtypes"})
137 private ShardingConditionValue mergeShardingConditionValues(final Column column, final Collection<ShardingConditionValue> shardingConditionValues) {
138 Collection<Comparable<?>> listValue = null;
139 Range<Comparable<?>> rangeValue = null;
140 Set<Integer> parameterMarkerIndexes = new HashSet<>();
141 for (ShardingConditionValue each : shardingConditionValues) {
142 parameterMarkerIndexes.addAll(each.getParameterMarkerIndexes());
143 if (each instanceof ListShardingConditionValue) {
144 listValue = mergeListShardingValues(((ListShardingConditionValue) each).getValues(), listValue);
145 if (listValue.isEmpty()) {
146 return new AlwaysFalseShardingConditionValue();
147 }
148 } else if (each instanceof RangeShardingConditionValue) {
149 try {
150 rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
151 } catch (final IllegalArgumentException ex) {
152 return new AlwaysFalseShardingConditionValue();
153 }
154 }
155 }
156 if (null == listValue) {
157 return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue, new ArrayList<>(parameterMarkerIndexes));
158 }
159 if (null == rangeValue) {
160 return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
161 }
162 listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
163 return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue()
164 : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
165 }
166
167 private Collection<Comparable<?>> mergeListShardingValues(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
168 if (null == value2) {
169 return value1;
170 }
171 value1.retainAll(value2);
172 return value1;
173 }
174
175 private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
176 return null == value2 ? value1 : SafeNumberOperationUtils.safeIntersection(value1, value2);
177 }
178
179 private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
180 Collection<Comparable<?>> result = new LinkedList<>();
181 for (Comparable<?> each : listValue) {
182 if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
183 result.add(each);
184 }
185 }
186 return result;
187 }
188 }