Commit b2d694f7ee2ee9a8ba0d0a6d5b9b69406d3f513b

Authored by Igor Kulikov
1 parent 9f499d91

Add separate SSL channel for mqtt transport

@@ -595,6 +595,10 @@ transport: @@ -595,6 +595,10 @@ transport:
595 ssl: 595 ssl:
596 # Enable/disable SSL support 596 # Enable/disable SSL support
597 enabled: "${MQTT_SSL_ENABLED:false}" 597 enabled: "${MQTT_SSL_ENABLED:false}"
  598 + # MQTT SSL bind address
  599 + bind_address: "${MQTT_SSL_BIND_ADDRESS:0.0.0.0}"
  600 + # MQTT SSL bind port
  601 + bind_port: "${MQTT_SSL_BIND_PORT:8883}"
598 # SSL protocol: See http://docs.oracle.com/javase/8/docs/technotes/guides/security/StandardNames.html#SSLContext 602 # SSL protocol: See http://docs.oracle.com/javase/8/docs/technotes/guides/security/StandardNames.html#SSLContext
599 protocol: "${MQTT_SSL_PROTOCOL:TLSv1.2}" 603 protocol: "${MQTT_SSL_PROTOCOL:TLSv1.2}"
600 # Path to the key store that holds the SSL certificate 604 # Path to the key store that holds the SSL certificate
@@ -73,7 +73,16 @@ public class MqttSslHandlerProvider { @@ -73,7 +73,16 @@ public class MqttSslHandlerProvider {
73 @Autowired 73 @Autowired
74 private TransportService transportService; 74 private TransportService transportService;
75 75
  76 + private SslHandler sslHandler;
  77 +
76 public SslHandler getSslHandler() { 78 public SslHandler getSslHandler() {
  79 + if (sslHandler == null) {
  80 + sslHandler = createSslHandler();
  81 + }
  82 + return sslHandler;
  83 + }
  84 +
  85 + private SslHandler createSslHandler() {
77 try { 86 try {
78 TrustManagerFactory tmFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); 87 TrustManagerFactory tmFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
79 KeyStore trustStore = KeyStore.getInstance(keyStoreType); 88 KeyStore trustStore = KeyStore.getInstance(keyStoreType);
@@ -28,16 +28,18 @@ import io.netty.handler.ssl.SslHandler; @@ -28,16 +28,18 @@ import io.netty.handler.ssl.SslHandler;
28 public class MqttTransportServerInitializer extends ChannelInitializer<SocketChannel> { 28 public class MqttTransportServerInitializer extends ChannelInitializer<SocketChannel> {
29 29
30 private final MqttTransportContext context; 30 private final MqttTransportContext context;
  31 + private final boolean sslEnabled;
31 32
32 - public MqttTransportServerInitializer(MqttTransportContext context) { 33 + public MqttTransportServerInitializer(MqttTransportContext context, boolean sslEnabled) {
33 this.context = context; 34 this.context = context;
  35 + this.sslEnabled = sslEnabled;
34 } 36 }
35 37
36 @Override 38 @Override
37 public void initChannel(SocketChannel ch) { 39 public void initChannel(SocketChannel ch) {
38 ChannelPipeline pipeline = ch.pipeline(); 40 ChannelPipeline pipeline = ch.pipeline();
39 SslHandler sslHandler = null; 41 SslHandler sslHandler = null;
40 - if (context.getSslHandlerProvider() != null) { 42 + if (sslEnabled && context.getSslHandlerProvider() != null) {
41 sslHandler = context.getSslHandlerProvider().getSslHandler(); 43 sslHandler = context.getSslHandlerProvider().getSslHandler();
42 pipeline.addLast(sslHandler); 44 pipeline.addLast(sslHandler);
43 } 45 }
@@ -46,6 +46,14 @@ public class MqttTransportService implements TbTransportService { @@ -46,6 +46,14 @@ public class MqttTransportService implements TbTransportService {
46 @Value("${transport.mqtt.bind_port}") 46 @Value("${transport.mqtt.bind_port}")
47 private Integer port; 47 private Integer port;
48 48
  49 + @Value("${transport.mqtt.ssl.enabled}")
  50 + private boolean sslEnabled;
  51 +
  52 + @Value("${transport.mqtt.ssl.bind_address}")
  53 + private String sslHost;
  54 + @Value("${transport.mqtt.ssl.bind_port}")
  55 + private Integer sslPort;
  56 +
49 @Value("${transport.mqtt.netty.leak_detector_level}") 57 @Value("${transport.mqtt.netty.leak_detector_level}")
50 private String leakDetectorLevel; 58 private String leakDetectorLevel;
51 @Value("${transport.mqtt.netty.boss_group_thread_count}") 59 @Value("${transport.mqtt.netty.boss_group_thread_count}")
@@ -59,6 +67,7 @@ public class MqttTransportService implements TbTransportService { @@ -59,6 +67,7 @@ public class MqttTransportService implements TbTransportService {
59 private MqttTransportContext context; 67 private MqttTransportContext context;
60 68
61 private Channel serverChannel; 69 private Channel serverChannel;
  70 + private Channel sslServerChannel;
62 private EventLoopGroup bossGroup; 71 private EventLoopGroup bossGroup;
63 private EventLoopGroup workerGroup; 72 private EventLoopGroup workerGroup;
64 73
@@ -73,10 +82,18 @@ public class MqttTransportService implements TbTransportService { @@ -73,10 +82,18 @@ public class MqttTransportService implements TbTransportService {
73 ServerBootstrap b = new ServerBootstrap(); 82 ServerBootstrap b = new ServerBootstrap();
74 b.group(bossGroup, workerGroup) 83 b.group(bossGroup, workerGroup)
75 .channel(NioServerSocketChannel.class) 84 .channel(NioServerSocketChannel.class)
76 - .childHandler(new MqttTransportServerInitializer(context)) 85 + .childHandler(new MqttTransportServerInitializer(context, false))
77 .childOption(ChannelOption.SO_KEEPALIVE, keepAlive); 86 .childOption(ChannelOption.SO_KEEPALIVE, keepAlive);
78 87
79 serverChannel = b.bind(host, port).sync().channel(); 88 serverChannel = b.bind(host, port).sync().channel();
  89 + if (sslEnabled) {
  90 + b = new ServerBootstrap();
  91 + b.group(bossGroup, workerGroup)
  92 + .channel(NioServerSocketChannel.class)
  93 + .childHandler(new MqttTransportServerInitializer(context, true))
  94 + .childOption(ChannelOption.SO_KEEPALIVE, keepAlive);
  95 + sslServerChannel = b.bind(sslHost, sslPort).sync().channel();
  96 + }
80 log.info("Mqtt transport started!"); 97 log.info("Mqtt transport started!");
81 } 98 }
82 99
@@ -85,6 +102,9 @@ public class MqttTransportService implements TbTransportService { @@ -85,6 +102,9 @@ public class MqttTransportService implements TbTransportService {
85 log.info("Stopping MQTT transport!"); 102 log.info("Stopping MQTT transport!");
86 try { 103 try {
87 serverChannel.close().sync(); 104 serverChannel.close().sync();
  105 + if (sslEnabled) {
  106 + sslServerChannel.close().sync();
  107 + }
88 } finally { 108 } finally {
89 workerGroup.shutdownGracefully(); 109 workerGroup.shutdownGracefully();
90 bossGroup.shutdownGracefully(); 110 bossGroup.shutdownGracefully();
@@ -99,6 +99,10 @@ transport: @@ -99,6 +99,10 @@ transport:
99 ssl: 99 ssl:
100 # Enable/disable SSL support 100 # Enable/disable SSL support
101 enabled: "${MQTT_SSL_ENABLED:false}" 101 enabled: "${MQTT_SSL_ENABLED:false}"
  102 + # MQTT SSL bind address
  103 + bind_address: "${MQTT_SSL_BIND_ADDRESS:0.0.0.0}"
  104 + # MQTT SSL bind port
  105 + bind_port: "${MQTT_SSL_BIND_PORT:8883}"
102 # SSL protocol: See http://docs.oracle.com/javase/8/docs/technotes/guides/security/StandardNames.html#SSLContext 106 # SSL protocol: See http://docs.oracle.com/javase/8/docs/technotes/guides/security/StandardNames.html#SSLContext
103 protocol: "${MQTT_SSL_PROTOCOL:TLSv1.2}" 107 protocol: "${MQTT_SSL_PROTOCOL:TLSv1.2}"
104 # Path to the key store that holds the SSL certificate 108 # Path to the key store that holds the SSL certificate
@@ -298,4 +302,4 @@ management: @@ -298,4 +302,4 @@ management:
298 web: 302 web:
299 exposure: 303 exposure:
300 # Expose metrics endpoint (use value 'prometheus' to enable prometheus metrics). 304 # Expose metrics endpoint (use value 'prometheus' to enable prometheus metrics).
301 - include: '${METRICS_ENDPOINTS_EXPOSE:info}'  
  305 + include: '${METRICS_ENDPOINTS_EXPOSE:info}'