1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.encrypt.rewrite.token.generator.insert;
19
20 import com.google.common.base.Preconditions;
21 import lombok.RequiredArgsConstructor;
22 import lombok.Setter;
23 import org.apache.shardingsphere.encrypt.rewrite.token.pojo.EncryptInsertValuesToken;
24 import org.apache.shardingsphere.encrypt.rule.EncryptRule;
25 import org.apache.shardingsphere.encrypt.rule.column.EncryptColumn;
26 import org.apache.shardingsphere.encrypt.rule.column.item.AssistedQueryColumnItem;
27 import org.apache.shardingsphere.encrypt.rule.column.item.LikeQueryColumnItem;
28 import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
29 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
30 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
31 import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
32 import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
33 import org.apache.shardingsphere.infra.binder.context.statement.type.dml.InsertStatementContext;
34 import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
35 import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
36 import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
37 import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.PreviousSQLTokensAware;
38 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
39 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValue;
40 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.InsertValuesToken;
41 import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.UseDefaultInsertColumnsToken;
42 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
43 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
44 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
45 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
46 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
47 import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
48
49 import java.util.Collection;
50 import java.util.Iterator;
51 import java.util.LinkedList;
52 import java.util.List;
53 import java.util.Optional;
54
55
56
57
58 @RequiredArgsConstructor
59 @Setter
60 public final class EncryptInsertValuesTokenGenerator implements OptionalSQLTokenGenerator<InsertStatementContext>, PreviousSQLTokensAware {
61
62 private final EncryptRule rule;
63
64 private final ShardingSphereDatabase database;
65
66 private List<SQLToken> previousSQLTokens;
67
68 @Override
69 public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
70 return sqlStatementContext instanceof InsertStatementContext && !(((InsertStatementContext) sqlStatementContext).getSqlStatement()).getValues().isEmpty();
71 }
72
73 @Override
74 public InsertValuesToken generateSQLToken(final InsertStatementContext insertStatementContext) {
75 Optional<SQLToken> insertValuesToken = findPreviousSQLToken(InsertValuesToken.class);
76 if (insertValuesToken.isPresent()) {
77 processPreviousSQLToken(insertStatementContext, (InsertValuesToken) insertValuesToken.get());
78 return (InsertValuesToken) insertValuesToken.get();
79 }
80 return generateNewSQLToken(insertStatementContext);
81 }
82
83 private Optional<SQLToken> findPreviousSQLToken(final Class<?> sqlToken) {
84 for (SQLToken each : previousSQLTokens) {
85 if (sqlToken.isAssignableFrom(each.getClass())) {
86 return Optional.of(each);
87 }
88 }
89 return Optional.empty();
90 }
91
92 private void processPreviousSQLToken(final InsertStatementContext insertStatementContext, final InsertValuesToken insertValuesToken) {
93 String tableName = insertStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
94 EncryptTable encryptTable = rule.getEncryptTable(tableName);
95 int count = 0;
96 String schemaName = insertStatementContext.getTablesContext().getSchemaName()
97 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
98 for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
99 encryptToken(insertValuesToken.getInsertValues().get(count), schemaName, encryptTable, insertStatementContext, each);
100 count++;
101 }
102 }
103
104 private InsertValuesToken generateNewSQLToken(final InsertStatementContext insertStatementContext) {
105 String tableName = insertStatementContext.getSqlStatement().getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
106 Collection<InsertValuesSegment> insertValuesSegments = insertStatementContext.getSqlStatement().getValues();
107 InsertValuesToken result = new EncryptInsertValuesToken(getStartIndex(insertValuesSegments), getStopIndex(insertValuesSegments));
108 EncryptTable encryptTable = rule.getEncryptTable(tableName);
109 String schemaName = insertStatementContext.getTablesContext().getSchemaName()
110 .orElseGet(() -> new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
111 for (InsertValueContext each : insertStatementContext.getInsertValueContexts()) {
112 InsertValue insertValueToken = new InsertValue(new LinkedList<>(each.getValueExpressions()));
113 encryptToken(insertValueToken, schemaName, encryptTable, insertStatementContext, each);
114 result.getInsertValues().add(insertValueToken);
115 }
116 return result;
117 }
118
119 private int getStartIndex(final Collection<InsertValuesSegment> segments) {
120 int result = segments.iterator().next().getStartIndex();
121 for (InsertValuesSegment each : segments) {
122 result = Math.min(result, each.getStartIndex());
123 }
124 return result;
125 }
126
127 private int getStopIndex(final Collection<InsertValuesSegment> segments) {
128 int result = segments.iterator().next().getStopIndex();
129 for (InsertValuesSegment each : segments) {
130 result = Math.max(result, each.getStopIndex());
131 }
132 return result;
133 }
134
135 private void encryptToken(final InsertValue insertValueToken, final String schemaName, final EncryptTable encryptTable,
136 final InsertStatementContext insertStatementContext, final InsertValueContext insertValueContext) {
137 String tableName = encryptTable.getTable();
138 Optional<SQLToken> useDefaultInsertColumnsToken = findPreviousSQLToken(UseDefaultInsertColumnsToken.class);
139 Iterator<String> descendingColumnNames = insertStatementContext.getDescendingColumnNames();
140 while (descendingColumnNames.hasNext()) {
141 String columnName = descendingColumnNames.next();
142 if (!encryptTable.isEncryptColumn(columnName)) {
143 continue;
144 }
145 EncryptColumn encryptColumn = rule.getEncryptTable(tableName).getEncryptColumn(columnName);
146 int columnIndex = useDefaultInsertColumnsToken
147 .map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
148 Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
149 ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
150 setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
151 int indexDelta = 1;
152 if (encryptColumn.getAssistedQuery().isPresent()) {
153 addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
154 indexDelta++;
155 }
156 if (encryptColumn.getLikeQuery().isPresent()) {
157 addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
158 }
159 }
160 }
161
162 private void setCipherColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
163 final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final Object originalValue) {
164 if (valueExpression instanceof LiteralExpressionSegment) {
165 insertValueToken.getValues().set(columnIndex, new LiteralExpressionSegment(
166 valueExpression.getStartIndex(), valueExpression.getStopIndex(),
167 encryptColumn.getCipher().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue)));
168 }
169 }
170
171 private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
172 final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
173 Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
174 Preconditions.checkState(assistedQueryColumnItem.isPresent());
175 Object derivedValue = assistedQueryColumnItem.get().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue);
176 addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
177 }
178
179 private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
180 final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
181 Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
182 Preconditions.checkState(likeQueryColumnItem.isPresent());
183 Object derivedValue = likeQueryColumnItem.get().encrypt(database.getName(), schemaName, tableName, encryptColumn.getName(), originalValue);
184 addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
185 }
186
187 private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
188 final String derivedColumnName) {
189 ExpressionSegment derivedExpression;
190 if (valueExpression instanceof LiteralExpressionSegment) {
191 derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
192 } else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
193 derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
194 } else if (valueExpression instanceof ColumnSegment) {
195 derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
196 } else {
197 derivedExpression = valueExpression;
198 }
199 insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
200 }
201
202 private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
203 ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
204 result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
205 originalColumn.getOwner().ifPresent(result::setOwner);
206 result.setColumnBoundInfo(originalColumn.getColumnBoundInfo());
207 result.setOtherUsingColumnBoundInfo(originalColumn.getOtherUsingColumnBoundInfo());
208 result.setVariable(originalColumn.isVariable());
209 return result;
210 }
211
212 private int getParameterIndexCount(final InsertValue insertValueToken) {
213 int result = 0;
214 for (ExpressionSegment each : insertValueToken.getValues()) {
215 if (each instanceof ParameterMarkerExpressionSegment) {
216 result++;
217 }
218 }
219 return result;
220 }
221 }