1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.update;
19
20 import org.apache.calcite.sql.SqlNode;
21 import org.apache.calcite.sql.SqlNodeList;
22 import org.apache.calcite.sql.SqlOrderBy;
23 import org.apache.calcite.sql.SqlUpdate;
24 import org.apache.calcite.sql.parser.SqlParserPos;
25 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
26 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
27 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
28 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
29 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
30 import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
31 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
32 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ColumnConverter;
33 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.from.TableConverter;
34 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.limit.PaginationValueSQLConverter;
35 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.orderby.OrderByConverter;
36 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter;
37 import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter;
38
39 import java.util.List;
40 import java.util.Optional;
41 import java.util.stream.Collectors;
42
43
44
45
46 public final class UpdateStatementConverter implements SQLStatementConverter<UpdateStatement, SqlNode> {
47
48 @Override
49 public SqlNode convert(final UpdateStatement updateStatement) {
50 SqlUpdate sqlUpdate = convertUpdate(updateStatement);
51 SqlNodeList orderBy = UpdateStatementHandler.getOrderBySegment(updateStatement).flatMap(OrderByConverter::convert).orElse(SqlNodeList.EMPTY);
52 Optional<LimitSegment> limit = UpdateStatementHandler.getLimitSegment(updateStatement);
53 if (limit.isPresent()) {
54 SqlNode offset = limit.get().getOffset().flatMap(PaginationValueSQLConverter::convert).orElse(null);
55 SqlNode rowCount = limit.get().getRowCount().flatMap(PaginationValueSQLConverter::convert).orElse(null);
56 return new SqlOrderBy(SqlParserPos.ZERO, sqlUpdate, orderBy, offset, rowCount);
57 }
58 return orderBy.isEmpty() ? sqlUpdate : new SqlOrderBy(SqlParserPos.ZERO, sqlUpdate, orderBy, null, null);
59 }
60
61 private SqlUpdate convertUpdate(final UpdateStatement updateStatement) {
62 SqlNode table = TableConverter.convert(updateStatement.getTable()).orElseThrow(IllegalStateException::new);
63 SqlNode condition = updateStatement.getWhere().flatMap(WhereConverter::convert).orElse(null);
64 SqlNodeList columns = new SqlNodeList(SqlParserPos.ZERO);
65 SqlNodeList expressions = new SqlNodeList(SqlParserPos.ZERO);
66 for (ColumnAssignmentSegment each : updateStatement.getAssignmentSegment().orElseThrow(IllegalStateException::new).getAssignments()) {
67 columns.addAll(convertColumn(each.getColumns()));
68 expressions.add(convertExpression(each.getValue()));
69 }
70 return new SqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null);
71 }
72
73 private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
74 return columnSegments.stream().map(each -> ColumnConverter.convert(each).orElseThrow(IllegalStateException::new)).collect(Collectors.toList());
75 }
76
77 private SqlNode convertExpression(final ExpressionSegment expressionSegment) {
78 return ExpressionConverter.convert(expressionSegment).orElseThrow(IllegalStateException::new);
79 }
80 }