Commit 3ffd62f8443aed8e02938ba360aa50e57cdd583e

Authored by Andrii Shvaika
1 parent 055331a8

Fix corner case when access token matches user name in credentials

  1 +/**
  2 + * Copyright © 2016-2021 The Thingsboard Authors
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +package org.thingsboard.server.service.transport;
  17 +
  18 +enum BasicCredentialsValidationResult {HASH_MISMATCH, PASSWORD_MISMATCH, VALID}
... ...
... ... @@ -25,7 +25,6 @@ import com.google.protobuf.ByteString;
25 25 import lombok.RequiredArgsConstructor;
26 26 import lombok.extern.slf4j.Slf4j;
27 27 import org.springframework.stereotype.Service;
28   -import org.springframework.util.StringUtils;
29 28 import org.thingsboard.common.util.JacksonUtil;
30 29 import org.thingsboard.server.cache.ota.OtaPackageDataCache;
31 30 import org.thingsboard.server.common.data.ApiUsageState;
... ... @@ -37,6 +36,7 @@ import org.thingsboard.server.common.data.EntityType;
37 36 import org.thingsboard.server.common.data.OtaPackage;
38 37 import org.thingsboard.server.common.data.OtaPackageInfo;
39 38 import org.thingsboard.server.common.data.ResourceType;
  39 +import org.thingsboard.server.common.data.StringUtils;
40 40 import org.thingsboard.server.common.data.TbResource;
41 41 import org.thingsboard.server.common.data.TenantProfile;
42 42 import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials;
... ... @@ -107,6 +107,9 @@ import java.util.concurrent.locks.Lock;
107 107 import java.util.concurrent.locks.ReentrantLock;
108 108 import java.util.stream.Collectors;
109 109
  110 +import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.PASSWORD_MISMATCH;
  111 +import static org.thingsboard.server.service.transport.BasicCredentialsValidationResult.VALID;
  112 +
110 113 /**
111 114 * Created by ashvayka on 05.10.18.
112 115 */
... ... @@ -181,71 +184,89 @@ public class DefaultTransportApiService implements TransportApiService {
181 184 //TODO: Make async and enable caching
182 185 DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credentialsId);
183 186 if (credentials != null && credentials.getCredentialsType() == credentialsType) {
184   - return getDeviceInfo(credentials.getDeviceId(), credentials);
  187 + return getDeviceInfo(credentials);
185 188 } else {
186 189 return getEmptyTransportApiResponseFuture();
187 190 }
188 191 }
189 192
190 193 private ListenableFuture<TransportApiResponseMsg> validateCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
191   - DeviceCredentials credentials = null;
192   - if (!StringUtils.isEmpty(mqtt.getUserName())) {
193   - credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
  194 + DeviceCredentials credentials;
  195 + if (StringUtils.isEmpty(mqtt.getUserName())) {
  196 + credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId()));
194 197 if (credentials != null) {
195   - if (credentials.getCredentialsType() == DeviceCredentialsType.ACCESS_TOKEN) {
196   - return getDeviceInfo(credentials.getDeviceId(), credentials);
197   - } else if (credentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) {
198   - if (!checkMqttCredentials(mqtt, credentials)) {
199   - credentials = null;
200   - }
201   - } else {
  198 + return getDeviceInfo(credentials);
  199 + } else {
  200 + return getEmptyTransportApiResponseFuture();
  201 + }
  202 + } else {
  203 + credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(
  204 + EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName()));
  205 + if (checkIsMqttCredentials(credentials)) {
  206 + var validationResult = validateMqttCredentials(mqtt, credentials);
  207 + if (VALID.equals(validationResult)) {
  208 + return getDeviceInfo(credentials);
  209 + } else if (PASSWORD_MISMATCH.equals(validationResult)) {
202 210 return getEmptyTransportApiResponseFuture();
  211 + } else {
  212 + return validateUserNameCredentials(mqtt);
203 213 }
  214 + } else {
  215 + return validateUserNameCredentials(mqtt);
204 216 }
205   - if (credentials == null) {
206   - credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash("|", mqtt.getClientId(), mqtt.getUserName()));
207   - }
208   - }
209   - if (credentials == null) {
210   - credentials = checkMqttCredentials(mqtt, EncryptionUtil.getSha3Hash(mqtt.getClientId()));
211 217 }
  218 + }
  219 +
  220 + private ListenableFuture<TransportApiResponseMsg> validateUserNameCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg mqtt) {
  221 + DeviceCredentials credentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(mqtt.getUserName());
212 222 if (credentials != null) {
213   - return getDeviceInfo(credentials.getDeviceId(), credentials);
214   - } else {
215   - return getEmptyTransportApiResponseFuture();
  223 + switch (credentials.getCredentialsType()) {
  224 + case ACCESS_TOKEN:
  225 + return getDeviceInfo(credentials);
  226 + case MQTT_BASIC:
  227 + if (VALID.equals(validateMqttCredentials(mqtt, credentials))) {
  228 + return getDeviceInfo(credentials);
  229 + } else {
  230 + return getEmptyTransportApiResponseFuture();
  231 + }
  232 + }
216 233 }
  234 + return getEmptyTransportApiResponseFuture();
  235 + }
  236 +
  237 + private static boolean checkIsMqttCredentials(DeviceCredentials credentials) {
  238 + return credentials != null && DeviceCredentialsType.MQTT_BASIC.equals(credentials.getCredentialsType());
217 239 }
218 240
219 241 private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, String credId) {
220   - DeviceCredentials deviceCredentials = deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId);
  242 + return checkMqttCredentials(clientCred, deviceCredentialsService.findDeviceCredentialsByCredentialsId(credId));
  243 + }
  244 +
  245 + private DeviceCredentials checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
