1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.merge;
19
20 import org.apache.calcite.sql.SqlMerge;
21 import org.apache.calcite.sql.SqlNode;
22 import org.apache.calcite.sql.SqlNodeList;
23 import org.apache.calcite.sql.SqlUpdate;
24 import org.apache.calcite.sql.parser.SqlParserPos;
25 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
26 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
27 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
28 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.MergeStatement;
29 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
30 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
31 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ColumnConverter;
32 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.from.TableConverter;
33 import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter;
34 import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter;
35
36 import java.util.List;
37 import java.util.stream.Collectors;
38
39
40
41
42 public final class MergeStatementConverter implements SQLStatementConverter<MergeStatement, SqlNode> {
43
44 @Override
45 public SqlNode convert(final MergeStatement mergeStatement) {
46 SqlNode targetTable = TableConverter.convert(mergeStatement.getTarget()).orElseThrow(IllegalStateException::new);
47 SqlNode condition = ExpressionConverter.convert(mergeStatement.getExpression().getExpr()).orElseThrow(IllegalStateException::new);
48 SqlNode sourceTable = TableConverter.convert(mergeStatement.getSource()).orElseThrow(IllegalStateException::new);
49 SqlUpdate sqlUpdate = mergeStatement.getUpdate().map(this::convertUpdate).orElse(null);
50 return new SqlMerge(SqlParserPos.ZERO, targetTable, condition, sourceTable, sqlUpdate, null, null, null);
51 }
52
53 private SqlUpdate convertUpdate(final UpdateStatement updateStatement) {
54 SqlNode table = TableConverter.convert(updateStatement.getTable()).orElse(SqlNodeList.EMPTY);
55 SqlNode condition = updateStatement.getWhere().flatMap(WhereConverter::convert).orElse(null);
56 SqlNodeList columns = new SqlNodeList(SqlParserPos.ZERO);
57 SqlNodeList expressions = new SqlNodeList(SqlParserPos.ZERO);
58 for (ColumnAssignmentSegment each : updateStatement.getAssignmentSegment().orElseThrow(IllegalStateException::new).getAssignments()) {
59 columns.addAll(convertColumn(each.getColumns()));
60 expressions.add(convertExpression(each.getValue()));
61 }
62 return new SqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null);
63 }
64
65 private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
66 return columnSegments.stream().map(each -> ColumnConverter.convert(each).orElseThrow(IllegalStateException::new)).collect(Collectors.toList());
67 }
68
69 private SqlNode convertExpression(final ExpressionSegment expressionSegment) {
70 return ExpressionConverter.convert(expressionSegment).orElseThrow(IllegalStateException::new);
71 }
72 }