thingsboard-aplcache

WS rate limites

11/13/2018 7:17:10 AM

Details

diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java
index 34b6833..7e905f2 100644
--- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java
+++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java
@@ -18,14 +18,19 @@ package org.thingsboard.server.controller.plugin;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.BeanCreationNotAllowedException;
 import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.context.annotation.Lazy;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.scheduling.annotation.Scheduled;
 import org.springframework.stereotype.Service;
 import org.springframework.web.socket.CloseStatus;
 import org.springframework.web.socket.TextMessage;
 import org.springframework.web.socket.WebSocketSession;
 import org.springframework.web.socket.handler.TextWebSocketHandler;
+import org.thingsboard.server.common.data.id.CustomerId;
+import org.thingsboard.server.common.data.id.TenantId;
+import org.thingsboard.server.common.data.id.UserId;
 import org.thingsboard.server.config.WebSocketConfiguration;
 import org.thingsboard.server.service.security.model.SecurityUser;
+import org.thingsboard.server.service.security.model.UserPrincipal;
 import org.thingsboard.server.service.telemetry.SessionEvent;
 import org.thingsboard.server.service.telemetry.TelemetryWebSocketMsgEndpoint;
 import org.thingsboard.server.service.telemetry.TelemetryWebSocketService;
@@ -34,6 +39,7 @@ import org.thingsboard.server.service.telemetry.TelemetryWebSocketSessionRef;
 import java.io.IOException;
 import java.net.URI;
 import java.security.InvalidParameterException;
+import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -48,12 +54,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
     @Autowired
     private TelemetryWebSocketService webSocketService;
 
+    @Value("${server.ws.limits.max_sessions_per_tenant:0}")
+    private int maxSessionsPerTenant;
+    @Value("${server.ws.limits.max_sessions_per_customer:0}")
+    private int maxSessionsPerCustomer;
+    @Value("${server.ws.limits.max_sessions_per_regular_user:0}")
+    private int maxSessionsPerRegularUser;
+    @Value("${server.ws.limits.max_sessions_per_public_user:0}")
+    private int maxSessionsPerPublicUser;
+
+    private ConcurrentMap<TenantId, Set<String>> tenantSessionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<CustomerId, Set<String>> customerSessionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<UserId, Set<String>> regularUserSessionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<UserId, Set<String>> publicUserSessionsMap = new ConcurrentHashMap<>();
+    
     @Override
     public void handleTextMessage(WebSocketSession session, TextMessage message) {
         try {
             SessionMetaData sessionMd = internalSessionMap.get(session.getId());
             if (sessionMd != null) {
-                log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message);
+                log.info("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload());
                 webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload());
             } else {
                 log.warn("[{}] Failed to find session", session.getId());
@@ -71,12 +91,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
             String internalSessionId = session.getId();
             TelemetryWebSocketSessionRef sessionRef = toRef(session);
             String externalSessionId = sessionRef.getSessionId();
+            if (!checkLimits(session, sessionRef)) {
+                return;
+            }
             internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef));
             externalSessionMap.put(externalSessionId, internalSessionId);
             processInWebSocketService(sessionRef, SessionEvent.onEstablished());
             log.info("[{}][{}][{}] Session is opened", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId());
         } catch (InvalidParameterException e) {
-            log.warn("[[{}] Failed to start session", session.getId(), e);
+            log.warn("[{}] Failed to start session", session.getId(), e);
             session.close(CloseStatus.BAD_DATA.withReason(e.getMessage()));
         } catch (Exception e) {
             log.warn("[{}] Failed to start session", session.getId(), e);
@@ -101,6 +124,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
         super.afterConnectionClosed(session, closeStatus);
         SessionMetaData sessionMd = internalSessionMap.remove(session.getId());
         if (sessionMd != null) {
+            cleanupLimits(session, sessionMd.sessionRef);
             externalSessionMap.remove(sessionMd.sessionRef.getSessionId());
             processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed());
         }
@@ -136,7 +160,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
         private final WebSocketSession session;
         private final TelemetryWebSocketSessionRef sessionRef;
 
-        public SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
+        SessionMetaData(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
             super();
             this.session = session;
             this.sessionRef = sessionRef;
@@ -162,15 +186,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
         }
     }
 
+
     @Override
     public void close(TelemetryWebSocketSessionRef sessionRef) throws IOException {
+        close(sessionRef, CloseStatus.NORMAL);
+    }
+
+    @Override
+    public void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus reason) throws IOException {
         String externalId = sessionRef.getSessionId();
         log.debug("[{}] Processing close request", externalId);
         String internalId = externalSessionMap.get(externalId);
         if (internalId != null) {
             SessionMetaData sessionMd = internalSessionMap.get(internalId);
             if (sessionMd != null) {
-                sessionMd.session.close(CloseStatus.NORMAL);
+                sessionMd.session.close(reason);
             } else {
                 log.warn("[{}][{}] Failed to find session by internal id", externalId, internalId);
             }
@@ -179,4 +209,94 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements Telemetr
         }
     }
 
