// BruteForceProtector.java
package org.keycloak.services.managers;


import org.jboss.resteasy.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.services.ClientConnection;

import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Records failed logins and computes how long a user must wait before logging in again.  All failures are
 * written by a single worker thread so that concurrent writes cannot skew the per-user failure count.
 *
 * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class BruteForceProtector implements Runnable {
    /** Logger used by this class. */
    protected static Logger logger = Logger.getLogger(BruteForceProtector.class);

    /** Upper bound, in seconds, on the wait computed in failure(). */
    protected int maxFailureWaitSeconds = 900;
    /** Wait, in seconds, applied by the quick-login check in failure() when the computed wait is zero. */
    protected int minimumQuickLoginWaitSeconds = 60;
    /** Seconds of wait added per multiple of failureFactor failures. */
    protected int waitIncrementSeconds = 60;
    /** Threshold, in milliseconds, that failure() compares against the time since the previous failure. */
    protected long quickLoginCheckMilliSeconds = 1000;
    /** Gap between failures, in milliseconds, beyond which earlier failures are cleared. */
    protected long maxDeltaTimeMilliSeconds = 60 * 60 * 12 * 1000; // 12 hours
    /** Divisor applied to a user's failure count when computing the wait. */
    protected int failureFactor = 30;
    /** Loop flag for run(); shutdown() sets it to false. */
    protected volatile boolean run = true;
    /** Source of the sessions the worker uses to apply failures. */
    protected KeycloakSessionFactory factory;
    /** Counted down by run() and awaited by shutdown(). */
    protected CountDownLatch shutdownLatch = new CountDownLatch(1);

    /** Number of failures passed to logFailure(). */
    protected volatile long failures;
    /** Millisecond timestamp that logFailure() compares against the current time. */
    protected volatile long lastFailure;
    /** Sum of the gaps between failures accumulated by logFailure(); reset after a long gap. */
    protected volatile long totalTime;

    /** Events waiting to be processed by run(). */
    protected LinkedBlockingQueue<LoginEvent> queue = new LinkedBlockingQueue<LoginEvent>();
    /** Maximum number of extra events run() drains from the queue into one batch. */
    public static final int TRANSACTION_SIZE = 20;


    /**
     * Base type for everything placed on the worker queue.  Carries the realm id, the username and the
     * remote address involved; any of them may be {@code null} (the shutdown marker passes all nulls).
     */
    protected abstract class LoginEvent implements Comparable<LoginEvent> {
        protected final String realmId;
        protected final String username;
        protected final String ip;

        /**
         * @param realmId  id of the realm the login was attempted in
         * @param username name the login was attempted with
         * @param ip       remote address the attempt came from
         */
        protected LoginEvent(String realmId, String username, String ip) {
            this.realmId = realmId;
            this.username = username;
            this.ip = ip;
        }

        /**
         * Orders events by username so that a batch can be sorted before its updates are applied.
         * A {@code null} username sorts before every non-null one instead of throwing, because
         * events without a username can end up in the same batch as ordinary ones.
         */
        @Override
        public int compareTo(LoginEvent o) {
            if (username == null) {
                return o.username == null ? 0 : -1;
            }
            if (o.username == null) {
                return 1;
            }
            return username.compareTo(o.username);
        }
    }

    /** Event for a login that succeeded; the worker loop hands these to logSuccess(). */
    protected class SuccessfulLogin extends LoginEvent {
        /**
         * @param realmId  id of the realm the user logged in to
         * @param username name of the user that logged in
         * @param ip       remote address of the client
         */
        public SuccessfulLogin(String realmId, String username, String ip) {
            super(realmId, username, ip);
        }
    }

    /**
     * Marker event that shutdown() offers to the queue.  It carries no realm, username or address;
     * presumably its only purpose is to wake the worker's blocking poll.
     */
    protected class ShutdownEvent extends LoginEvent {
        public ShutdownEvent() {
            super(null, null, null);
        }
    }

    /**
     * Event for a login that failed.  The latch is counted down by the worker once the batch
     * containing this event has been handled; failedLogin() waits on it.
     */
    protected class FailedLogin extends LoginEvent {
        protected final CountDownLatch latch;

        /**
         * @param realmId  id of the realm the login was attempted in
         * @param username name the login was attempted with
         * @param ip       remote address of the client
         */
        public FailedLogin(String realmId, String username, String ip) {
            super(realmId, username, ip);
            this.latch = new CountDownLatch(1);
        }
    }

    /**
     * Creates a protector.  The worker thread is not started here; call start() for that.
     *
     * @param factory factory the worker uses to create a session for each batch of events
     */
    public BruteForceProtector(KeycloakSessionFactory factory) {
        this.factory = factory;
    }

    /**
     * Applies one failed login to the user's failure record and, when a wait is due, moves the
     * record's "not before" time forward.
     * <p>
     * The wait is {@code waitIncrementSeconds * (failures / failureFactor)}, capped at
     * {@code maxFailureWaitSeconds}.  While that product is still zero, a failure that follows the
     * previous one by less than {@code quickLoginCheckMilliSeconds} gets
     * {@code minimumQuickLoginWaitSeconds} instead.  Does nothing when no failure record is found.
     *
     * @param session session used to look up the realm and the failure record
     * @param event   the failed login being applied
     */
    public void failure(KeycloakSession session, LoginEvent event) {
        UsernameLoginFailureModel user = getUserModel(session, event);
        if (user == null) return;
        user.setLastIPFailure(event.ip);

        long currentTime = System.currentTimeMillis();
        long last = user.getLastFailure();
        // deltaTime stays 0 when there is no earlier failure on record
        long deltaTime = 0;
        if (last > 0) {
            deltaTime = currentTime - last;
        }
        user.setLastFailure(currentTime);
        // if last failure was more than MAX_DELTA clear failures
        if (deltaTime > 0 && deltaTime > maxDeltaTimeMilliSeconds) {
            user.clearFailures();
        }
        user.incrementFailures();

        int waitSeconds = waitIncrementSeconds * (user.getNumFailures() / failureFactor);
        // Below the failureFactor threshold only rapid-fire attempts are penalised: the gap to the
        // previous failure must be SHORTER than the quick-login window, and there must actually be
        // a previous failure (otherwise deltaTime is just the 0 placeholder).
        if (waitSeconds == 0 && last > 0 && deltaTime < quickLoginCheckMilliSeconds) {
            waitSeconds = minimumQuickLoginWaitSeconds;
        }
        waitSeconds = Math.min(maxFailureWaitSeconds, waitSeconds);
        if (waitSeconds > 0) {
            // notBefore is kept in whole seconds
            user.setFailedLoginNotBefore((int) (currentTime / 1000) + waitSeconds);
        }
    }

    /**
     * Looks up the failure record for the event's username in the event's realm.
     *
     * @return the record, or {@code null} when the realm or the record cannot be found
     */
    protected UsernameLoginFailureModel getUserModel(KeycloakSession session, LoginEvent event) {
        RealmModel owningRealm = getRealmModel(session, event);
        return owningRealm == null ? null : owningRealm.getUserLoginFailure(event.username);
    }

    /**
     * Resolves the realm named by the event's realm id.
     *
     * @return whatever the session returns for that id, which may be {@code null}
     */
    protected RealmModel getRealmModel(KeycloakSession session, LoginEvent event) {
        return session.getRealm(event.realmId);
    }

    /** Starts a new thread named "Brute Force Protector" that executes {@link #run()}. */
    public void start() {
        Thread worker = new Thread(this, "Brute Force Protector");
        worker.start();
    }

    /**
     * Asks the worker loop to stop and waits up to five seconds on the shutdown latch.
     * <p>
     * Clears the run flag, then offers a {@link ShutdownEvent} so that a worker blocked on the
     * queue has something to take.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
     *                          interrupted while waiting; the thread's interrupt flag is restored first
     */
    public void shutdown() {
        run = false;
        queue.offer(new ShutdownEvent());
        try {
            shutdownLatch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // keep the interrupt visible to code further up the stack
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }


    /**
     * Worker loop.  Until the run flag is cleared or the thread is interrupted it polls the queue
     * (waiting up to two seconds), drains up to {@code TRANSACTION_SIZE} further events into the
     * same batch and processes the batch.
     * <p>
     * Whatever happens to a batch, every {@link FailedLogin} in it has its latch released and the
     * batch is emptied, so a failing batch neither leaves callers waiting for the full timeout nor
     * gets processed a second time.  The shutdown latch is counted down exactly once, when the
     * loop has really ended.
     */
    public void run() {
        final ArrayList<LoginEvent> events = new ArrayList<LoginEvent>(TRANSACTION_SIZE + 1);
        try {
            while (run) {
                try {
                    LoginEvent take = queue.poll(2, TimeUnit.SECONDS);
                    if (take == null) {
                        continue;
                    }
                    events.add(take);
                    queue.drainTo(events, TRANSACTION_SIZE);
                    processBatch(events);
                } catch (InterruptedException e) {
                    // keep the interrupt visible to whoever owns this thread, then stop
                    Thread.currentThread().interrupt();
                    break;
                } catch (Exception e) {
                    logger.error("Failed processing event", e);
                } finally {
                    // runs for successful and failed batches alike
                    for (LoginEvent event : events) {
                        if (event instanceof FailedLogin) {
                            ((FailedLogin) event).latch.countDown();
                        }
                    }
                    events.clear();
                }
            }
        } finally {
            // signal shutdown() only once the loop has actually exited
            shutdownLatch.countDown();
        }
    }

    /**
     * Logs every event of the batch in arrival order, then applies the failed logins, sorted by
     * username, in one session whose transaction is committed at the end or rolled back on error.
     */
    private void processBatch(ArrayList<LoginEvent> events) throws Exception {
        ArrayList<FailedLogin> failed = new ArrayList<FailedLogin>(events.size());
        for (LoginEvent event : events) {
            if (event instanceof FailedLogin) {
                logFailure(event);
                failed.add((FailedLogin) event);
            } else if (event instanceof SuccessfulLogin) {
                logSuccess(event);
            }
        }

        // Only failed logins are written, so only they are sorted; this keeps events that carry
        // no username (such as the shutdown marker) out of the comparison.  The sort gives the
        // updates a consistent order, which the original author added as a deadlock precaution.
        Collections.sort(failed);

        // NOTE(review): no begin() call is visible here - confirm that sessions created by the
        // factory start with an active transaction.
        KeycloakSession session = factory.createSession();
        try {
            for (FailedLogin event : failed) {
                failure(session, event);
            }
            session.getTransaction().commit();
        } catch (Exception e) {
            session.getTransaction().rollback();
            throw e;
        } finally {
            session.close();
        }
    }

    /**
     * Writes a warn-level log line naming the user and remote address of a successful login.
     *
     * @param event the successful login to report
     */
    protected void logSuccess(LoginEvent event) {
        String message = "login success for user " + event.username + " from ip " + event.ip;
        logger.warn(message);
    }

    /**
     * Logs a failed login and updates the protector-wide counters: the failure count, the time of
     * the most recent failure and the accumulated time between failures.  The accumulated time is
     * reset when the gap since the previous failure exceeds {@code maxDeltaTimeMilliSeconds}.
     * <p>
     * NOTE(review): the increments on the volatile counters are not atomic; this is only safe if
     * the method is called from a single thread, as run() does.
     *
     * @param event the failed login to report
     */
    protected void logFailure(LoginEvent event) {
        logger.warn("login failure for user " + event.username + " from ip " + event.ip);
        failures++;
        long now = System.currentTimeMillis();
        if (lastFailure > 0) {
            long delta = now - lastFailure;
            if (delta > maxDeltaTimeMilliSeconds) {
                totalTime = 0;
            } else {
                totalTime += delta;
            }
        }
        // remember this failure so the next call has something to measure the gap against
        lastFailure = now;
    }

    /**
     * Writes an info-level log line for a successful login.  Nothing is queued.
     *
     * @param realm            realm the user logged in to (not used here)
     * @param username         name of the user that logged in
     * @param clientConnection connection whose remote address is logged
     */
    public void successfulLogin(RealmModel realm, String username, ClientConnection clientConnection) {
        String remoteAddress = clientConnection.getRemoteAddr();
        logger.info("successful login user: " + username + " from ip " + remoteAddress);
    }

    /**
     * Writes a warn-level log line for a login attempt with an unknown username.  Nothing is queued.
     *
     * @param realm            realm the attempt was made in (not used here)
     * @param username         the name that was supplied
     * @param clientConnection connection whose remote address is logged
     */
    public void invalidUser(RealmModel realm, String username, ClientConnection clientConnection) {
        String remoteAddress = clientConnection.getRemoteAddr();
        logger.warn("invalid user: " + username + " from ip " + remoteAddress);
        // todo more?
    }

    /**
     * Queues a failed login for the worker and blocks the caller for up to five seconds until the
     * worker has handled it.
     * <p>
     * If the calling thread is interrupted while waiting, the method returns early with the
     * thread's interrupt flag set again rather than silently discarding the interrupt.
     *
     * @param realm            realm the login was attempted in
     * @param username         name the login was attempted with
     * @param clientConnection connection the attempt came from
     */
    public void failedLogin(RealmModel realm, String username, ClientConnection clientConnection) {
        FailedLogin event = new FailedLogin(realm.getId(), username, clientConnection.getRemoteAddr());
        queue.offer(event);
        try {
            // wait a minimum of seconds for event to process so that a hacker
            // cannot flood with failed logins and overwhelm the queue and not have notBefore updated to block next requests
            // todo failure HTTP responses should be queued via async HTTP
            event.latch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // stop waiting, but leave the interrupt for the caller to see
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Tells whether logins for the given username are currently blocked.
     *
     * @param realm    realm that holds the user's failure record
     * @param username name to check
     * @return {@code true} when a failure record exists and the current time, in whole seconds, is
     *         still before the record's "not before" time; {@code false} otherwise
     */
    public boolean isTemporarilyDisabled(RealmModel realm, String username) {
        UsernameLoginFailureModel entry = realm.getUserLoginFailure(username);
        if (entry == null) {
            return false;
        }

        int nowSeconds = (int) (System.currentTimeMillis() / 1000);
        return nowSeconds < entry.getFailedLoginNotBefore();
    }

    /** @return the number of failures counted by logFailure() */
    public long getFailures() {
        return failures;
    }

    /** @return the millisecond timestamp that logFailure() compares against the current time */
    public long getLastFailure() {
        return lastFailure;
    }

    /** @return the upper bound, in seconds, on the wait computed in failure() */
    public int getMaxFailureWaitSeconds() {
        return maxFailureWaitSeconds;
    }

    /** @param maxFailureWaitSeconds new upper bound, in seconds, on the computed wait */
    public void setMaxFailureWaitSeconds(int maxFailureWaitSeconds) {
        this.maxFailureWaitSeconds = maxFailureWaitSeconds;
    }

    /** @return the wait, in seconds, applied by the quick-login check in failure() */
    public int getMinimumQuickLoginWaitSeconds() {
        return minimumQuickLoginWaitSeconds;
    }

    /** @param minimumQuickLoginWaitSeconds new wait, in seconds, for the quick-login check */
    public void setMinimumQuickLoginWaitSeconds(int minimumQuickLoginWaitSeconds) {
        this.minimumQuickLoginWaitSeconds = minimumQuickLoginWaitSeconds;
    }

    /** @return the seconds of wait added per multiple of failureFactor failures */
    public int getWaitIncrementSeconds() {
        return waitIncrementSeconds;
    }

    /** @param waitIncrementSeconds new wait increment, in seconds */
    public void setWaitIncrementSeconds(int waitIncrementSeconds) {
        this.waitIncrementSeconds = waitIncrementSeconds;
    }

    /** @return the quick-login threshold, in milliseconds */
    public long getQuickLoginCheckMilliSeconds() {
        return quickLoginCheckMilliSeconds;
    }

    /** @param quickLoginCheckMilliSeconds new quick-login threshold, in milliseconds */
    public void setQuickLoginCheckMilliSeconds(long quickLoginCheckMilliSeconds) {
        this.quickLoginCheckMilliSeconds = quickLoginCheckMilliSeconds;
    }

    /** @return the gap between failures, in milliseconds, beyond which earlier failures are cleared */
    public long getMaxDeltaTimeMilliSeconds() {
        return maxDeltaTimeMilliSeconds;
    }

    /** @param maxDeltaTimeMilliSeconds new maximum gap between failures, in milliseconds */
    public void setMaxDeltaTimeMilliSeconds(long maxDeltaTimeMilliSeconds) {
        this.maxDeltaTimeMilliSeconds = maxDeltaTimeMilliSeconds;
    }

    /** @return the divisor applied to a user's failure count when computing the wait */
    public int getFailureFactor() {
        return failureFactor;
    }

    /**
     * NOTE(review): failure() divides by this value, so 0 would presumably make that division
     * throw ArithmeticException - confirm callers never pass it.
     *
     * @param failureFactor new divisor for the failure count
     */
    public void setFailureFactor(int failureFactor) {
        this.failureFactor = failureFactor;
    }
}