Commit 1b48704c0d5ae7aa82c42407ef21a5fcafa6bcf8

Authored by Andrew Shvayka
1 parent 7e1466ff

WS rate limites

... ... @@ -18,14 +18,19 @@ package org.thingsboard.server.controller.plugin;
18 18 import lombok.extern.slf4j.Slf4j;
19 19 import org.springframework.beans.factory.BeanCreationNotAllowedException;
20 20 import org.springframework.beans.factory.annotation.Autowired;
21   -import org.springframework.context.annotation.Lazy;
  21 +import org.springframework.beans.factory.annotation.Value;
  22 +import org.springframework.scheduling.annotation.Scheduled;
22 23 import org.springframework.stereotype.Service;
23 24 import org.springframework.web.socket.CloseStatus;
24 25 import org.springframework.web.socket.TextMessage;
25 26 import org.springframework.web.socket.WebSocketSession;
26 27 import org.springframework.web.socket.handler.TextWebSocketHandler;
  28 +import org.thingsboard.server.common.data.id.CustomerId;
  29 +import org.thingsboard.server.common.data.id.TenantId;
  30 +import org.thingsboard.server.common.data.id.UserId;
27 31 import org.thingsboard.server.config.WebSocketConfiguration;
28 32 import org.thingsboard.server.service.security.model.SecurityUser;
  33 +import org.thingsboard.server.service.security.model.UserPrincipal;
29 34 import org.thingsboard.server.service.telemetry.SessionEvent;
30 35 import org.thingsboard.server.service.telemetry.TelemetryWebSocketMsgEndpoint;
31 36 import org.thingsboard.server.service.telemetry.TelemetryWebSocketService;
... ... @@ -34,6 +39,7 @@ import org.thingsboard.server.service.telemetry.TelemetryWebSocketSessionRef;
34 39 import java.io.IOException;
35 40 import java.net.URI;
36 41 import java.security.InvalidParameterException;
  42 +import java.util.Set;
37 43 import java.util.UUID;
38 44 import java.util.concurrent.ConcurrentHashMap;
39 45 import java.util.concurrent.ConcurrentMap;
... ... @@ -48,12 +54,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
48 54 @Autowired
49 55 private TelemetryWebSocketService webSocketService;
50 56
  57 + @Value("${server.ws.limits.max_sessions_per_tenant:0}")
  58 + private int maxSessionsPerTenant;
  59 + @Value("${server.ws.limits.max_sessions_per_customer:0}")
  60 + private int maxSessionsPerCustomer;
  61 + @Value("${server.ws.limits.max_sessions_per_regular_user:0}")
  62 + private int maxSessionsPerRegularUser;
  63 + @Value("${server.ws.limits.max_sessions_per_public_user:0}")
  64 + private int maxSessionsPerPublicUser;
  65 +
  66 + private ConcurrentMap<TenantId, Set<String>> tenantSessionsMap = new ConcurrentHashMap<>();
  67 + private ConcurrentMap<CustomerId, Set<String>> customerSessionsMap = new ConcurrentHashMap<>();
  68 + private ConcurrentMap<UserId, Set<String>> regularUserSessionsMap = new ConcurrentHashMap<>();
  69 + private ConcurrentMap<UserId, Set<String>> publicUserSessionsMap = new ConcurrentHashMap<>();
  70 +
