View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.sharding.rewrite.token.generator.impl;
19  
20  import lombok.Setter;
21  import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
22  import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.RouteContextAware;
23  import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValue;
24  import org.apache.shardingsphere.sharding.rewrite.token.pojo.ShardingInsertValuesToken;
25  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
26  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
27  import org.apache.shardingsphere.infra.binder.context.statement.type.dml.InsertStatementContext;
28  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
29  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
30  import org.apache.shardingsphere.infra.datanode.DataNode;
31  import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
32  import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValuesToken;
33  import org.apache.shardingsphere.infra.route.context.RouteContext;
34  
35  import java.util.Collection;
36  import java.util.Collections;
37  import java.util.Iterator;
38  import java.util.List;
39  
40  /**
41   * Insert values token generator for sharding.
42   */
43  @HighFrequencyInvocation
44  @Setter
45  public final class ShardingInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, RouteContextAware {
46      
47      private RouteContext routeContext;
48      
49      @Override
50      public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
51          return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
52      }
53      
54      @Override
55      public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
56          Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
57          InsertValuesToken result = new ShardingInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
58          Iterator<Collection<DataNode>> dataNodesIterator = routeContext.getOriginalDataNodes().isEmpty() ? Collections.emptyIterator() : routeContext.getOriginalDataNodes().iterator();
59          for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
60              List<ExpressionSegment> expressionSegments = each.getValueExpressions();
61              Collection<DataNode> dataNodes = dataNodesIterator.hasNext() ? dataNodesIterator.next() : Collections.emptyList();
62              result.getInsertValues().add(new ShardingInsertValue(expressionSegments, dataNodes));
63          }
64          return result;
65      }
66      
67      private int getStartIndex(final Collection<InsertValuesSegment> segments) {
68          int result = segments.iterator().next().getStartIndex();
69          for (InsertValuesSegment each : segments) {
70              result = Math.min(result, each.getStartIndex());
71          }
72          return result;
73      }
74      
75      private int getStopIndex(final Collection<InsertValuesSegment> segments) {
76          int result = segments.iterator().next().getStopIndex();
77          for (InsertValuesSegment each : segments) {
78              result = Math.max(result, each.getStopIndex());
79          }
80          return result;
81      }
82  }