+    private boolean checkLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) throws Exception {
+        String sessionId = session.getId();
+        if (maxSessionsPerTenant > 0) {
+            Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
+            synchronized (tenantSessions) {
+                if (tenantSessions.size() < maxSessionsPerTenant) {
+                    tenantSessions.add(sessionId);
+                } else {
+                    log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached"
+                            , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
+                    session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!"));
+                    return false;
+                }
+            }
+        }
+
+        if (sessionRef.getSecurityCtx().isCustomerUser()) {
+            if (maxSessionsPerCustomer > 0) {
+                Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (customerSessions) {
+                    if (customerSessions.size() < maxSessionsPerCustomer) {
+                        customerSessions.add(sessionId);
+                    } else {
+                        log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached"
+                                , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
+                        session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached"));
+                        return false;
+                    }
+                }
+            }
+            if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (regularUserSessions) {
+                    if (regularUserSessions.size() < maxSessionsPerRegularUser) {
+                        regularUserSessions.add(sessionId);
+                    } else {
+                        log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached"
+                                , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
+                        session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached"));
+                        return false;
+                    }
+                }
+            }
+            if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (publicUserSessions) {
+                    if (publicUserSessions.size() < maxSessionsPerPublicUser) {
+                        publicUserSessions.add(sessionId);
+                    } else {
+                        log.info("[{}][{}][{}] Failed to start session. Max user sessions limit reached"
+                                , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId);
+                        session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached"));
+                        return false;
+                    }
+                }
+            }
+        }
+        return true;
+    }
+
+    private void cleanupLimits(WebSocketSession session, TelemetryWebSocketSessionRef sessionRef) {
+        String sessionId = session.getId();
+        if (maxSessionsPerTenant > 0) {
+            Set<String> tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
+            synchronized (tenantSessions) {
+                tenantSessions.remove(sessionId);
+            }
+        }
+        if (sessionRef.getSecurityCtx().isCustomerUser()) {
+            if (maxSessionsPerCustomer > 0) {
+                Set<String> customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (customerSessions) {
+                    customerSessions.remove(sessionId);
+                }
+            }
+            if (maxSessionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (regularUserSessions) {
+                    regularUserSessions.remove(sessionId);
+                }
+            }
+            if (maxSessionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (publicUserSessions) {
+                    publicUserSessions.remove(sessionId);
+                }
+            }
+        }
+    }
+
 }
diff --git a/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java b/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java
index ef4f819..3712e2c 100644
--- a/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java
+++ b/application/src/main/java/org/thingsboard/server/service/telemetry/DefaultTelemetryWebSocketService.java
@@ -23,12 +23,17 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 import org.springframework.util.StringUtils;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.WebSocketSession;
 import org.thingsboard.server.common.data.DataConstants;
+import org.thingsboard.server.common.data.id.CustomerId;
 import org.thingsboard.server.common.data.id.EntityId;
 import org.thingsboard.server.common.data.id.EntityIdFactory;
 import org.thingsboard.server.common.data.id.TenantId;
+import org.thingsboard.server.common.data.id.UserId;
 import org.thingsboard.server.common.data.kv.Aggregation;
 import org.thingsboard.server.common.data.kv.AttributeKvEntry;
 import org.thingsboard.server.common.data.kv.BaseReadTsKvQuery;
@@ -42,6 +47,7 @@ import org.thingsboard.server.service.security.AccessValidator;
 import org.thingsboard.server.service.security.ValidationCallback;
 import org.thingsboard.server.service.security.ValidationResult;
 import org.thingsboard.server.service.security.ValidationResultCode;
+import org.thingsboard.server.service.security.model.UserPrincipal;
 import org.thingsboard.server.service.telemetry.cmd.AttributesSubscriptionCmd;
 import org.thingsboard.server.service.telemetry.cmd.GetHistoryCmd;
 import org.thingsboard.server.service.telemetry.cmd.SubscriptionCmd;
@@ -64,6 +70,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -112,11 +119,25 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
     @Autowired
     private TimeseriesService tsService;
 