51 71 @Override
52 72 public void handleTextMessage(WebSocketSession session, TextMessage message) {
53 73 try {
54 74 SessionMetaData sessionMd = internalSessionMap.get(session.getId());
55 75 if (sessionMd != null) {
56   - log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message);
  76 + log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload());
57 77 webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload());
58 78 } else {
59 79 log.warn("[{}] Failed to find session", session.getId());
... ... @@ -71,12 +91,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
71 91 String internalSessionId = session.getId();
72 92 TelemetryWebSocketSessionRef sessionRef = toRef(session);
73 93 String externalSessionId = sessionRef.getSessionId();
  94 + if (!checkLimits(session, sessionRef)) {
  95 + return;
  96 + }
74 97 internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef));
75 98 externalSessionMap.put(externalSessionId, internalSessionId);
76 99 processInWebSocketService(sessionRef, SessionEvent.onEstablished());
77 100 log.info("[{}][{}][{}] Session is opened", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId());
78 101 } catch (InvalidParameterException e) {
79   - log.warn("[[{}] Failed to start session", session.getId(), e);
  102 + log.warn("[{}] Failed to start session", session.getId(), e);
80 103 session.close(CloseStatus.BAD_DATA.withReason(e.getMessage()));
81 104 } catch (Exception e) {
82 105 log.warn("[{}] Failed to start session", session.getId(), e);
... ... @@ -101,6 +124,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
101 124 super.afterConnectionClosed(session, closeStatus);
102 125 SessionMetaData sessionMd = internalSessionMap.remove(session.getId());
103 126 if (sessionMd != null) {
  127 + cleanupLimits(session, sessionMd.sessionRef);
104 128 externalSessionMap.remove(sessionMd.sessionRef.getSessionId());
105 129 processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed());
106 130 }
... ... @@ -136,7 +160,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
136 160 private final WebSocketSession session;
137 161 private final TelemetryWebSocketSessionRef sessionRef;
138 162
139   - public SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
  163 + SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
140 164 super();
141 165 this.session = session;
142 166 this.sessionRef = sessionRef;
... ... @@ -162,15 +186,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
162 186 }
163 187 }
164 188
  189 +
165 190 @Override
166 191 public void close(TelemetryWebSocketSessionRef sessionRef) throws IOException {
  192 + close(sessionRef, CloseStatus.NORMAL);
  193 + }
  194 +
  195 + @Override
  196 + public void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus reason) throws IOException {
167 197 String externalId = sessionRef.getSessionId();
168 198 log.debug("[{}] Processing close request", externalId);
169 199 String internalId = externalSessionMap.get(externalId);
170 200 if (internalId != null) {
171 201 SessionMetaData sessionMd = internalSessionMap.get(internalId);
172 202 if (sessionMd != null) {
173   - sessionMd.session.close(CloseStatus.NORMAL);
  203 + sessionMd.session.close(reason);
174 204 } else {
175 205 log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId);
176 206 }
... ... @@ -179,4 +209,94 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
179 209 }
180 210 }
181 211
  212 + private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception {
  213 + String sessionId = session.getId();
  214 + if (maxSessionsPerTenant > 0) {
  215 + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
  216 + synchronized (tenantSessions) {
  217 + if (tenantSessions.size() < maxSessionsPerTenant) {
  218 + tenantSessions.add(sessionId);
  219 + } else {
  220 + log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached"
  221 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
  222 + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!"));
  223 + return false;
  224 + }
  225 + }
  226 + }
  227 +
  228 + if (sessionRef.getSecurityCtx().isCustomerUser()) {
  229 + if (maxSessionsPerCustomer > 0) {
  230 + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
  231 + synchronized (customerSessions) {
  232 + if (customerSessions.size() < maxSessionsPerCustomer) {
  233 + customerSessions.add(sessionId);
  234 + } else {
  235 + log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached"
  236 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
  237 + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached"));
  238 + return false;
  239 + }
  240 + }
  241 + }
  242 + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  243 + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  244 + synchronized (regularUserSessions) {
  245 + if (regularUserSessions.size() < maxSessionsPerRegularUser) {
  246 + regularUserSessions.add(sessionId);
  247 + } else {
  248 + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached"
  249 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
  250 + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached"));
  251 + return false;
  252 + }
  253 + }
  254 + }
  255 + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  256 + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  257 + synchronized (publicUserSessions) {
  258 + if (publicUserSessions.size() < maxSessionsPerPublicUser) {
  259 + publicUserSessions.add(sessionId);
  260 + } else {
  261 + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached"
  262 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
  263 + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached"));
  264 + return false;
  265 + }
  266 + }
  267 + }
  268 + }
  269 + return true;
  270 + }
  271 +
  272 + private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
  273 + String sessionId = session.getId();
  274 + if (maxSessionsPerTenant > 0) {
  275 + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
  276 + synchronized (tenantSessions) {
  277 + tenantSessions.remove(sessionId);
  278 + }
  279 + }
  280 + if (sessionRef.getSecurityCtx().isCustomerUser()) {
  281 + if (maxSessionsPerCustomer > 0) {
  282 + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
  283 + synchronized (customerSessions) {
  284 + customerSessions.remove(sessionId);
  285 + }
  286 + }
  287 + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  288 + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  289 + synchronized (regularUserSessions) {
  290 + regularUserSessions.remove(sessionId);
  291 + }
  292 + }
  293 + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  294 + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  295 + synchronized (publicUserSessions) {
  296 + publicUserSessions.remove(sessionId);
  297 + }
  298 + }
  299 + }
  300 + }
  301 +
