1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.shardingsphere.broadcast.rule;
19
import lombok.Getter;
import org.apache.shardingsphere.broadcast.api.config.BroadcastRuleConfiguration;
import org.apache.shardingsphere.broadcast.rule.attribute.BroadcastDataNodeRuleAttribute;
import org.apache.shardingsphere.broadcast.rule.attribute.BroadcastTableNamesRuleAttribute;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes;
import org.apache.shardingsphere.infra.rule.attribute.datasource.DataSourceMapperRuleAttribute;
import org.apache.shardingsphere.infra.rule.scope.DatabaseRule;

import javax.sql.DataSource;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.TreeSet;
import java.util.stream.Collectors;
37
38
39
40
41 @Getter
42 public final class BroadcastRule implements DatabaseRule {
43
44 private final BroadcastRuleConfiguration configuration;
45
46 private final String databaseName;
47
48 private final Collection<String> tables;
49
50 private final Collection<String> dataSourceNames;
51
52 private final RuleAttributes attributes;
53
54 public BroadcastRule(final BroadcastRuleConfiguration config, final String databaseName, final Map<String, DataSource> dataSources, final Collection<ShardingSphereRule> builtRules) {
55 configuration = config;
56 this.databaseName = databaseName;
57 dataSourceNames = getAggregatedDataSourceNames(dataSources, builtRules);
58 tables = createBroadcastTables(config.getTables());
59 attributes = new RuleAttributes(new BroadcastDataNodeRuleAttribute(dataSourceNames, tables), new BroadcastTableNamesRuleAttribute(tables));
60 }
61
62 private Collection<String> getAggregatedDataSourceNames(final Map<String, DataSource> dataSources, final Collection<ShardingSphereRule> builtRules) {
63 Collection<String> result = new LinkedList<>(dataSources.keySet());
64 for (ShardingSphereRule each : builtRules) {
65 Optional<DataSourceMapperRuleAttribute> ruleAttribute = each.getAttributes().findAttribute(DataSourceMapperRuleAttribute.class);
66 if (ruleAttribute.isPresent()) {
67 result = getAggregatedDataSourceNames(result, ruleAttribute.get());
68 }
69 }
70 return result;
71 }
72
73 private Collection<String> getAggregatedDataSourceNames(final Collection<String> dataSourceNames, final DataSourceMapperRuleAttribute ruleAttribute) {
74 Collection<String> result = new LinkedList<>();
75 for (Entry<String, Collection<String>> entry : ruleAttribute.getDataSourceMapper().entrySet()) {
76 for (String each : entry.getValue()) {
77 if (dataSourceNames.contains(each)) {
78 dataSourceNames.remove(each);
79 if (!result.contains(entry.getKey())) {
80 result.add(entry.getKey());
81 }
82 }
83 }
84 }
85 result.addAll(dataSourceNames);
86 return result;
87 }
88
89 private Collection<String> createBroadcastTables(final Collection<String> broadcastTables) {
90 Collection<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
91 result.addAll(broadcastTables);
92 return result;
93 }
94
95
96
97
98
99
100
101 public Collection<String> getBroadcastRuleTableNames(final Collection<String> logicTableNames) {
102 return logicTableNames.stream().filter(tables::contains).collect(Collectors.toSet());
103 }
104
105
106
107
108
109
110
111 public boolean isAllBroadcastTables(final Collection<String> logicTableNames) {
112 return !logicTableNames.isEmpty() && tables.containsAll(logicTableNames);
113 }
114 }