Showing
4 changed files
with
272 additions
and
9 deletions
@@ -18,14 +18,19 @@ package org.thingsboard.server.controller.plugin; | @@ -18,14 +18,19 @@ package org.thingsboard.server.controller.plugin; | ||
18 | import lombok.extern.slf4j.Slf4j; | 18 | import lombok.extern.slf4j.Slf4j; |
19 | import org.springframework.beans.factory.BeanCreationNotAllowedException; | 19 | import org.springframework.beans.factory.BeanCreationNotAllowedException; |
20 | import org.springframework.beans.factory.annotation.Autowired; | 20 | import org.springframework.beans.factory.annotation.Autowired; |
21 | -import org.springframework.context.annotation.Lazy; | 21 | +import org.springframework.beans.factory.annotation.Value; |
22 | +import org.springframework.scheduling.annotation.Scheduled; | ||
22 | import org.springframework.stereotype.Service; | 23 | import org.springframework.stereotype.Service; |
23 | import org.springframework.web.socket.CloseStatus; | 24 | import org.springframework.web.socket.CloseStatus; |
24 | import org.springframework.web.socket.TextMessage; | 25 | import org.springframework.web.socket.TextMessage; |
25 | import org.springframework.web.socket.WebSocketSession; | 26 | import org.springframework.web.socket.WebSocketSession; |
26 | import org.springframework.web.socket.handler.TextWebSocketHandler; | 27 | import org.springframework.web.socket.handler.TextWebSocketHandler; |
28 | +import org.thingsboard.server.common.data.id.CustomerId; | ||
29 | +import org.thingsboard.server.common.data.id.TenantId; | ||
30 | +import org.thingsboard.server.common.data.id.UserId; | ||
27 | import org.thingsboard.server.config.WebSocketConfiguration; | 31 | import org.thingsboard.server.config.WebSocketConfiguration; |
28 | import org.thingsboard.server.service.security.model.SecurityUser; | 32 | import org.thingsboard.server.service.security.model.SecurityUser; |
33 | +import org.thingsboard.server.service.security.model.UserPrincipal; | ||
29 | import org.thingsboard.server.service.telemetry.SessionEvent; | 34 | import org.thingsboard.server.service.telemetry.SessionEvent; |
30 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketMsgEndpoint; | 35 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketMsgEndpoint; |
31 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketService; | 36 | import org.thingsboard.server.service.telemetry.TelemetryWebSocketService; |
@@ -34,6 +39,7 @@ import org.thingsboard.server.service.telemetry.TelemetryWebSocketSessionRef; | @@ -34,6 +39,7 @@ import org.thingsboard.server.service.telemetry.TelemetryWebSocketSessionRef; | ||
34 | import java.io.IOException; | 39 | import java.io.IOException; |
35 | import java.net.URI; | 40 | import java.net.URI; |
36 | import java.security.InvalidParameterException; | 41 | import java.security.InvalidParameterException; |
42 | +import java.util.Set; | ||
37 | import java.util.UUID; | 43 | import java.util.UUID; |
38 | import java.util.concurrent.ConcurrentHashMap; | 44 | import java.util.concurrent.ConcurrentHashMap; |
39 | import java.util.concurrent.ConcurrentMap; | 45 | import java.util.concurrent.ConcurrentMap; |
@@ -48,12 +54,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -48,12 +54,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
48 | @Autowired | 54 | @Autowired |
49 | private TelemetryWebSocketService webSocketService; | 55 | private TelemetryWebSocketService webSocketService; |
50 | 56 | ||
57 | + @Value("${server.ws.limits.max_sessions_per_tenant:0}") | ||
58 | + private int maxSessionsPerTenant; | ||
59 | + @Value("${server.ws.limits.max_sessions_per_customer:0}") | ||
60 | + private int maxSessionsPerCustomer; | ||
61 | + @Value("${server.ws.limits.max_sessions_per_regular_user:0}") | ||
62 | + private int maxSessionsPerRegularUser; | ||
63 | + @Value("${server.ws.limits.max_sessions_per_public_user:0}") | ||
64 | + private int maxSessionsPerPublicUser; | ||
65 | + | ||
66 | + private ConcurrentMap<TenantId, Set<String>> tenantSessionsMap = new ConcurrentHashMap<>(); | ||
67 | + private ConcurrentMap<CustomerId, Set<String>> customerSessionsMap = new ConcurrentHashMap<>(); | ||
68 | + private ConcurrentMap<UserId, Set<String>> regularUserSessionsMap = new ConcurrentHashMap<>(); | ||
69 | + private ConcurrentMap<UserId, Set<String>> publicUserSessionsMap = new ConcurrentHashMap<>(); | ||
70 | + | ||
51 | @Override | 71 | @Override |
52 | public void handleTextMessage(WebSocketSession session, TextMessage message) { | 72 | public void handleTextMessage(WebSocketSession session, TextMessage message) { |
53 | try { | 73 | try { |
54 | SessionMetaData sessionMd = internalSessionMap.get(session.getId()); | 74 | SessionMetaData sessionMd = internalSessionMap.get(session.getId()); |
55 | if (sessionMd != null) { | 75 | if (sessionMd != null) { |
56 | - log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message); | 76 | + log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload()); |
57 | webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload()); | 77 | webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload()); |
58 | } else { | 78 | } else { |
59 | log.warn("[{}] Failed to find session", session.getId()); | 79 | log.warn("[{}] Failed to find session", session.getId()); |
@@ -71,12 +91,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -71,12 +91,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
71 | String internalSessionId = session.getId(); | 91 | String internalSessionId = session.getId(); |
72 | TelemetryWebSocketSessionRef sessionRef = toRef(session); | 92 | TelemetryWebSocketSessionRef sessionRef = toRef(session); |
73 | String externalSessionId = sessionRef.getSessionId(); | 93 | String externalSessionId = sessionRef.getSessionId(); |
94 | + if (!checkLimits(session, sessionRef)) { | ||
95 | + return; | ||
96 | + } | ||
74 | internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef)); | 97 | internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef)); |
75 | externalSessionMap.put(externalSessionId, internalSessionId); | 98 | externalSessionMap.put(externalSessionId, internalSessionId); |
76 | processInWebSocketService(sessionRef, SessionEvent.onEstablished()); | 99 | processInWebSocketService(sessionRef, SessionEvent.onEstablished()); |
77 | log.info("[{}][{}][{}] Session is opened", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId()); | 100 | log.info("[{}][{}][{}] Session is opened", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId()); |
78 | } catch (InvalidParameterException e) { | 101 | } catch (InvalidParameterException e) { |
79 | - log.warn("[[{}] Failed to start session", session.getId(), e); | 102 | + log.warn("[{}] Failed to start session", session.getId(), e); |
80 | session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); | 103 | session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); |
81 | } catch (Exception e) { | 104 | } catch (Exception e) { |
82 | log.warn("[{}] Failed to start session", session.getId(), e); | 105 | log.warn("[{}] Failed to start session", session.getId(), e); |
@@ -101,6 +124,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -101,6 +124,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
101 | super.afterConnectionClosed(session, closeStatus); | 124 | super.afterConnectionClosed(session, closeStatus); |
102 | SessionMetaData sessionMd = internalSessionMap.remove(session.getId()); | 125 | SessionMetaData sessionMd = internalSessionMap.remove(session.getId()); |
103 | if (sessionMd != null) { | 126 | if (sessionMd != null) { |
127 | + cleanupLimits(session, sessionMd.sessionRef); | ||
104 | externalSessionMap.remove(sessionMd.sessionRef.getSessionId()); | 128 | externalSessionMap.remove(sessionMd.sessionRef.getSessionId()); |
105 | processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); | 129 | processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); |
106 | } | 130 | } |
@@ -136,7 +160,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -136,7 +160,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
136 | private final WebSocketSession session; | 160 | private final WebSocketSession session; |
137 | private final TelemetryWebSocketSessionRef sessionRef; | 161 | private final TelemetryWebSocketSessionRef sessionRef; |
138 | 162 | ||
139 | - public SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { | 163 | + SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { |
140 | super(); | 164 | super(); |
141 | this.session = session; | 165 | this.session = session; |
142 | this.sessionRef = sessionRef; | 166 | this.sessionRef = sessionRef; |
@@ -162,15 +186,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -162,15 +186,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
162 | } | 186 | } |
163 | } | 187 | } |
164 | 188 | ||
189 | + | ||
165 | @Override | 190 | @Override |
166 | public void close(TelemetryWebSocketSessionRef sessionRef) throws IOException { | 191 | public void close(TelemetryWebSocketSessionRef sessionRef) throws IOException { |
192 | + close(sessionRef, CloseStatus.NORMAL); | ||
193 | + } | ||
194 | + | ||
195 | + @Override | ||
196 | + public void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus reason) throws IOException { | ||
167 | String externalId = sessionRef.getSessionId(); | 197 | String externalId = sessionRef.getSessionId(); |
168 | log.debug("[{}] Processing close request", externalId); | 198 | log.debug("[{}] Processing close request", externalId); |
169 | String internalId = externalSessionMap.get(externalId); | 199 | String internalId = externalSessionMap.get(externalId); |
170 | if (internalId != null) { | 200 | if (internalId != null) { |
171 | SessionMetaData sessionMd = internalSessionMap.get(internalId); | 201 | SessionMetaData sessionMd = internalSessionMap.get(internalId); |
172 | if (sessionMd != null) { | 202 | if (sessionMd != null) { |
173 | - sessionMd.session.close(CloseStatus.NORMAL); | 203 | + sessionMd.session.close(reason); |
174 | } else { | 204 | } else { |
175 | log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId); | 205 | log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId); |
176 | } | 206 | } |
@@ -179,4 +209,94 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | @@ -179,4 +209,94 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr | ||
179 | } | 209 | } |
180 | } | 210 | } |
181 | 211 | ||
212 | + private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception { | ||
213 | + String sessionId = session.getId(); | ||
214 | + if (maxSessionsPerTenant > 0) { | ||
215 | + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | ||
216 | + synchronized (tenantSessions) { | ||
217 | + if (tenantSessions.size() < maxSessionsPerTenant) { | ||
218 | + tenantSessions.add(sessionId); | ||
219 | + } else { | ||
220 | + log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached" | ||
221 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | ||
222 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!")); | ||
223 | + return false; | ||
224 | + } | ||
225 | + } | ||
226 | + } | ||
227 | + | ||
228 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | ||
229 | + if (maxSessionsPerCustomer > 0) { | ||
230 | + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | ||
231 | + synchronized (customerSessions) { | ||
232 | + if (customerSessions.size() < maxSessionsPerCustomer) { | ||
233 | + customerSessions.add(sessionId); | ||
234 | + } else { | ||
235 | + log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached" | ||
236 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | ||
237 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached")); | ||
238 | + return false; | ||
239 | + } | ||
240 | + } | ||
241 | + } | ||
242 | + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
243 | + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
244 | + synchronized (regularUserSessions) { | ||
245 | + if (regularUserSessions.size() < maxSessionsPerRegularUser) { | ||
246 | + regularUserSessions.add(sessionId); | ||
247 | + } else { | ||
248 | + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached" | ||
249 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | ||
250 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached")); | ||
251 | + return false; | ||
252 | + } | ||
253 | + } | ||
254 | + } | ||
255 | + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
256 | + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
257 | + synchronized (publicUserSessions) { | ||
258 | + if (publicUserSessions.size() < maxSessionsPerPublicUser) { | ||
259 | + publicUserSessions.add(sessionId); | ||
260 | + } else { | ||
261 | + log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached" | ||
262 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); | ||
263 | + session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached")); | ||
264 | + return false; | ||
265 | + } | ||
266 | + } | ||
267 | + } | ||
268 | + } | ||
269 | + return true; | ||
270 | + } | ||
271 | + | ||
272 | + private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) { | ||
273 | + String sessionId = session.getId(); | ||
274 | + if (maxSessionsPerTenant > 0) { | ||
275 | + Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | ||
276 | + synchronized (tenantSessions) { | ||
277 | + tenantSessions.remove(sessionId); | ||
278 | + } | ||
279 | + } | ||
280 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | ||
281 | + if (maxSessionsPerCustomer > 0) { | ||
282 | + Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | ||
283 | + synchronized (customerSessions) { | ||
284 | + customerSessions.remove(sessionId); | ||
285 | + } | ||
286 | + } | ||
287 | + if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
288 | + Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
289 | + synchronized (regularUserSessions) { | ||
290 | + regularUserSessions.remove(sessionId); | ||
291 | + } | ||
292 | + } | ||
293 | + if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
294 | + Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
295 | + synchronized (publicUserSessions) { | ||
296 | + publicUserSessions.remove(sessionId); | ||
297 | + } | ||
298 | + } | ||
299 | + } | ||
300 | + } | ||
301 | + | ||
182 | } | 302 | } |
@@ -23,12 +23,17 @@ import com.google.common.util.concurrent.Futures; | @@ -23,12 +23,17 @@ import com.google.common.util.concurrent.Futures; | ||
23 | import com.google.common.util.concurrent.ListenableFuture; | 23 | import com.google.common.util.concurrent.ListenableFuture; |
24 | import lombok.extern.slf4j.Slf4j; | 24 | import lombok.extern.slf4j.Slf4j; |
25 | import org.springframework.beans.factory.annotation.Autowired; | 25 | import org.springframework.beans.factory.annotation.Autowired; |
26 | +import org.springframework.beans.factory.annotation.Value; | ||
26 | import org.springframework.stereotype.Service; | 27 | import org.springframework.stereotype.Service; |
27 | import org.springframework.util.StringUtils; | 28 | import org.springframework.util.StringUtils; |
29 | +import org.springframework.web.socket.CloseStatus; | ||
30 | +import org.springframework.web.socket.WebSocketSession; | ||
28 | import org.thingsboard.server.common.data.DataConstants; | 31 | import org.thingsboard.server.common.data.DataConstants; |
32 | +import org.thingsboard.server.common.data.id.CustomerId; | ||
29 | import org.thingsboard.server.common.data.id.EntityId; | 33 | import org.thingsboard.server.common.data.id.EntityId; |
30 | import org.thingsboard.server.common.data.id.EntityIdFactory; | 34 | import org.thingsboard.server.common.data.id.EntityIdFactory; |
31 | import org.thingsboard.server.common.data.id.TenantId; | 35 | import org.thingsboard.server.common.data.id.TenantId; |
36 | +import org.thingsboard.server.common.data.id.UserId; | ||
32 | import org.thingsboard.server.common.data.kv.Aggregation; | 37 | import org.thingsboard.server.common.data.kv.Aggregation; |
33 | import org.thingsboard.server.common.data.kv.AttributeKvEntry; | 38 | import org.thingsboard.server.common.data.kv.AttributeKvEntry; |
34 | import org.thingsboard.server.common.data.kv.BaseReadTsKvQuery; | 39 | import org.thingsboard.server.common.data.kv.BaseReadTsKvQuery; |
@@ -42,6 +47,7 @@ import org.thingsboard.server.service.security.AccessValidator; | @@ -42,6 +47,7 @@ import org.thingsboard.server.service.security.AccessValidator; | ||
42 | import org.thingsboard.server.service.security.ValidationCallback; | 47 | import org.thingsboard.server.service.security.ValidationCallback; |
43 | import org.thingsboard.server.service.security.ValidationResult; | 48 | import org.thingsboard.server.service.security.ValidationResult; |
44 | import org.thingsboard.server.service.security.ValidationResultCode; | 49 | import org.thingsboard.server.service.security.ValidationResultCode; |
50 | +import org.thingsboard.server.service.security.model.UserPrincipal; | ||
45 | import org.thingsboard.server.service.telemetry.cmd.AttributesSubscriptionCmd; | 51 | import org.thingsboard.server.service.telemetry.cmd.AttributesSubscriptionCmd; |
46 | import org.thingsboard.server.service.telemetry.cmd.GetHistoryCmd; | 52 | import org.thingsboard.server.service.telemetry.cmd.GetHistoryCmd; |
47 | import org.thingsboard.server.service.telemetry.cmd.SubscriptionCmd; | 53 | import org.thingsboard.server.service.telemetry.cmd.SubscriptionCmd; |
@@ -64,6 +70,7 @@ import java.util.ArrayList; | @@ -64,6 +70,7 @@ import java.util.ArrayList; | ||
64 | import java.util.Collections; | 70 | import java.util.Collections; |
65 | import java.util.HashMap; | 71 | import java.util.HashMap; |
66 | import java.util.HashSet; | 72 | import java.util.HashSet; |
73 | +import java.util.Iterator; | ||
67 | import java.util.List; | 74 | import java.util.List; |
68 | import java.util.Map; | 75 | import java.util.Map; |
69 | import java.util.Optional; | 76 | import java.util.Optional; |
@@ -112,11 +119,25 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | @@ -112,11 +119,25 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | ||
112 | @Autowired | 119 | @Autowired |
113 | private TimeseriesService tsService; | 120 | private TimeseriesService tsService; |
114 | 121 | ||
122 | + @Value("${server.ws.limits.max_subscriptions_per_tenant:0}") | ||
123 | + private int maxSubscriptionsPerTenant; | ||
124 | + @Value("${server.ws.limits.max_subscriptions_per_customer:0}") | ||
125 | + private int maxSubscriptionsPerCustomer; | ||
126 | + @Value("${server.ws.limits.max_subscriptions_per_regular_user:0}") | ||
127 | + private int maxSubscriptionsPerRegularUser; | ||
128 | + @Value("${server.ws.limits.max_subscriptions_per_public_user:0}") | ||
129 | + private int maxSubscriptionsPerPublicUser; | ||
130 | + | ||
131 | + private ConcurrentMap<TenantId, Set<String>> tenantSubscriptionsMap = new ConcurrentHashMap<>(); | ||
132 | + private ConcurrentMap<CustomerId, Set<String>> customerSubscriptionsMap = new ConcurrentHashMap<>(); | ||
133 | + private ConcurrentMap<UserId, Set<String>> regularUserSubscriptionsMap = new ConcurrentHashMap<>(); | ||
134 | + private ConcurrentMap<UserId, Set<String>> publicUserSubscriptionsMap = new ConcurrentHashMap<>(); | ||
135 | + | ||
115 | private ExecutorService executor; | 136 | private ExecutorService executor; |
116 | 137 | ||
117 | @PostConstruct | 138 | @PostConstruct |
118 | public void initExecutor() { | 139 | public void initExecutor() { |
119 | - executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); | 140 | + executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); |
120 | } | 141 | } |
121 | 142 | ||
122 | @PreDestroy | 143 | @PreDestroy |
@@ -140,6 +161,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | @@ -140,6 +161,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | ||
140 | case CLOSED: | 161 | case CLOSED: |
141 | wsSessionsMap.remove(sessionId); | 162 | wsSessionsMap.remove(sessionId); |
142 | subscriptionManager.cleanupLocalWsSessionSubscriptions(sessionRef, sessionId); | 163 | subscriptionManager.cleanupLocalWsSessionSubscriptions(sessionRef, sessionId); |
164 | + processSessionClose(sessionRef); | ||
143 | break; | 165 | break; |
144 | } | 166 | } |
145 | } | 167 | } |
@@ -154,10 +176,18 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | @@ -154,10 +176,18 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | ||
154 | TelemetryPluginCmdsWrapper cmdsWrapper = jsonMapper.readValue(msg, TelemetryPluginCmdsWrapper.class); | 176 | TelemetryPluginCmdsWrapper cmdsWrapper = jsonMapper.readValue(msg, TelemetryPluginCmdsWrapper.class); |
155 | if (cmdsWrapper != null) { | 177 | if (cmdsWrapper != null) { |
156 | if (cmdsWrapper.getAttrSubCmds() != null) { | 178 | if (cmdsWrapper.getAttrSubCmds() != null) { |
157 | - cmdsWrapper.getAttrSubCmds().forEach(cmd -> handleWsAttributesSubscriptionCmd(sessionRef, cmd)); | 179 | + cmdsWrapper.getAttrSubCmds().forEach(cmd -> { |
180 | + if (processSubscription(sessionRef, cmd)) { | ||
181 | + handleWsAttributesSubscriptionCmd(sessionRef, cmd); | ||
182 | + } | ||
183 | + }); | ||
158 | } | 184 | } |
159 | if (cmdsWrapper.getTsSubCmds() != null) { | 185 | if (cmdsWrapper.getTsSubCmds() != null) { |
160 | - cmdsWrapper.getTsSubCmds().forEach(cmd -> handleWsTimeseriesSubscriptionCmd(sessionRef, cmd)); | 186 | + cmdsWrapper.getTsSubCmds().forEach(cmd -> { |
187 | + if (processSubscription(sessionRef, cmd)) { | ||
188 | + handleWsTimeseriesSubscriptionCmd(sessionRef, cmd); | ||
189 | + } | ||
190 | + }); | ||
161 | } | 191 | } |
162 | if (cmdsWrapper.getHistoryCmds() != null) { | 192 | if (cmdsWrapper.getHistoryCmds() != null) { |
163 | cmdsWrapper.getHistoryCmds().forEach(cmd -> handleWsHistoryCmd(sessionRef, cmd)); | 193 | cmdsWrapper.getHistoryCmds().forEach(cmd -> handleWsHistoryCmd(sessionRef, cmd)); |
@@ -178,6 +208,105 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | @@ -178,6 +208,105 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | ||
178 | } | 208 | } |
179 | } | 209 | } |
180 | 210 | ||
211 | + private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) { | ||
212 | + String sessionId = "[" + sessionRef.getSessionId() + "]"; | ||
213 | + if (maxSubscriptionsPerTenant > 0) { | ||
214 | + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | ||
215 | + synchronized (tenantSubscriptions) { | ||
216 | + tenantSubscriptions.removeIf(subId -> subId.startsWith(sessionId)); | ||
217 | + } | ||
218 | + } | ||
219 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | ||
220 | + if (maxSubscriptionsPerCustomer > 0) { | ||
221 | + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | ||
222 | + synchronized (customerSessions) { | ||
223 | + customerSessions.removeIf(subId -> subId.startsWith(sessionId)); | ||
224 | + } | ||
225 | + } | ||
226 | + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
227 | + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
228 | + synchronized (regularUserSessions) { | ||
229 | + regularUserSessions.removeIf(subId -> subId.startsWith(sessionId)); | ||
230 | + } | ||
231 | + } | ||
232 | + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
233 | + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
234 | + synchronized (publicUserSessions) { | ||
235 | + publicUserSessions.removeIf(subId -> subId.startsWith(sessionId)); | ||
236 | + } | ||
237 | + } | ||
238 | + } | ||
239 | + } | ||
240 | + | ||
241 | + private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) { | ||
242 | + String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]"; | ||
243 | + try { | ||
244 | + if (maxSubscriptionsPerTenant > 0) { | ||
245 | + Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); | ||
246 | + synchronized (tenantSubscriptions) { | ||
247 | + if (cmd.isUnsubscribe()) { | ||
248 | + tenantSubscriptions.remove(subId); | ||
249 | + } else if (tenantSubscriptions.size() < maxSubscriptionsPerTenant) { | ||
250 | + tenantSubscriptions.add(subId); | ||
251 | + } else { | ||
252 | + log.info("[{}][{}][{}] Failed to start subscription. Max tenant subscriptions limit reached" | ||
253 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | ||
254 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max tenant subscriptions limit reached!")); | ||
255 | + return false; | ||
256 | + } | ||
257 | + } | ||
258 | + } | ||
259 | + | ||
260 | + if (sessionRef.getSecurityCtx().isCustomerUser()) { | ||
261 | + if (maxSubscriptionsPerCustomer > 0) { | ||
262 | + Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); | ||
263 | + synchronized (customerSessions) { | ||
264 | + if (cmd.isUnsubscribe()) { | ||
265 | + customerSessions.remove(subId); | ||
266 | + } else if (customerSessions.size() < maxSubscriptionsPerCustomer) { | ||
267 | + customerSessions.add(subId); | ||
268 | + } else { | ||
269 | + log.info("[{}][{}][{}] Failed to start subscription. Max customer sessions limit reached" | ||
270 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | ||
271 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max customer subscriptions limit reached")); | ||
272 | + return false; | ||
273 | + } | ||
274 | + } | ||
275 | + } | ||
276 | + if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
277 | + Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
278 | + synchronized (regularUserSessions) { | ||
279 | + if (regularUserSessions.size() < maxSubscriptionsPerRegularUser) { | ||
280 | + regularUserSessions.add(subId); | ||
281 | + } else { | ||
282 | + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached" | ||
283 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | ||
284 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max regular user subscriptions limit reached")); | ||
285 | + return false; | ||
286 | + } | ||
287 | + } | ||
288 | + } | ||
289 | + if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { | ||
290 | + Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); | ||
291 | + synchronized (publicUserSessions) { | ||
292 | + if (publicUserSessions.size() < maxSubscriptionsPerPublicUser) { | ||
293 | + publicUserSessions.add(subId); | ||
294 | + } else { | ||
295 | + log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached" | ||
296 | + , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId); | ||
297 | + msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max public user subscriptions limit reached")); | ||
298 | + return false; | ||
299 | + } | ||
300 | + } | ||
301 | + } | ||
302 | + } | ||
303 | + } catch (IOException e) { | ||
304 | + log.warn("[{}] Failed to send session close: {}", sessionRef.getSessionId(), e); | ||
305 | + return false; | ||
306 | + } | ||
307 | + return true; | ||
308 | + } | ||
309 | + | ||
181 | private void handleWsAttributesSubscriptionCmd(TelemetryWebSocketSessionRef sessionRef, AttributesSubscriptionCmd cmd) { | 310 | private void handleWsAttributesSubscriptionCmd(TelemetryWebSocketSessionRef sessionRef, AttributesSubscriptionCmd cmd) { |
182 | String sessionId = sessionRef.getSessionId(); | 311 | String sessionId = sessionRef.getSessionId(); |
183 | log.debug("[{}] Processing: {}", sessionId, cmd); | 312 | log.debug("[{}] Processing: {}", sessionId, cmd); |
@@ -220,7 +349,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | @@ -220,7 +349,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi | ||
220 | public void onFailure(Throwable e) { | 349 | public void onFailure(Throwable e) { |
221 | log.error(FAILED_TO_FETCH_ATTRIBUTES, e); | 350 | log.error(FAILED_TO_FETCH_ATTRIBUTES, e); |
222 | SubscriptionUpdate update; | 351 | SubscriptionUpdate update; |
223 | - if (UnauthorizedException.class.isInstance(e)) { | 352 | + if (e instanceof UnauthorizedException) { |
224 | update = new SubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.UNAUTHORIZED, | 353 | update = new SubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.UNAUTHORIZED, |
225 | SubscriptionErrorCode.UNAUTHORIZED.getDefaultMsg()); | 354 | SubscriptionErrorCode.UNAUTHORIZED.getDefaultMsg()); |
226 | } else { | 355 | } else { |
@@ -15,6 +15,8 @@ | @@ -15,6 +15,8 @@ | ||
15 | */ | 15 | */ |
16 | package org.thingsboard.server.service.telemetry; | 16 | package org.thingsboard.server.service.telemetry; |
17 | 17 | ||
18 | +import org.springframework.web.socket.CloseStatus; | ||
19 | + | ||
18 | import java.io.IOException; | 20 | import java.io.IOException; |
19 | 21 | ||
20 | /** | 22 | /** |
@@ -26,4 +28,5 @@ public interface TelemetryWebSocketMsgEndpoint { | @@ -26,4 +28,5 @@ public interface TelemetryWebSocketMsgEndpoint { | ||
26 | 28 | ||
27 | void close(TelemetryWebSocketSessionRef sessionRef) throws IOException; | 29 | void close(TelemetryWebSocketSessionRef sessionRef) throws IOException; |
28 | 30 | ||
31 | + void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus withReason) throws IOException; | ||
29 | } | 32 | } |
@@ -32,6 +32,17 @@ server: | @@ -32,6 +32,17 @@ server: | ||
32 | # Alias that identifies the key in the key store | 32 | # Alias that identifies the key in the key store |
33 | key-alias: "${SSL_KEY_ALIAS:tomcat}" | 33 | key-alias: "${SSL_KEY_ALIAS:tomcat}" |
34 | log_controller_error_stack_trace: "${HTTP_LOG_CONTROLLER_ERROR_STACK_TRACE:true}" | 34 | log_controller_error_stack_trace: "${HTTP_LOG_CONTROLLER_ERROR_STACK_TRACE:true}" |
35 | + ws: | ||
36 | + limits: | ||
37 | + # Limit the amount of sessions and subscriptions available on each server. Put values to zero to disable particular limitation | ||
38 | + max_sessions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_TENANT:0}" | ||
39 | + max_sessions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_CUSTOMER:0}" | ||
40 | + max_sessions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_REGULAR_USER:0}" | ||
41 | + max_sessions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_PUBLIC_USER:0}" | ||
42 | + max_subscriptions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_TENANT:0}" | ||
43 | + max_subscriptions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_CUSTOMER:0}" | ||
44 | + max_subscriptions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_REGULAR_USER:0}" | ||
45 | + max_subscriptions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_PUBLIC_USER:0}" | ||
35 | 46 | ||
36 | # Zookeeper connection parameters. Used for service discovery. | 47 | # Zookeeper connection parameters. Used for service discovery. |
37 | zk: | 48 | zk: |