182 302 }
... ...
... ... @@ -23,12 +23,17 @@ import com.google.common.util.concurrent.Futures;
23 23 import com.google.common.util.concurrent.ListenableFuture;
24 24 import lombok.extern.slf4j.Slf4j;
25 25 import org.springframework.beans.factory.annotation.Autowired;
  26 +import org.springframework.beans.factory.annotation.Value;
26 27 import org.springframework.stereotype.Service;
27 28 import org.springframework.util.StringUtils;
  29 +import org.springframework.web.socket.CloseStatus;
  30 +import org.springframework.web.socket.WebSocketSession;
28 31 import org.thingsboard.server.common.data.DataConstants;
  32 +import org.thingsboard.server.common.data.id.CustomerId;
29 33 import org.thingsboard.server.common.data.id.EntityId;
30 34 import org.thingsboard.server.common.data.id.EntityIdFactory;
31 35 import org.thingsboard.server.common.data.id.TenantId;
  36 +import org.thingsboard.server.common.data.id.UserId;
32 37 import org.thingsboard.server.common.data.kv.Aggregation;
33 38 import org.thingsboard.server.common.data.kv.AttributeKvEntry;
34 39 import org.thingsboard.server.common.data.kv.BaseReadTsKvQuery;
... ... @@ -42,6 +47,7 @@ import org.thingsboard.server.service.security.AccessValidator;
42 47 import org.thingsboard.server.service.security.ValidationCallback;
43 48 import org.thingsboard.server.service.security.ValidationResult;
44 49 import org.thingsboard.server.service.security.ValidationResultCode;
  50 +import org.thingsboard.server.service.security.model.UserPrincipal;
45 51 import org.thingsboard.server.service.telemetry.cmd.AttributesSubscriptionCmd;
46 52 import org.thingsboard.server.service.telemetry.cmd.GetHistoryCmd;
47 53 import org.thingsboard.server.service.telemetry.cmd.SubscriptionCmd;
... ... @@ -64,6 +70,7 @@ import java.util.ArrayList;
64 70 import java.util.Collections;
65 71 import java.util.HashMap;
66 72 import java.util.HashSet;
  73 +import java.util.Iterator;
67 74 import java.util.List;
68 75 import java.util.Map;
69 76 import java.util.Optional;
... ... @@ -112,11 +119,25 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
112 119 @Autowired
113 120 private TimeseriesService tsService;
114 121
  122 + @Value("${server.ws.limits.max_subscriptions_per_tenant:0}")
  123 + private int maxSubscriptionsPerTenant;
  124 + @Value("${server.ws.limits.max_subscriptions_per_customer:0}")
  125 + private int maxSubscriptionsPerCustomer;
  126 + @Value("${server.ws.limits.max_subscriptions_per_regular_user:0}")
  127 + private int maxSubscriptionsPerRegularUser;
  128 + @Value("${server.ws.limits.max_subscriptions_per_public_user:0}")
  129 + private int maxSubscriptionsPerPublicUser;
  130 +
  131 + private ConcurrentMap<TenantId, Set<String>> tenantSubscriptionsMap = new ConcurrentHashMap<>();
  132 + private ConcurrentMap<CustomerId, Set<String>> customerSubscriptionsMap = new ConcurrentHashMap<>();
  133 + private ConcurrentMap<UserId, Set<String>> regularUserSubscriptionsMap = new ConcurrentHashMap<>();
  134 + private ConcurrentMap<UserId, Set<String>> publicUserSubscriptionsMap = new ConcurrentHashMap<>();
  135 +
