Commit daac250c2e18c026aabde65c97217b45485c4132

Authored by Andrii Shvaika
1 parent c1d8aa13

Correct close and cleanup of the MQTT session context

... ... @@ -81,6 +81,7 @@ import java.util.ArrayList;
81 81 import java.util.List;
82 82 import java.util.Optional;
83 83 import java.util.UUID;
  84 +import java.util.concurrent.Callable;
84 85 import java.util.concurrent.ConcurrentHashMap;
85 86 import java.util.concurrent.ConcurrentMap;
86 87 import java.util.concurrent.TimeUnit;
... ... @@ -153,10 +154,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
153 154 if (message.decoderResult().isSuccess()) {
154 155 processMqttMsg(ctx, message);
155 156 } else {
156   - log.error("[{}] Message processing failed: {}", sessionId, message.decoderResult().cause().getMessage());
  157 + log.error("[{}] Message decoding failed: {}", sessionId, message.decoderResult().cause().getMessage());
157 158 ctx.close();
158 159 }
159 160 } else {
  161 + log.debug("[{}] Received non mqtt message: {}", sessionId, msg.getClass().getSimpleName());
160 162 ctx.close();
161 163 }
162 164 } finally {
... ... @@ -168,7 +170,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
168 170 address = getAddress(ctx);
169 171 if (msg.fixedHeader() == null) {
170 172 log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
171   - processDisconnect(ctx);
  173 + ctx.close();
172 174 return;
173 175 }
174 176 deviceSessionCtx.setChannel(ctx);
... ... @@ -208,8 +210,8 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
208 210 }
209 211 }
210 212 } else {
  213 + log.debug("[{}] Unsupported topic for provisioning requests: {}!", sessionId, topicName);
211 214 ctx.close();
212   - throw new RuntimeException("Unsupported topic for provisioning requests!");
213 215 }
214 216 } catch (RuntimeException | AdaptorException e) {
215 217 log.warn("[{}] Failed to process publish msg [{}][{}]", sessionId, topicName, msgId, e);
... ... @@ -220,48 +222,30 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
220 222 ctx.writeAndFlush(new MqttMessage(new MqttFixedHeader(PINGRESP, false, AT_MOST_ONCE, false, 0)));
221 223 break;
222 224 case DISCONNECT:
223   - if (checkConnected(ctx, msg)) {
224   - processDisconnect(ctx);
225   - }
  225 + ctx.close();
226 226 break;
227 227 }
228 228 }
229 229
230 230 void enqueueRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
231   - final int queueSize = deviceSessionCtx.getMsgQueueSize().incrementAndGet();
232   - if (queueSize > context.getMessageQueueSizePerDeviceLimit()) {
233   - log.warn("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}",
234   - deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueue().size());
  231 + final int queueSize = deviceSessionCtx.getMsgQueueSize();
  232 + if (queueSize >= context.getMessageQueueSizePerDeviceLimit()) {
  233 + log.info("Closing current session because msq queue size for device {} exceed limit {} with msgQueueSize counter {} and actual queue size {}",
  234 + deviceSessionCtx.getDeviceId(), context.getMessageQueueSizePerDeviceLimit(), queueSize, deviceSessionCtx.getMsgQueueSize());
235 235 ctx.close();
236 236 return;
237 237 }
238 238
239   - deviceSessionCtx.getMsgQueue().add(msg);
240   - ReferenceCountUtil.retain(msg);
  239 + deviceSessionCtx.addToQueue(msg);
241 240 processMsgQueue(ctx); //Under the normal conditions the msg queue will contain 0 messages. Many messages will be processed on device connect event in separate thread pool
242 241 }
243 242
244 243 void processMsgQueue(ChannelHandlerContext ctx) {
245 244 if (!deviceSessionCtx.isConnected()) {
246   - log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size());
  245 + log.trace("[{}][{}] Postpone processing msg due to device is not connected. Msg queue size is {}", sessionId, deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueueSize());
247 246 return;
248 247 }
249   - while (!deviceSessionCtx.getMsgQueue().isEmpty()) {
250   - if (deviceSessionCtx.getMsgQueueProcessorLock().tryLock()) {
251   - try {
252   - MqttMessage msg;
253   - while ((msg = deviceSessionCtx.getMsgQueue().poll()) != null) {
254   - deviceSessionCtx.getMsgQueueSize().decrementAndGet();
255   - processRegularSessionMsg(ctx, msg);
256   - ReferenceCountUtil.safeRelease(msg);
257   - }
258   - } finally {
259   - deviceSessionCtx.getMsgQueueProcessorLock().unlock();
260   - }
261   - } else {
262   - return;
263   - }
264   - }
  248 + deviceSessionCtx.tryProcessQueuedMsgs(msg -> processRegularSessionMsg(ctx, msg));
