1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.sql.parser.postgresql.visitor.statement.type;
19
20 import org.antlr.v4.runtime.misc.Interval;
21 import org.apache.shardingsphere.sql.parser.api.ASTNode;
22 import org.apache.shardingsphere.sql.parser.api.visitor.statement.type.DMLStatementVisitor;
23 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CallArgumentContext;
24 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CallContext;
25 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CheckpointContext;
26 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CopyContext;
27 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CopyWithTableBinaryContext;
28 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CopyWithTableOrQueryBinaryCsvContext;
29 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CopyWithTableOrQueryContext;
30 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.DoStatementContext;
31 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.PreparableStmtContext;
32 import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.ReturningClauseContext;
33 import org.apache.shardingsphere.sql.parser.postgresql.visitor.statement.PostgreSQLStatementVisitor;
34 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.ReturningSegment;
35 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
36 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
37 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
38 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
39 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.prepare.PrepareStatementQuerySegment;
40 import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
41 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DeleteStatement;
42 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
43 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
44 import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
45 import org.apache.shardingsphere.sql.parser.sql.common.value.collection.CollectionValue;
46 import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
47 import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLCallStatement;
48 import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLCheckpointStatement;
49 import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLCopyStatement;
50 import org.apache.shardingsphere.sql.parser.sql.dialect.statement.postgresql.dml.PostgreSQLDoStatement;
51
52 import java.util.Collection;
53 import java.util.LinkedList;
54
55
56
57
58 public final class PostgreSQLDMLStatementVisitor extends PostgreSQLStatementVisitor implements DMLStatementVisitor {
59
60 @Override
61 public ASTNode visitCall(final CallContext ctx) {
62 PostgreSQLCallStatement result = new PostgreSQLCallStatement();
63 result.setProcedureName(((IdentifierValue) visit(ctx.identifier())).getValue());
64 if (null != ctx.callArguments()) {
65 Collection<ExpressionSegment> params = new LinkedList<>();
66 for (CallArgumentContext each : ctx.callArguments().callArgument()) {
67 params.add((ExpressionSegment) visit(each));
68 }
69 result.getParameters().addAll(params);
70 }
71 return result;
72 }
73
74 @Override
75 public ASTNode visitCallArgument(final CallArgumentContext ctx) {
76 if (null == ctx.positionalNotation()) {
77 String text = ctx.namedNotation().start.getInputStream().getText(new Interval(ctx.namedNotation().start.getStartIndex(), ctx.namedNotation().stop.getStopIndex()));
78 return new CommonExpressionSegment(ctx.namedNotation().getStart().getStartIndex(), ctx.namedNotation().getStop().getStopIndex(), text);
79 }
80 return visit(ctx.positionalNotation().aExpr());
81 }
82
83 @Override
84 public ASTNode visitDoStatement(final DoStatementContext ctx) {
85 return new PostgreSQLDoStatement();
86 }
87
88 @Override
89 public ASTNode visitCopy(final CopyContext ctx) {
90 if (null != ctx.copyWithTableOrQuery()) {
91 return visit(ctx.copyWithTableOrQuery());
92 }
93 if (null != ctx.copyWithTableOrQueryBinaryCsv()) {
94 return visit(ctx.copyWithTableOrQueryBinaryCsv());
95 }
96 return visit(ctx.copyWithTableBinary());
97 }
98
99 @Override
100 public ASTNode visitCopyWithTableOrQuery(final CopyWithTableOrQueryContext ctx) {
101 PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
102 if (null != ctx.qualifiedName()) {
103 result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
104 if (null != ctx.columnNames()) {
105 result.getColumns().addAll(((CollectionValue<ColumnSegment>) visit(ctx.columnNames())).getValue());
106 }
107 }
108 if (null != ctx.preparableStmt()) {
109 result.setPrepareStatementQuerySegment(extractPrepareStatementQuerySegmentFromPreparableStmt(ctx.preparableStmt()));
110 }
111 return result;
112 }
113
114 private PrepareStatementQuerySegment extractPrepareStatementQuerySegmentFromPreparableStmt(final PreparableStmtContext ctx) {
115 PrepareStatementQuerySegment result = new PrepareStatementQuerySegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex());
116 if (null != ctx.select()) {
117 result.setSelect((SelectStatement) visit(ctx.select()));
118 } else if (null != ctx.insert()) {
119 result.setInsert((InsertStatement) visit(ctx.insert()));
120 } else if (null != ctx.update()) {
121 result.setUpdate((UpdateStatement) visit(ctx.update()));
122 } else {
123 result.setDelete((DeleteStatement) visit(ctx.delete()));
124 }
125 return result;
126 }
127
128 @Override
129 public ASTNode visitCopyWithTableOrQueryBinaryCsv(final CopyWithTableOrQueryBinaryCsvContext ctx) {
130 PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
131 if (null != ctx.qualifiedName()) {
132 result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
133 if (null != ctx.columnNames()) {
134 result.getColumns().addAll(((CollectionValue<ColumnSegment>) visit(ctx.columnNames())).getValue());
135 }
136 }
137 if (null != ctx.preparableStmt()) {
138 result.setPrepareStatementQuerySegment(extractPrepareStatementQuerySegmentFromPreparableStmt(ctx.preparableStmt()));
139 }
140 return result;
141 }
142
143 @Override
144 public ASTNode visitCopyWithTableBinary(final CopyWithTableBinaryContext ctx) {
145 PostgreSQLCopyStatement result = new PostgreSQLCopyStatement();
146 if (null != ctx.qualifiedName()) {
147 result.setTableSegment((SimpleTableSegment) visit(ctx.qualifiedName()));
148 }
149 return result;
150 }
151
152 @Override
153 public ASTNode visitCheckpoint(final CheckpointContext ctx) {
154 return new PostgreSQLCheckpointStatement();
155 }
156
157 @Override
158 public ASTNode visitReturningClause(final ReturningClauseContext ctx) {
159 return new ReturningSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), (ProjectionsSegment) visit(ctx.targetList()));
160 }
161 }