1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.data.pipeline.postgresql.ddlgenerator;
19
20 import lombok.SneakyThrows;
21
22 import java.sql.Array;
23 import java.sql.Connection;
24 import java.sql.SQLException;
25 import java.util.Arrays;
26 import java.util.Collection;
27 import java.util.Collections;
28 import java.util.HashSet;
29 import java.util.LinkedHashMap;
30 import java.util.LinkedList;
31 import java.util.Map;
32 import java.util.regex.Matcher;
33 import java.util.regex.Pattern;
34 import java.util.stream.Collectors;
35
36
37
38
39 public final class PostgreSQLColumnPropertiesAppender extends AbstractPostgreSQLDDLAdapter {
40
41 private static final Collection<String> TIME_TYPE_NAMES = new HashSet<>(Arrays.asList(
42 "time", "timetz", "time without time zone", "time with time zone", "timestamp", "timestamptz", "timestamp without time zone", "timestamp with time zone"));
43
44 private static final Collection<String> BIT_TYPE_NAMES = new HashSet<>(Arrays.asList("bit", "bit varying", "varbit"));
45
46 private static final Pattern LENGTH_PRECISION_PATTERN = Pattern.compile("(\\d+),(\\d+)");
47
48 private static final Pattern LENGTH_PATTERN = Pattern.compile("(\\d+)");
49
50 private static final Pattern BRACKETS_PATTERN = Pattern.compile("(\\(\\d+\\))");
51
52 private static final String ATT_OPTION_SPLIT = "=";
53
54 public PostgreSQLColumnPropertiesAppender(final Connection connection, final int majorVersion, final int minorVersion) {
55 super(connection, majorVersion, minorVersion);
56 }
57
58
59
60
61
62
63 @SneakyThrows(SQLException.class)
64 public void append(final Map<String, Object> context) {
65 Collection<Map<String, Object>> typeAndInheritedColumns = getTypeAndInheritedColumns(context);
66 Collection<Map<String, Object>> allColumns = executeByTemplate(context, "component/columns/%s/properties.ftl");
67 for (Map<String, Object> each : allColumns) {
68 for (Map<String, Object> column : typeAndInheritedColumns) {
69 if (each.get("name").equals(column.get("name"))) {
70 each.put(getInheritedFromTableOrType(context), column.get("inheritedfrom"));
71 }
72 }
73 }
74 if (!allColumns.isEmpty()) {
75 Map<String, Collection<String>> editTypes = getEditTypes(allColumns);
76 for (Map<String, Object> each : allColumns) {
77 columnFormatter(each, editTypes.getOrDefault(each.get("atttypid").toString(), new LinkedList<>()));
78 }
79 }
80 context.put("columns", allColumns);
81 }
82
83 private Collection<Map<String, Object>> getTypeAndInheritedColumns(final Map<String, Object> context) throws SQLException {
84 if (null != context.get("typoid")) {
85 return getColumnFromType(context);
86 }
87 if (null == context.get("coll_inherits")) {
88 return Collections.emptyList();
89 }
90 Collection<String> collInherits = toCollection((Array) context.get("coll_inherits"));
91 context.put("coll_inherits", collInherits);
92 return collInherits.isEmpty() ? Collections.emptyList() : getColumnFromInherits(collInherits);
93 }
94
95 private Collection<Map<String, Object>> getColumnFromType(final Map<String, Object> context) {
96 Map<String, Object> params = new LinkedHashMap<>();
97 params.put("tid", context.get("typoid"));
98 return executeByTemplate(params, "component/table/%s/get_columns_for_table.ftl");
99 }
100
101 private Collection<String> toCollection(final Array array) throws SQLException {
102 return Arrays.stream((String[]) array.getArray()).collect(Collectors.toList());
103 }
104
105 private Collection<Map<String, Object>> getColumnFromInherits(final Collection<String> collInherits) {
106 Collection<Map<String, Object>> result = new LinkedList<>();
107 for (Map<String, Object> each : executeByTemplate(new LinkedHashMap<>(), "component/table/%s/get_inherits.ftl")) {
108 if (collInherits.contains((String) each.get("inherits"))) {
109 Map<String, Object> params = new LinkedHashMap<>();
110 params.put("tid", each.get("oid"));
111 result.addAll(executeByTemplate(params, "table/%s/get_columns_for_table.ftl"));
112 }
113 }
114 return result;
115 }
116
117 @SuppressWarnings("unchecked")
118 private String getInheritedFromTableOrType(final Map<String, Object> context) {
119 String result = "inheritedfrom";
120 if (null != context.get("typoid")) {
121 result += "type";
122 } else if (null != context.get("coll_inherits") && !((Collection<String>) context.get("coll_inherits")).isEmpty()) {
123 result += "table";
124 }
125 return result;
126 }
127
128 private Map<String, Collection<String>> getEditTypes(final Collection<Map<String, Object>> allColumns) throws SQLException {
129 Map<String, Collection<String>> result = new LinkedHashMap<>();
130 Map<String, Object> params = new LinkedHashMap<>();
131 params.put("type_ids", allColumns.stream().map(each -> each.get("atttypid").toString()).collect(Collectors.joining(",")));
132 for (Map<String, Object> each : executeByTemplate(params, "component/columns/%s/edit_mode_types_multi.ftl")) {
133 result.put(each.get("main_oid").toString(), toCollectionAndSort((Array) each.get("edit_types")));
134 }
135 return result;
136 }
137
138 private Collection<String> toCollectionAndSort(final Array editTypes) throws SQLException {
139 return Arrays.stream((String[]) editTypes.getArray()).sorted(String::compareTo).collect(Collectors.toList());
140 }
141
142 private void columnFormatter(final Map<String, Object> column, final Collection<String> editTypes) throws SQLException {
143 handlePrimaryColumn(column);
144 fetchLengthPrecision(column);
145 formatColumnVariables(column);
146 formatSecurityLabels(column);
147 editTypes.add(column.get("cltype").toString());
148 column.put("edit_types", editTypes.stream().sorted().collect(Collectors.toList()));
149 column.put("cltype", parseTypeName(column.get("cltype").toString()));
150 }
151
152 private void handlePrimaryColumn(final Map<String, Object> column) {
153 if (null == column.get("attnum") || null == column.get("indkey")) {
154 return;
155 }
156 if (Arrays.stream(column.get("indkey").toString().split(" ")).collect(Collectors.toList()).contains(column.get("attnum").toString())) {
157 column.put("is_pk", true);
158 column.put("is_primary_key", true);
159 } else {
160 column.put("is_pk", false);
161 column.put("is_primary_key", false);
162 }
163 }
164
165 private void fetchLengthPrecision(final Map<String, Object> column) {
166 String fullType = getFullDataType(column);
167 if (column.containsKey("elemoid")) {
168 handleLengthPrecision((Long) column.get("elemoid"), column, fullType);
169 }
170 }
171
172 private void handleLengthPrecision(final Long elemoid, final Map<String, Object> column, final String fullType) {
173 switch (PostgreSQLColumnType.valueOf(elemoid)) {
174 case NUMERIC:
175 setColumnPrecision(column, fullType);
176 break;
177 case DATE:
178 case VARCHAR:
179 setColumnLength(column, fullType);
180 break;
181 default:
182 break;
183 }
184 }
185
186 private void setColumnPrecision(final Map<String, Object> column, final String fullType) {
187 Matcher matcher = LENGTH_PRECISION_PATTERN.matcher(fullType);
188 if (matcher.find()) {
189 column.put("attlen", matcher.group(1));
190 column.put("attprecision", matcher.group(2));
191 }
192 }
193
194 private static void setColumnLength(final Map<String, Object> column, final String fullType) {
195 Matcher matcher = LENGTH_PATTERN.matcher(fullType);
196 if (matcher.find()) {
197 column.put("attlen", matcher.group(1));
198 column.put("attprecision", null);
199 }
200 }
201
202 private String getFullDataType(final Map<String, Object> column) {
203 String namespace = (String) column.get("typnspname");
204 String typeName = (String) column.get("typname");
205 Integer numdims = (Integer) column.get("attndims");
206 String schema = null == namespace ? "" : namespace;
207 String name = checkSchemaInName(typeName, schema);
208 if (name.startsWith("_")) {
209 if (null == numdims || 0 == numdims) {
210 numdims = 1;
211 }
212 name = name.substring(1);
213 }
214 if (name.endsWith("[]")) {
215 if (null == numdims || 0 == numdims) {
216 numdims = 1;
217 }
218 name = name.substring(0, name.length() - 2);
219 }
220 if (name.startsWith("\"") && name.endsWith("\"")) {
221 name = name.substring(1, name.length() - 1);
222 }
223 Integer typmod = (Integer) column.get("atttypmod");
224 String length = -1 == typmod ? "" : checkTypmod(typmod, name);
225 return getFullTypeValue(name, schema, length, numdims == 1 ? "[]" : "");
226 }
227
228 private String checkSchemaInName(final String typname, final String schema) {
229 if (typname.contains(schema + "\".")) {
230 return typname.substring(schema.length() + 3);
231 }
232 if (typname.contains(schema + ".")) {
233 return typname.substring(schema.length() + 1);
234 }
235 return typname;
236 }
237
238 private String getFullTypeValue(final String name, final String schema, final String length, final String array) {
239 if ("char".equals(name) && "pg_catalog".equals(schema)) {
240 return "\"char\"" + array;
241 }
242 if ("time with time zone".equals(name)) {
243 return "time" + length + " with time zone" + array;
244 }
245 if ("time without time zone".equals(name)) {
246 return "time" + length + " without time zone" + array;
247 }
248 if ("timestamp with time zone".equals(name)) {
249 return "timestamp" + length + " with time zone" + array;
250 }
251 if ("timestamp without time zone".equals(name)) {
252 return "timestamp" + length + " without time zone" + array;
253 }
254 return name + length + array;
255 }
256
257 private String checkTypmod(final Integer typmod, final String name) {
258 String result = "(";
259 if ("numeric".equals(name)) {
260 int len = (typmod - 4) >> 16;
261 int prec = (typmod - 4) & 0xffff;
262 result += String.valueOf(len);
263 result += "," + prec;
264 } else if (TIME_TYPE_NAMES.contains(name) || BIT_TYPE_NAMES.contains(name)) {
265 int len = typmod;
266 result += String.valueOf(len);
267 } else if ("interval".equals(name)) {
268 int len = typmod & 0xffff;
269 result += len > 6 ? "" : String.valueOf(len);
270 } else if ("date".equals(name)) {
271 result = "";
272 } else {
273 int len = typmod - 4;
274 result += String.valueOf(len);
275 }
276 if (!result.isEmpty()) {
277 result += ")";
278 }
279 return result;
280 }
281
282 private void formatColumnVariables(final Map<String, Object> column) throws SQLException {
283 if (null == column.get("attoptions")) {
284 return;
285 }
286 Collection<Map<String, String>> attOptions = new LinkedList<>();
287 Collection<String> columnVariables = Arrays.stream((String[]) ((Array) column.get("attoptions")).getArray()).collect(Collectors.toList());
288 for (String each : columnVariables) {
289 Map<String, String> columnVariable = new LinkedHashMap<>();
290 columnVariable.put("name", each.substring(0, each.indexOf(ATT_OPTION_SPLIT)));
291 columnVariable.put("value", each.substring(each.indexOf(ATT_OPTION_SPLIT) + 1));
292 attOptions.add(columnVariable);
293 }
294 column.put("attoptions", attOptions);
295 }
296
297 private String parseTypeName(final String name) {
298 String result = name;
299 boolean isArray = false;
300 if (result.endsWith("[]")) {
301 isArray = true;
302 result = result.substring(0, result.lastIndexOf("[]"));
303 }
304 int idx = result.indexOf('(');
305 if (idx > 0 && result.endsWith(")")) {
306 result = result.substring(0, idx);
307 } else if (idx > 0 && result.startsWith("time")) {
308 int endIdx = result.indexOf(')');
309 if (1 != endIdx) {
310 Matcher matcher = BRACKETS_PATTERN.matcher(result);
311 StringBuffer buffer = new StringBuffer();
312 while (matcher.find()) {
313 matcher.appendReplacement(buffer, "");
314 }
315 matcher.appendTail(buffer);
316 result = buffer.toString();
317 }
318 } else if (result.startsWith("interval")) {
319 result = "interval";
320 }
321 if (isArray) {
322 result += "[]";
323 }
324 return result;
325 }
326 }