221 246 if (deviceCredentials != null && deviceCredentials.getCredentialsType() == DeviceCredentialsType.MQTT_BASIC) {
222   - if (!checkMqttCredentials(clientCred, deviceCredentials)) {
223   - return null;
224   - } else {
  247 + if (VALID.equals(validateMqttCredentials(clientCred, deviceCredentials))) {
225 248 return deviceCredentials;
226 249 }
227 250 }
228 251 return null;
229 252 }
230 253
231   - private boolean checkMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
  254 + private BasicCredentialsValidationResult validateMqttCredentials(TransportProtos.ValidateBasicMqttCredRequestMsg clientCred, DeviceCredentials deviceCredentials) {
232 255 BasicMqttCredentials dbCred = JacksonUtil.fromString(deviceCredentials.getCredentialsValue(), BasicMqttCredentials.class);
233 256 if (!StringUtils.isEmpty(dbCred.getClientId()) && !dbCred.getClientId().equals(clientCred.getClientId())) {
234   - return false;
  257 + return BasicCredentialsValidationResult.HASH_MISMATCH;
235 258 }
236 259 if (!StringUtils.isEmpty(dbCred.getUserName()) && !dbCred.getUserName().equals(clientCred.getUserName())) {
237   - return false;
  260 + return BasicCredentialsValidationResult.HASH_MISMATCH;
238 261 }
239 262 if (!StringUtils.isEmpty(dbCred.getPassword())) {
240 263 if (StringUtils.isEmpty(clientCred.getPassword())) {
241   - return false;
  264 + return BasicCredentialsValidationResult.PASSWORD_MISMATCH;
242 265 } else {
243   - if (!dbCred.getPassword().equals(clientCred.getPassword())) {
244   - return false;
245   - }
  266 + return dbCred.getPassword().equals(clientCred.getPassword()) ? VALID : BasicCredentialsValidationResult.PASSWORD_MISMATCH;
246 267 }
247 268 }
248   - return true;
  269 + return VALID;
249 270 }
250 271
251 272 private ListenableFuture<TransportApiResponseMsg> handle(GetOrCreateDeviceFromGatewayRequestMsg requestMsg) {
... ... @@ -437,10 +458,10 @@ public class DefaultTransportApiService implements TransportApiService {
437 458 .build());
438 459 }
439 460
440   - private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceId deviceId, DeviceCredentials credentials) {
441   - return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, deviceId), device -> {
  461 + private ListenableFuture<TransportApiResponseMsg> getDeviceInfo(DeviceCredentials credentials) {
  462 + return Futures.transform(deviceService.findDeviceByIdAsync(TenantId.SYS_TENANT_ID, credentials.getDeviceId()), device -> {
442 463 if (device == null) {
443   - log.trace("[{}] Failed to lookup device by id", deviceId);
  464 + log.trace("[{}] Failed to lookup device by id", credentials.getDeviceId());
444 465 return getEmptyTransportApiResponse();
445 466 }
446 467 try {
... ... @@ -458,7 +479,7 @@ public class DefaultTransportApiService implements TransportApiService {
458 479 return TransportApiResponseMsg.newBuilder()
459 480 .setValidateCredResponseMsg(builder.build()).build();
460 481 } catch (JsonProcessingException e) {
461   - log.warn("[{}] Failed to lookup device by id", deviceId, e);
  482 + log.warn("[{}] Failed to lookup device by id", credentials.getDeviceId(), e);
462 483 return getEmptyTransportApiResponse();
463 484 }
464 485 }, MoreExecutors.directExecutor());
... ...
... ... @@ -33,6 +33,7 @@ import java.util.Arrays;
33 33 "org.thingsboard.server.transport.*.attributes.request.sql.*Test",
34 34 "org.thingsboard.server.transport.*.claim.sql.*Test",
35 35 "org.thingsboard.server.transport.*.provision.sql.*Test",
  36 + "org.thingsboard.server.transport.*.credentials.*Test",
36 37 "org.thingsboard.server.transport.lwm2m.sql.*Test"
37 38 })
38 39 public class TransportSqlTestSuite {
... ...
  1 +/**
  2 + * Copyright © 2016-2021 The Thingsboard Authors
  3 + *
  4 + * Licensed under the Apache License, Version 2.0 (the "License");
  5 + * you may not use this file except in compliance with the License.
  6 + * You may obtain a copy of the License at
  7 + *
  8 + * http://www.apache.org/licenses/LICENSE-2.0
  9 + *
  10 + * Unless required by applicable law or agreed to in writing, software
  11 + * distributed under the License is distributed on an "AS IS" BASIS,
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 + * See the License for the specific language governing permissions and
  14 + * limitations under the License.
  15 + */
  16 +package org.thingsboard.server.transport.mqtt.credentials;
  17 +
  18 +import com.fasterxml.jackson.core.type.TypeReference;
  19 +import org.apache.commons.lang3.RandomStringUtils;
  20 +import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
  21 +import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
  22 +import org.eclipse.paho.client.mqttv3.MqttException;
  23 +import org.eclipse.paho.client.mqttv3.MqttSecurityException;
  24 +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
  25 +import org.junit.After;
  26 +import org.junit.Assert;
  27 +import org.junit.Before;
  28 +import org.junit.Test;
  29 +import org.thingsboard.common.util.JacksonUtil;
  30 +import org.thingsboard.server.common.data.Device;
  31 +import org.thingsboard.server.common.data.StringUtils;
  32 +import org.thingsboard.server.common.data.Tenant;
  33 +import org.thingsboard.server.common.data.User;
  34 +import org.thingsboard.server.common.data.device.credentials.BasicMqttCredentials;
  35 +import org.thingsboard.server.common.data.device.profile.MqttTopics;
  36 +import org.thingsboard.server.common.data.security.Authority;
  37 +import org.thingsboard.server.common.data.security.DeviceCredentials;
  38 +import org.thingsboard.server.common.data.security.DeviceCredentialsType;
  39 +import org.thingsboard.server.dao.service.DaoSqlTest;
  40 +import org.thingsboard.server.transport.mqtt.AbstractMqttIntegrationTest;
  41 +
  42 +import java.util.Arrays;
  43 +import java.util.HashSet;
  44 +import java.util.List;
  45 +import java.util.Set;
  46 +
  47 +import static org.junit.Assert.assertEquals;
  48 +import static org.junit.Assert.assertNotNull;
  49 +import static org.junit.Assert.assertNull;
  50 +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
  51 +
  52 +@DaoSqlTest
  53 +public class BasicMqttCredentialsTest extends AbstractMqttIntegrationTest {
  54 +
  55 + public static final String CLIENT_ID = "ClientId";
  56 + public static final String USER_NAME1 = "UserName1";
  57 + public static final String USER_NAME2 = "UserName2";
  58 + public static final String USER_NAME3 = "UserName3";
  59 + public static final String PASSWORD = "secret";
  60 +
  61 + private Device clientIdDevice;
  62 + private Device clientIdAndUserNameDevice1;
  63 + private Device clientIdAndUserNameAndPasswordDevice2;
  64 + private Device clientIdAndUserNameAndPasswordDevice3;
  65 + private Device accessTokenDevice;
  66 + private Device accessToken2Device;
  67 +
  68 +
  69 + @Before
  70 + public void before() throws Exception {
  71 + loginSysAdmin();
  72 +
  73 + Tenant tenant = new Tenant();
  74 + tenant.setTitle("My tenant");
  75 + savedTenant = doPost("/api/tenant", tenant, Tenant.class);
  76 + Assert.assertNotNull(savedTenant);
  77 +
  78 + tenantAdmin = new User();
  79 + tenantAdmin.setAuthority(Authority.TENANT_ADMIN);
  80 + tenantAdmin.setTenantId(savedTenant.getId());
  81 + tenantAdmin.setEmail("tenant" + atomicInteger.getAndIncrement() + "@thingsboard.org");
  82 + tenantAdmin.setFirstName("Joe");
  83 + tenantAdmin.setLastName("Downs");
  84 +
  85 + tenantAdmin = createUserAndLogin(tenantAdmin, "testPassword1");
  86 +
  87 + BasicMqttCredentials credValue = new BasicMqttCredentials();
  88 + credValue.setClientId(CLIENT_ID);
  89 + clientIdDevice = createDevice("clientIdDevice", credValue);
  90 +
  91 + credValue = new BasicMqttCredentials();
  92 + credValue.setClientId(CLIENT_ID);
  93 + credValue.setUserName(USER_NAME1);
  94 + clientIdAndUserNameDevice1 = createDevice("clientIdAndUserNameDevice", credValue);
  95 +
  96 + credValue = new BasicMqttCredentials();
  97 + credValue.setClientId(CLIENT_ID);
  98 + credValue.setUserName(USER_NAME2);
  99 + credValue.setPassword(PASSWORD);
  100 + clientIdAndUserNameAndPasswordDevice2 = createDevice("clientIdAndUserNameAndPasswordDevice", credValue);
  101 +
  102 + credValue = new BasicMqttCredentials();
  103 + credValue.setClientId(CLIENT_ID);
  104 + credValue.setUserName(USER_NAME3);
  105 + credValue.setPassword(PASSWORD);
  106 + clientIdAndUserNameAndPasswordDevice3 = createDevice("clientIdAndUserNameAndPasswordDevice2", credValue);
  107 +
  108 + accessTokenDevice = createDevice("accessTokenDevice", USER_NAME1);
  109 + accessToken2Device = createDevice("accessToken2Device", USER_NAME2);
  110 + }
  111 +
  112 + @Test
  113 + public void testCorrectCredentials() throws Exception {
  114 + // Check that correct devices receive telemetry
  115 + testTelemetryIsDelivered(accessTokenDevice, getMqttAsyncClient(null, USER_NAME1, null));
  116 + testTelemetryIsDelivered(clientIdDevice, getMqttAsyncClient(CLIENT_ID, null, null));
  117 + testTelemetryIsDelivered(clientIdAndUserNameDevice1, getMqttAsyncClient(CLIENT_ID, USER_NAME1, null));
  118 + testTelemetryIsDelivered(clientIdAndUserNameAndPasswordDevice2, getMqttAsyncClient(CLIENT_ID, USER_NAME2, PASSWORD));
  119 +
  120 + // Also correct. Random clientId and password, but matches access token
  121 + testTelemetryIsDelivered(accessToken2Device, getMqttAsyncClient(RandomStringUtils.randomAlphanumeric(10), USER_NAME2, RandomStringUtils.randomAlphanumeric(10)));
  122 + }
  123 +
  124 + @Test(expected = MqttSecurityException.class)
  125 + public void testCorrectClientIdAndUserNameButWrongPassword() throws Exception {
  126 + // Not correct. Correct clientId and username, but wrong password
  127 + testTelemetryIsNotDelivered(clientIdAndUserNameAndPasswordDevice3, getMqttAsyncClient(CLIENT_ID, USER_NAME3, "WRONG PASSWORD"));
  128 + }
  129 +
  130 + private void testTelemetryIsDelivered(Device device, MqttAsyncClient client) throws Exception {
  131 + testTelemetryIsDelivered(device, client, true);
  132 + }
  133 +
  134 + private void testTelemetryIsNotDelivered(Device device, MqttAsyncClient client) throws Exception {
  135 + testTelemetryIsDelivered(device, client, false);
  136 + }
  137 +
  138 + private void testTelemetryIsDelivered(Device device, MqttAsyncClient client, boolean ok) throws Exception {
  139 + String randomKey = RandomStringUtils.randomAlphanumeric(10);
  140 + List<String> expectedKeys = Arrays.asList(randomKey);
  141 + publishMqttMsg(client, JacksonUtil.toString(JacksonUtil.newObjectNode().put(randomKey, true)).getBytes(), MqttTopics.DEVICE_TELEMETRY_TOPIC);
  142 +
  143 + String deviceId = device.getId().getId().toString();
  144 +
  145 + long start = System.currentTimeMillis();
  146 + long end = System.currentTimeMillis() + 5000;
  147 +
  148 + List<String> actualKeys = null;
  149 + while (start <= end) {
  150 + actualKeys = doGetAsyncTyped("/api/plugins/telemetry/DEVICE/" + deviceId + "/keys/timeseries", new TypeReference<>() {
  151 + });
  152 + if (actualKeys.size() == expectedKeys.size()) {
  153 + break;
  154 + }
  155 + Thread.sleep(100);
  156 + start += 100;
  157 + }
  158 + if (ok) {
  159 + assertNotNull(actualKeys);
  160 +
  161 + Set<String> actualKeySet = new HashSet<>(actualKeys);
  162 + Set<String> expectedKeySet = new HashSet<>(expectedKeys);
  163 +
  164 + assertEquals(expectedKeySet, actualKeySet);
  165 + } else {
  166 + assertNull(actualKeys);
  167 + }
  168 + client.disconnect().waitForCompletion();
  169 + }
  170 +
  171 + @After
  172 + public void after() throws Exception {
  173 + processAfterTest();
  174 + }
  175 +
  176 + protected MqttAsyncClient getMqttAsyncClient(String clientId, String username, String password) throws MqttException {
  177 + if (StringUtils.isEmpty(clientId)) {
  178 + clientId = MqttAsyncClient.generateClientId();
  179 + }
  180 + MqttAsyncClient client = new MqttAsyncClient(MQTT_URL, clientId, new MemoryPersistence());
  181 +
  182 + MqttConnectOptions options = new MqttConnectOptions();
  183 + if (StringUtils.isNotEmpty(username)) {
  184 + options.setUserName(username);
  185 + }
  186 + if (StringUtils.isNotEmpty(password)) {
  187 + options.setPassword(password.toCharArray());
  188 + }
  189 + client.connect(options).waitForCompletion();
  190 + return client;
  191 + }
  192 +
  193 + private Device createDevice(String deviceName, BasicMqttCredentials clientIdCredValue) throws Exception {
  194 + Device device = new Device();
  195 + device.setName(deviceName);
  196 + device.setType("default");
  197 +
  198 + device = doPost("/api/device", device, Device.class);
  199 +
  200 + DeviceCredentials clientIdCred =
  201 + doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class);
  202 +
  203 + clientIdCred.setCredentialsType(DeviceCredentialsType.MQTT_BASIC);
  204 +
  205 +
  206 + clientIdCred.setCredentialsValue(JacksonUtil.toString(clientIdCredValue));
  207 + doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk());
  208 + return device;
  209 + }
  210 +
  211 + private Device createDevice(String deviceName, String accessToken) throws Exception {
  212 + Device device = new Device();
  213 + device.setName(deviceName);
  214 + device.setType("default");
  215 +
  216 + device = doPost("/api/device", device, Device.class);
  217 +
  218 + DeviceCredentials clientIdCred =
  219 + doGet("/api/device/" + device.getId().getId().toString() + "/credentials", DeviceCredentials.class);
  220 +
  221 + clientIdCred.setCredentialsType(DeviceCredentialsType.ACCESS_TOKEN);
  222 + clientIdCred.setCredentialsId(accessToken);
  223 + doPost("/api/device/credentials", clientIdCred).andExpect(status().isOk());
  224 + return device;
  225 + }
  226 +}
... ...
... ... @@ -19,6 +19,8 @@ import org.thingsboard.server.common.data.id.DeviceId;
19 19 import org.thingsboard.server.common.data.id.TenantId;
20 20 import org.thingsboard.server.common.data.security.DeviceCredentials;
21 21
  22 +import java.util.List;
  23 +
22 24 public interface DeviceCredentialsService {
23 25
24 26 DeviceCredentials findDeviceCredentialsByDeviceId(TenantId tenantId, DeviceId deviceId);
... ... @@ -32,4 +34,5 @@ public interface DeviceCredentialsService {
32 34 void formatCredentials(DeviceCredentials deviceCredentials);
33 35
34 36 void deleteDeviceCredentials(TenantId tenantId, DeviceCredentials deviceCredentials);
  37 +
35 38 }
... ...