View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.describe;
19  
20  import lombok.RequiredArgsConstructor;
21  import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
22  import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
23  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLColumnDescription;
24  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLNoDataPacket;
25  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.PostgreSQLRowDescriptionPacket;
26  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.PostgreSQLColumnType;
27  import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.describe.PostgreSQLComDescribePacket;
28  import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
29  import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
30  import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
31  import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
32  import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
33  import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
34  import org.apache.shardingsphere.infra.exception.postgresql.exception.metadata.ColumnNotFoundException;
35  import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
36  import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
37  import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
38  import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
39  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
40  import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
41  import org.apache.shardingsphere.infra.session.query.QueryContext;
42  import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
43  import org.apache.shardingsphere.proxy.backend.connector.ProxyDatabaseConnectionManager;
44  import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
45  import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
46  import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
47  import org.apache.shardingsphere.proxy.frontend.postgresql.command.PortalContext;
48  import org.apache.shardingsphere.proxy.frontend.postgresql.command.query.extended.PostgreSQLServerPreparedStatement;
49  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.ReturningSegment;
50  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
51  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
52  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
53  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
54  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
55  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
56  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
57  import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
58  import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
59  import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.InsertStatementHandler;
60  
61  import java.sql.Connection;
62  import java.sql.ParameterMetaData;
63  import java.sql.PreparedStatement;
64  import java.sql.ResultSetMetaData;
65  import java.sql.SQLException;
66  import java.sql.Types;
67  import java.util.ArrayList;
68  import java.util.Collection;
69  import java.util.Collections;
70  import java.util.HashSet;
71  import java.util.LinkedList;
72  import java.util.List;
73  import java.util.ListIterator;
74  import java.util.Optional;
75  import java.util.stream.Collectors;
76  
77  /**
78   * Command describe for PostgreSQL.
79   */
80  @RequiredArgsConstructor
81  public final class PostgreSQLComDescribeExecutor implements CommandExecutor {
82      
83      private static final String ANONYMOUS_COLUMN_NAME = "?column?";
84      
85      private final PortalContext portalContext;
86      
87      private final PostgreSQLComDescribePacket packet;
88      
89      private final ConnectionSession connectionSession;
90      
91      @Override
92      public Collection<DatabasePacket> execute() throws SQLException {
93          switch (packet.getType()) {
94              case 'S':
95                  return describePreparedStatement();
96              case 'P':
97                  return Collections.singleton(portalContext.get(packet.getName()).describe());
98              default:
99                  throw new UnsupportedSQLOperationException("Unsupported describe type: " + packet.getType());
100         }
101     }
102     
103     private List<DatabasePacket> describePreparedStatement() throws SQLException {
104         List<DatabasePacket> result = new ArrayList<>(2);
105         PostgreSQLServerPreparedStatement preparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getName());
106         result.add(preparedStatement.describeParameters());
107         Optional<PostgreSQLPacket> rowDescription = preparedStatement.describeRows();
108         if (rowDescription.isPresent()) {
109             result.add(rowDescription.get());
110         } else {
111             tryDescribePreparedStatement(preparedStatement);
112             preparedStatement.describeRows().ifPresent(result::add);
113         }
114         return result;
115     }
116     
117     private void tryDescribePreparedStatement(final PostgreSQLServerPreparedStatement preparedStatement) throws SQLException {
118         if (preparedStatement.getSqlStatementContext().getSqlStatement() instanceof InsertStatement) {
119             describeInsertStatementByDatabaseMetaData(preparedStatement);
120         } else {
121             tryDescribePreparedStatementByJDBC(preparedStatement);
122         }
123     }
124     
125     private void describeInsertStatementByDatabaseMetaData(final PostgreSQLServerPreparedStatement preparedStatement) {
126         InsertStatement insertStatement = (InsertStatement) preparedStatement.getSqlStatementContext().getSqlStatement();
127         Collection<Integer> unspecifiedTypeParameterIndexes = getUnspecifiedTypeParameterIndexes(preparedStatement);
128         Optional<ReturningSegment> returningSegment = InsertStatementHandler.getReturningSegment(insertStatement);
129         if (insertStatement.getParameterMarkerSegments().isEmpty() && unspecifiedTypeParameterIndexes.isEmpty() && !returningSegment.isPresent()) {
130             return;
131         }
132         String logicTableName = insertStatement.getTable().getTableName().getIdentifier().getValue();
133         ShardingSphereTable table = getTableFromMetaData(connectionSession.getDatabaseName(), insertStatement, logicTableName);
134         List<String> columnNamesOfInsert = getColumnNamesOfInsertStatement(insertStatement, table);
135         preparedStatement.setRowDescription(returningSegment.<PostgreSQLPacket>map(returning -> describeReturning(returning, table))
136                 .orElseGet(PostgreSQLNoDataPacket::getInstance));
137         int parameterMarkerIndex = 0;
138         for (InsertValuesSegment each : insertStatement.getValues()) {
139             ListIterator<ExpressionSegment> listIterator = each.getValues().listIterator();
140             for (int columnIndex = listIterator.nextIndex(); listIterator.hasNext(); columnIndex = listIterator.nextIndex()) {
141                 ExpressionSegment value = listIterator.next();
142                 if (!(value instanceof ParameterMarkerExpressionSegment)) {
143                     continue;
144                 }
145                 if (!unspecifiedTypeParameterIndexes.contains(parameterMarkerIndex)) {
146                     parameterMarkerIndex++;
147                     continue;
148                 }
149                 String columnName = columnNamesOfInsert.get(columnIndex);
150                 ShardingSphereColumn column = table.getColumn(columnName);
151                 ShardingSpherePreconditions.checkNotNull(column, () -> new ColumnNotFoundException(logicTableName, columnName));
152                 preparedStatement.getParameterTypes().set(parameterMarkerIndex++, PostgreSQLColumnType.valueOfJDBCType(column.getDataType()));
153             }
154         }
155     }
156     
157     private Collection<Integer> getUnspecifiedTypeParameterIndexes(final PostgreSQLServerPreparedStatement preparedStatement) {
158         Collection<Integer> result = new HashSet<>();
159         ListIterator<PostgreSQLColumnType> parameterTypesListIterator = preparedStatement.getParameterTypes().listIterator();
160         for (int index = parameterTypesListIterator.nextIndex(); parameterTypesListIterator.hasNext(); index = parameterTypesListIterator.nextIndex()) {
161             if (PostgreSQLColumnType.UNSPECIFIED == parameterTypesListIterator.next()) {
162                 result.add(index);
163             }
164         }
165         return result;
166     }
167     
168     private ShardingSphereTable getTableFromMetaData(final String databaseName, final InsertStatement insertStatement, final String logicTableName) {
169         ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getDatabase(databaseName);
170         String schemaName = insertStatement.getTable().getOwner().map(optional -> optional.getIdentifier()
171                 .getValue()).orElseGet(() -> new DatabaseTypeRegistry(database.getProtocolType()).getDefaultSchemaName(databaseName));
172         return database.getSchema(schemaName).getTable(logicTableName);
173     }
174     
175     private List<String> getColumnNamesOfInsertStatement(final InsertStatement insertStatement, final ShardingSphereTable table) {
176         return insertStatement.getColumns().isEmpty() ? table.getColumnNames() : insertStatement.getColumns().stream().map(each -> each.getIdentifier().getValue()).collect(Collectors.toList());
177     }
178     
179     private PostgreSQLRowDescriptionPacket describeReturning(final ReturningSegment returningSegment, final ShardingSphereTable table) {
180         Collection<PostgreSQLColumnDescription> result = new LinkedList<>();
181         for (ProjectionSegment each : returningSegment.getProjections().getProjections()) {
182             if (each instanceof ShorthandProjectionSegment) {
183                 table.getColumnValues().stream().map(column -> new PostgreSQLColumnDescription(column.getName(), 0, column.getDataType(), estimateColumnLength(column.getDataType()), ""))
184                         .forEach(result::add);
185             }
186             if (each instanceof ColumnProjectionSegment) {
187                 ColumnProjectionSegment segment = (ColumnProjectionSegment) each;
188                 String columnName = segment.getColumn().getIdentifier().getValue();
189                 ShardingSphereColumn column = table.containsColumn(columnName) ? table.getColumn(columnName) : generateDefaultColumn(segment);
190                 String alias = segment.getAliasName().orElseGet(column::getName);
191                 result.add(new PostgreSQLColumnDescription(alias, 0, column.getDataType(), estimateColumnLength(column.getDataType()), ""));
192             }
193             if (each instanceof ExpressionProjectionSegment) {
194                 result.add(convertExpressionToDescription((ExpressionProjectionSegment) each));
195             }
196         }
197         return new PostgreSQLRowDescriptionPacket(result);
198     }
199     
200     private ShardingSphereColumn generateDefaultColumn(final ColumnProjectionSegment segment) {
201         return new ShardingSphereColumn(segment.getColumn().getIdentifier().getValue(), Types.VARCHAR, false, false, false, true, false, false);
202     }
203     
204     private PostgreSQLColumnDescription convertExpressionToDescription(final ExpressionProjectionSegment expressionProjectionSegment) {
205         ExpressionSegment expressionSegment = expressionProjectionSegment.getExpr();
206         String columnName = expressionProjectionSegment.getAliasName().orElse(ANONYMOUS_COLUMN_NAME);
207         if (expressionSegment instanceof LiteralExpressionSegment) {
208             Object value = ((LiteralExpressionSegment) expressionSegment).getLiterals();
209             if (value instanceof String) {
210                 return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
211             }
212             if (value instanceof Integer) {
213                 return new PostgreSQLColumnDescription(columnName, 0, Types.INTEGER, estimateColumnLength(Types.INTEGER), "");
214             }
215             if (value instanceof Long) {
216                 return new PostgreSQLColumnDescription(columnName, 0, Types.BIGINT, estimateColumnLength(Types.BIGINT), "");
217             }
218             if (value instanceof Number) {
219                 return new PostgreSQLColumnDescription(columnName, 0, Types.NUMERIC, estimateColumnLength(Types.NUMERIC), "");
220             }
221         }
222         return new PostgreSQLColumnDescription(columnName, 0, Types.VARCHAR, estimateColumnLength(Types.VARCHAR), "");
223     }
224     
225     private int estimateColumnLength(final int jdbcType) {
226         switch (jdbcType) {
227             case Types.SMALLINT:
228                 return 2;
229             case Types.INTEGER:
230                 return 4;
231             case Types.BIGINT:
232                 return 8;
233             default:
234                 return -1;
235         }
236     }
237     
238     private void tryDescribePreparedStatementByJDBC(final PostgreSQLServerPreparedStatement logicPreparedStatement) throws SQLException {
239         MetaDataContexts metaDataContexts = ProxyContext.getInstance().getContextManager().getMetaDataContexts();
240         String databaseName = connectionSession.getDatabaseName();
241         SQLStatementContext sqlStatementContext = new SQLBindEngine(metaDataContexts.getMetaData(), databaseName, logicPreparedStatement.getHintValueContext())
242                 .bind(logicPreparedStatement.getSqlStatementContext().getSqlStatement(), Collections.emptyList());
243         QueryContext queryContext = new QueryContext(sqlStatementContext, logicPreparedStatement.getSql(), Collections.emptyList(), logicPreparedStatement.getHintValueContext());
244         ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getDatabase(databaseName);
245         ExecutionContext executionContext = new KernelProcessor().generateExecutionContext(
246                 queryContext, database, metaDataContexts.getMetaData().getGlobalRuleMetaData(), metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
247         ExecutionUnit executionUnitSample = executionContext.getExecutionUnits().iterator().next();
248         ProxyDatabaseConnectionManager databaseConnectionManager = connectionSession.getDatabaseConnectionManager();
249         Connection connection = databaseConnectionManager.getConnections(executionUnitSample.getDataSourceName(), 0, 1, ConnectionMode.CONNECTION_STRICTLY).iterator().next();
250         try (PreparedStatement actualPreparedStatement = connection.prepareStatement(executionUnitSample.getSqlUnit().getSql())) {
251             populateParameterTypes(logicPreparedStatement, actualPreparedStatement);
252             populateColumnTypes(logicPreparedStatement, actualPreparedStatement);
253         }
254     }
255     
256     private void populateParameterTypes(final PostgreSQLServerPreparedStatement logicPreparedStatement, final PreparedStatement actualPreparedStatement) throws SQLException {
257         if (0 == logicPreparedStatement.getSqlStatementContext().getSqlStatement().getParameterCount()
258                 || logicPreparedStatement.getParameterTypes().stream().noneMatch(each -> PostgreSQLColumnType.UNSPECIFIED == each)) {
259             return;
260         }
261         ParameterMetaData parameterMetaData = actualPreparedStatement.getParameterMetaData();
262         for (int i = 0; i < logicPreparedStatement.getSqlStatementContext().getSqlStatement().getParameterCount(); i++) {
263             if (PostgreSQLColumnType.UNSPECIFIED == logicPreparedStatement.getParameterTypes().get(i)) {
264                 logicPreparedStatement.getParameterTypes().set(i, PostgreSQLColumnType.valueOfJDBCType(parameterMetaData.getParameterType(i + 1), parameterMetaData.getParameterTypeName(i + 1)));
265             }
266         }
267     }
268     
269     private void populateColumnTypes(final PostgreSQLServerPreparedStatement logicPreparedStatement, final PreparedStatement actualPreparedStatement) throws SQLException {
270         if (logicPreparedStatement.describeRows().isPresent()) {
271             return;
272         }
273         ResultSetMetaData resultSetMetaData = actualPreparedStatement.getMetaData();
274         if (null == resultSetMetaData) {
275             logicPreparedStatement.setRowDescription(PostgreSQLNoDataPacket.getInstance());
276             return;
277         }
278         List<PostgreSQLColumnDescription> columnDescriptions = new ArrayList<>(resultSetMetaData.getColumnCount());
279         for (int columnIndex = 1; columnIndex <= resultSetMetaData.getColumnCount(); columnIndex++) {
280             String columnName = resultSetMetaData.getColumnName(columnIndex);
281             int columnType = resultSetMetaData.getColumnType(columnIndex);
282             int columnLength = resultSetMetaData.getColumnDisplaySize(columnIndex);
283             String columnTypeName = resultSetMetaData.getColumnTypeName(columnIndex);
284             columnDescriptions.add(new PostgreSQLColumnDescription(columnName, columnIndex, columnType, columnLength, columnTypeName));
285         }
286         logicPreparedStatement.setRowDescription(new PostgreSQLRowDescriptionPacket(columnDescriptions));
287     }
288 }