1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.infra.binder.context.statement.dml;
19
20 import lombok.Getter;
21 import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
22 import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
23 import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
24 import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
25 import org.apache.shardingsphere.sql.parser.sql.common.extractor.TableExtractor;
26 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
27 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
28 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
29 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
30 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
31 import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
32 import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionExtractUtils;
33
34 import java.util.Collection;
35 import java.util.LinkedList;
36
37
38
39
40 @Getter
41 public final class UpdateStatementContext extends CommonSQLStatementContext implements TableAvailable, WhereAvailable {
42
43 private final TablesContext tablesContext;
44
45 private final Collection<WhereSegment> whereSegments = new LinkedList<>();
46
47 private final Collection<ColumnSegment> columnSegments = new LinkedList<>();
48
49 private final Collection<BinaryOperationExpression> joinConditions = new LinkedList<>();
50
51 public UpdateStatementContext(final UpdateStatement sqlStatement) {
52 super(sqlStatement);
53 tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType());
54 getSqlStatement().getWhere().ifPresent(whereSegments::add);
55 ColumnExtractor.extractColumnSegments(columnSegments, whereSegments);
56 ExpressionExtractUtils.extractJoinConditions(joinConditions, whereSegments);
57 }
58
59 private Collection<SimpleTableSegment> getAllSimpleTableSegments() {
60 TableExtractor tableExtractor = new TableExtractor();
61 tableExtractor.extractTablesFromUpdate(getSqlStatement());
62 return tableExtractor.getRewriteTables();
63 }
64
65 @Override
66 public UpdateStatement getSqlStatement() {
67 return (UpdateStatement) super.getSqlStatement();
68 }
69
70 @Override
71 public Collection<SimpleTableSegment> getAllTables() {
72 return tablesContext.getSimpleTableSegments();
73 }
74
75 @Override
76 public Collection<WhereSegment> getWhereSegments() {
77 return whereSegments;
78 }
79
80 @Override
81 public Collection<ColumnSegment> getColumnSegments() {
82 return columnSegments;
83 }
84 }