View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *     http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.shardingsphere.data.pipeline.postgresql.ddlgenerator;
19  
20  import lombok.SneakyThrows;
21  
22  import java.sql.Array;
23  import java.sql.Connection;
24  import java.sql.SQLException;
25  import java.util.Arrays;
26  import java.util.Collection;
27  import java.util.Collections;
28  import java.util.HashSet;
29  import java.util.LinkedHashMap;
30  import java.util.LinkedList;
31  import java.util.Map;
32  import java.util.regex.Matcher;
33  import java.util.regex.Pattern;
34  import java.util.stream.Collectors;
35  
36  /**
37   * Column properties appender for PostgreSQL.
38   */
39  public final class PostgreSQLColumnPropertiesAppender extends AbstractPostgreSQLDDLAdapter {
40      
41      private static final Collection<String> TIME_TYPE_NAMES = new HashSet<>(Arrays.asList(
42              "time", "timetz", "time without time zone", "time with time zone", "timestamp", "timestamptz", "timestamp without time zone", "timestamp with time zone"));
43      
44      private static final Collection<String> BIT_TYPE_NAMES = new HashSet<>(Arrays.asList("bit", "bit varying", "varbit"));
45      
46      private static final Pattern LENGTH_PRECISION_PATTERN = Pattern.compile("(\\d+),(\\d+)");
47      
48      private static final Pattern LENGTH_PATTERN = Pattern.compile("(\\d+)");
49      
50      private static final Pattern BRACKETS_PATTERN = Pattern.compile("(\\(\\d+\\))");
51      
52      private static final String ATT_OPTION_SPLIT = "=";
53      
54      public PostgreSQLColumnPropertiesAppender(final Connection connection, final int majorVersion, final int minorVersion) {
55          super(connection, majorVersion, minorVersion);
56      }
57      
58      /**
59       * Append column properties.
60       *
61       * @param context create table SQL context
62       */
63      @SneakyThrows(SQLException.class)
64      public void append(final Map<String, Object> context) {
65          Collection<Map<String, Object>> typeAndInheritedColumns = getTypeAndInheritedColumns(context);
66          Collection<Map<String, Object>> allColumns = executeByTemplate(context, "component/columns/%s/properties.ftl");
67          for (Map<String, Object> each : allColumns) {
68              for (Map<String, Object> column : typeAndInheritedColumns) {
69                  if (each.get("name").equals(column.get("name"))) {
70                      each.put(getInheritedFromTableOrType(context), column.get("inheritedfrom"));
71                  }
72              }
73          }
74          if (!allColumns.isEmpty()) {
75              Map<String, Collection<String>> editTypes = getEditTypes(allColumns);
76              for (Map<String, Object> each : allColumns) {
77                  columnFormatter(each, editTypes.getOrDefault(each.get("atttypid").toString(), new LinkedList<>()));
78              }
79          }
80          context.put("columns", allColumns);
81      }
82      
83      private Collection<Map<String, Object>> getTypeAndInheritedColumns(final Map<String, Object> context) throws SQLException {
84          if (null != context.get("typoid")) {
85              return getColumnFromType(context);
86          }
87          if (null == context.get("coll_inherits")) {
88              return Collections.emptyList();
89          }
90          Collection<String> collInherits = toCollection((Array) context.get("coll_inherits"));
91          context.put("coll_inherits", collInherits);
92          return collInherits.isEmpty() ? Collections.emptyList() : getColumnFromInherits(collInherits);
93      }
94      
95      private Collection<Map<String, Object>> getColumnFromType(final Map<String, Object> context) {
96          Map<String, Object> params = new LinkedHashMap<>();
97          params.put("tid", context.get("typoid"));
98          return executeByTemplate(params, "component/table/%s/get_columns_for_table.ftl");
99      }
100     
101     private Collection<String> toCollection(final Array array) throws SQLException {
102         return Arrays.stream((String[]) array.getArray()).collect(Collectors.toList());
103     }
104     
105     private Collection<Map<String, Object>> getColumnFromInherits(final Collection<String> collInherits) {
106         Collection<Map<String, Object>> result = new LinkedList<>();
107         for (Map<String, Object> each : executeByTemplate(new LinkedHashMap<>(), "component/table/%s/get_inherits.ftl")) {
108             if (collInherits.contains((String) each.get("inherits"))) {
109                 Map<String, Object> params = new LinkedHashMap<>();
110                 params.put("tid", each.get("oid"));
111                 result.addAll(executeByTemplate(params, "table/%s/get_columns_for_table.ftl"));
112             }
113         }
114         return result;
115     }
116     
117     @SuppressWarnings("unchecked")
118     private String getInheritedFromTableOrType(final Map<String, Object> context) {
119         String result = "inheritedfrom";
120         if (null != context.get("typoid")) {
121             result += "type";
122         } else if (null != context.get("coll_inherits") && !((Collection<String>) context.get("coll_inherits")).isEmpty()) {
123             result += "table";
124         }
125         return result;
126     }
127     
128     private Map<String, Collection<String>> getEditTypes(final Collection<Map<String, Object>> allColumns) throws SQLException {
129         Map<String, Collection<String>> result = new LinkedHashMap<>();
130         Map<String, Object> params = new LinkedHashMap<>();
131         params.put("type_ids", allColumns.stream().map(each -> each.get("atttypid").toString()).collect(Collectors.joining(",")));
132         for (Map<String, Object> each : executeByTemplate(params, "component/columns/%s/edit_mode_types_multi.ftl")) {
133             result.put(each.get("main_oid").toString(), toCollectionAndSort((Array) each.get("edit_types")));
134         }
135         return result;
136     }
137     
138     private Collection<String> toCollectionAndSort(final Array editTypes) throws SQLException {
139         return Arrays.stream((String[]) editTypes.getArray()).sorted(String::compareTo).collect(Collectors.toList());
140     }
141     
142     private void columnFormatter(final Map<String, Object> column, final Collection<String> editTypes) throws SQLException {
143         handlePrimaryColumn(column);
144         fetchLengthPrecision(column);
145         formatColumnVariables(column);
146         formatSecurityLabels(column);
147         editTypes.add(column.get("cltype").toString());
148         column.put("edit_types", editTypes.stream().sorted().collect(Collectors.toList()));
149         column.put("cltype", parseTypeName(column.get("cltype").toString()));
150     }
151     
152     private void handlePrimaryColumn(final Map<String, Object> column) {
153         if (null == column.get("attnum") || null == column.get("indkey")) {
154             return;
155         }
156         if (Arrays.stream(column.get("indkey").toString().split(" ")).collect(Collectors.toList()).contains(column.get("attnum").toString())) {
157             column.put("is_pk", true);
158             column.put("is_primary_key", true);
159         } else {
160             column.put("is_pk", false);
161             column.put("is_primary_key", false);
162         }
163     }
164     
165     private void fetchLengthPrecision(final Map<String, Object> column) {
166         String fullType = getFullDataType(column);
167         if (column.containsKey("elemoid")) {
168             handleLengthPrecision((Long) column.get("elemoid"), column, fullType);
169         }
170     }
171     
172     private void handleLengthPrecision(final Long elemoid, final Map<String, Object> column, final String fullType) {
173         switch (PostgreSQLColumnType.valueOf(elemoid)) {
174             case NUMERIC:
175                 setColumnPrecision(column, fullType);
176                 break;
177             case DATE:
178             case VARCHAR:
179                 setColumnLength(column, fullType);
180                 break;
181             default:
182                 break;
183         }
184     }
185     
186     private void setColumnPrecision(final Map<String, Object> column, final String fullType) {
187         Matcher matcher = LENGTH_PRECISION_PATTERN.matcher(fullType);
188         if (matcher.find()) {
189             column.put("attlen", matcher.group(1));
190             column.put("attprecision", matcher.group(2));
191         }
192     }
193     
194     private static void setColumnLength(final Map<String, Object> column, final String fullType) {
195         Matcher matcher = LENGTH_PATTERN.matcher(fullType);
196         if (matcher.find()) {
197             column.put("attlen", matcher.group(1));
198             column.put("attprecision", null);
199         }
200     }
201     
202     private String getFullDataType(final Map<String, Object> column) {
203         String namespace = (String) column.get("typnspname");
204         String typeName = (String) column.get("typname");
205         Integer numdims = (Integer) column.get("attndims");
206         String schema = null == namespace ? "" : namespace;
207         String name = checkSchemaInName(typeName, schema);
208         if (name.startsWith("_")) {
209             if (null == numdims || 0 == numdims) {
210                 numdims = 1;
211             }
212             name = name.substring(1);
213         }
214         if (name.endsWith("[]")) {
215             if (null == numdims || 0 == numdims) {
216                 numdims = 1;
217             }
218             name = name.substring(0, name.length() - 2);
219         }
220         if (name.startsWith("\"") && name.endsWith("\"")) {
221             name = name.substring(1, name.length() - 1);
222         }
223         Integer typmod = (Integer) column.get("atttypmod");
224         String length = -1 == typmod ? "" : checkTypmod(typmod, name);
225         return getFullTypeValue(name, schema, length, numdims == 1 ? "[]" : "");
226     }
227     
228     private String checkSchemaInName(final String typname, final String schema) {
229         if (typname.contains(schema + "\".")) {
230             return typname.substring(schema.length() + 3);
231         }
232         if (typname.contains(schema + ".")) {
233             return typname.substring(schema.length() + 1);
234         }
235         return typname;
236     }
237     
238     private String getFullTypeValue(final String name, final String schema, final String length, final String array) {
239         if ("char".equals(name) && "pg_catalog".equals(schema)) {
240             return "\"char\"" + array;
241         }
242         if ("time with time zone".equals(name)) {
243             return "time" + length + " with time zone" + array;
244         }
245         if ("time without time zone".equals(name)) {
246             return "time" + length + " without time zone" + array;
247         }
248         if ("timestamp with time zone".equals(name)) {
249             return "timestamp" + length + " with time zone" + array;
250         }
251         if ("timestamp without time zone".equals(name)) {
252             return "timestamp" + length + " without time zone" + array;
253         }
254         return name + length + array;
255     }
256     
257     private String checkTypmod(final Integer typmod, final String name) {
258         String result = "(";
259         if ("numeric".equals(name)) {
260             int len = (typmod - 4) >> 16;
261             int prec = (typmod - 4) & 0xffff;
262             result += String.valueOf(len);
263             result += "," + prec;
264         } else if (TIME_TYPE_NAMES.contains(name) || BIT_TYPE_NAMES.contains(name)) {
265             int len = typmod;
266             result += String.valueOf(len);
267         } else if ("interval".equals(name)) {
268             int len = typmod & 0xffff;
269             result += len > 6 ? "" : String.valueOf(len);
270         } else if ("date".equals(name)) {
271             result = "";
272         } else {
273             int len = typmod - 4;
274             result += String.valueOf(len);
275         }
276         if (!result.isEmpty()) {
277             result += ")";
278         }
279         return result;
280     }
281     
282     private void formatColumnVariables(final Map<String, Object> column) throws SQLException {
283         if (null == column.get("attoptions")) {
284             return;
285         }
286         Collection<Map<String, String>> attOptions = new LinkedList<>();
287         Collection<String> columnVariables = Arrays.stream((String[]) ((Array) column.get("attoptions")).getArray()).collect(Collectors.toList());
288         for (String each : columnVariables) {
289             Map<String, String> columnVariable = new LinkedHashMap<>();
290             columnVariable.put("name", each.substring(0, each.indexOf(ATT_OPTION_SPLIT)));
291             columnVariable.put("value", each.substring(each.indexOf(ATT_OPTION_SPLIT) + 1));
292             attOptions.add(columnVariable);
293         }
294         column.put("attoptions", attOptions);
295     }
296     
297     private String parseTypeName(final String name) {
298         String result = name;
299         boolean isArray = false;
300         if (result.endsWith("[]")) {
301             isArray = true;
302             result = result.substring(0, result.lastIndexOf("[]"));
303         }
304         int idx = result.indexOf('(');
305         if (idx > 0 && result.endsWith(")")) {
306             result = result.substring(0, idx);
307         } else if (idx > 0 && result.startsWith("time")) {
308             int endIdx = result.indexOf(')');
309             if (1 != endIdx) {
310                 Matcher matcher = BRACKETS_PATTERN.matcher(result);
311                 StringBuffer buffer = new StringBuffer();
312                 while (matcher.find()) {
313                     matcher.appendReplacement(buffer, "");
314                 }
315                 matcher.appendTail(buffer);
316                 result = buffer.toString();
317             }
318         } else if (result.startsWith("interval")) {
319             result = "interval";
320         }
321         if (isArray) {
322             result += "[]";
323         }
324         return result;
325     }
326 }