265 249 }
266 250
267 251 void processRegularSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
... ... @@ -282,9 +266,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
282 266 }
283 267 break;
284 268 case DISCONNECT:
285   - if (checkConnected(ctx, msg)) {
286   - processDisconnect(ctx);
287   - }
  269 + ctx.close();
288 270 break;
289 271 case PUBACK:
290 272 int msgId = ((MqttPubAckMessage) msg).variableHeader().messageId();
... ... @@ -438,7 +420,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
438 420 @Override
439 421 public void onError(Throwable e) {
440 422 log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
441   - processDisconnect(ctx);
  423 + ctx.close();
442 424 }
443 425 };
444 426 }
... ... @@ -464,7 +446,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
464 446 } else {
465 447 deviceSessionCtx.getContext().getProtoMqttAdaptor().convertToPublish(deviceSessionCtx, provisionResponseMsg).ifPresent(deviceSessionCtx.getChannel()::writeAndFlush);
466 448 }
467   - scheduler.schedule(() -> processDisconnect(ctx), 60, TimeUnit.SECONDS);
  449 + scheduler.schedule((Callable<ChannelFuture>) ctx::close, 60, TimeUnit.SECONDS);
468 450 } catch (Exception e) {
469 451 log.trace("[{}] Failed to convert device attributes response to MQTT msg", sessionId, e);
470 452 }
... ... @@ -473,7 +455,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
473 455 @Override
474 456 public void onError(Throwable e) {
475 457 log.trace("[{}] Failed to publish msg: {}", sessionId, msg, e);
476   - processDisconnect(ctx);
  458 + ctx.close();
477 459 }
478 460 }
479 461
... ... @@ -508,7 +490,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
508 490 @Override
509 491 public void onError(Throwable e) {
510 492 log.trace("[{}] Failed to get firmware: {}", sessionId, msg, e);
511   - processDisconnect(ctx);
  493 + ctx.close();
512 494 }
513 495 }
514 496
... ... @@ -530,7 +512,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
530 512 deviceSessionCtx.getChannel().writeAndFlush(deviceSessionCtx
531 513 .getPayloadAdaptor()
532 514 .createMqttPublishMsg(deviceSessionCtx, MqttTopics.DEVICE_FIRMWARE_ERROR_TOPIC, error.getBytes()));
533   - processDisconnect(ctx);
  515 + ctx.close();
