1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sharding.route.engine.condition.engine;
19
20 import com.google.common.collect.Range;
21 import lombok.RequiredArgsConstructor;
22 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
23 import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
24 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
25 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
26 import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
27 import org.apache.shardingsphere.sharding.exception.data.ShardingValueDataTypeException;
28 import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
29 import org.apache.shardingsphere.sharding.route.engine.condition.Column;
30 import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
31 import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
32 import org.apache.shardingsphere.sharding.route.engine.condition.value.AlwaysFalseShardingConditionValue;
33 import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
34 import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
35 import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
36 import org.apache.shardingsphere.sharding.rule.ShardingRule;
37 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
38 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
39 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
40 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
41 import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
42 import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
43 import org.apache.shardingsphere.sql.parser.sql.common.util.SafeNumberOperationUtils;
44 import org.apache.shardingsphere.timeservice.core.rule.TimestampServiceRule;
45
46 import java.util.ArrayList;
47 import java.util.Collection;
48 import java.util.Collections;
49 import java.util.HashMap;
50 import java.util.HashSet;
51 import java.util.LinkedList;
52 import java.util.List;
53 import java.util.Map;
54 import java.util.Map.Entry;
55 import java.util.Optional;
56 import java.util.Set;
57
58
59
60
61 @RequiredArgsConstructor
62 public final class WhereClauseShardingConditionEngine {
63
64 private final ShardingSphereDatabase database;
65
66 private final ShardingRule shardingRule;
67
68 private final TimestampServiceRule timestampServiceRule;
69
70
71
72
73
74
75
76
77 public List<ShardingCondition> createShardingConditions(final SQLStatementContext sqlStatementContext, final List<Object> params) {
78 if (!(sqlStatementContext instanceof WhereAvailable)) {
79 return Collections.emptyList();
80 }
81 Collection<ColumnSegment> columnSegments = ((WhereAvailable) sqlStatementContext).getColumnSegments();
82 String defaultSchemaName = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName());
83 ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName()
84 .map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchemaName));
85 Map<String, String> columnExpressionTableNames = sqlStatementContext.getTablesContext().findTableNamesByColumnSegment(columnSegments, schema);
86 List<ShardingCondition> result = new ArrayList<>();
87 for (WhereSegment each : ((WhereAvailable) sqlStatementContext).getWhereSegments()) {
88 result.addAll(createShardingConditions(each.getExpr(), params, columnExpressionTableNames));
89 }
90 return result;
91 }
92
93 private Collection<ShardingCondition> createShardingConditions(final ExpressionSegment expression, final List<Object> params, final Map<String, String> columnExpressionTableNames) {
94 Collection<AndPredicate> andPredicates = ExpressionExtractUtils.getAndPredicates(expression);
95 Collection<ShardingCondition> result = new LinkedList<>();
96 for (AndPredicate each : andPredicates) {
97 Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), params, columnExpressionTableNames);
98 if (shardingConditionValues.isEmpty()) {
99 return Collections.emptyList();
100 }
101 ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
102
103 shardingCondition.setStartIndex(expression.getStartIndex());
104 result.add(shardingCondition);
105 }
106 return result;
107 }
108
109 private Map<Column, Collection<ShardingConditionValue>> createShardingConditionValueMap(final Collection<ExpressionSegment> predicates,
110 final List<Object> params, final Map<String, String> columnTableNames) {
111 Map<Column, Collection<ShardingConditionValue>> result = new HashMap<>(predicates.size(), 1F);
112 for (ExpressionSegment each : predicates) {
113 for (ColumnSegment columnSegment : ColumnExtractor.extract(each)) {
114 Optional<String> tableName = Optional.ofNullable(columnTableNames.get(columnSegment.getExpression()));
115 Optional<String> shardingColumn = tableName.flatMap(optional -> shardingRule.findShardingColumn(columnSegment.getIdentifier().getValue(), optional));
116 if (!tableName.isPresent() || !shardingColumn.isPresent()) {
117 continue;
118 }
119 Column column = new Column(shardingColumn.get(), tableName.get());
120 Optional<ShardingConditionValue> shardingConditionValue = ConditionValueGeneratorFactory.generate(each, column, params, timestampServiceRule);
121 if (!shardingConditionValue.isPresent()) {
122 continue;
123 }
124 result.computeIfAbsent(column, unused -> new LinkedList<>()).add(shardingConditionValue.get());
125 }
126 }
127 return result;
128 }
129
130 private ShardingCondition createShardingCondition(final Map<Column, Collection<ShardingConditionValue>> shardingConditionValues) {
131 ShardingCondition result = new ShardingCondition();
132 for (Entry<Column, Collection<ShardingConditionValue>> entry : shardingConditionValues.entrySet()) {
133 try {
134 ShardingConditionValue shardingConditionValue = mergeShardingConditionValues(entry.getKey(), entry.getValue());
135 if (shardingConditionValue instanceof AlwaysFalseShardingConditionValue) {
136 return new AlwaysFalseShardingCondition();
137 }
138 result.getValues().add(shardingConditionValue);
139 } catch (final ClassCastException ignored) {
140 throw new ShardingValueDataTypeException(entry.getKey());
141 }
142 }
143 return result;
144 }
145
146 @SuppressWarnings({"unchecked", "rawtypes"})
147 private ShardingConditionValue mergeShardingConditionValues(final Column column, final Collection<ShardingConditionValue> shardingConditionValues) {
148 Collection<Comparable<?>> listValue = null;
149 Range<Comparable<?>> rangeValue = null;
150 Set<Integer> parameterMarkerIndexes = new HashSet<>();
151 for (ShardingConditionValue each : shardingConditionValues) {
152 parameterMarkerIndexes.addAll(each.getParameterMarkerIndexes());
153 if (each instanceof ListShardingConditionValue) {
154 listValue = mergeListShardingValues(((ListShardingConditionValue) each).getValues(), listValue);
155 if (listValue.isEmpty()) {
156 return new AlwaysFalseShardingConditionValue();
157 }
158 } else if (each instanceof RangeShardingConditionValue) {
159 try {
160 rangeValue = mergeRangeShardingValues(((RangeShardingConditionValue) each).getValueRange(), rangeValue);
161 } catch (final IllegalArgumentException ex) {
162 return new AlwaysFalseShardingConditionValue();
163 }
164 }
165 }
166 if (null == listValue) {
167 return new RangeShardingConditionValue<>(column.getName(), column.getTableName(), rangeValue, new ArrayList<>(parameterMarkerIndexes));
168 }
169 if (null == rangeValue) {
170 return new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
171 }
172 listValue = mergeListAndRangeShardingValues(listValue, rangeValue);
173 return listValue.isEmpty() ? new AlwaysFalseShardingConditionValue()
174 : new ListShardingConditionValue<>(column.getName(), column.getTableName(), listValue, new ArrayList<>(parameterMarkerIndexes));
175 }
176
177 private Collection<Comparable<?>> mergeListShardingValues(final Collection<Comparable<?>> value1, final Collection<Comparable<?>> value2) {
178 if (null == value2) {
179 return value1;
180 }
181 value1.retainAll(value2);
182 return value1;
183 }
184
185 private Range<Comparable<?>> mergeRangeShardingValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
186 return null == value2 ? value1 : SafeNumberOperationUtils.safeIntersection(value1, value2);
187 }
188
189 private Collection<Comparable<?>> mergeListAndRangeShardingValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
190 Collection<Comparable<?>> result = new LinkedList<>();
191 for (Comparable<?> each : listValue) {
192 if (SafeNumberOperationUtils.safeContains(rangeValue, each)) {
193 result.add(each);
194 }
195 }
196 return result;
197 }
198 }