115 136 private ExecutorService executor;
116 137
117 138 @PostConstruct
118 139 public void initExecutor() {
119   - executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
  140 + executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
120 141 }
121 142
122 143 @PreDestroy
... ... @@ -140,6 +161,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
140 161 case CLOSED:
141 162 wsSessionsMap.remove(sessionId);
142 163 subscriptionManager.cleanupLocalWsSessionSubscriptions(sessionRef, sessionId);
  164 + processSessionClose(sessionRef);
143 165 break;
144 166 }
145 167 }
... ... @@ -154,10 +176,18 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
154 176 TelemetryPluginCmdsWrapper cmdsWrapper = jsonMapper.readValue(msg, TelemetryPluginCmdsWrapper.class);
155 177 if (cmdsWrapper != null) {
156 178 if (cmdsWrapper.getAttrSubCmds() != null) {
157   - cmdsWrapper.getAttrSubCmds().forEach(cmd -> handleWsAttributesSubscriptionCmd(sessionRef, cmd));
  179 + cmdsWrapper.getAttrSubCmds().forEach(cmd -> {
  180 + if (processSubscription(sessionRef, cmd)) {
  181 + handleWsAttributesSubscriptionCmd(sessionRef, cmd);
  182 + }
  183 + });
158 184 }
159 185 if (cmdsWrapper.getTsSubCmds() != null) {
160   - cmdsWrapper.getTsSubCmds().forEach(cmd -> handleWsTimeseriesSubscriptionCmd(sessionRef, cmd));
  186 + cmdsWrapper.getTsSubCmds().forEach(cmd -> {
  187 + if (processSubscription(sessionRef, cmd)) {
  188 + handleWsTimeseriesSubscriptionCmd(sessionRef, cmd);
  189 + }
  190 + });
161 191 }
162 192 if (cmdsWrapper.getHistoryCmds() != null) {
163 193 cmdsWrapper.getHistoryCmds().forEach(cmd -> handleWsHistoryCmd(sessionRef, cmd));
... ... @@ -178,6 +208,105 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
178 208 }
179 209 }
180 210
  211 + private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) {
  212 + String sessionId = "[" + sessionRef.getSessionId() + "]";
  213 + if (maxSubscriptionsPerTenant > 0) {
  214 + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
  215 + synchronized (tenantSubscriptions) {
  216 + tenantSubscriptions.removeIf(subId -> subId.startsWith(sessionId));
  217 + }
  218 + }
  219 + if (sessionRef.getSecurityCtx().isCustomerUser()) {
  220 + if (maxSubscriptionsPerCustomer > 0) {
  221 + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
  222 + synchronized (customerSessions) {
  223 + customerSessions.removeIf(subId -> subId.startsWith(sessionId));
  224 + }
  225 + }
  226 + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  227 + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  228 + synchronized (regularUserSessions) {
  229 + regularUserSessions.removeIf(subId -> subId.startsWith(sessionId));
  230 + }
  231 + }
  232 + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  233 + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  234 + synchronized (publicUserSessions) {
  235 + publicUserSessions.removeIf(subId -> subId.startsWith(sessionId));
  236 + }
  237 + }
  238 + }
  239 + }
  240 +
  241 + private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) {
  242 + String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]";
  243 + try {
  244 + if (maxSubscriptionsPerTenant > 0) {
  245 + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
  246 + synchronized (tenantSubscriptions) {
  247 + if (cmd.isUnsubscribe()) {
  248 + tenantSubscriptions.remove(subId);
  249 + } else if (tenantSubscriptions.size() < maxSubscriptionsPerTenant) {
  250 + tenantSubscriptions.add(subId);
  251 + } else {
  252 + log.info("[{}][{}][{}] Failed to start subscription. Max tenant subscriptions limit reached"
  253 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
  254 + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max tenant subscriptions limit reached!"));
  255 + return false;
  256 + }
  257 + }
  258 + }
  259 +
  260 + if (sessionRef.getSecurityCtx().isCustomerUser()) {
  261 + if (maxSubscriptionsPerCustomer > 0) {
  262 + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
  263 + synchronized (customerSessions) {
  264 + if (cmd.isUnsubscribe()) {
  265 + customerSessions.remove(subId);
  266 + } else if (customerSessions.size() < maxSubscriptionsPerCustomer) {
  267 + customerSessions.add(subId);
  268 + } else {
  269 + log.info("[{}][{}][{}] Failed to start subscription. Max customer sessions limit reached"
  270 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
  271 + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max customer subscriptions limit reached"));
  272 + return false;
  273 + }
  274 + }
  275 + }
  276 + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  277 + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  278 + synchronized (regularUserSessions) {
  279 + if (regularUserSessions.size() < maxSubscriptionsPerRegularUser) {
  280 + regularUserSessions.add(subId);
  281 + } else {
  282 + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached"
  283 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
  284 + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max regular user subscriptions limit reached"));
  285 + return false;
  286 + }
  287 + }
  288 + }
  289 + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
  290 + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
  291 + synchronized (publicUserSessions) {
  292 + if (publicUserSessions.size() < maxSubscriptionsPerPublicUser) {
  293 + publicUserSessions.add(subId);
  294 + } else {
  295 + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached"
  296 + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
  297 + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max public user subscriptions limit reached"));
  298 + return false;
  299 + }
  300 + }
  301 + }
  302 + }
  303 + } catch (IOException e) {
  304 + log.warn("[{}] Failed to send session close: {}", sessionRef.getSessionId(), e);
  305 + return false;
  306 + }
  307 + return true;
  308 + }
  309 +
