View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.encrypt.merge.dal.show;
19  
20  import org.apache.shardingsphere.encrypt.exception.syntax.UnsupportedEncryptSQLException;
21  import org.apache.shardingsphere.encrypt.rule.EncryptRule;
22  import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
23  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
24  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
25  import org.apache.shardingsphere.infra.merge.result.MergedResult;
26  import org.apache.shardingsphere.infra.merge.result.impl.decorator.DecoratorMergedResult;
27  import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
28  import org.apache.shardingsphere.infra.parser.SQLParserEngine;
29  import org.apache.shardingsphere.parser.rule.SQLParserRule;
30  import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.column.ColumnDefinitionSegment;
31  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
32  import org.apache.shardingsphere.sql.parser.statement.core.statement.attribute.type.TableInResultSetSQLStatementAttribute;
33  import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.CreateTableStatement;
34  
35  import java.sql.SQLException;
36  import java.util.ArrayList;
37  import java.util.List;
38  import java.util.Optional;
39  
40  /**
41   * Encrypt show create table merged result.
42   */
43  public final class EncryptShowCreateTableMergedResult extends DecoratorMergedResult {
44      
45      private static final String COMMA = ", ";
46      
47      private final EncryptRule rule;
48      
49      private final String tableName;
50      
51      private final int tableNameResultSetIndex;
52      
53      private final SQLParserEngine sqlParserEngine;
54      
55      public EncryptShowCreateTableMergedResult(final RuleMetaData globalRuleMetaData, final MergedResult mergedResult, final SQLStatementContext sqlStatementContext, final EncryptRule rule) {
56          super(mergedResult);
57          ShardingSpherePreconditions.checkState(1 == sqlStatementContext.getTablesContext().getSimpleTables().size(),
58                  () -> new UnsupportedEncryptSQLException("SHOW CREATE TABLE FOR MULTI TABLES"));
59          this.rule = rule;
60          tableName = sqlStatementContext.getTablesContext().getSimpleTables().iterator().next().getTableName().getIdentifier().getValue();
61          TableInResultSetSQLStatementAttribute attribute = sqlStatementContext.getSqlStatement().getAttributes().getAttribute(TableInResultSetSQLStatementAttribute.class);
62          tableNameResultSetIndex = attribute.getNameResultSetIndex();
63          sqlParserEngine = globalRuleMetaData.getSingleRule(SQLParserRule.class).getSQLParserEngine(sqlStatementContext.getSqlStatement().getDatabaseType());
64      }
65      
66      @Override
67      public Object getValue(final int columnIndex, final Class<?> type) throws SQLException {
68          if (tableNameResultSetIndex != columnIndex) {
69              return getMergedResult().getValue(columnIndex, type);
70          }
71          String createTableSQL = getMergedResult().getValue(tableNameResultSetIndex, type).toString();
72          Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
73          if (!encryptTable.isPresent() || !createTableSQL.contains("(")) {
74              return createTableSQL;
75          }
76          CreateTableStatement createTableStatement = (CreateTableStatement) sqlParserEngine.parse(createTableSQL, false);
77          List<ColumnDefinitionSegment> columnDefinitions = new ArrayList<>(createTableStatement.getColumnDefinitions());
78          StringBuilder result = new StringBuilder(createTableSQL.substring(0, columnDefinitions.get(0).getStartIndex()));
79          for (ColumnDefinitionSegment each : columnDefinitions) {
80              findLogicColumnDefinition(each, encryptTable.get(), createTableSQL).ifPresent(optional -> result.append(optional).append(COMMA));
81          }
82          // TODO decorate encrypt column index when we support index rewrite
83          result.delete(result.length() - COMMA.length(), result.length()).append(createTableSQL.substring(columnDefinitions.get(columnDefinitions.size() - 1).getStopIndex() + 1));
84          return result.toString();
85      }
86      
87      private Optional<String> findLogicColumnDefinition(final ColumnDefinitionSegment columnDefinition, final EncryptTable encryptTable, final String createTableSQL) {
88          ColumnSegment columnSegment = columnDefinition.getColumnName();
89          String columnName = columnSegment.getIdentifier().getValue();
90          if (encryptTable.isCipherColumn(columnName)) {
91              String logicColumn = encryptTable.getLogicColumnByCipherColumn(columnName);
92              return Optional.of(createTableSQL.substring(columnDefinition.getStartIndex(), columnSegment.getStartIndex())
93                      + columnSegment.getIdentifier().getQuoteCharacter().wrap(logicColumn) + createTableSQL.substring(columnSegment.getStopIndex() + 1, columnDefinition.getStopIndex() + 1));
94          }
95          if (encryptTable.isDerivedColumn(columnName)) {
96              return Optional.empty();
97          }
98          return Optional.of(createTableSQL.substring(columnDefinition.getStartIndex(), columnDefinition.getStopIndex() + 1));
99      }
100 }