Showing
4 changed files
with
272 additions
and
9 deletions
... | ... | @@ -18,14 +18,19 @@ package org.thingsboard.server.controller.plugin; |
18 | 18 | import lombok.extern.slf4j.Slf4j; |
19 | 19 | import org.springframework.beans.factory.BeanCreationNotAllowedException; |
20 | 20 | import org.springframework.beans.factory.annotation.Autowired; |
21 | -import org.springframework.context.annotation.Lazy; | |
21 | +import org.springframework.beans.factory.annotation.Value; | |
22 | +import org.springframework.scheduling.annotation.Scheduled; | |
22 | 23 | import org.springframework.stereotype.Service; |
23 | 24 | import org.springframework.web.socket.CloseStatus; |
24 | 25 | import org.springframework.web.socket.TextMessage; |
25 | 26 | import org.springframework.web.socket.WebSocketSession; |
26 | 27 | import org.springframework.web.socket.handler.TextWebSocketHandler; |
28 | +import org.thingsboard.server.common.data.id.CustomerId; | |
29 | +import org.thingsboard.server.common.data.id.TenantId; | |
30 | +import org.thingsboard.server.common.data.id.UserId; | |
27 | 31 | import org.thingsboard.server.config.WebSocketConfiguration; |
28 | 32 | import org.thingsboard.server.service.security.model.SecurityUser; |
33 | +import org.thingsboard.server.service.security.model.UserPrincipal; | |
29 | 34 | import org.thingsboard.server.service.telemetry.SessionEvent; |
30 | 35 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketMsgEndpoint; |
31 | 36 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketService; |
... | ... | @@ -34,6 +39,7 @@ import org.thingsboard.server.service.telemetry.TelemetryWebSocketSessionRef; |
34 | 39 | import java.io.IOException; |
35 | 40 | import java.net.URI; |
36 | 41 | import java.security.InvalidParameterException; |
42 | +import java.util.Set; | |
37 | 43 | import java.util.UUID; |
38 | 44 | import java.util.concurrent.ConcurrentHashMap; |
39 | 45 | import java.util.concurrent.ConcurrentMap; |
... | ... | @@ -48,12 +54,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
48 | 54 | @Autowired |
49 | 55 | private TelemetryWebSocketService webSocketService; |
50 | 56 | |
57 | + @Value("${server.ws.limits.max_sessions_per_tenant:0}") | |
58 | + private int maxSessionsPerTenant; | |
59 | + @Value("${server.ws.limits.max_sessions_per_customer:0}") | |
60 | + private int maxSessionsPerCustomer; | |
61 | + @Value("${server.ws.limits.max_sessions_per_regular_user:0}") | |
62 | + private int maxSessionsPerRegularUser; | |
63 | + @Value("${server.ws.limits.max_sessions_per_public_user:0}") | |
64 | + private int maxSessionsPerPublicUser; | |
65 | + | |
66 | + private ConcurrentMap<TenantId, Set<String>> tenantSessionsMap = new ConcurrentHashMap<>(); | |
67 | + private ConcurrentMap<CustomerId, Set<String>> customerSessionsMap = new ConcurrentHashMap<>(); | |
68 | + private ConcurrentMap<UserId, Set<String>> regularUserSessionsMap = new ConcurrentHashMap<>(); | |
69 | + private ConcurrentMap<UserId, Set<String>> publicUserSessionsMap = new ConcurrentHashMap<>(); | |
70 | + | |
51 | 71 | @Override |
52 | 72 | public void handleTextMessage(WebSocketSession session, TextMessage message) { |
53 | 73 | try { |
54 | 74 | SessionMetaData sessionMd = internalSessionMap.get(session.getId()); |
55 | 75 | if (sessionMd != null) { |
56 | - log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message); | |
76 | + log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload()); | |
57 | 77 | webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload()); |
58 | 78 | } else { |
59 | 79 | log.warn("[{}] Failed to find session", session.getId()); |
... | ... | @@ -71,12 +91,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
71 | 91 | String internalSessionId = session.getId(); |
72 | 92 | TelemetryWebSocketSessionRef sessionRef = toRef(session); |
73 | 93 | String externalSessionId = sessionRef.getSessionId(); |
94 | + if (!checkLimits(session, sessionRef)) { | |
95 | + return; | |
96 | + } | |
74 | 97 | internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef)); |
75 | 98 | externalSessionMap.put(externalSessionId, internalSessionId); |
76 | 99 | processInWebSocketService(sessionRef, SessionEvent.onEstablished()); |
77 | 100 | log.info("[{}][{}][{}] Session is opened", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId()); |
78 | 101 | } catch (InvalidParameterException e) { |
79 | - log.warn("[[{}] Failed to start session", session.getId(), e); | |
102 | + log.warn("[{}] Failed to start session", session.getId(), e); | |
80 | 103 | session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); |
81 | 104 | } catch (Exception e) { |
82 | 105 | log.warn("[{}] Failed to start session", session.getId(), e); |
... | ... | @@ -101,6 +124,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
101 | 124 | super.afterConnectionClosed(session, closeStatus); |
102 | 125 | SessionMetaData sessionMd = internalSessionMap.remove(session.getId()); |
103 | 126 | if (sessionMd != null) { |
127 | + cleanupLimits(session, sessionMd.sessionRef); | |
104 | 128 | externalSessionMap.remove(sessionMd.sessionRef.getSessionId()); |
105 | 129 | processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); |
106 | 130 | } |
... | ... | @@ -136,7 +160,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
136 | 160 | private final WebSocketSession session; |
137 | 161 | private final TelemetryWebSocketSessionRef sessionRef; |
138 | 162 | |
139 | - public SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { | |
163 | + SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { | |
140 | 164 | super(); |
141 | 165 | this.session = session; |
142 | 166 | this.sessionRef = sessionRef; |
... | ... | @@ -162,15 +186,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
162 | 186 | } |
163 | 187 | } |
164 | 188 | |
189 | + | |
165 | 190 | @Override |
166 | 191 | public void close(TelemetryWebSocketSessionRef sessionRef) throws IOException { |
192 | + close(sessionRef, CloseStatus.NORMAL); | |
193 | + } | |
194 | + | |
195 | + @Override | |
196 | + public void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus reason) throws IOException { | |
167 | 197 | String externalId = sessionRef.getSessionId(); |
168 | 198 | log.debug("[{}] Processing close request", externalId); |
169 | 199 | String internalId = externalSessionMap.get(externalId); |
170 | 200 | if (internalId != null) { |
171 | 201 | SessionMetaData sessionMd = internalSessionMap.get(internalId); |
172 | 202 | if (sessionMd != null) { |
173 | - sessionMd.session.close(CloseStatus.NORMAL); | |
203 | + sessionMd.session.close(reason); | |
174 | 204 | } else { |
175 | 205 | log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId); |
176 | 206 | } |
... | ... | @@ -179,4 +209,94 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr |
179 | 209 | } |
180 | 210 | } |
181 | 211 | |
212 | + private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception { | |
213 | + String sessionId = session.getId(); | |
214 | + if (maxSessionsPerTenant > 0) { | |
215 | + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | |
216 | + synchronized (tenantSessions) { | |
217 | + if (tenantSessions.size() < maxSessionsPerTenant) { | |
218 | + tenantSessions.add(sessionId); | |
219 | + } else { | |
220 | + log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached" | |
221 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | |
222 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!")); | |
223 | + return false; | |
224 | + } | |
225 | + } | |
226 | + } | |
227 | + | |
228 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | |
229 | + if (maxSessionsPerCustomer > 0) { | |
230 | + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | |
231 | + synchronized (customerSessions) { | |
232 | + if (customerSessions.size() < maxSessionsPerCustomer) { | |
233 | + customerSessions.add(sessionId); | |
234 | + } else { | |
235 | + log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached" | |
236 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | |
237 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached")); | |
238 | + return false; | |
239 | + } | |
240 | + } | |
241 | + } | |
242 | + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
243 | + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
244 | + synchronized (regularUserSessions) { | |
245 | + if (regularUserSessions.size() < maxSessionsPerRegularUser) { | |
246 | + regularUserSessions.add(sessionId); | |
247 | + } else { | |
248 | + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached" | |
249 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | |
250 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached")); | |
251 | + return false; | |
252 | + } | |
253 | + } | |
254 | + } | |
255 | + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
256 | + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
257 | + synchronized (publicUserSessions) { | |
258 | + if (publicUserSessions.size() < maxSessionsPerPublicUser) { | |
259 | + publicUserSessions.add(sessionId); | |
260 | + } else { | |
261 | + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached" | |
262 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | |
263 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached")); | |
264 | + return false; | |
265 | + } | |
266 | + } | |
267 | + } | |
268 | + } | |
269 | + return true; | |
270 | + } | |
271 | + | |
272 | + private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { | |
273 | + String sessionId = session.getId(); | |
274 | + if (maxSessionsPerTenant > 0) { | |
275 | + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | |
276 | + synchronized (tenantSessions) { | |
277 | + tenantSessions.remove(sessionId); | |
278 | + } | |
279 | + } | |
280 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | |
281 | + if (maxSessionsPerCustomer > 0) { | |
282 | + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | |
283 | + synchronized (customerSessions) { | |
284 | + customerSessions.remove(sessionId); | |
285 | + } | |
286 | + } | |
287 | + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
288 | + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
289 | + synchronized (regularUserSessions) { | |
290 | + regularUserSessions.remove(sessionId); | |
291 | + } | |
292 | + } | |
293 | + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
294 | + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
295 | + synchronized (publicUserSessions) { | |
296 | + publicUserSessions.remove(sessionId); | |
297 | + } | |
298 | + } | |
299 | + } | |
300 | + } | |
301 | + | |
182 | 302 | } | ... | ... |
... | ... | @@ -23,12 +23,17 @@ import com.google.common.util.concurrent.Futures; |
23 | 23 | import com.google.common.util.concurrent.ListenableFuture; |
24 | 24 | import lombok.extern.slf4j.Slf4j; |
25 | 25 | import org.springframework.beans.factory.annotation.Autowired; |
26 | +import org.springframework.beans.factory.annotation.Value; | |
26 | 27 | import org.springframework.stereotype.Service; |
27 | 28 | import org.springframework.util.StringUtils; |
29 | +import org.springframework.web.socket.CloseStatus; | |
30 | +import org.springframework.web.socket.WebSocketSession; | |
28 | 31 | import org.thingsboard.server.common.data.DataConstants; |
32 | +import org.thingsboard.server.common.data.id.CustomerId; | |
29 | 33 | import org.thingsboard.server.common.data.id.EntityId; |
30 | 34 | import org.thingsboard.server.common.data.id.EntityIdFactory; |
31 | 35 | import org.thingsboard.server.common.data.id.TenantId; |
36 | +import org.thingsboard.server.common.data.id.UserId; | |
32 | 37 | import org.thingsboard.server.common.data.kv.Aggregation; |
33 | 38 | import org.thingsboard.server.common.data.kv.AttributeKvEntry; |
34 | 39 | import org.thingsboard.server.common.data.kv.BaseReadTsKvQuery; |
... | ... | @@ -42,6 +47,7 @@ import org.thingsboard.server.service.security.AccessValidator; |
42 | 47 | import org.thingsboard.server.service.security.ValidationCallback; |
43 | 48 | import org.thingsboard.server.service.security.ValidationResult; |
44 | 49 | import org.thingsboard.server.service.security.ValidationResultCode; |
50 | +import org.thingsboard.server.service.security.model.UserPrincipal; | |
45 | 51 | import org.thingsboard.server.service.telemetry.cmd.AttributesSubscriptionCmd; |
46 | 52 | import org.thingsboard.server.service.telemetry.cmd.GetHistoryCmd; |
47 | 53 | import org.thingsboard.server.service.telemetry.cmd.SubscriptionCmd; |
... | ... | @@ -64,6 +70,7 @@ import java.util.ArrayList; |
64 | 70 | import java.util.Collections; |
65 | 71 | import java.util.HashMap; |
66 | 72 | import java.util.HashSet; |
73 | +import java.util.Iterator; | |
67 | 74 | import java.util.List; |
68 | 75 | import java.util.Map; |
69 | 76 | import java.util.Optional; |
... | ... | @@ -112,11 +119,25 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi |
112 | 119 | @Autowired |
113 | 120 | private TimeseriesService tsService; |
114 | 121 | |
122 | + @Value("${server.ws.limits.max_subscriptions_per_tenant:0}") | |
123 | + private int maxSubscriptionsPerTenant; | |
124 | + @Value("${server.ws.limits.max_subscriptions_per_customer:0}") | |
125 | + private int maxSubscriptionsPerCustomer; | |
126 | + @Value("${server.ws.limits.max_subscriptions_per_regular_user:0}") | |
127 | + private int maxSubscriptionsPerRegularUser; | |
128 | + @Value("${server.ws.limits.max_subscriptions_per_public_user:0}") | |
129 | + private int maxSubscriptionsPerPublicUser; | |
130 | + | |
131 | + private ConcurrentMap<TenantId, Set<String>> tenantSubscriptionsMap = new ConcurrentHashMap<>(); | |
132 | + private ConcurrentMap<CustomerId, Set<String>> customerSubscriptionsMap = new ConcurrentHashMap<>(); | |
133 | + private ConcurrentMap<UserId, Set<String>> regularUserSubscriptionsMap = new ConcurrentHashMap<>(); | |
134 | + private ConcurrentMap<UserId, Set<String>> publicUserSubscriptionsMap = new ConcurrentHashMap<>(); | |
135 | + | |
115 | 136 | private ExecutorService executor; |
116 | 137 | |
117 | 138 | @PostConstruct |
118 | 139 | public void initExecutor() { |
119 | - executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); | |
140 | + executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); | |
120 | 141 | } |
121 | 142 | |
122 | 143 | @PreDestroy |
... | ... | @@ -140,6 +161,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi |
140 | 161 | case CLOSED: |
141 | 162 | wsSessionsMap.remove(sessionId); |
142 | 163 | subscriptionManager.cleanupLocalWsSessionSubscriptions(sessionRef, sessionId); |
164 | + processSessionClose(sessionRef); | |
143 | 165 | break; |
144 | 166 | } |
145 | 167 | } |
... | ... | @@ -154,10 +176,18 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi |
154 | 176 | TelemetryPluginCmdsWrapper cmdsWrapper = jsonMapper.readValue(msg, TelemetryPluginCmdsWrapper.class); |
155 | 177 | if (cmdsWrapper != null) { |
156 | 178 | if (cmdsWrapper.getAttrSubCmds() != null) { |
157 | - cmdsWrapper.getAttrSubCmds().forEach(cmd -> handleWsAttributesSubscriptionCmd(sessionRef, cmd)); | |
179 | + cmdsWrapper.getAttrSubCmds().forEach(cmd -> { | |
180 | + if (processSubscription(sessionRef, cmd)) { | |
181 | + handleWsAttributesSubscriptionCmd(sessionRef, cmd); | |
182 | + } | |
183 | + }); | |
158 | 184 | } |
159 | 185 | if (cmdsWrapper.getTsSubCmds() != null) { |
160 | - cmdsWrapper.getTsSubCmds().forEach(cmd -> handleWsTimeseriesSubscriptionCmd(sessionRef, cmd)); | |
186 | + cmdsWrapper.getTsSubCmds().forEach(cmd -> { | |
187 | + if (processSubscription(sessionRef, cmd)) { | |
188 | + handleWsTimeseriesSubscriptionCmd(sessionRef, cmd); | |
189 | + } | |
190 | + }); | |
161 | 191 | } |
162 | 192 | if (cmdsWrapper.getHistoryCmds() != null) { |
163 | 193 | cmdsWrapper.getHistoryCmds().forEach(cmd -> handleWsHistoryCmd(sessionRef, cmd)); |
... | ... | @@ -178,6 +208,105 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi |
178 | 208 | } |
179 | 209 | } |
180 | 210 | |
211 | + private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) { | |
212 | + String sessionId = "[" + sessionRef.getSessionId() + "]"; | |
213 | + if (maxSubscriptionsPerTenant > 0) { | |
214 | + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | |
215 | + synchronized (tenantSubscriptions) { | |
216 | + tenantSubscriptions.removeIf(subId -> subId.startsWith(sessionId)); | |
217 | + } | |
218 | + } | |
219 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | |
220 | + if (maxSubscriptionsPerCustomer > 0) { | |
221 | + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | |
222 | + synchronized (customerSessions) { | |
223 | + customerSessions.removeIf(subId -> subId.startsWith(sessionId)); | |
224 | + } | |
225 | + } | |
226 | + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
227 | + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
228 | + synchronized (regularUserSessions) { | |
229 | + regularUserSessions.removeIf(subId -> subId.startsWith(sessionId)); | |
230 | + } | |
231 | + } | |
232 | + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
233 | + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
234 | + synchronized (publicUserSessions) { | |
235 | + publicUserSessions.removeIf(subId -> subId.startsWith(sessionId)); | |
236 | + } | |
237 | + } | |
238 | + } | |
239 | + } | |
240 | + | |
241 | + private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) { | |
242 | + String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]"; | |
243 | + try { | |
244 | + if (maxSubscriptionsPerTenant > 0) { | |
245 | + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | |
246 | + synchronized (tenantSubscriptions) { | |
247 | + if (cmd.isUnsubscribe()) { | |
248 | + tenantSubscriptions.remove(subId); | |
249 | + } else if (tenantSubscriptions.size() < maxSubscriptionsPerTenant) { | |
250 | + tenantSubscriptions.add(subId); | |
251 | + } else { | |
252 | + log.info("[{}][{}][{}] Failed to start subscription. Max tenant subscriptions limit reached" | |
253 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | |
254 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max tenant subscriptions limit reached!")); | |
255 | + return false; | |
256 | + } | |
257 | + } | |
258 | + } | |
259 | + | |
260 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | |
261 | + if (maxSubscriptionsPerCustomer > 0) { | |
262 | + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | |
263 | + synchronized (customerSessions) { | |
264 | + if (cmd.isUnsubscribe()) { | |
265 | + customerSessions.remove(subId); | |
266 | + } else if (customerSessions.size() < maxSubscriptionsPerCustomer) { | |
267 | + customerSessions.add(subId); | |
268 | + } else { | |
269 | + log.info("[{}][{}][{}] Failed to start subscription. Max customer sessions limit reached" | |
270 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | |
271 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max customer subscriptions limit reached")); | |
272 | + return false; | |
273 | + } | |
274 | + } | |
275 | + } | |
276 | + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
277 | + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
278 | + synchronized (regularUserSessions) { | |
279 | + if (regularUserSessions.size() < maxSubscriptionsPerRegularUser) { | |
280 | + regularUserSessions.add(subId); | |
281 | + } else { | |
282 | + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached" | |
283 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | |
284 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max regular user subscriptions limit reached")); | |
285 | + return false; | |
286 | + } | |
287 | + } | |
288 | + } | |
289 | + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | |
290 | + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | |
291 | + synchronized (publicUserSessions) { | |
292 | + if (publicUserSessions.size() < maxSubscriptionsPerPublicUser) { | |
293 | + publicUserSessions.add(subId); | |
294 | + } else { | |
295 | + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached" | |
296 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | |
297 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max public user subscriptions limit reached")); | |
298 | + return false; | |
299 | + } | |
300 | + } | |
301 | + } | |
302 | + } | |
303 | + } catch (IOException e) { | |
304 | + log.warn("[{}] Failed to send session close: {}", sessionRef.getSessionId(), e); | |
305 | + return false; | |
306 | + } | |
307 | + return true; | |
308 | + } | |
309 | + | |
181 | 310 | private void handleWsAttributesSubscriptionCmd(TelemetryWebSocketSessionRef sessionRef, AttributesSubscriptionCmd cmd) { |
182 | 311 | String sessionId = sessionRef.getSessionId(); |
183 | 312 | log.debug("[{}] Processing: {}", sessionId, cmd); |
... | ... | @@ -220,7 +349,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi |
220 | 349 | public void onFailure(Throwable e) { |
221 | 350 | log.error(FAILED_TO_FETCH_ATTRIBUTES, e); |
222 | 351 | SubscriptionUpdate update; |
223 | - if (UnauthorizedException.class.isInstance(e)) { | |
352 | + if (e instanceof UnauthorizedException) { | |
224 | 353 | update = new SubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.UNAUTHORIZED, |
225 | 354 | SubscriptionErrorCode.UNAUTHORIZED.getDefaultMsg()); |
226 | 355 | } else { | ... | ... |
... | ... | @@ -15,6 +15,8 @@ |
15 | 15 | */ |
16 | 16 | package org.thingsboard.server.service.telemetry; |
17 | 17 | |
18 | +import org.springframework.web.socket.CloseStatus; | |
19 | + | |
18 | 20 | import java.io.IOException; |
19 | 21 | |
20 | 22 | /** |
... | ... | @@ -26,4 +28,5 @@ public interface TelemetryWebSocketMsgEndpoint { |
26 | 28 | |
27 | 29 | void close(TelemetryWebSocketSessionRef sessionRef) throws IOException; |
28 | 30 | |
31 | + void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus withReason) throws IOException; | |
29 | 32 | } | ... | ... |
... | ... | @@ -32,6 +32,17 @@ server: |
32 | 32 | # Alias that identifies the key in the key store |
33 | 33 | key-alias: "${SSL_KEY_ALIAS:tomcat}" |
34 | 34 | log_controller_error_stack_trace: "${HTTP_LOG_CONTROLLER_ERROR_STACK_TRACE:true}" |
35 | + ws: | |
36 | + limits: | |
37 | + # Limit the amount of sessions and subscriptions available on each server. Put values to zero to disable particular limitation | |
38 | + max_sessions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_TENANT:0}" | |
39 | + max_sessions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_CUSTOMER:0}" | |
40 | + max_sessions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_REGULAR_USER:0}" | |
41 | + max_sessions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_PUBLIC_USER:0}" | |
42 | + max_subscriptions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_TENANT:0}" | |
43 | + max_subscriptions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_CUSTOMER:0}" | |
44 | + max_subscriptions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_REGULAR_USER:0}" | |
45 | + max_subscriptions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_PUBLIC_USER:0}" | |
35 | 46 | |
36 | 47 | # Zookeeper connection parameters. Used for service discovery. |
37 | 48 | zk: | ... | ... |