534 516 }
535 517
536 518 private void processSubscribe(ChannelHandlerContext ctx, MqttSubscribeMessage mqttMsg) {
... ... @@ -699,6 +681,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
699 681 });
700 682 } catch (Exception e) {
701 683 ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
  684 + log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
702 685 ctx.close();
703 686 }
704 687 }
... ... @@ -716,12 +699,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
716 699 return null;
717 700 }
718 701
719   - void processDisconnect(ChannelHandlerContext ctx) {
720   - ctx.close();
721   - log.info("[{}] Client disconnected!", sessionId);
722   - doDisconnect();
723   - }
724   -
725 702 private MqttConnAckMessage createMqttConnAckMsg(MqttConnectReturnCode returnCode, MqttConnectMessage msg) {
726 703 MqttFixedHeader mqttFixedHeader =
727 704 new MqttFixedHeader(CONNACK, false, AT_MOST_ONCE, false, 0);
... ... @@ -766,7 +743,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
766 743 return true;
767 744 } else {
768 745 log.info("[{}] Closing current session due to invalid msg order: {}", sessionId, msg);
769   - ctx.close();
770 746 return false;
771 747 }
772 748 }
... ... @@ -791,11 +767,13 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
791 767
792 768 @Override
793 769 public void operationComplete(Future<? super Void> future) throws Exception {
  770 + log.trace("[{}] Channel closed!", sessionId);
794 771 doDisconnect();
795 772 }
796 773
797   - private void doDisconnect() {
  774 + public void doDisconnect() {
798 775 if (deviceSessionCtx.isConnected()) {
  776 + log.info("[{}] Client disconnected!", sessionId);
799 777 transportService.process(deviceSessionCtx.getSessionInfo(), DefaultTransportService.getSessionEventMsg(SessionEvent.CLOSED), null);
800 778 transportService.deregisterSession(deviceSessionCtx.getSessionInfo());
801 779 if (gatewaySessionHandler != null) {
... ... @@ -803,11 +781,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
803 781 }
804 782 deviceSessionCtx.setDisconnected();
805 783 }
806   -
807   - if (!deviceSessionCtx.getMsgQueue().isEmpty()) {
808   - log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", deviceSessionCtx.getDeviceId(), deviceSessionCtx.getMsgQueue().size());
809   - deviceSessionCtx.getMsgQueue().clear();
810   - }
  784 + deviceSessionCtx.release();
811 785 }
812 786
813 787
... ... @@ -866,7 +840,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
866 840 @Override
867 841 public void onRemoteSessionCloseCommand(UUID sessionId, TransportProtos.SessionCloseNotificationProto sessionCloseNotification) {
868 842 log.trace("[{}] Received the remote command to close the session: {}", sessionId, sessionCloseNotification.getMessage());
869   - processDisconnect(deviceSessionCtx.getChannel());
  843 + deviceSessionCtx.getChannel().close();
870 844 }
871 845
872 846 @Override
... ...
... ... @@ -19,6 +19,7 @@ import com.google.protobuf.Descriptors;
19 19 import com.google.protobuf.DynamicMessage;
20 20 import io.netty.channel.ChannelHandlerContext;
21 21 import io.netty.handler.codec.mqtt.MqttMessage;
  22 +import io.netty.util.ReferenceCountUtil;
22 23 import lombok.Getter;
23 24 import lombok.Setter;
24 25 import lombok.extern.slf4j.Slf4j;
... ... @@ -35,12 +36,16 @@ import org.thingsboard.server.transport.mqtt.adaptors.MqttTransportAdaptor;
35 36 import org.thingsboard.server.transport.mqtt.util.MqttTopicFilter;
36 37 import org.thingsboard.server.transport.mqtt.util.MqttTopicFilterFactory;
37 38
  39 +import java.util.Collection;
  40 +import java.util.Collections;
  41 +import java.util.Queue;
38 42 import java.util.UUID;
39 43 import java.util.concurrent.ConcurrentLinkedQueue;
40 44 import java.util.concurrent.ConcurrentMap;
41 45 import java.util.concurrent.atomic.AtomicInteger;
42 46 import java.util.concurrent.locks.Lock;
43 47 import java.util.concurrent.locks.ReentrantLock;
  48 +import java.util.function.Consumer;
44 49
45 50 /**
46 51 * @author Andrew Shvayka
... ... @@ -57,13 +62,11 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
57 62
58 63 private final AtomicInteger msgIdSeq = new AtomicInteger(0);
59 64
60   - @Getter
61 65 private final ConcurrentLinkedQueue<MqttMessage> msgQueue = new ConcurrentLinkedQueue<>();
62 66
63 67 @Getter
64 68 private final Lock msgQueueProcessorLock = new ReentrantLock();
65 69
66   - @Getter
67 70 private final AtomicInteger msgQueueSize = new AtomicInteger(0);
68 71
69 72 @Getter
... ... @@ -160,4 +163,49 @@ public class DeviceSessionCtx extends MqttDeviceAwareSessionContext {
160 163 rpcResponseDynamicMessageDescriptor = protoTransportPayloadConfig.getRpcResponseDynamicMessageDescriptor(protoTransportPayloadConfig.getDeviceRpcResponseProtoSchema());
161 164 rpcRequestDynamicMessageBuilder = protoTransportPayloadConfig.getRpcRequestDynamicMessageBuilder(protoTransportPayloadConfig.getDeviceRpcRequestProtoSchema());
162 165 }
  166 +
  167 + public void addToQueue(MqttMessage msg) {
  168 + msgQueueSize.incrementAndGet();
  169 + ReferenceCountUtil.retain(msg);
  170 + msgQueue.add(msg);
  171 + }
  172 +
  173 + public void tryProcessQueuedMsgs(Consumer<MqttMessage> msgProcessor) {
  174 + while (!msgQueue.isEmpty()) {
  175 + if (msgQueueProcessorLock.tryLock()) {
  176 + try {
  177 + MqttMessage msg;
  178 + while ((msg = msgQueue.poll()) != null) {
  179 + try {
  180 + msgQueueSize.decrementAndGet();
  181 + msgProcessor.accept(msg);
  182 + } finally {
  183 + ReferenceCountUtil.safeRelease(msg);
  184 + }
  185 + }
  186 + } finally {
  187 + msgQueueProcessorLock.unlock();
  188 + }
  189 + } else {
  190 + return;
  191 + }
  192 + }
  193 + }
  194 +
  195 + public int getMsgQueueSize() {
  196 + return msgQueueSize.get();
  197 + }
  198 +
  199 + public void release() {
  200 + if (!msgQueue.isEmpty()) {
  201 + log.warn("doDisconnect for device {} but unprocessed messages {} left in the msg queue", getDeviceId(), msgQueue.size());
  202 + msgQueue.forEach(ReferenceCountUtil::safeRelease);
  203 + msgQueue.clear();
  204 + }
  205 + }
  206 +
  207 + public Collection<MqttMessage> getMsgQueueSnapshot(){
  208 + return Collections.unmodifiableCollection(msgQueue);
  209 + }
  210 +
163 211 }
... ...
... ... @@ -112,18 +112,6 @@ public class MqttTransportHandlerTest {
112 112 }
113 113
114 114 @Test
115   - public void givenMessageWithoutFixedHeader_whenProcessMqttMsg_thenProcessDisconnect() {
116   - MqttFixedHeader mqttFixedHeader = null;
117   - MqttMessage msg = new MqttMessage(mqttFixedHeader);
118   - willDoNothing().given(handler).processDisconnect(ctx);
119   -
120   - handler.processMqttMsg(ctx, msg);
121   -
122   - assertThat(handler.address, is(IP_ADDR));
123   - verify(handler, times(1)).processDisconnect(ctx);
124   - }
125   -
126   - @Test
127 115 public void givenMqttConnectMessage_whenProcessMqttMsg_thenProcessConnect() {
128 116 MqttConnectMessage msg = getMqttConnectMessage();
129 117 willDoNothing().given(handler).processConnect(ctx, msg);
... ... @@ -132,7 +120,7 @@ public class MqttTransportHandlerTest {
132 120
133 121 assertThat(handler.address, is(IP_ADDR));
134 122 assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
135   - verify(handler, never()).processDisconnect(any());
  123 + verify(handler, never()).doDisconnect();
136 124 verify(handler, times(1)).processConnect(ctx, msg);
137 125 }
138 126
... ... @@ -140,8 +128,8 @@ public class MqttTransportHandlerTest {
140 128 public void givenQueueLimit_whenEnqueueRegularSessionMsgOverLimit_thenOK() {
141 129 List<MqttPublishMessage> messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList());
142 130 messages.forEach(msg -> handler.enqueueRegularSessionMsg(ctx, msg));
143   - assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT));
144   - assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray()));
  131 + assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
  132 + assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
145 133 }
146 134
147 135 @Test
... ... @@ -152,7 +140,7 @@ public class MqttTransportHandlerTest {
152 140
153 141 messages.forEach((msg) -> handler.enqueueRegularSessionMsg(ctx, msg));
154 142
155   - assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(limit));
  143 + assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
156 144 verify(handler, times(limit)).enqueueRegularSessionMsg(any(), any());
157 145 verify(handler, times(MSG_QUEUE_LIMIT)).processMsgQueue(any());
158 146 verify(ctx, times(1)).close();
... ... @@ -169,9 +157,9 @@ public class MqttTransportHandlerTest {
169 157 assertThat(handler.address, is(IP_ADDR));
170 158 assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
171 159 assertThat(handler.deviceSessionCtx.isConnected(), is(false));
172   - assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(MSG_QUEUE_LIMIT));
173   - assertThat(handler.deviceSessionCtx.getMsgQueue(), contains(messages.toArray()));
174   - verify(handler, never()).processDisconnect(any());
  160 + assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(MSG_QUEUE_LIMIT));
  161 + assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), contains(messages.toArray()));
  162 + verify(handler, never()).doDisconnect();
175 163 verify(handler, times(1)).processConnect(any(), any());
176 164 verify(handler, times(MSG_QUEUE_LIMIT)).enqueueRegularSessionMsg(any(), any());
177 165 verify(handler, never()).processRegularSessionMsg(any(), any());
... ... @@ -212,8 +200,8 @@ public class MqttTransportHandlerTest {
212 200 assertThat(finishLatch.await(TIMEOUT, TimeUnit.SECONDS), is(true));
213 201
214 202 //then
215   - assertThat(handler.deviceSessionCtx.getMsgQueueSize().get(), is(0));
216   - assertThat(handler.deviceSessionCtx.getMsgQueue(), empty());
  203 + assertThat(handler.deviceSessionCtx.getMsgQueueSize(), is(0));
  204 + assertThat(handler.deviceSessionCtx.getMsgQueueSnapshot(), empty());
217 205 verify(handler, times(MSG_QUEUE_LIMIT)).processRegularSessionMsg(any(), any());
218 206 messages.forEach((msg) -> verify(handler, times(1)).processRegularSessionMsg(ctx, msg));
219 207 }
... ...