View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.describe;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.database.connector.core.type.DatabaseTypeRegistry;
22  import org.apache.shardingsphere.database.exception.core.exception.syntax.column.ColumnNotFoundException;
23  import org.apache.shardingsphere.database.protocol.packet.DatabasePacket;
24  import org.apache.shardingsphere.database.protocol.postgresql.packet.PostgreSQLPacket;
25  import org.apache.shardingsphere.database.protocol.postgresql.packet.command.query.PostgreSQLColumnDescription;
26  import org.apache.shardingsphere.database.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
27  import org.apache.shardingsphere.database.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
28  import org.apache.shardingsphere.database.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
29  import org.apache.shardingsphere.database.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
30  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
31  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
32  import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
33  import org.apache.shardingsphere.infra.exception.ShardingSpherePreconditions;
34  import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
35  import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
36  import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
37  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
38  import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
39  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
40  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
41  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
42  import org.apache.shardingsphere.infra.metadata.identifier.ShardingSphereIdentifier;
43  import org.apache.shardingsphere.infra.session.query.QueryContext;
44  import org.apache.shardingsphere.proxy.backend.connector.ProxyDatabaseConnectionManager;
45  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
46  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
47  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
48  import org.apache.shardingsphere.proxy.frontend.postgresql.command.PortalContext;
49  import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
50  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.ReturningSegment;
51  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
52  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
53  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
54  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
55  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
56  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
57  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
58  import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
59  import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
60  import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.InsertStatement;
61  
62  import java.sql.Connection;
63  import java.sql.ParameterMetaData;
64  import java.sql.PreparedStatement;
65  import java.sql.ResultSetMetaData;
66  import java.sql.SQLException;
67  import java.sql.Types;
68  import java.util.ArrayList;
69  import java.util.Collection;
70  import java.util.Collections;
71  import java.util.HashSet;
72  import java.util.LinkedList;
73  import java.util.List;
74  import java.util.ListIterator;
75  import java.util.Optional;
76  import java.util.stream.Collectors;
77  
78  /**
79   * Command describe for PostgreSQL.
80   */
81  @RequiredArgsConstructor
82  public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
83      
84      private static final String ANONYMOUS_COLUMN_NAME = "?column?";
85      
86      private final PortalContext portalContext;
87      
88      private final PostgreSQLComDescribePacket packet;
89      
90      private final ConnectionSession connectionSession;
91      
92      @Override
93      public Collection<DatabasePacket> execute() throws SQLException {
94          switch (packet.getType()) {
95              case 'S':
96                  return describePreparedStatement();
97              case 'P':
98                  return Collections.singleton(portalContext.get(packet.getName()).describe());
99              default:
100                 throw new UnsupportedSQLOperationException("Unsupported describe type: " + packet.getType());
101         }
102     }
103     
104     private List<DatabasePacket> describePreparedStatement() throws SQLException {
105         List<DatabasePacket> result = new ArrayList<>(2);
106         PostgreSQLServerPreparedStatement preparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getName());
107         result.add(preparedStatement.describeParameters());
108         Optional<PostgreSQLPacket> rowDescription = preparedStatement.describeRows();
109         if (rowDescription.isPresent()) {
110             result.add(rowDescription.get());
111         } else {
112             tryDescribePreparedStatement(preparedStatement);
113             preparedStatement.describeRows().ifPresent(result::add);
114         }
115         return result;
116     }
117     
118     private void tryDescribePreparedStatement(final PostgreSQLServerPreparedStatement preparedStatement) throws SQLException {
119         if (preparedStatement.getSqlStatementContext().getSqlStatement() instanceof InsertStatement) {
120             describeInsertStatementByDatabaseMetaData(preparedStatement);
121         } else {
122             tryDescribePreparedStatementByJDBC(preparedStatement);
123         }
124     }
125     
126     private void describeInsertStatementByDatabaseMetaData(final PostgreSQLServerPreparedStatement preparedStatement) {
127         InsertStatement insertStatement = (InsertStatement) preparedStatement.getSqlStatementContext().getSqlStatement();
128         Collection<Integer> unspecifiedTypeParameterIndexes = getUnspecifiedTypeParameterIndexes(preparedStatement);
129         Optional<ReturningSegment> returningSegment = insertStatement.getReturning();
130         if (insertStatement.getParameterMarkers().isEmpty() && unspecifiedTypeParameterIndexes.isEmpty() && !returningSegment.isPresent()) {
131             return;
132         }
133         String logicTableName = insertStatement.getTable().map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
134         ShardingSphereTable table = getTableFromMetaData(connectionSession.getUsedDatabaseName(), insertStatement, logicTableName);
135         List<ShardingSphereIdentifier> columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table);
136         preparedStatement.setRowDescription(returningSegment.<PostgreSQLPacket>map(returning -> describeReturning(returning, table)).orElseGet(PostgreSQLNoDataPacket::getInstance));
137         int parameterMarkerIndex = 0;
138         for (InsertValuesSegment each : insertStatement.getValues()) {
139             for (int i = 0; i < each.getValues().size(); i++) {
140                 ExpressionSegment value = each.getValues().get(i);
141                 if (!(value instanceof ParameterMarkerExpressionSegment)) {
142                     continue;
143                 }
144                 if (!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
145                     parameterMarkerIndex++;
146                     continue;
147                 }
148                 String columnName = columnNamesOfInsert.get(i).toString();
149                 ShardingSpherePreconditions.checkState(table.containsColumn(columnName), () -> new ColumnNotFoundException(logicTableName, columnName));
150                 preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(table.getColumn(columnName).getDataType()));
151             }
152         }
153     }
154     
155     private Collection<Integer> getUnspecifiedTypeParameterIndexes(final PostgreSQLServerPreparedStatement preparedStatement) {
156         Collection<Integer> result = new HashSet<>();
157         ListIterator<PostgreSQLColumnType> parameterTypesListIterator = preparedStatement.getParameterTypes().listIterator();
158         for (int index = parameterTypesListIterator.nextIndex(); parameterTypesListIterator.hasNext(); index = parameterTypesListIterator.nextIndex()) {
159             if (PostgreSQLColumnType.UNSPECIFIED == parameterTypesListIterator.next()) {
160                 result.add(index);
161             }
162         }
163         return result;
164     }
165     
166     private ShardingSphereTable getTableFromMetaData(final String databaseName, final InsertStatement insertStatement, final String logicTableName) {
167         ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getDatabase(databaseName);
168         String schemaName = insertStatement.getTable().flatMap(SimpleTableSegment::getOwner).map(optional -> optional.getIdentifier()
169                 .getValue()).orElseGet(() -> new DatabaseTypeRegistry(database.getProtocolType()).getDefaultSchemaName(databaseName));
170         return database.getSchema(schemaName).getTable(logicTableName);
171     }
172     
173     private List<ShardingSphereIdentifier> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
174         return insertStatement.getColumns().isEmpty()
175                 ? table.getColumnNames()
176                 : insertStatement.getColumns().stream().map(each -> new ShardingSphereIdentifier(each.getIdentifier().getValue())).collect(Collectors.toList());
177     }
178     
179     private PostgreSQLRowDescriptionPacket describeReturning(final ReturningSegment returningSegment, final ShardingSphereTable table) {
180         Collection<PostgreSQLColumnDescription> result = new LinkedList<>();
181         for (ProjectionSegment each : returningSegment.getProjections().getProjections()) {
182             if (each instanceof ShorthandProjectionSegment) {
183                 table.getAllColumns().stream()
184                         .map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), "")).forEach(result::add);
185             }
186             if (each instanceof ColumnProjectionSegment) {
187                 ColumnProjectionSegment segment = (ColumnProjectionSegment) each;
188                 String columnName = segment.getColumn().getIdentifier().getValue();
189                 ShardingSphereColumn column = table.containsColumn(columnName) ? table.getColumn(columnName) : generateDefaultColumn(segment);
190                 String alias = segment.getAliasName().orElseGet(column::getName);
191                 result.add(new PostgreSQLColumnDescription(alias, 0, column.getDataType(), estimateColumnLength(column.getDataType()), ""));
192             }
193             if (each instanceof ExpressionProjectionSegment) {
194                 result.add(convertExpressionToDescription((ExpressionProjectionSegment) each));
195             }
196         }
197         return new PostgreSQLRowDescriptionPacket(result);
198     }
199     
200     private ShardingSphereColumn generateDefaultColumn(final ColumnProjectionSegment segment) {
201         return new ShardingSphereColumn(segment.getColumn().getIdentifier().getValue(), Types.VARCHAR, false, false, false, true, false, false);
202     }
203     
204     private PostgreSQLColumnDescription convertExpressionToDescription(final ExpressionProjectionSegment expressionProjectionSegment) {
205         ExpressionSegment expressionSegment = expressionProjectionSegment.getExpr();
206         String columnName = expressionProjectionSegment.getAliasName().orElse(ANONYMOUS_COLUMN_NAME);
207         if (expressionSegment instanceof LiteralExpressionSegment) {
208             Object value = ((LiteralExpressionSegment) expressionSegment).getLiterals();
209             if (value instanceof String) {
210                 return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
211             }
212             if (value instanceof Integer) {
213                 return new PostgreSQLColumnDescription(columnName, 0, Types.INTEGER, estimateColumnLength(Types.INTEGER), "");
214             }
215             if (value instanceof Long) {
216                 return new PostgreSQLColumnDescription(columnName, 0, Types.BIGINT, estimateColumnLength(Types.BIGINT), "");
217             }
218             if (value instanceof Number) {
219                 return new PostgreSQLColumnDescription(columnName, 0, Types.NUMERIC, estimateColumnLength(Types.NUMERIC), "");
220             }
221         }
222         return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
223     }
224     
225     private int estimateColumnLength(final int jdbcType) {
226         switch (jdbcType) {
227             case Types.SMALLINT:
228                 return 2;
229             case Types.INTEGER:
230                 return 4;
231             case Types.BIGINT:
232                 return 8;
233             default:
234                 return -1;
235         }
236     }
237     
238     private void tryDescribePreparedStatementByJDBC(final PostgreSQLServerPreparedStatement logicPreparedStatement) throws SQLException {
239         ShardingSphereMetaData metaData = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData();
240         SQLStatementContext sqlStatementContext = new SQLBindEngine(metaData, connectionSession.getCurrentDatabaseName(), logicPreparedStatement.getHintValueContext())
241                 .bind(logicPreparedStatement.getSqlStatementContext().getSqlStatement());
242         QueryContext queryContext = new QueryContext(sqlStatementContext, logicPreparedStatement.getSql(), Collections.emptyList(), logicPreparedStatement.getHintValueContext(),
243                 connectionSession.getConnectionContext(), metaData);
244         ExecutionContext executionContext =
245                 new KernelProcessor().generateExecutionContext(queryContext, metaData.getGlobalRuleMetaData(), metaData.getProps());
246         ExecutionUnit executionUnitSample = executionContext.getExecutionUnits().iterator().next();
247         ProxyDatabaseConnectionManager databaseConnectionManager = connectionSession.getDatabaseConnectionManager();
248         Connection connection = databaseConnectionManager.getConnections(
249                 connectionSession.getUsedDatabaseName(), executionUnitSample.getDataSourceName(), 0, 1, ConnectionMode.CONNECTION_STRICTLY).iterator().next();
250         try (PreparedStatement actualPreparedStatement = connection.prepareStatement(executionUnitSample.getSqlUnit().getSql())) {
251             populateParameterTypes(logicPreparedStatement, actualPreparedStatement);
252             populateColumnTypes(logicPreparedStatement, actualPreparedStatement);
253         }
254     }
255     
256     private void populateParameterTypes(final PostgreSQLServerPreparedStatement logicPreparedStatement, final PreparedStatement actualPreparedStatement) throws SQLException {
257         if (0 == logicPreparedStatement.getSqlStatementContext().getSqlStatement().getParameterCount()
258                 || logicPreparedStatement.getParameterTypes().stream().noneMatch(each -> PostgreSQLColumnType.UNSPECIFIED == each)) {
259             return;
260         }
261         ParameterMetaData parameterMetaData = actualPreparedStatement.getParameterMetaData();
262         for (int i = 0; i < logicPreparedStatement.getSqlStatementContext().getSqlStatement().getParameterCount(); i++) {
263             if (PostgreSQLColumnType.UNSPECIFIED == logicPreparedStatement.getParameterTypes().get(i)) {
264                 logicPreparedStatement.getParameterTypes().set(i, PostgreSQLColumnType.valueOfJDBCType(parameterMetaData.getParameterType(i + 1), parameterMetaData.getParameterTypeName(i + 1)));
265             }
266         }
267     }
268     
269     private void populateColumnTypes(final PostgreSQLServerPreparedStatement logicPreparedStatement, final PreparedStatement actualPreparedStatement) throws SQLException {
270         if (logicPreparedStatement.describeRows().isPresent()) {
271             return;
272         }
273         ResultSetMetaData resultSetMetaData = actualPreparedStatement.getMetaData();
274         if (null == resultSetMetaData) {
275             logicPreparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
276             return;
277         }
278         List<PostgreSQLColumnDescription> columnDescriptions = new ArrayList<>(resultSetMetaData.getColumnCount());
279         for (int columnIndex = 1; columnIndex <= resultSetMetaData.getColumnCount(); columnIndex++) {
280             String columnName = resultSetMetaData.getColumnName(columnIndex);
281             int columnType = resultSetMetaData.getColumnType(columnIndex);
282             int columnLength = resultSetMetaData.getColumnDisplaySize(columnIndex);
283             String columnTypeName = resultSetMetaData.getColumnTypeName(columnIndex);
284             columnDescriptions.add(new PostgreSQLColumnDescription(columnName, columnIndex, columnType, columnLength, columnTypeName));
285         }
286         logicPreparedStatement.setRowDescription(new PostgreSQLRowDescriptionPacket(columnDescriptions));
287     }
288 }