View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sql.parser.sql.common.extractor;
19  
20  import lombok.Getter;
21  import org.apache.shardingsphere.sql.parser.sql.common.segment.ddl.routine.RoutineBodySegment;
22  import org.apache.shardingsphere.sql.parser.sql.common.segment.ddl.routine.ValidStatementSegment;
23  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
24  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
25  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
26  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
27  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
28  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
29  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
30  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
31  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
32  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
33  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubqueryExpressionSegment;
34  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.AggregationProjectionSegment;
35  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
36  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
37  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
38  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
39  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.SubqueryProjectionSegment;
40  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.ColumnOrderByItemSegment;
41  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.OrderByItemSegment;
42  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.LockSegment;
43  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerAvailable;
44  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
45  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.DeleteMultiTableSegment;
46  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.JoinTableSegment;
47  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
48  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SubqueryTableSegment;
49  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableNameSegment;
50  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
51  import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
52  import org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.CreateTableStatement;
53  import org.apache.shardingsphere.sql.parser.sql.common.statement.ddl.CreateViewStatement;
54  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DeleteStatement;
55  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
56  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
57  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
58  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.ddl.CreateTableStatementHandler;
59  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
60  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
61  import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.match.MatchAgainstExpression;
62  
63  import java.util.Collection;
64  import java.util.LinkedList;
65  import java.util.Optional;
66  
67  @Getter
68  public final class TableExtractor {
69      
70      private final Collection<SimpleTableSegment> rewriteTables = new LinkedList<>();
71      
72      private final Collection<TableSegment> tableContext = new LinkedList<>();
73      
74      private final Collection<JoinTableSegment> joinTables = new LinkedList<>();
75      
76      /**
77       * Extract table that should be rewritten from select statement.
78       *
79       * @param selectStatement select statement
80       */
81      public void extractTablesFromSelect(final SelectStatement selectStatement) {
82          if (selectStatement.getCombine().isPresent()) {
83              CombineSegment combineSegment = selectStatement.getCombine().get();
84              extractTablesFromSelect(combineSegment.getLeft().getSelect());
85              extractTablesFromSelect(combineSegment.getRight().getSelect());
86          }
87          if (selectStatement.getFrom().isPresent() && !selectStatement.getCombine().isPresent()) {
88              extractTablesFromTableSegment(selectStatement.getFrom().get());
89          }
90          if (selectStatement.getWhere().isPresent()) {
91              extractTablesFromExpression(selectStatement.getWhere().get().getExpr());
92          }
93          if (null != selectStatement.getProjections() && !selectStatement.getCombine().isPresent()) {
94              extractTablesFromProjections(selectStatement.getProjections());
95          }
96          if (selectStatement.getGroupBy().isPresent()) {
97              extractTablesFromOrderByItems(selectStatement.getGroupBy().get().getGroupByItems());
98          }
99          if (selectStatement.getOrderBy().isPresent()) {
100             extractTablesFromOrderByItems(selectStatement.getOrderBy().get().getOrderByItems());
101         }
102         Optional<LockSegment> lockSegment = SelectStatementHandler.getLockSegment(selectStatement);
103         lockSegment.ifPresent(this::extractTablesFromLock);
104     }
105     
106     private void extractTablesFromTableSegment(final TableSegment tableSegment) {
107         if (tableSegment instanceof SimpleTableSegment) {
108             tableContext.add(tableSegment);
109             rewriteTables.add((SimpleTableSegment) tableSegment);
110         }
111         if (tableSegment instanceof SubqueryTableSegment) {
112             tableContext.add(tableSegment);
113             TableExtractor tableExtractor = new TableExtractor();
114             tableExtractor.extractTablesFromSelect(((SubqueryTableSegment) tableSegment).getSubquery().getSelect());
115             rewriteTables.addAll(tableExtractor.rewriteTables);
116             joinTables.addAll(tableExtractor.joinTables);
117         }
118         if (tableSegment instanceof JoinTableSegment) {
119             joinTables.add((JoinTableSegment) tableSegment);
120             extractTablesFromJoinTableSegment((JoinTableSegment) tableSegment);
121         }
122         if (tableSegment instanceof DeleteMultiTableSegment) {
123             DeleteMultiTableSegment deleteMultiTableSegment = (DeleteMultiTableSegment) tableSegment;
124             rewriteTables.addAll(deleteMultiTableSegment.getActualDeleteTables());
125             extractTablesFromTableSegment(deleteMultiTableSegment.getRelationTable());
126         }
127     }
128     
129     private void extractTablesFromJoinTableSegment(final JoinTableSegment tableSegment) {
130         extractTablesFromTableSegment(tableSegment.getLeft());
131         extractTablesFromTableSegment(tableSegment.getRight());
132         extractTablesFromExpression(tableSegment.getCondition());
133     }
134     
135     private void extractTablesFromExpression(final ExpressionSegment expressionSegment) {
136         if (expressionSegment instanceof ColumnSegment && ((ColumnSegment) expressionSegment).getOwner().isPresent() && needRewrite(((ColumnSegment) expressionSegment).getOwner().get())) {
137             OwnerSegment ownerSegment = ((ColumnSegment) expressionSegment).getOwner().get();
138             rewriteTables.add(new SimpleTableSegment(new TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), ownerSegment.getIdentifier())));
139         }
140         if (expressionSegment instanceof ListExpression) {
141             for (ExpressionSegment each : ((ListExpression) expressionSegment).getItems()) {
142                 extractTablesFromExpression(each);
143             }
144         }
145         if (expressionSegment instanceof ExistsSubqueryExpression) {
146             extractTablesFromSelect(((ExistsSubqueryExpression) expressionSegment).getSubquery().getSelect());
147         }
148         if (expressionSegment instanceof BetweenExpression) {
149             extractTablesFromExpression(((BetweenExpression) expressionSegment).getLeft());
150             extractTablesFromExpression(((BetweenExpression) expressionSegment).getBetweenExpr());
151             extractTablesFromExpression(((BetweenExpression) expressionSegment).getAndExpr());
152         }
153         if (expressionSegment instanceof InExpression) {
154             extractTablesFromExpression(((InExpression) expressionSegment).getLeft());
155             extractTablesFromExpression(((InExpression) expressionSegment).getRight());
156         }
157         if (expressionSegment instanceof SubqueryExpressionSegment) {
158             extractTablesFromSelect(((SubqueryExpressionSegment) expressionSegment).getSubquery().getSelect());
159         }
160         if (expressionSegment instanceof BinaryOperationExpression) {
161             extractTablesFromExpression(((BinaryOperationExpression) expressionSegment).getLeft());
162             extractTablesFromExpression(((BinaryOperationExpression) expressionSegment).getRight());
163         }
164         if (expressionSegment instanceof MatchAgainstExpression) {
165             for (ColumnSegment each : ((MatchAgainstExpression) expressionSegment).getColumns()) {
166                 extractTablesFromExpression(each);
167             }
168         }
169         if (expressionSegment instanceof FunctionSegment) {
170             for (ExpressionSegment each : ((FunctionSegment) expressionSegment).getParameters()) {
171                 extractTablesFromExpression(each);
172             }
173         }
174     }
175     
176     private void extractTablesFromProjections(final ProjectionsSegment projections) {
177         for (ProjectionSegment each : projections.getProjections()) {
178             if (each instanceof SubqueryProjectionSegment) {
179                 extractTablesFromSelect(((SubqueryProjectionSegment) each).getSubquery().getSelect());
180             } else if (each instanceof OwnerAvailable) {
181                 if (((OwnerAvailable) each).getOwner().isPresent() && needRewrite(((OwnerAvailable) each).getOwner().get())) {
182                     OwnerSegment ownerSegment = ((OwnerAvailable) each).getOwner().get();
183                     rewriteTables.add(createSimpleTableSegment(ownerSegment));
184                 }
185             } else if (each instanceof ColumnProjectionSegment) {
186                 if (((ColumnProjectionSegment) each).getColumn().getOwner().isPresent() && needRewrite(((ColumnProjectionSegment) each).getColumn().getOwner().get())) {
187                     OwnerSegment ownerSegment = ((ColumnProjectionSegment) each).getColumn().getOwner().get();
188                     rewriteTables.add(createSimpleTableSegment(ownerSegment));
189                 }
190             } else if (each instanceof AggregationProjectionSegment) {
191                 ((AggregationProjectionSegment) each).getParameters().forEach(this::extractTablesFromExpression);
192             } else if (each instanceof ExpressionProjectionSegment) {
193                 extractTablesFromExpression(((ExpressionProjectionSegment) each).getExpr());
194             }
195         }
196     }
197     
198     private SimpleTableSegment createSimpleTableSegment(final OwnerSegment ownerSegment) {
199         SimpleTableSegment result = new SimpleTableSegment(new TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), ownerSegment.getIdentifier()));
200         ownerSegment.getOwner().ifPresent(result::setOwner);
201         return result;
202     }
203     
204     private void extractTablesFromOrderByItems(final Collection<OrderByItemSegment> orderByItems) {
205         for (OrderByItemSegment each : orderByItems) {
206             if (each instanceof ColumnOrderByItemSegment) {
207                 Optional<OwnerSegment> owner = ((ColumnOrderByItemSegment) each).getColumn().getOwner();
208                 if (owner.isPresent() && needRewrite(owner.get())) {
209                     rewriteTables.add(new SimpleTableSegment(new TableNameSegment(owner.get().getStartIndex(), owner.get().getStopIndex(), owner.get().getIdentifier())));
210                 }
211             }
212         }
213     }
214     
215     private void extractTablesFromLock(final LockSegment lockSegment) {
216         rewriteTables.addAll(lockSegment.getTables());
217     }
218     
219     /**
220      * Extract table that should be rewritten from delete statement.
221      *
222      * @param deleteStatement delete statement
223      */
224     public void extractTablesFromDelete(final DeleteStatement deleteStatement) {
225         extractTablesFromTableSegment(deleteStatement.getTable());
226         if (deleteStatement.getWhere().isPresent()) {
227             extractTablesFromExpression(deleteStatement.getWhere().get().getExpr());
228         }
229     }
230     
231     /**
232      * Extract table that should be rewritten from insert statement.
233      *
234      * @param insertStatement insert statement
235      */
236     public void extractTablesFromInsert(final InsertStatement insertStatement) {
237         if (null != insertStatement.getTable()) {
238             extractTablesFromTableSegment(insertStatement.getTable());
239         }
240         if (!insertStatement.getColumns().isEmpty()) {
241             for (ColumnSegment each : insertStatement.getColumns()) {
242                 extractTablesFromExpression(each);
243             }
244         }
245         InsertStatementHandler.getOnDuplicateKeyColumnsSegment(insertStatement).ifPresent(optional -> extractTablesFromAssignmentItems(optional.getColumns()));
246         if (insertStatement.getInsertSelect().isPresent()) {
247             extractTablesFromSelect(insertStatement.getInsertSelect().get().getSelect());
248         }
249     }
250     
251     private void extractTablesFromAssignmentItems(final Collection<ColumnAssignmentSegment> assignmentItems) {
252         assignmentItems.forEach(each -> extractTablesFromColumnSegments(each.getColumns()));
253     }
254     
255     private void extractTablesFromColumnSegments(final Collection<ColumnSegment> columnSegments) {
256         columnSegments.forEach(each -> {
257             if (each.getOwner().isPresent() && needRewrite(each.getOwner().get())) {
258                 OwnerSegment ownerSegment = each.getOwner().get();
259                 rewriteTables.add(new SimpleTableSegment(new TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), ownerSegment.getIdentifier())));
260             }
261         });
262     }
263     
264     /**
265      * Extract table that should be rewritten from update statement.
266      *
267      * @param updateStatement update statement.
268      */
269     public void extractTablesFromUpdate(final UpdateStatement updateStatement) {
270         extractTablesFromTableSegment(updateStatement.getTable());
271         updateStatement.getSetAssignment().getAssignments().forEach(each -> extractTablesFromExpression(each.getColumns().get(0)));
272         if (updateStatement.getWhere().isPresent()) {
273             extractTablesFromExpression(updateStatement.getWhere().get().getExpr());
274         }
275     }
276     
277     /**
278      * Check if the table needs to be overwritten.
279      *
280      * @param owner owner
281      * @return boolean
282      */
283     public boolean needRewrite(final OwnerSegment owner) {
284         for (TableSegment each : tableContext) {
285             if (owner.getIdentifier().getValue().equalsIgnoreCase(each.getAliasName().orElse(null))) {
286                 return false;
287             }
288         }
289         return true;
290     }
291     
292     /**
293      * Extract the tables that should exist from routine body segment.
294      *
295      * @param routineBody routine body segment
296      * @return the tables that should exist
297      */
298     public Collection<SimpleTableSegment> extractExistTableFromRoutineBody(final RoutineBodySegment routineBody) {
299         Collection<SimpleTableSegment> result = new LinkedList<>();
300         for (ValidStatementSegment each : routineBody.getValidStatements()) {
301             if (each.getAlterTable().isPresent()) {
302                 result.add(each.getAlterTable().get().getTable());
303             }
304             if (each.getDropTable().isPresent()) {
305                 result.addAll(each.getDropTable().get().getTables());
306             }
307             if (each.getTruncate().isPresent()) {
308                 result.addAll(each.getTruncate().get().getTables());
309             }
310             result.addAll(extractExistTableFromDMLStatement(each));
311         }
312         return result;
313     }
314     
315     private Collection<SimpleTableSegment> extractExistTableFromDMLStatement(final ValidStatementSegment validStatementSegment) {
316         if (validStatementSegment.getInsert().isPresent()) {
317             extractTablesFromInsert(validStatementSegment.getInsert().get());
318         } else if (validStatementSegment.getReplace().isPresent()) {
319             extractTablesFromInsert(validStatementSegment.getReplace().get());
320         } else if (validStatementSegment.getUpdate().isPresent()) {
321             extractTablesFromUpdate(validStatementSegment.getUpdate().get());
322         } else if (validStatementSegment.getDelete().isPresent()) {
323             extractTablesFromDelete(validStatementSegment.getDelete().get());
324         } else if (validStatementSegment.getSelect().isPresent()) {
325             extractTablesFromSelect(validStatementSegment.getSelect().get());
326         }
327         return rewriteTables;
328     }
329     
330     /**
331      * Extract the tables that should not exist from routine body segment.
332      *
333      * @param routineBody routine body segment
334      * @return the tables that should not exist
335      */
336     public Collection<SimpleTableSegment> extractNotExistTableFromRoutineBody(final RoutineBodySegment routineBody) {
337         Collection<SimpleTableSegment> result = new LinkedList<>();
338         for (ValidStatementSegment each : routineBody.getValidStatements()) {
339             Optional<CreateTableStatement> createTable = each.getCreateTable();
340             if (createTable.isPresent() && !CreateTableStatementHandler.ifNotExists(createTable.get())) {
341                 result.add(createTable.get().getTable());
342             }
343         }
344         return result;
345     }
346     
347     /**
348      * Extract table that should be rewritten from SQL statement.
349      *
350      * @param sqlStatement SQL statement
351      */
352     public void extractTablesFromSQLStatement(final SQLStatement sqlStatement) {
353         if (sqlStatement instanceof SelectStatement) {
354             extractTablesFromSelect((SelectStatement) sqlStatement);
355         } else if (sqlStatement instanceof InsertStatement) {
356             extractTablesFromInsert((InsertStatement) sqlStatement);
357         } else if (sqlStatement instanceof UpdateStatement) {
358             extractTablesFromUpdate((UpdateStatement) sqlStatement);
359         } else if (sqlStatement instanceof DeleteStatement) {
360             extractTablesFromDelete((DeleteStatement) sqlStatement);
361         }
362     }
363     
364     /**
365      * Extract table that should be rewritten from create view statement.
366      * 
367      * @param createViewStatement create view statement
368      */
369     public void extractTablesFromCreateViewStatement(final CreateViewStatement createViewStatement) {
370         tableContext.add(createViewStatement.getView());
371         rewriteTables.add(createViewStatement.getView());
372         extractTablesFromSelect(createViewStatement.getSelect());
373     }
374 }