1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
19
20 import com.google.common.base.Preconditions;
21 import lombok.Setter;
22 import org.apache.shardingsphere.encrypt.rewrite.aware.DatabaseNameAware;
23 import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptRuleAware;
24 import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptInsertValuesToken;
25 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
26 import org.apache.shardingsphere.encrypt.rule.EncryptTable;
27 import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
28 import org.apache.shardingsphere.encrypt.rule.column.item.AssistedQueryColumnItem;
29 import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
30 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
31 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
32 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
33 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
34 import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
35 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
36 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.OptionalSQLTokenGenerator;
37 import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.PreviousSQLTokensAware;
38 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.SQLToken;
39 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValue;
40 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValuesToken;
41 import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
42 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
43 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
44 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
45 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
46 import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
47 import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
48
49 import java.util.Collection;
50 import java.util.Iterator;
51 import java.util.List;
52 import java.util.Optional;
53
54
55
56
57 @Setter
58 public final class EncryptInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, PreviousSQLTokensAware, EncryptRuleAware, DatabaseNameAware {
59
60 private List<SQLToken> previousSQLTokens;
61
62 private EncryptRule encryptRule;
63
64 private String databaseName;
65
66 @Override
67 public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
68 return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
69 }
70
71 @Override
72 public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
73 Optional<SQLToken> insertValuesToken = findPreviousSQLToken(InsertValuesToken.class);
74 if (insertValuesToken.isPresent()) {
75 processPreviousSQLToken(insertStatementContext, (InsertValuesToken) insertValuesToken.get());
76 return (InsertValuesToken) insertValuesToken.get();
77 }
78 return generateNewSQLToken(insertStatementContext);
79 }
80
81 private Optional<SQLToken> findPreviousSQLToken(final Class<?> sqlToken) {
82 for (SQLToken each : previousSQLTokens) {
83 if (sqlToken.isAssignableFrom(each.getClass())) {
84 return Optional.of(each);
85 }
86 }
87 return Optional.empty();
88 }
89
90 private void processPreviousSQLToken(final InsertStatementContext insertStatementContext, final InsertValuesToken insertValuesToken) {
91 String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
92 EncryptTable encryptTable = encryptRule.getEncryptTable(tableName);
93 int count = 0;
94 String schemaName = insertStatementContext.getTablesContext().getSchemaName()
95 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName));
96 for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
97 encryptToken(insertValuesToken.getInsertValues().get(count), schemaName, encryptTable, insertStatementContext, each);
98 count++;
99 }
100 }
101
102 private InsertValuesToken generateNewSQLToken(final InsertStatementContext insertStatementContext) {
103 String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
104 Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
105 InsertValuesToken result = new EncryptInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
106 EncryptTable encryptTable = encryptRule.getEncryptTable(tableName);
107 String schemaName = insertStatementContext.getTablesContext().getSchemaName()
108 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName));
109 for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
110 InsertValue insertValueToken = new InsertValue(each.getValueExpressions());
111 encryptToken(insertValueToken, schemaName, encryptTable, insertStatementContext, each);
112 result.getInsertValues().add(insertValueToken);
113 }
114 return result;
115 }
116
117 private int getStartIndex(final Collection<InsertValuesSegment> segments) {
118 int result = segments.iterator().next().getStartIndex();
119 for (InsertValuesSegment each : segments) {
120 result = Math.min(result, each.getStartIndex());
121 }
122 return result;
123 }
124
125 private int getStopIndex(final Collection<InsertValuesSegment> segments) {
126 int result = segments.iterator().next().getStopIndex();
127 for (InsertValuesSegment each : segments) {
128 result = Math.max(result, each.getStopIndex());
129 }
130 return result;
131 }
132
133 private void encryptToken(final InsertValue insertValueToken, final String schemaName, final EncryptTable encryptTable,
134 final InsertStatementContext insertStatementContext, final InsertValueContext insertValueContext) {
135 String tableName = encryptTable.getTable();
136 Optional<SQLToken> useDefaultInsertColumnsToken = findPreviousSQLToken(UseDefaultInsertColumnsToken.class);
137 Iterator<String> descendingColumnNames = insertStatementContext.getDescendingColumnNames();
138 while (descendingColumnNames.hasNext()) {
139 String columnName = descendingColumnNames.next();
140 if (!encryptTable.isEncryptColumn(columnName)) {
141 continue;
142 }
143 EncryptColumn encryptColumn = encryptRule.getEncryptTable(tableName).getEncryptColumn(columnName);
144 int columnIndex = useDefaultInsertColumnsToken
145 .map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
146 Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
147 ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
148 setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
149 int indexDelta = 1;
150 if (encryptColumn.getAssistedQuery().isPresent()) {
151 addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
152 indexDelta++;
153 }
154 if (encryptColumn.getLikeQuery().isPresent()) {
155 addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
156 }
157 }
158 }
159
160 private void setCipherColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
161 final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final Object originalValue) {
162 if (valueExpression instanceof LiteralExpressionSegment) {
163 insertValueToken.getValues().set(columnIndex, new LiteralExpressionSegment(
164 valueExpression.getStartIndex(), valueExpression.getStopIndex(), encryptColumn.getCipher().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue)));
165 }
166 }
167
168 private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
169 final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
170 Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
171 Preconditions.checkState(assistedQueryColumnItem.isPresent());
172 Object derivedValue = assistedQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
173 addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
174 }
175
176 private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
177 final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
178 Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
179 Preconditions.checkState(likeQueryColumnItem.isPresent());
180 Object derivedValue = likeQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
181 addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
182 }
183
184 private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
185 final String derivedColumnName) {
186 ExpressionSegment derivedExpression;
187 if (valueExpression instanceof LiteralExpressionSegment) {
188 derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
189 } else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
190 derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
191 } else if (valueExpression instanceof ColumnSegment) {
192 derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
193 } else {
194 derivedExpression = valueExpression;
195 }
196 insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
197 }
198
199 private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
200 ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
201 result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
202 originalColumn.getOwner().ifPresent(result::setOwner);
203 result.setColumnBoundedInfo(originalColumn.getColumnBoundedInfo());
204 result.setOtherUsingColumnBoundedInfo(originalColumn.getOtherUsingColumnBoundedInfo());
205 result.setVariable(originalColumn.isVariable());
206 return result;
207 }
208
209 private int getParameterIndexCount(final InsertValue insertValueToken) {
210 int result = 0;
211 for (ExpressionSegment each : insertValueToken.getValues()) {
212 if (each instanceof ParameterMarkerExpressionSegment) {
213 result++;
214 }
215 }
216 return result;
217 }
218 }