package org.keycloak.models.sessions.infinispan;

import org.infinispan.Cache;
import org.infinispan.distexec.mapreduce.MapReduceTask;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientSessionModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakTransaction;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.models.sessions.infinispan.entities.ClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.LoginFailureEntity;
import org.keycloak.models.sessions.infinispan.entities.LoginFailureKey;
import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import org.keycloak.models.sessions.infinispan.mapreduce.ClientSessionMapper;
import org.keycloak.models.sessions.infinispan.mapreduce.FirstResultReducer;
import org.keycloak.models.sessions.infinispan.mapreduce.LargestResultReducer;
import org.keycloak.models.sessions.infinispan.mapreduce.SessionMapper;
import org.keycloak.models.sessions.infinispan.mapreduce.UserSessionMapper;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.util.Time;

import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
 */
public class InfinispanUserSessionProvider implements UserSessionProvider {

    // Keycloak session this provider was created for; used for realm lookups and handed to the adapters.
    private final KeycloakSession session;
    // Cache keyed by session id; both user-session and client-session entities are stored in it.
    private final Cache<String, SessionEntity> sessionCache;
    // Cache of login-failure entities keyed by LoginFailureKey (built from realm id and username).
    private final Cache<LoginFailureKey, LoginFailureEntity> loginFailureCache;
    // Collects cache operations queued by this provider; they are executed when commit() is called.
    private final InfinispanKeycloakTransaction tx;

    public InfinispanUserSessionProvider(KeycloakSession session, Cache<String, SessionEntity> sessionCache, Cache<LoginFailureKey, LoginFailureEntity> loginFailureCache) {
        this.session = session;
        this.sessionCache = sessionCache;
        this.loginFailureCache = loginFailureCache;
        this.tx = new InfinispanKeycloakTransaction();

        session.getTransaction().enlistAfterCompletion(tx);
    }

    /**
     * Builds a new client-session entity with a generated id, the realm id, the
     * current time as timestamp and the client id, and queues it for insertion
     * into the session cache on the provider transaction.
     *
     * @param realm  realm the session belongs to
     * @param client client the session is created for
     * @return a model wrapping the newly created entity
     */
    @Override
    public ClientSessionModel createClientSession(RealmModel realm, ClientModel client) {
        final String sessionId = KeycloakModelUtils.generateId();

        final ClientSessionEntity clientSession = new ClientSessionEntity();
        clientSession.setId(sessionId);
        clientSession.setRealm(realm.getId());
        clientSession.setTimestamp(Time.currentTime());
        clientSession.setClient(client.getId());

        // The entity is only queued here; the cache is written when the transaction commits.
        tx.put(sessionCache, sessionId, clientSession);

        return wrap(realm, clientSession);
    }

    /**
     * Builds a new user-session entity with a generated id and queues it for
     * insertion into the session cache on the provider transaction. The start
     * time and the last-refresh time are both set to the same current-time value.
     *
     * @param realm         realm the session belongs to
     * @param user          user the session is created for
     * @param loginUsername username stored on the entity
     * @param ipAddress     IP address stored on the entity
     * @param authMethod    authentication method stored on the entity
     * @param rememberMe    remember-me flag stored on the entity
     * @return a model wrapping the newly created entity
     */
    @Override
    public UserSessionModel createUserSession(RealmModel realm, UserModel user, String loginUsername, String ipAddress, String authMethod, boolean rememberMe) {
        final String sessionId = KeycloakModelUtils.generateId();

        final UserSessionEntity userSession = new UserSessionEntity();
        userSession.setId(sessionId);
        userSession.setRealm(realm.getId());
        userSession.setUser(user.getId());
        userSession.setLoginUsername(loginUsername);
        userSession.setIpAddress(ipAddress);
        userSession.setAuthMethod(authMethod);
        userSession.setRememberMe(rememberMe);

        // One clock reading so that "started" and "last refresh" are identical.
        final int now = Time.currentTime();
        userSession.setStarted(now);
        userSession.setLastSessionRefresh(now);

        tx.put(sessionCache, sessionId, userSession);

        return wrap(realm, userSession);
    }

