1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sharding.route.engine.condition.engine;
19
20 import com.cedarsoftware.util.CaseInsensitiveSet;
21 import com.google.common.collect.Range;
22 import lombok.RequiredArgsConstructor;
23 import org.apache.shardingsphere.infra.binder.context.available.WhereContextAvailable;
24 import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
25 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
26 import org.apache.shardingsphere.infra.exception.ShardingSpherePreconditions;
27 import org.apache.shardingsphere.infra.exception.kernel.metadata.SchemaNotFoundException;
28 import org.apache.shardingsphere.infra.exception.kernel.metadata.TableNotFoundException;
29 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
30 import org.apache.shardingsphere.infra.metadata.database.schema.HashColumn;
31 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
32 import org.apache.shardingsphere.sharding.exception.data.ShardingValueDataTypeException;
33 import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
34 import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
35 import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
36 import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
37 import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
38 import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
39 import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
40 import org.apache.shardingsphere.sharding.rule.ShardingRule;
41 import org.apache.shardingsphere.sharding.util.ShardingValueTypeConvertUtils;
42 import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
43 import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
44 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
45 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
46 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
47 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
48 import org.apache.shardingsphere.sql.parser.statement.core.util.SafeNumberOperationUtils;
49 import org.apache.shardingsphere.timeservice.core.rule.TimestampServiceRule;
50
51 import java.util.ArrayList;
52 import java.util.Collection;
53 import java.util.Collections;
54 import java.util.HashMap;
55 import java.util.HashSet;
56 import java.util.LinkedList;
57 import java.util.List;
58 import java.util.Map;
59 import java.util.Map.Entry;
60 import java.util.Optional;
61 import java.util.Set;
62
63
64
65
66 @RequiredArgsConstructor
67 public final class WhereClauseShardingConditionEngine {
68
69 private final ShardingSphereDatabase database;
70
71 private final ShardingRule rule;
72
73 private final TimestampServiceRule timestampServiceRule;
74
75
76
77
78
79
80
81
82 public List<ShardingCondition> createShardingConditions(final SQLStatementContext sqlStatementContext, final List<Object> params) {
83 if (!(sqlStatementContext instanceof WhereContextAvailable)) {
84 return Collections.emptyList();
85 }
86 List<ShardingCondition> result = new ArrayList<>();
87 for (WhereSegment each : SQLStatementContextExtractor.getAllWhereSegments(sqlStatementContext)) {
88 result.addAll(createShardingConditions(each.getExpr(), params));
89 }
90 return result;
91 }
92
93 private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> params) {
94 Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(expression);
95 Collection<ShardingCondition> result = new LinkedList<>();
96 for (AndPredicate each : andPredicates) {
97 Map<HashColumn, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), params);
98 if (shardingConditionValues.isEmpty()) {
99 return Collections.emptyList();
100 }
101 ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
102
103 shardingCondition.setStartIndex(expression.getStartIndex());
104 result.add(shardingCondition);
105 }
106 return result;
107 }
108
109 private Map<HashColumn, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates, final List<Object> params) {
110 Map<HashColumn, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1F);
111 for (ExpressionSegment each : predicates) {
112 for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
113 String tableName = columnSegment.getColumnBoundInfo().getOriginalTable().getValue();
114 Optional<String> shardingColumn = rule.findShardingColumn(columnSegment.getColumnBoundInfo().getOriginalColumn().getValue(), tableName);
115 if (!shardingColumn.isPresent()) {
116 continue;
117 }
118 String schemaName = columnSegment.getColumnBoundInfo().getOriginalSchema().getValue();
119 ShardingSpherePreconditions.checkState(database.containsSchema(schemaName), () -> new SchemaNotFoundException(schemaName));
120 ShardingSphereSchema schema = database.getSchema(schemaName);
121 ShardingSpherePreconditions.checkState(schema.containsTable(tableName), () -> new TableNotFoundException(tableName));
122 HashColumn column = new HashColumn(shardingColumn.get(), tableName, schema.getTable(tableName).getColumn(shardingColumn.get()).isCaseSensitive());
123 Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, params, timestampServiceRule);
124 if (!shardingConditionValue.isPresent()) {
125 continue;
126 }
127 result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
128 }
129 }
130 return result;
131 }
132
133 private ShardingCondition createShardingCondition(final Map<HashColumn, Collection<ShardingConditionValue>> shardingConditionValues) {
134 ShardingCondition result = new ShardingCondition();
135 for (Entry<HashColumn, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
136 try {
137 ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
138 if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
139 return new AlwaysFalseShardingCondition();
140 }
141 result.getValues().add(shardingConditionValue);
142 } catch (final ClassCastException ignored) {
143 throw new ShardingValueDataTypeException(entry.getKey());
144 }
145 }
146 return result;
147 }
148
149 @SuppressWarnings({"unchecked", "rawtypes"})
150 private ShardingConditionValue mergeShardingConditionValues(final HashColumn column, final Collection<ShardingConditionValue> shardingConditionValues) {
151 Collection<Comparable<?>> listValue = null;
152 Range<Comparable<?>> rangeValue = null;
153 Set<Integer> parameterMarkerIndexes = new HashSet<>();
154 for (ShardingConditionValue each : shardingConditionValues) {
155 parameterMarkerIndexes.addAll(each.getParameterMarkerIndexes());
156 if (each instanceof ListShardingConditionValue) {
157 listValue = mergeListShardingValues(column, ((ListShardingConditionValue) each).getValues(), listValue);
158 if (listValue.isEmpty()) {
159 return new AlwaysFalseShardingConditionValue();
160 }
161 } else if (each instanceof RangeShardingConditionValue) {
162 try {
163 rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
164 } catch (final IllegalArgumentException ex) {
165 return new AlwaysFalseShardingConditionValue();
166 }
167 }
168 }
169 if (null == listValue) {
170 return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue, new ArrayList<>(parameterMarkerIndexes));
171 }
172 if (null == rangeValue) {
173 return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
174 }
175 listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
176 return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue()
177 : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
178 }
179
180 private Collection<Comparable<?>> mergeListShardingValues(final HashColumn column, final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
181 if (null == value2) {
182 return value1;
183 }
184 Collection<Comparable<?>> convertedValue2 = value2;
185 if (!value1.isEmpty() && !value2.isEmpty() && isDifferentType(value1, value2)) {
186 convertedValue2 = ShardingValueTypeConvertUtils.convertCollectionType(value2, value1.iterator().next().getClass());
187 }
188 if (column.isCaseSensitive()) {
189 value1.retainAll(convertedValue2);
190 return value1;
191 }
192 Collection<Comparable<?>> caseInSensitiveValue1 = new CaseInsensitiveSet<>(value1);
193 Collection<Comparable<?>> caseInSensitiveValue2 = new CaseInsensitiveSet<>(convertedValue2);
194 caseInSensitiveValue1.retainAll(caseInSensitiveValue2);
195 return caseInSensitiveValue1;
196 }
197
198 private boolean isDifferentType(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
199 return value1.iterator().next().getClass() != value2.iterator().next().getClass();
200 }
201
202 private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
203 return null == value2 ? value1 : SafeNumberOperationUtils.safeIntersection(value1, value2);
204 }
205
206 private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
207 Collection<Comparable<?>> result = new LinkedList<>();
208 for (Comparable<?> each : listValue) {
209 if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
210 result.add(each);
211 }
212 }
213 return result;
214 }
215 }