1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sharding.rewrite.token.generator.impl;
19
20 import lombok.Setter;
21 import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
22 import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.RouteContextAware;
23 import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValue;
24 import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValuesToken;
25 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
26 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27 import org.apache.shardingsphere.infra.binder.context.statement.type.dml.InsertStatementContext;
28 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
29 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
30 import org.apache.shardingsphere.infra.datanode.DataNode;
31 import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
32 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValuesToken;
33 import org.apache.shardingsphere.infra.route.context.RouteContext;
34
35 import java.util.Collection;
36 import java.util.Collections;
37 import java.util.Iterator;
38 import java.util.List;
39
40
41
42
43 @HighFrequencyInvocation
44 @Setter
45 public final class ShardingInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, RouteContextAware {
46
47 private RouteContext routeContext;
48
49 @Override
50 public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
51 return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
52 }
53
54 @Override
55 public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
56 Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
57 InsertValuesToken result = new ShardingInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
58 Iterator<Collection<DataNode>> dataNodesIterator = routeContext.getOriginalDataNodes().isEmpty() ? Collections.emptyIterator() : routeContext.getOriginalDataNodes().iterator();
59 for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
60 List<ExpressionSegment> expressionSegments = each.getValueExpressions();
61 Collection<DataNode> dataNodes = dataNodesIterator.hasNext() ? dataNodesIterator.next() : Collections.emptyList();
62 result.getInsertValues().add(new ShardingInsertValue(expressionSegments, dataNodes));
63 }
64 return result;
65 }
66
67 private int getStartIndex(final Collection<InsertValuesSegment> segments) {
68 int result = segments.iterator().next().getStartIndex();
69 for (InsertValuesSegment each : segments) {
70 result = Math.min(result, each.getStartIndex());
71 }
72 return result;
73 }
74
75 private int getStopIndex(final Collection<InsertValuesSegment> segments) {
76 int result = segments.iterator().next().getStopIndex();
77 for (InsertValuesSegment each : segments) {
78 result = Math.max(result, each.getStopIndex());
79 }
80 return result;
81 }
82 }