181 310 private void handleWsAttributesSubscriptionCmd(TelemetryWebSocketSessionRef sessionRef, AttributesSubscriptionCmd cmd) {
182 311 String sessionId = sessionRef.getSessionId();
183 312 log.debug("[{}] Processing: {}", sessionId, cmd);
... ... @@ -220,7 +349,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
220 349 public void onFailure(Throwable e) {
221 350 log.error(FAILED_TO_FETCH_ATTRIBUTES, e);
222 351 SubscriptionUpdate update;
223   - if (UnauthorizedException.class.isInstance(e)) {
  352 + if (e instanceof UnauthorizedException) {
224 353 update = new SubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.UNAUTHORIZED,
225 354 SubscriptionErrorCode.UNAUTHORIZED.getDefaultMsg());
226 355 } else {
... ...
... ... @@ -15,6 +15,8 @@
15 15 */
16 16 package org.thingsboard.server.service.telemetry;
17 17
  18 +import org.springframework.web.socket.CloseStatus;
  19 +
18 20 import java.io.IOException;
19 21
20 22 /**
... ... @@ -26,4 +28,5 @@ public interface TelemetryWebSocketMsgEndpoint {
26 28
27 29 void close(TelemetryWebSocketSessionRef sessionRef) throws IOException;
28 30
  31 + void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus withReason) throws IOException;
29 32 }
... ...
... ... @@ -32,6 +32,17 @@ server:
32 32 # Alias that identifies the key in the key store
33 33 key-alias: "${SSL_KEY_ALIAS:tomcat}"
34 34 log_controller_error_stack_trace: "${HTTP_LOG_CONTROLLER_ERROR_STACK_TRACE:true}"
  35 + ws:
  36 + limits:
  37 + # Limit the amount of sessions and subscriptions available on each server. Put values to zero to disable particular limitation
  38 + max_sessions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_TENANT:0}"
  39 + max_sessions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_CUSTOMER:0}"
  40 + max_sessions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_REGULAR_USER:0}"
  41 + max_sessions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_PUBLIC_USER:0}"
  42 + max_subscriptions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_TENANT:0}"
  43 + max_subscriptions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_CUSTOMER:0}"
  44 + max_subscriptions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_REGULAR_USER:0}"
  45 + max_subscriptions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_PUBLIC_USER:0}"
35 46
36 47 # Zookeeper connection parameters. Used for service discovery.
37 48 zk:
... ...