1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.shadow.route.engine.dml;
19
20 import org.apache.shardingsphere.infra.binder.context.statement.dml.UpdateStatementContext;
21 import org.apache.shardingsphere.shadow.api.shadow.ShadowOperationType;
22 import org.apache.shardingsphere.shadow.condition.ShadowColumnCondition;
23 import org.apache.shardingsphere.shadow.route.engine.util.ShadowExtractor;
24 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
25 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
26 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
27 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
28 import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
29 import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
30
31 import java.util.Collection;
32 import java.util.LinkedList;
33 import java.util.List;
34
35
36
37
38 public final class ShadowUpdateStatementRoutingEngine extends AbstractShadowDMLStatementRouteEngine {
39
40 private final UpdateStatementContext sqlStatementContext;
41
42 private final List<Object> parameters;
43
44 public ShadowUpdateStatementRoutingEngine(final UpdateStatementContext sqlStatementContext, final List<Object> parameters) {
45 super(sqlStatementContext, ShadowOperationType.UPDATE);
46 this.sqlStatementContext = sqlStatementContext;
47 this.parameters = parameters;
48 }
49
50 @Override
51 protected Collection<ShadowColumnCondition> getShadowColumnConditions(final String shadowColumnName) {
52 Collection<ShadowColumnCondition> result = new LinkedList<>();
53 for (ExpressionSegment each : getWhereSegment()) {
54 Collection<ColumnSegment> columns = ColumnExtractor.extract(each);
55 if (1 != columns.size()) {
56 continue;
57 }
58 ShadowExtractor.extractValues(each, parameters).map(values -> new ShadowColumnCondition(getSingleTableName(), shadowColumnName, values)).ifPresent(result::add);
59 }
60 return result;
61 }
62
63 private Collection<ExpressionSegment> getWhereSegment() {
64 Collection<ExpressionSegment> result = new LinkedList<>();
65 for (WhereSegment each : sqlStatementContext.getWhereSegments()) {
66 for (AndPredicate predicate : ExpressionExtractUtils.getAndPredicates(each.getExpr())) {
67 result.addAll(predicate.getPredicates());
68 }
69 }
70 return result;
71 }
72 }