1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.proxy.backend.session;
19
20 import java.util.ArrayList;
21 import java.util.Collections;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Map.Entry;
25 import java.util.StringJoiner;
26 import java.util.concurrent.ConcurrentHashMap;
27
28
29
30
/**
 * Recorder of session variables which renders the SQL statements needed to
 * set them, or to reset them, for a given database type.
 *
 * <p>Variables are kept in a {@link ConcurrentHashMap}, so neither the variable name
 * nor the variable value may be {@code null}, and the order in which variables
 * appear in the generated SQL is the iteration order of that map.</p>
 */
public final class RequiredSessionVariableRecorder {

    private static final String DEFAULT = "DEFAULT";

    private static final String NULL = "NULL";

    private static final String MYSQL = "MySQL";

    private static final String POSTGRESQL = "PostgreSQL";

    private final Map<String, String> sessionVariables = new ConcurrentHashMap<>();

    /**
     * Record a variable, replacing any value previously recorded under the same name.
     *
     * @param variableName variable name
     * @param variableValue variable value, used verbatim in the generated SQL
     */
    public void setVariable(final String variableName, final String variableValue) {
        sessionVariables.put(variableName, variableValue);
    }

    /**
     * Tell whether no variable is currently recorded.
     *
     * @return {@code true} if nothing is recorded, otherwise {@code false}
     */
    public boolean isEmpty() {
        return sessionVariables.isEmpty();
    }

    /**
     * Render the recorded variables as SET statements.
     *
     * <p>MySQL yields one statement assigning every variable; PostgreSQL yields one
     * statement per variable; any other database type yields an empty list.</p>
     *
     * @param databaseType database type name, compared case-sensitively
     * @return SET statements, or an empty list if nothing is recorded or the type is not handled
     * @throws NullPointerException if {@code databaseType} is {@code null} while variables are recorded
     */
    public List<String> toSetSQLs(final String databaseType) {
        // The emptiness check comes first so an empty recorder never inspects databaseType.
        if (isEmpty()) {
            return Collections.emptyList();
        }
        switch (databaseType) {
            case MYSQL:
                return Collections.singletonList(buildMySQLSetSQL());
            case POSTGRESQL:
                return buildPostgreSQLSetSQLs();
            default:
                return Collections.emptyList();
        }
    }

    /**
     * Join every recorded variable into a single MySQL statement of the form {@code SET a=1,b=2}.
     */
    private String buildMySQLSetSQL() {
        StringJoiner assignments = new StringJoiner(",", "SET ", "");
        sessionVariables.forEach((name, value) -> assignments.add(name + "=" + value));
        return assignments.toString();
    }

    /**
     * Produce one {@code SET name=value} statement per recorded variable.
     */
    private List<String> buildPostgreSQLSetSQLs() {
        List<String> statements = new ArrayList<>(sessionVariables.size());
        sessionVariables.forEach((name, value) -> statements.add("SET " + name + "=" + value));
        return statements;
    }

    /**
     * Render the statements which reset the recorded variables.
     *
     * <p>MySQL yields one SET statement assigning {@code NULL} to names starting with
     * {@code @} and {@code DEFAULT} to all other names; PostgreSQL yields {@code RESET ALL};
     * any other database type yields an empty list.</p>
     *
     * @param databaseType database type name, compared case-sensitively
     * @return reset statements, or an empty list if nothing is recorded or the type is not handled
     * @throws NullPointerException if {@code databaseType} is {@code null} while variables are recorded
     */
    public List<String> toResetSQLs(final String databaseType) {
        if (isEmpty()) {
            return Collections.emptyList();
        }
        switch (databaseType) {
            case MYSQL:
                return Collections.singletonList(buildMySQLResetSQL());
            case POSTGRESQL:
                return Collections.singletonList("RESET ALL");
            default:
                return Collections.emptyList();
        }
    }

    /**
     * Join the reset assignment of every recorded variable into a single MySQL SET statement.
     */
    private String buildMySQLResetSQL() {
        StringJoiner assignments = new StringJoiner(",", "SET ", "");
        for (String name : sessionVariables.keySet()) {
            // Names beginning with '@' are assigned NULL, every other name is assigned DEFAULT.
            String resetValue = name.startsWith("@") ? NULL : DEFAULT;
            assignments.add(name + "=" + resetValue);
        }
        return assignments.toString();
    }

    /**
     * Drop every recorded variable whose value is {@code DEFAULT}, ignoring case.
     */
    public void removeVariablesWithDefaultValue() {
        sessionVariables.values().removeIf(DEFAULT::equalsIgnoreCase);
    }
}