    /**
     * Reads the entry stored under {@code id} in the session cache, casts it to a
     * client-session entity and wraps it for the given realm.
     *
     * @param realm realm passed to the returned model
     * @param id    session id used as cache key
     * @return the wrapped client session, or {@code null} when the cache has no entry
     */
    @Override
    public ClientSessionModel getClientSession(RealmModel realm, String id) {
        SessionEntity cached = sessionCache.get(id);
        ClientSessionEntity clientSession = (ClientSessionEntity) cached;
        return wrap(realm, clientSession);
    }

    /**
     * Reads the client session stored under {@code id} and resolves its realm from
     * the realm id recorded on the entity.
     *
     * @param id session id used as cache key
     * @return the wrapped client session, or {@code null} when the cache has no entry
     */
    @Override
    public ClientSessionModel getClientSession(String id) {
        ClientSessionEntity clientSession = (ClientSessionEntity) sessionCache.get(id);
        if (clientSession == null) {
            return null;
        }

        RealmModel owningRealm = session.realms().getRealm(clientSession.getRealm());
        return wrap(owningRealm, clientSession);
    }

    /**
     * Reads the entry stored under {@code id} in the session cache, casts it to a
     * user-session entity and wraps it for the given realm.
     *
     * @param realm realm passed to the returned model
     * @param id    session id used as cache key
     * @return the wrapped user session, or {@code null} when the cache has no entry
     */
    @Override
    public UserSessionModel getUserSession(RealmModel realm, String id) {
        SessionEntity cached = sessionCache.get(id);
        UserSessionEntity userSession = (UserSessionEntity) cached;
        return wrap(realm, userSession);
    }

    /**
     * Runs a map/reduce task over the session cache with a {@link UserSessionMapper}
     * configured for the realm id and user id, and wraps every entity in the result.
     *
     * @param realm realm whose id is given to the mapper
     * @param user  user whose id is given to the mapper
     * @return models for all entities returned by the task
     */
    @Override
    public List<UserSessionModel> getUserSessions(RealmModel realm, UserModel user) {
        Map<String, UserSessionEntity> matches = new MapReduceTask(sessionCache)
                .mappedWith(UserSessionMapper.create(realm.getId()).user(user.getId()))
                .reducedWith(new FirstResultReducer())
                .execute();

        Collection<UserSessionEntity> entities = matches.values();
        return wrapUserSessions(realm, entities);
    }

