View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
19  
20  import com.google.common.base.Preconditions;
21  import lombok.RequiredArgsConstructor;
22  import lombok.Setter;
23  import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptInsertValuesToken;
24  import org.apache.shardingsphere.encrypt.rule.EncryptRule;
25  import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
26  import org.apache.shardingsphere.encrypt.rule.column.item.AssistedQueryColumnItem;
27  import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
28  import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
29  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
30  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
31  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
32  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
33  import org.apache.shardingsphere.infra.binder.context.statement.type.dml.InsertStatementContext;
34  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
35  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
36  import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
37  import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.PreviousSQLTokensAware;
38  import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
39  import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValue;
40  import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValuesToken;
41  import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.UseDefaultInsertColumnsToken;
42  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
43  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
44  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
45  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
46  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
47  import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
48  
49  import java.util.Collection;
50  import java.util.Iterator;
51  import java.util.LinkedList;
52  import java.util.List;
53  import java.util.Optional;
54  
55  /**
56   * Insert values token generator for encrypt.
57   */
58  @RequiredArgsConstructor
59  @Setter
60  public final class EncryptInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, PreviousSQLTokensAware {
61      
62      private final EncryptRule rule;
63      
64      private final ShardingSphereDatabase database;
65      
66      private List<SQLToken> previousSQLTokens;
67      
68      @Override
69      public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
70          return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
71      }
72      
73      @Override
74      public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
75          Optional<SQLToken> insertValuesToken = findPreviousSQLToken(InsertValuesToken.class);
76          if (insertValuesToken.isPresent()) {
77              processPreviousSQLToken(insertStatementContext, (InsertValuesToken) insertValuesToken.get());
78              return (InsertValuesToken) insertValuesToken.get();
79          }
80          return generateNewSQLToken(insertStatementContext);
81      }
82      
83      private Optional<SQLToken> findPreviousSQLToken(final Class<?> sqlToken) {
84          for (SQLToken each : previousSQLTokens) {
85              if (sqlToken.isAssignableFrom(each.getClass())) {
86                  return Optional.of(each);
87              }
88          }
89          return Optional.empty();
90      }
91      
92      private void processPreviousSQLToken(final InsertStatementContext insertStatementContext, final InsertValuesToken insertValuesToken) {
93          String tableName = insertStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
94          EncryptTable encryptTable = rule.getEncryptTable(tableName);
95          int count = 0;
96          String schemaName = insertStatementContext.getTablesContext().getSchemaName()
97                  .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
98          for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
99              encryptToken(insertValuesToken.getInsertValues().get(count), schemaName, encryptTable, insertStatementContext, each);
100             count++;
101         }
102     }
103     
104     private InsertValuesToken generateNewSQLToken(final InsertStatementContext insertStatementContext) {
105         String tableName = insertStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
106         Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
107         InsertValuesToken result = new EncryptInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
108         EncryptTable encryptTable = rule.getEncryptTable(tableName);
109         String schemaName = insertStatementContext.getTablesContext().getSchemaName()
110                 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
111         for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
112             InsertValue insertValueToken = new InsertValue(new LinkedList<>(each.getValueExpressions()));
113             encryptToken(insertValueToken, schemaName, encryptTable, insertStatementContext, each);
114             result.getInsertValues().add(insertValueToken);
115         }
116         return result;
117     }
118     
119     private int getStartIndex(final Collection<InsertValuesSegment> segments) {
120         int result = segments.iterator().next().getStartIndex();
121         for (InsertValuesSegment each : segments) {
122             result = Math.min(result, each.getStartIndex());
123         }
124         return result;
125     }
126     
127     private int getStopIndex(final Collection<InsertValuesSegment> segments) {
128         int result = segments.iterator().next().getStopIndex();
129         for (InsertValuesSegment each : segments) {
130             result = Math.max(result, each.getStopIndex());
131         }
132         return result;
133     }
134     
135     private void encryptToken(final InsertValue insertValueToken, final String schemaName, final EncryptTable encryptTable,
136                               final InsertStatementContext insertStatementContext, final InsertValueContext insertValueContext) {
137         String tableName = encryptTable.getTable();
138         Optional<SQLToken> useDefaultInsertColumnsToken = findPreviousSQLToken(UseDefaultInsertColumnsToken.class);
139         Iterator<String> descendingColumnNames = insertStatementContext.getDescendingColumnNames();
140         while (descendingColumnNames.hasNext()) {
141             String columnName = descendingColumnNames.next();
142             if (!encryptTable.isEncryptColumn(columnName)) {
143                 continue;
144             }
145             EncryptColumn encryptColumn = rule.getEncryptTable(tableName).getEncryptColumn(columnName);
146             int columnIndex = useDefaultInsertColumnsToken
147                     .map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
148             Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
149             ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
150             setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
151             int indexDelta = 1;
152             if (encryptColumn.getAssistedQuery().isPresent()) {
153                 addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
154                 indexDelta++;
155             }
156             if (encryptColumn.getLikeQuery().isPresent()) {
157                 addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
158             }
159         }
160     }
161     
162     private void setCipherColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
163                                  final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final Object originalValue) {
164         if (valueExpression instanceof LiteralExpressionSegment) {
165             insertValueToken.getValues().set(columnIndex, new LiteralExpressionSegment(
166                     valueExpression.getStartIndex(), valueExpression.getStopIndex(),
167                     encryptColumn.getCipher().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue)));
168         }
169     }
170     
171     private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
172                                         final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
173         Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
174         Preconditions.checkState(assistedQueryColumnItem.isPresent());
175         Object derivedValue = assistedQueryColumnItem.get().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue);
176         addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
177     }
178     
179     private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
180                                     final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
181         Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
182         Preconditions.checkState(likeQueryColumnItem.isPresent());
183         Object derivedValue = likeQueryColumnItem.get().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue);
184         addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
185     }
186     
187     private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
188                                   final String derivedColumnName) {
189         ExpressionSegment derivedExpression;
190         if (valueExpression instanceof LiteralExpressionSegment) {
191             derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
192         } else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
193             derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
194         } else if (valueExpression instanceof ColumnSegment) {
195             derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
196         } else {
197             derivedExpression = valueExpression;
198         }
199         insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
200     }
201     
202     private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
203         ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
204         result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
205         originalColumn.getOwner().ifPresent(result::setOwner);
206         result.setColumnBoundInfo(originalColumn.getColumnBoundInfo());
207         result.setOtherUsingColumnBoundInfo(originalColumn.getOtherUsingColumnBoundInfo());
208         result.setVariable(originalColumn.isVariable());
209         return result;
210     }
211     
212     private int getParameterIndexCount(final InsertValue insertValueToken) {
213         int result = 0;
214         for (ExpressionSegment each : insertValueToken.getValues()) {
215             if (each instanceof ParameterMarkerExpressionSegment) {
216                 result++;
217             }
218         }
219         return result;
220     }
221 }