View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
19  
20  import com.google.common.base.Preconditions;
21  import lombok.Setter;
22  import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
23  import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
24  import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptInsertValuesToken;
25  import org.apache.shardingsphere.encrypt.rule.EncryptRule;
26  import org.apache.shardingsphere.encrypt.rule.EncryptTable;
27  import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
28  import org.apache.shardingsphere.encrypt.rule.column.item.AssistedQueryColumnItem;
29  import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
30  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
31  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
32  import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
33  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34  import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
35  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
36  import org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
37  import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.PreviousSQLTokensAware;
38  import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
39  import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValue;
40  import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValuesToken;
41  import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
42  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
43  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
44  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
45  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
46  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
47  import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
48  
49  import java.util.Collection;
50  import java.util.Iterator;
51  import java.util.List;
52  import java.util.Optional;
53  
54  /**
55   * Insert values token generator for encrypt.
56   */
57  @Setter
58  public final class EncryptInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, PreviousSQLTokensAware, EncryptRuleAware, DatabaseNameAware {
59      
60      private List<SQLToken> previousSQLTokens;
61      
62      private EncryptRule encryptRule;
63      
64      private String databaseName;
65      
66      @Override
67      public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
68          return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
69      }
70      
71      @Override
72      public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
73          Optional<SQLToken> insertValuesToken = findPreviousSQLToken(InsertValuesToken.class);
74          if (insertValuesToken.isPresent()) {
75              processPreviousSQLToken(insertStatementContext, (InsertValuesToken) insertValuesToken.get());
76              return (InsertValuesToken) insertValuesToken.get();
77          }
78          return generateNewSQLToken(insertStatementContext);
79      }
80      
81      private Optional<SQLToken> findPreviousSQLToken(final Class<?> sqlToken) {
82          for (SQLToken each : previousSQLTokens) {
83              if (sqlToken.isAssignableFrom(each.getClass())) {
84                  return Optional.of(each);
85              }
86          }
87          return Optional.empty();
88      }
89      
90      private void processPreviousSQLToken(final InsertStatementContext insertStatementContext, final InsertValuesToken insertValuesToken) {
91          String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
92          EncryptTable encryptTable = encryptRule.getEncryptTable(tableName);
93          int count = 0;
94          String schemaName = insertStatementContext.getTablesContext().getSchemaName()
95                  .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName));
96          for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
97              encryptToken(insertValuesToken.getInsertValues().get(count), schemaName, encryptTable, insertStatementContext, each);
98              count++;
99          }
100     }
101     
102     private InsertValuesToken generateNewSQLToken(final InsertStatementContext insertStatementContext) {
103         String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
104         Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
105         InsertValuesToken result = new EncryptInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
106         EncryptTable encryptTable = encryptRule.getEncryptTable(tableName);
107         String schemaName = insertStatementContext.getTablesContext().getSchemaName()
108                 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName));
109         for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
110             InsertValue insertValueToken = new InsertValue(each.getValueExpressions());
111             encryptToken(insertValueToken, schemaName, encryptTable, insertStatementContext, each);
112             result.getInsertValues().add(insertValueToken);
113         }
114         return result;
115     }
116     
117     private int getStartIndex(final Collection<InsertValuesSegment> segments) {
118         int result = segments.iterator().next().getStartIndex();
119         for (InsertValuesSegment each : segments) {
120             result = Math.min(result, each.getStartIndex());
121         }
122         return result;
123     }
124     
125     private int getStopIndex(final Collection<InsertValuesSegment> segments) {
126         int result = segments.iterator().next().getStopIndex();
127         for (InsertValuesSegment each : segments) {
128             result = Math.max(result, each.getStopIndex());
129         }
130         return result;
131     }
132     
133     private void encryptToken(final InsertValue insertValueToken, final String schemaName, final EncryptTable encryptTable,
134                               final InsertStatementContext insertStatementContext, final InsertValueContext insertValueContext) {
135         String tableName = encryptTable.getTable();
136         Optional<SQLToken> useDefaultInsertColumnsToken = findPreviousSQLToken(UseDefaultInsertColumnsToken.class);
137         Iterator<String> descendingColumnNames = insertStatementContext.getDescendingColumnNames();
138         while (descendingColumnNames.hasNext()) {
139             String columnName = descendingColumnNames.next();
140             if (!encryptTable.isEncryptColumn(columnName)) {
141                 continue;
142             }
143             EncryptColumn encryptColumn = encryptRule.getEncryptTable(tableName).getEncryptColumn(columnName);
144             int columnIndex = useDefaultInsertColumnsToken
145                     .map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
146             Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
147             ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
148             setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
149             int indexDelta = 1;
150             if (encryptColumn.getAssistedQuery().isPresent()) {
151                 addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
152                 indexDelta++;
153             }
154             if (encryptColumn.getLikeQuery().isPresent()) {
155                 addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
156             }
157         }
158     }
159     
160     private void setCipherColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
161                                  final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final Object originalValue) {
162         if (valueExpression instanceof LiteralExpressionSegment) {
163             insertValueToken.getValues().set(columnIndex, new LiteralExpressionSegment(
164                     valueExpression.getStartIndex(), valueExpression.getStopIndex(), encryptColumn.getCipher().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue)));
165         }
166     }
167     
168     private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
169                                         final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
170         Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
171         Preconditions.checkState(assistedQueryColumnItem.isPresent());
172         Object derivedValue = assistedQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
173         addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
174     }
175     
176     private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
177                                     final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
178         Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
179         Preconditions.checkState(likeQueryColumnItem.isPresent());
180         Object derivedValue = likeQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
181         addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
182     }
183     
184     private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
185                                   final String derivedColumnName) {
186         ExpressionSegment derivedExpression;
187         if (valueExpression instanceof LiteralExpressionSegment) {
188             derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
189         } else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
190             derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
191         } else if (valueExpression instanceof ColumnSegment) {
192             derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
193         } else {
194             derivedExpression = valueExpression;
195         }
196         insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
197     }
198     
199     private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
200         ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
201         result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
202         originalColumn.getOwner().ifPresent(result::setOwner);
203         result.setColumnBoundedInfo(originalColumn.getColumnBoundedInfo());
204         result.setOtherUsingColumnBoundedInfo(originalColumn.getOtherUsingColumnBoundedInfo());
205         result.setVariable(originalColumn.isVariable());
206         return result;
207     }
208     
209     private int getParameterIndexCount(final InsertValue insertValueToken) {
210         int result = 0;
211         for (ExpressionSegment each : insertValueToken.getValues()) {
212             if (each instanceof ParameterMarkerExpressionSegment) {
213                 result++;
214             }
215         }
216         return result;
217     }
218 }