    /**
     * Returns the user sessions for a client without paging, by delegating to the
     * paged overload with {@code -1} for both {@code firstResult} and {@code maxResults}.
     *
     * @param realm  realm to search in
     * @param client client to search for
     * @return the result of the paged overload
     */
    @Override
    public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client) {
        final int unset = -1;
        return getUserSessions(realm, client, unset, unset);
    }

    /**
     * Returns one page of the user sessions associated with a client.
     * <p>
     * A map/reduce task over the session cache (a {@link ClientSessionMapper}
     * configured for the realm and client with {@code emitUserSessionAndTimestamp()},
     * reduced with {@link LargestResultReducer}) yields a map whose keys are used as
     * user-session ids and whose values are integers (presumably the most recent
     * client-session timestamp per user session - confirm against the mapper).
     * The entries are sorted ascending by that value, the requested window is cut
     * out, and every key that still resolves in the session cache is wrapped.
     *
     * @param realm       realm to search in
     * @param client      client to search for
     * @param firstResult index of the first entry to return; a negative value
     *                    (conventionally {@code -1}) means "start at the beginning"
     * @param maxResults  maximum number of entries to return; a negative value
     *                    (conventionally {@code -1}) means "no limit"
     * @return the sessions in the requested window; empty when {@code firstResult}
     *         lies beyond the number of available entries
     */
    @Override
    public List<UserSessionModel> getUserSessions(RealmModel realm, ClientModel client, int firstResult, int maxResults) {
        Map<String, Integer> map = new MapReduceTask(sessionCache)
                .mappedWith(ClientSessionMapper.create(realm.getId()).client(client.getId()).emitUserSessionAndTimestamp())
                .reducedWith(new LargestResultReducer())
                .execute();

        List<Map.Entry<String, Integer>> sessionTimestamps = new LinkedList<Map.Entry<String, Integer>>(map.entrySet());

        Collections.sort(sessionTimestamps, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Map.Entry<String, Integer> e1, Map.Entry<String, Integer> e2) {
                return e1.getValue().compareTo(e2.getValue());
            }
        });

        // Page whenever either bound is set. The previous guard tested
        // "maxResults == -1", which skipped paging when only maxResults was given.
        boolean hasFirst = firstResult >= 0;
        boolean hasMax = maxResults >= 0;
        if (hasFirst || hasMax) {
            int total = sessionTimestamps.size();
            int fromIndex = hasFirst ? firstResult : 0;

            if (fromIndex > total) {
                return Collections.emptyList();
            }

            // Compare against the remaining count instead of computing
            // firstResult + maxResults, which overflows for large maxResults.
            int remaining = total - fromIndex;
            int toIndex = (!hasMax || maxResults > remaining) ? total : fromIndex + maxResults;

            sessionTimestamps = sessionTimestamps.subList(fromIndex, toIndex);
        }

        List<UserSessionModel> userSessions = new LinkedList<UserSessionModel>();
        for (Map.Entry<String, Integer> e : sessionTimestamps) {
            UserSessionEntity userSessionEntity = (UserSessionEntity) sessionCache.get(e.getKey());
            // The key may no longer resolve in the cache; such entries are skipped.
            if (userSessionEntity != null) {
                userSessions.add(wrap(realm, userSessionEntity));
            }
        }

        return userSessions;
    }

    /**
     * Counts the entries produced by the same map/reduce task that the paged
     * client lookup uses ({@link ClientSessionMapper} for realm and client with
     * {@code emitUserSessionAndTimestamp()}, reduced with {@link LargestResultReducer}).
     *
     * @param realm  realm to search in
     * @param client client to search for
     * @return the size of the map returned by the task
     */
    @Override
    public int getActiveUserSessions(RealmModel realm, ClientModel client) {
        Map reduced = new MapReduceTask(sessionCache)
                .mappedWith(ClientSessionMapper.create(realm.getId()).client(client.getId()).emitUserSessionAndTimestamp())
                .reducedWith(new LargestResultReducer())
                .execute();

        int activeCount = reduced.size();
        return activeCount;
    }

    /**
     * Removes the given user session by delegating to the id-based overload.
     *
     * @param realm   realm the session belongs to
     * @param session user session to remove
     */
    @Override
    public void removeUserSession(RealmModel realm, UserSessionModel session) {
        String userSessionId = session.getId();
        removeUserSession(realm, userSessionId);
    }

    /**
     * Finds the keys of the user's sessions with a map/reduce task
     * ({@link UserSessionMapper} for realm and user with {@code emitKey()}) and
     * removes each of them through the id-based {@code removeUserSession}.
     *
     * @param realm realm to search in
     * @param user  user whose sessions are removed
     */
    @Override
    public void removeUserSessions(RealmModel realm, UserModel user) {
        Map<String, String> found = new MapReduceTask(sessionCache)
                .mappedWith(UserSessionMapper.create(realm.getId()).user(user.getId()).emitKey())
                .reducedWith(new FirstResultReducer())
                .execute();

        for (Map.Entry<String, String> match : found.entrySet()) {
            removeUserSession(realm, match.getKey());
        }
    }

    /**
     * Removes the user sessions of a realm that the mapper reports as expired.
     * <p>
     * Two thresholds are derived from a single reading of the current time: one
     * from the realm's SSO session max lifespan and one from its SSO session idle
     * timeout. They are handed to {@code UserSessionMapper.expired(...)}; every key
     * emitted by the resulting map/reduce task is removed through the id-based
     * {@code removeUserSession}.
     *
     * @param realm realm whose expired sessions are removed
     */
    @Override
    public void removeExpiredUserSessions(RealmModel realm) {
        // Read the clock once so both thresholds refer to the same instant.
        int now = Time.currentTime();
        int expired = now - realm.getSsoSessionMaxLifespan();
        int expiredRefresh = now - realm.getSsoSessionIdleTimeout();

        Map<String, String> map = new MapReduceTask(sessionCache)
                .mappedWith(UserSessionMapper.create(realm.getId()).expired(expired, expiredRefresh).emitKey())
                .reducedWith(new FirstResultReducer())
                .execute();

        for (String id : map.keySet()) {
            removeUserSession(realm, id);
        }
    }

    /**
     * Removes every session-cache entry whose key is emitted by a
     * {@link SessionMapper} created for the realm id.
     * <p>
     * The entries are removed from the cache directly, not queued on the provider
     * transaction.
     *
     * @param realm realm whose sessions are removed
     */
    @Override
    public void removeUserSessions(RealmModel realm) {
        Map<String, String> realmSessions = new MapReduceTask(sessionCache)
                .mappedWith(SessionMapper.create(realm.getId()).emitKey())
                .reducedWith(new FirstResultReducer())
                .execute();

        for (Map.Entry<String, String> entry : realmSessions.entrySet()) {
            sessionCache.remove(entry.getKey());
        }
    }

    /**
     * Looks up the login-failure entry for a realm/username pair.
     *
     * @param realm    realm whose id is part of the cache key
     * @param username username that is part of the cache key
     * @return the wrapped entry, or {@code null} when the cache has none
     */
    @Override
    public UsernameLoginFailureModel getUserLoginFailure(RealmModel realm, String username) {
        LoginFailureKey failureKey = new LoginFailureKey(realm.getId(), username);
        LoginFailureEntity stored = loginFailureCache.get(failureKey);
        return wrap(failureKey, stored);
    }

    /**
     * Creates a new login-failure entity for a realm/username pair and queues it
     * for insertion into the login-failure cache on the provider transaction.
     *
     * @param realm    realm whose id is stored on the entity and used in the key
     * @param username username stored on the entity and used in the key
     * @return a model wrapping the new entity
     */
    @Override
    public UsernameLoginFailureModel addUserLoginFailure(RealmModel realm, String username) {
        final LoginFailureKey failureKey = new LoginFailureKey(realm.getId(), username);

        final LoginFailureEntity failure = new LoginFailureEntity();
        failure.setRealm(realm.getId());
        failure.setUsername(username);

        tx.put(loginFailureCache, failureKey, failure);

        return wrap(failureKey, failure);
    }

    /**
     * Realm-removal callback: removes all sessions of the realm via
     * {@code removeUserSessions(realm)}.
     *
     * @param realm the realm that was removed
     */
    @Override
    public void onRealmRemoved(RealmModel realm) {
        this.removeUserSessions(realm);
    }

    /**
     * Client-removal callback: finds the keys emitted by a
     * {@link ClientSessionMapper} configured for the realm and client and queues
     * the removal of each of them on the provider transaction.
     *
     * @param realm  realm the client belonged to
     * @param client the client that was removed
     */
    @Override
    public void onClientRemoved(RealmModel realm, ClientModel client) {
        Map<String, String> clientSessions = new MapReduceTask(sessionCache)
                .mappedWith(ClientSessionMapper.create(realm.getId()).client(client.getId()).emitKey())
                .reducedWith(new FirstResultReducer())
                .execute();

        for (Map.Entry<String, String> entry : clientSessions.entrySet()) {
            tx.remove(sessionCache, entry.getKey());
        }
    }

    /**
     * User-removal callback: removes the user's sessions and then deletes the
     * login-failure entries keyed by the user's username and by the user's email.
     * <p>
     * The login-failure entries are removed from the cache directly, not queued on
     * the provider transaction.
     *
     * @param realm realm the user belonged to
     * @param user  the user that was removed
     */
    @Override
    public void onUserRemoved(RealmModel realm, UserModel user) {
        removeUserSessions(realm, user);

        LoginFailureKey byUsername = new LoginFailureKey(realm.getId(), user.getUsername());
        loginFailureCache.remove(byUsername);

        LoginFailureKey byEmail = new LoginFailureKey(realm.getId(), user.getEmail());
        loginFailureCache.remove(byEmail);
    }

    /**
     * No-op; this provider does nothing on close.
     */
    @Override
    public void close() {
        // intentionally empty
    }

    /**
     * Records a client session's id on a user session's entity. The id collection
     * is created on first use. When the id was not yet present, it is added and a
     * replace of the user-session entity is queued on the provider transaction.
     *
     * @param userSession   user session, expected to be a {@code UserSessionAdapter}
     * @param clientSession client session whose id is recorded
     */
    void attachSession(UserSessionModel userSession, ClientSessionModel clientSession) {
        UserSessionEntity userEntity = ((UserSessionAdapter) userSession).getEntity();
        String attachedId = clientSession.getId();

        if (userEntity.getClientSessions() == null) {
            userEntity.setClientSessions(new HashSet<String>());
        }

        if (userEntity.getClientSessions().contains(attachedId)) {
            // Already recorded; nothing to queue.
            return;
        }

        userEntity.getClientSessions().add(attachedId);
        tx.replace(sessionCache, userEntity.getId(), userEntity);
    }

    /**
     * Removes a client session's id from a user session's entity. When the id was
     * present it is removed, an emptied id collection is reset to {@code null},
     * and a replace of the user-session entity is queued on the provider
     * transaction. Otherwise nothing happens.
     *
     * @param userSession   user session, expected to be a {@code UserSessionAdapter}
     * @param clientSession client session whose id is removed
     */
    void dettachSession(UserSessionModel userSession, ClientSessionModel clientSession) {
        UserSessionEntity userEntity = ((UserSessionAdapter) userSession).getEntity();
        String detachedId = clientSession.getId();

        if (userEntity.getClientSessions() == null) {
            return;
        }
        if (!userEntity.getClientSessions().contains(detachedId)) {
            return;
        }

        userEntity.getClientSessions().remove(detachedId);
        if (userEntity.getClientSessions().isEmpty()) {
            userEntity.setClientSessions(null);
        }

        tx.replace(sessionCache, userEntity.getId(), userEntity);
    }

    /**
     * Queues the removal of a user session and of every key emitted by a
     * {@link ClientSessionMapper} configured for the realm and that user session.
     * All removals go through the provider transaction.
     *
     * @param realm         realm whose id is given to the mapper
     * @param userSessionId id of the user session to remove
     */
    void removeUserSession(RealmModel realm, String userSessionId) {
        tx.remove(sessionCache, userSessionId);

        Map<String, String> attached = new MapReduceTask(sessionCache)
                .mappedWith(ClientSessionMapper.create(realm.getId()).userSession(userSessionId).emitKey())
                .reducedWith(new FirstResultReducer())
                .execute();

        for (Map.Entry<String, String> entry : attached.entrySet()) {
            tx.remove(sessionCache, entry.getKey());
        }
    }

    /**
     * @return the transaction object on which this provider queues its cache operations
     */
    InfinispanKeycloakTransaction getTx() {
        return this.tx;
    }

    /**
     * Wraps a user-session entity in a {@code UserSessionAdapter}.
     *
     * @param realm  realm passed to the adapter
     * @param entity entity to wrap; may be {@code null}
     * @return the adapter, or {@code null} when {@code entity} is {@code null}
     */
    UserSessionModel wrap(RealmModel realm, UserSessionEntity entity) {
        if (entity == null) {
            return null;
        }
        return new UserSessionAdapter(session, this, sessionCache, realm, entity);
    }

    /**
     * Wraps each user-session entity of a collection, preserving iteration order.
     *
     * @param realm    realm passed to every adapter
     * @param entities entities to wrap
     * @return a new list with one element per entity
     */
    List<UserSessionModel> wrapUserSessions(RealmModel realm, Collection<UserSessionEntity> entities) {
        List<UserSessionModel> wrapped = new LinkedList<UserSessionModel>();
        for (UserSessionEntity userSession : entities) {
            UserSessionModel model = wrap(realm, userSession);
            wrapped.add(model);
        }
        return wrapped;
    }

    /**
     * Wraps a client-session entity in a {@code ClientSessionAdapter}.
     *
     * @param realm  realm passed to the adapter
     * @param entity entity to wrap; may be {@code null}
     * @return the adapter, or {@code null} when {@code entity} is {@code null}
     */
    ClientSessionModel wrap(RealmModel realm, ClientSessionEntity entity) {
        if (entity == null) {
            return null;
        }
        return new ClientSessionAdapter(session, this, sessionCache, realm, entity);
    }


    /**
     * Wraps a login-failure entity in a {@code UsernameLoginFailureAdapter}.
     *
     * @param key    cache key of the entity
     * @param entity entity to wrap; may be {@code null}
     * @return the adapter, or {@code null} when {@code entity} is {@code null}
     */
    UsernameLoginFailureModel wrap(LoginFailureKey key, LoginFailureEntity entity) {
        if (entity == null) {
            return null;
        }
        return new UsernameLoginFailureAdapter(this, loginFailureCache, key, entity);
    }

    /**
     * Wraps each client-session entity of a collection, preserving iteration order.
     *
     * @param realm    realm passed to every adapter
     * @param entities entities to wrap
     * @return a new list with one element per entity
     */
    List<ClientSessionModel> wrapClientSessions(RealmModel realm, Collection<ClientSessionEntity> entities) {
        List<ClientSessionModel> wrapped = new LinkedList<ClientSessionModel>();
        for (ClientSessionEntity clientSession : entities) {
            ClientSessionModel model = wrap(realm, clientSession);
            wrapped.add(model);
        }
        return wrapped;
    }

    /**
     * Transaction that queues cache operations in memory and applies them to the
     * caches only when {@link #commit()} is called.
     * <p>
     * Pending operations are stored in a map indexed by cache key, so there is at
     * most one pending task per key.
     */
    class InfinispanKeycloakTransaction implements KeycloakTransaction {

        private boolean active;
        private boolean rollback;
        // One pending task per cache key.
        private final Map<Object, CacheTask> tasks = new HashMap<Object, CacheTask>();

        /** Marks the transaction as active. */
        @Override
        public void begin() {
            active = true;
        }

        /**
         * Executes every queued task against its cache.
         *
         * @throws RuntimeException if the transaction was marked rollback-only
         */
        @Override
        public void commit() {
            if (rollback) {
                throw new RuntimeException("Rollback only!");
            }

            for (CacheTask task : tasks.values()) {
                task.execute();
            }
        }

        /** Discards all queued tasks without touching the caches. */
        @Override
        public void rollback() {
            tasks.clear();
        }

        /** Marks the transaction so that a later {@link #commit()} fails. */
        @Override
        public void setRollbackOnly() {
            rollback = true;
        }

        @Override
        public boolean getRollbackOnly() {
            return rollback;
        }

        @Override
        public boolean isActive() {
            return active;
        }

        /**
         * Queues an ADD of {@code value} under {@code key}.
         *
         * @throws IllegalStateException if any task is already queued for {@code key}
         */
        public void put(Cache cache, Object key, Object value) {
            if (tasks.containsKey(key)) {
                throw new IllegalStateException("Can't add session: task in progress for session");
            } else {
                tasks.put(key, new CacheTask(cache, CacheOperation.ADD, key, value));
            }
        }

        /**
         * Queues a REPLACE of the entry under {@code key}. When an ADD or REPLACE is
         * already queued for the key, only that task's value is updated and its
         * operation is kept.
         *
         * @throws IllegalStateException if a REMOVE is already queued for {@code key}
         */
        public void replace(Cache cache, Object key, Object value) {
            CacheTask current = tasks.get(key);
            if (current != null) {
                switch (current.operation) {
                    case ADD:
                    case REPLACE:
                        current.value = value;
                        return;
                    case REMOVE:
                        // The message previously named the wrong operation ("remove").
                        throw new IllegalStateException("Can't replace session: task in progress for session");
                }
            } else {
                tasks.put(key, new CacheTask(cache, CacheOperation.REPLACE, key, value));
            }
        }

        /**
         * Queues a REMOVE of the entry under {@code key}, overwriting any task that
         * was previously queued for the same key.
         */
        public void remove(Cache cache, String key) {
            tasks.put(key, new CacheTask(cache, CacheOperation.REMOVE, key, null));
        }

        /**
         * A single deferred cache operation: the target cache, the kind of operation,
         * the key, and (for ADD/REPLACE) the value.
         */
        public class CacheTask {
            private Cache cache;
            private CacheOperation operation;
            private Object key;
            private Object value;

            public CacheTask(Cache cache, CacheOperation operation, Object key, Object value) {
                this.cache = cache;
                this.operation = operation;
                this.key = key;
                this.value = value;
            }

            /** Applies the operation to the cache via {@code put}, {@code remove} or {@code replace}. */
            public void execute() {
                switch (operation) {
                    case ADD:
                        cache.put(key, value);
                        break;
                    case REMOVE:
                        cache.remove(key);
                        break;
                    case REPLACE:
                        cache.replace(key, value);
                        break;
                }
            }
        }

    }

    /**
     * Kinds of deferred cache operation a queued task can carry.
     */
    public enum CacheOperation {
        /** Executed as a cache {@code put}. */
        ADD,
        /** Executed as a cache {@code remove}. */
        REMOVE,
        /** Executed as a cache {@code replace}. */
        REPLACE
    }

}