+    @Value("${server.ws.limits.max_subscriptions_per_tenant:0}")
+    private int maxSubscriptionsPerTenant;
+    @Value("${server.ws.limits.max_subscriptions_per_customer:0}")
+    private int maxSubscriptionsPerCustomer;
+    @Value("${server.ws.limits.max_subscriptions_per_regular_user:0}")
+    private int maxSubscriptionsPerRegularUser;
+    @Value("${server.ws.limits.max_subscriptions_per_public_user:0}")
+    private int maxSubscriptionsPerPublicUser;
+
+    private ConcurrentMap<TenantId, Set<String>> tenantSubscriptionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<CustomerId, Set<String>> customerSubscriptionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<UserId, Set<String>> regularUserSubscriptionsMap = new ConcurrentHashMap<>();
+    private ConcurrentMap<UserId, Set<String>> publicUserSubscriptionsMap = new ConcurrentHashMap<>();
+
     private ExecutorService executor;
 
     @PostConstruct
     public void initExecutor() {
-        executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS,  new LinkedBlockingQueue<>());
+        executor = new ThreadPoolExecutor(0, 50, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
     }
 
     @PreDestroy
@@ -140,6 +161,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
             case CLOSED:
                 wsSessionsMap.remove(sessionId);
                 subscriptionManager.cleanupLocalWsSessionSubscriptions(sessionRef, sessionId);
+                processSessionClose(sessionRef);
                 break;
         }
     }
@@ -154,10 +176,18 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
             TelemetryPluginCmdsWrapper cmdsWrapper = jsonMapper.readValue(msg, TelemetryPluginCmdsWrapper.class);
             if (cmdsWrapper != null) {
                 if (cmdsWrapper.getAttrSubCmds() != null) {
-                    cmdsWrapper.getAttrSubCmds().forEach(cmd -> handleWsAttributesSubscriptionCmd(sessionRef, cmd));
+                    cmdsWrapper.getAttrSubCmds().forEach(cmd -> {
+                        if (processSubscription(sessionRef, cmd)) {
+                            handleWsAttributesSubscriptionCmd(sessionRef, cmd);
+                        }
+                    });
                 }
                 if (cmdsWrapper.getTsSubCmds() != null) {
-                    cmdsWrapper.getTsSubCmds().forEach(cmd -> handleWsTimeseriesSubscriptionCmd(sessionRef, cmd));
+                    cmdsWrapper.getTsSubCmds().forEach(cmd -> {
+                        if (processSubscription(sessionRef, cmd)) {
+                            handleWsTimeseriesSubscriptionCmd(sessionRef, cmd);
+                        }
+                    });
                 }
                 if (cmdsWrapper.getHistoryCmds() != null) {
                     cmdsWrapper.getHistoryCmds().forEach(cmd -> handleWsHistoryCmd(sessionRef, cmd));
@@ -178,6 +208,105 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
         }
     }
 
