View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sharding.rewrite.token.generator.impl;
19  
20  import lombok.Setter;
21  import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.RouteContextAware;
22  import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValue;
23  import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValuesToken;
24  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
25  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
26  import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
27  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
28  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
29  import org.apache.shardingsphere.infra.datanode.DataNode;
30  import org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
31  import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValuesToken;
32  import org.apache.shardingsphere.infra.route.context.RouteContext;
33  
34  import java.util.Collection;
35  import java.util.Collections;
36  import java.util.Iterator;
37  import java.util.List;
38  
39  /**
40   * Insert values token generator for sharding.
41   */
42  @Setter
43  public final class ShardingInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, RouteContextAware {
44      
45      private RouteContext routeContext;
46      
47      @Override
48      public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
49          return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
50      }
51      
52      @Override
53      public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
54          Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
55          InsertValuesToken result = new ShardingInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
56          Iterator<Collection<DataNode>> originalDataNodesIterator = null == routeContext || routeContext.getOriginalDataNodes().isEmpty()
57                  ? null
58                  : routeContext.getOriginalDataNodes().iterator();
59          for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
60              List<ExpressionSegment> expressionSegments = each.getValueExpressions();
61              Collection<DataNode> dataNodes = null == originalDataNodesIterator ? Collections.emptyList() : originalDataNodesIterator.next();
62              result.getInsertValues().add(new ShardingInsertValue(expressionSegments, dataNodes));
63          }
64          return result;
65      }
66      
67      private int getStartIndex(final Collection<InsertValuesSegment> segments) {
68          int result = segments.iterator().next().getStartIndex();
69          for (InsertValuesSegment each : segments) {
70              result = Math.min(result, each.getStartIndex());
71          }
72          return result;
73      }
74      
75      private int getStopIndex(final Collection<InsertValuesSegment> segments) {
76          int result = segments.iterator().next().getStopIndex();
77          for (InsertValuesSegment each : segments) {
78              result = Math.max(result, each.getStopIndex());
79          }
80          return result;
81      }
82  }