+    private void processSessionClose(TelemetryWebSocketSessionRef sessionRef) {
+        String sessionId = "[" + sessionRef.getSessionId() + "]";
+        if (maxSubscriptionsPerTenant > 0) {
+            Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
+            synchronized (tenantSubscriptions) {
+                tenantSubscriptions.removeIf(subId -> subId.startsWith(sessionId));
+            }
+        }
+        if (sessionRef.getSecurityCtx().isCustomerUser()) {
+            if (maxSubscriptionsPerCustomer > 0) {
+                Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (customerSessions) {
+                    customerSessions.removeIf(subId -> subId.startsWith(sessionId));
+                }
+            }
+            if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (regularUserSessions) {
+                    regularUserSessions.removeIf(subId -> subId.startsWith(sessionId));
+                }
+            }
+            if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (publicUserSessions) {
+                    publicUserSessions.removeIf(subId -> subId.startsWith(sessionId));
+                }
+            }
+        }
+    }
+
+    private boolean processSubscription(TelemetryWebSocketSessionRef sessionRef, SubscriptionCmd cmd) {
+        String subId = "[" + sessionRef.getSessionId() + "]:[" + cmd.getCmdId() + "]";
+        try {
+            if (maxSubscriptionsPerTenant > 0) {
+                Set<String> tenantSubscriptions = tenantSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet());
+                synchronized (tenantSubscriptions) {
+                    if (cmd.isUnsubscribe()) {
+                        tenantSubscriptions.remove(subId);
+                    } else if (tenantSubscriptions.size() < maxSubscriptionsPerTenant) {
+                        tenantSubscriptions.add(subId);
+                    } else {
+                        log.info("[{}][{}][{}] Failed to start subscription. Max tenant subscriptions limit reached"
+                                , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
+                        msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max tenant subscriptions limit reached!"));
+                        return false;
+                    }
+                }
+            }
+
+            if (sessionRef.getSecurityCtx().isCustomerUser()) {
+                if (maxSubscriptionsPerCustomer > 0) {
+                    Set<String> customerSessions = customerSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet());
+                    synchronized (customerSessions) {
+                        if (cmd.isUnsubscribe()) {
+                            customerSessions.remove(subId);
+                        } else if (customerSessions.size() < maxSubscriptionsPerCustomer) {
+                            customerSessions.add(subId);
+                        } else {
+                            log.info("[{}][{}][{}] Failed to start subscription. Max customer sessions limit reached"
+                                    , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
+                            msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max customer subscriptions limit reached"));
+                            return false;
+                        }
+                    }
+                }
+                if (maxSubscriptionsPerRegularUser > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                    Set<String> regularUserSessions = regularUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                    synchronized (regularUserSessions) {
+                        if (regularUserSessions.size() < maxSubscriptionsPerRegularUser) {
+                            regularUserSessions.add(subId);
+                        } else {
+                            log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached"
+                                    , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
+                            msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max regular user subscriptions limit reached"));
+                            return false;
+                        }
+                    }
+                }
+                if (maxSubscriptionsPerPublicUser > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) {
+                    Set<String> publicUserSessions = publicUserSubscriptionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet());
+                    synchronized (publicUserSessions) {
+                        if (publicUserSessions.size() < maxSubscriptionsPerPublicUser) {
+                            publicUserSessions.add(subId);
+                        } else {
+                            log.info("[{}][{}][{}] Failed to start subscription. Max user sessions limit reached"
+                                    , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), subId);
+                            msgEndpoint.close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Max public user subscriptions limit reached"));
+                            return false;
+                        }
+                    }
+                }
+            }
+        } catch (IOException e) {
+            log.warn("[{}] Failed to send session close: {}", sessionRef.getSessionId(), e);
+            return false;
+        }
+        return true;
+    }
+
     private void handleWsAttributesSubscriptionCmd(TelemetryWebSocketSessionRef sessionRef, AttributesSubscriptionCmd cmd) {
         String sessionId = sessionRef.getSessionId();
         log.debug("[{}] Processing: {}", sessionId, cmd);
@@ -220,7 +349,7 @@ public class DefaultTelemetryWebSocketService implements TelemetryWebSocketServi
             public void onFailure(Throwable e) {
                 log.error(FAILED_TO_FETCH_ATTRIBUTES, e);
                 SubscriptionUpdate update;
-                if (UnauthorizedException.class.isInstance(e)) {
+                if (e instanceof UnauthorizedException) {
                     update = new SubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.UNAUTHORIZED,
                             SubscriptionErrorCode.UNAUTHORIZED.getDefaultMsg());
                 } else {
diff --git a/application/src/main/java/org/thingsboard/server/service/telemetry/TelemetryWebSocketMsgEndpoint.java b/application/src/main/java/org/thingsboard/server/service/telemetry/TelemetryWebSocketMsgEndpoint.java
index 00fb80a..c21d6fd 100644
--- a/application/src/main/java/org/thingsboard/server/service/telemetry/TelemetryWebSocketMsgEndpoint.java
+++ b/application/src/main/java/org/thingsboard/server/service/telemetry/TelemetryWebSocketMsgEndpoint.java
@@ -15,6 +15,8 @@
  */
 package org.thingsboard.server.service.telemetry;
 
+import org.springframework.web.socket.CloseStatus;
+
 import java.io.IOException;
 
 /**
@@ -26,4 +28,5 @@ public interface TelemetryWebSocketMsgEndpoint {
 
     void close(TelemetryWebSocketSessionRef sessionRef) throws IOException;
 
+    void close(TelemetryWebSocketSessionRef sessionRef, CloseStatus withReason) throws IOException;
 }
diff --git a/application/src/main/resources/thingsboard.yml b/application/src/main/resources/thingsboard.yml
index ed18a0a..abb7079 100644
--- a/application/src/main/resources/thingsboard.yml
+++ b/application/src/main/resources/thingsboard.yml
@@ -32,6 +32,17 @@ server:
     # Alias that identifies the key in the key store
     key-alias: "${SSL_KEY_ALIAS:tomcat}"
   log_controller_error_stack_trace: "${HTTP_LOG_CONTROLLER_ERROR_STACK_TRACE:true}"
+  ws:
+    limits:
+      # Limit the amount of sessions and subscriptions available on each server. Put values to zero to disable particular limitation
+      max_sessions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_TENANT:0}"
+      max_sessions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_CUSTOMER:0}"
+      max_sessions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_REGULAR_USER:0}"
+      max_sessions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SESSIONS_PER_PUBLIC_USER:0}"
+      max_subscriptions_per_tenant: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_TENANT:0}"
+      max_subscriptions_per_customer: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_CUSTOMER:0}"
+      max_subscriptions_per_regular_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_REGULAR_USER:0}"
+      max_subscriptions_per_public_user: "${TB_SERVER_WS_TENANT_RATE_LIMITS_MAX_SUBSCRIPTIONS_PER_PUBLIC_USER:0}"
 
 # Zookeeper connection parameters. Used for service discovery.
 zk: