keycloak-uncached

KEYCLOAK-4630 Refactor RemoteCacheSessionsLoader to use

8/11/2017 5:34:05 AM

Details

diff --git a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanUserSessionProviderFactory.java b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanUserSessionProviderFactory.java
index 110a812..489dd60 100755
--- a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanUserSessionProviderFactory.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/InfinispanUserSessionProviderFactory.java
@@ -248,12 +248,12 @@ public class InfinispanUserSessionProviderFactory implements UserSessionProvider
 
     private void loadSessionsFromRemoteCaches(KeycloakSession session) {
         for (String cacheName : remoteCacheInvoker.getRemoteCacheNames()) {
-            loadSessionsFromRemoteCache(session.getKeycloakSessionFactory(), cacheName, getMaxErrors());
+            loadSessionsFromRemoteCache(session.getKeycloakSessionFactory(), cacheName, getSessionsPerSegment(), getMaxErrors());
         }
     }
 
 
-    private void loadSessionsFromRemoteCache(final KeycloakSessionFactory sessionFactory, String cacheName, final int maxErrors) {
+    private void loadSessionsFromRemoteCache(final KeycloakSessionFactory sessionFactory, String cacheName, final int sessionsPerSegment, final int maxErrors) {
         log.debugf("Check pre-loading userSessions from remote cache '%s'", cacheName);
 
         KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {
@@ -263,8 +263,7 @@ public class InfinispanUserSessionProviderFactory implements UserSessionProvider
                 InfinispanConnectionProvider connections = session.getProvider(InfinispanConnectionProvider.class);
                 Cache<String, Serializable> workCache = connections.getCache(InfinispanConnectionProvider.WORK_CACHE_NAME);
 
-                // Use limit for sessionsPerSegment as RemoteCache bulk load doesn't have support for pagination :/
-                BaseCacheInitializer initializer = new SingleWorkerCacheInitializer(session, workCache, new RemoteCacheSessionsLoader(cacheName), "remoteCacheLoad::" + cacheName);
+                InfinispanCacheInitializer initializer = new InfinispanCacheInitializer(sessionFactory, workCache, new RemoteCacheSessionsLoader(cacheName), "remoteCacheLoad::" + cacheName, sessionsPerSegment, maxErrors);
 
                 initializer.initCache();
                 initializer.loadSessions();
diff --git a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/initializer/BaseCacheInitializer.java b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/initializer/BaseCacheInitializer.java
index 43788d0..cca28cc 100644
--- a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/initializer/BaseCacheInitializer.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/initializer/BaseCacheInitializer.java
@@ -106,7 +106,7 @@ public abstract class BaseCacheInitializer extends CacheInitializer {
 
 
     private InitializerState getStateFromCache() {
-        // We ignore cacheStore for now, so that in Cross-DC scenario (with RemoteStore enabled) is the remoteStore ignored. This means that every DC needs to load offline sessions separately.
+        // We ignore cacheStore for now, so that in Cross-DC scenario (with RemoteStore enabled) is the remoteStore ignored.
         return (InitializerState) workCache.getAdvancedCache()
                 .withFlags(Flag.SKIP_CACHE_STORE, Flag.SKIP_CACHE_LOAD)
                 .get(stateKey);
@@ -122,7 +122,7 @@ public abstract class BaseCacheInitializer extends CacheInitializer {
             public void run() {
 
                 // Save this synchronously to ensure all nodes read correct state
-                // We ignore cacheStore for now, so that in Cross-DC scenario (with RemoteStore enabled) is the remoteStore ignored. This means that every DC needs to load offline sessions separately.
+                // We ignore cacheStore for now, so that in Cross-DC scenario (with RemoteStore enabled) is the remoteStore ignored.
                 BaseCacheInitializer.this.workCache.getAdvancedCache().
                         withFlags(Flag.IGNORE_RETURN_VALUES, Flag.FORCE_SYNCHRONOUS, Flag.SKIP_CACHE_STORE, Flag.SKIP_CACHE_LOAD)
                         .put(stateKey, state);
diff --git a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/remotestore/RemoteCacheSessionsLoader.java b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/remotestore/RemoteCacheSessionsLoader.java
index 65c31bc..ba01b71 100644
--- a/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/remotestore/RemoteCacheSessionsLoader.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/sessions/infinispan/remotestore/RemoteCacheSessionsLoader.java
@@ -18,10 +18,12 @@
 package org.keycloak.models.sessions.infinispan.remotestore;
 
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.infinispan.Cache;
 import org.infinispan.client.hotrod.RemoteCache;
+import org.infinispan.commons.marshall.Marshaller;
 import org.infinispan.context.Flag;
 import org.jboss.logging.Logger;
 import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
@@ -40,8 +42,33 @@ public class RemoteCacheSessionsLoader implements SessionLoader {
 
     private static final Logger log = Logger.getLogger(RemoteCacheSessionsLoader.class);
 
-    // Hardcoded limit for now. See if needs to be configurable (or if preloading can be enabled/disabled in configuration)
-    public static final int LIMIT = 100000;
+
+    // Javascript to be executed on remote infinispan server (Flag CACHE_MODE_LOCAL assumes that remoteCache is replicated)
+    private static final String REMOTE_SCRIPT_FOR_LOAD_SESSIONS =
+            "function loadSessions() {" +
+            "  var flagClazz = cache.getClass().getClassLoader().loadClass(\"org.infinispan.context.Flag\"); \n" +
+            "  var localFlag = java.lang.Enum.valueOf(flagClazz, \"CACHE_MODE_LOCAL\"); \n" +
+            "  var cacheStream = cache.getAdvancedCache().withFlags([ localFlag ]).entrySet().stream();\n" +
+            "  var result = cacheStream.skip(first).limit(max).collect(java.util.stream.Collectors.toMap(\n" +
+            "    new java.util.function.Function() {\n" +
+            "      apply: function(entry) {\n" +
+            "        return entry.getKey();\n" +
+            "      }\n" +
+            "    },\n" +
+            "    new java.util.function.Function() {\n" +
+            "      apply: function(entry) {\n" +
+            "        return entry.getValue();\n" +
+            "      }\n" +
+            "    }\n" +
+            "  ));\n" +
+            "\n" +
+            "  cacheStream.close();\n" +
+            "  return result;\n" +
+            "};\n" +
+            "\n" +
+            "loadSessions();";
+
+
 
     private final String cacheName;
 
@@ -51,7 +78,15 @@ public class RemoteCacheSessionsLoader implements SessionLoader {
 
     @Override
     public void init(KeycloakSession session) {
+        RemoteCache remoteCache = InfinispanUtil.getRemoteCache(getCache(session));
+
+        RemoteCache<String, String> scriptCache = remoteCache.getRemoteCacheManager().getCache("___script_cache");
 
+        if (!scriptCache.containsKey("load-sessions.js")) {
+            scriptCache.put("load-sessions.js",
+                    "// mode=local,language=javascript\n" +
+                            REMOTE_SCRIPT_FOR_LOAD_SESSIONS);
+        }
     }
 
     @Override
@@ -67,21 +102,31 @@ public class RemoteCacheSessionsLoader implements SessionLoader {
 
         RemoteCache<?, ?> remoteCache = InfinispanUtil.getRemoteCache(cache);
 
-        int size = remoteCache.size();
+        // TODO:mposolda
+        log.infof("Will do bulk load of sessions from remote cache '%s' . First: %d, max: %d", cache.getName(), first, max);
 
-        if (size > LIMIT) {
-            log.infof("Skip bulk load of '%d' sessions from remote cache '%s'. Sessions will be retrieved lazily", size, cache.getName());
-            return true;
-        } else {
-            log.infof("Will do bulk load of '%d' sessions from remote cache '%s'", size, cache.getName());
-        }
 
+        Map<String, Integer> remoteParams = new HashMap<>();
+        remoteParams.put("first", first);
+        remoteParams.put("max", max);
+        Map<byte[], byte[]> remoteObjects = remoteCache.execute("load-sessions.js", remoteParams);
+
+        // TODO:mposolda
+        log.infof("Finished loading sessions '%s' . First: %d, max: %d", cache.getName(), first, max);
+
+        Marshaller marshaller = remoteCache.getRemoteCacheManager().getMarshaller();
+
+        for (Map.Entry<byte[], byte[]> entry : remoteObjects.entrySet()) {
+            try {
+                String key = (String) marshaller.objectFromByteBuffer(entry.getKey());
+                SessionEntity entity = (SessionEntity) marshaller.objectFromByteBuffer(entry.getValue());
 
-        for (Map.Entry<?, ?> entry : remoteCache.getBulk().entrySet()) {
-            SessionEntity entity = (SessionEntity) entry.getValue();
-            SessionEntityWrapper entityWrapper = new SessionEntityWrapper(entity);
+                SessionEntityWrapper entityWrapper = new SessionEntityWrapper(entity);
 
-            decoratedCache.putAsync(entry.getKey(), entityWrapper);
+                decoratedCache.putAsync(key, entityWrapper);
+            } catch (Exception e) {
+                log.warnf("Error loading session from remote cache", e);
+            }
         }
 
         return true;
diff --git a/testsuite/integration/src/test/java/org/keycloak/testsuite/util/cli/AbstractSessionCacheCommand.java b/testsuite/integration/src/test/java/org/keycloak/testsuite/util/cli/AbstractSessionCacheCommand.java
index f85a8e3..8ea51af 100644
--- a/testsuite/integration/src/test/java/org/keycloak/testsuite/util/cli/AbstractSessionCacheCommand.java
+++ b/testsuite/integration/src/test/java/org/keycloak/testsuite/util/cli/AbstractSessionCacheCommand.java
@@ -17,6 +17,8 @@
 
 package org.keycloak.testsuite.util.cli;
 
+import java.util.function.Function;
+
 import org.infinispan.AdvancedCache;
 import org.infinispan.Cache;
 import org.infinispan.context.Flag;
@@ -25,6 +27,7 @@ import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
 import org.keycloak.models.KeycloakSession;
 import org.keycloak.models.RealmModel;
 import org.keycloak.models.UserModel;
+import org.keycloak.models.sessions.infinispan.changes.SessionEntityWrapper;
 import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
 import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
 import org.keycloak.models.utils.KeycloakModelUtils;
@@ -44,8 +47,20 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
             throw new HandledException();
         }
 
-        Cache<String, SessionEntity> ispnCache = provider.getCache(cacheName);
+        Cache<String, SessionEntityWrapper> ispnCache = provider.getCache(cacheName);
         doRunCacheCommand(session, ispnCache);
+
+        ispnCache.entrySet().stream().skip(0).limit(10).collect(java.util.stream.Collectors.toMap(new java.util.function.Function() {
+
+            public Object apply(Object entry) {
+                return ((java.util.Map.Entry) entry).getKey();
+            }
+        }, new java.util.function.Function() {
+
+            public Object apply(Object entry) {
+                return ((java.util.Map.Entry) entry).getValue();
+            }
+        }));
     }
 
     protected void printSession(String id, UserSessionEntity userSession) {
@@ -67,7 +82,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         return getName() + " <cache-name>";
     }
 
-    protected abstract void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache);
+    protected abstract void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache);
 
 
     // IMPLS
@@ -80,7 +95,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             UserSessionEntity userSession = new UserSessionEntity();
             String id = getArg(1);
 
@@ -88,7 +103,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
             userSession.setRealm(getArg(2));
 
             userSession.setLastSessionRefresh(Time.currentTime());
-            cache.put(id, userSession);
+            cache.put(id, new SessionEntityWrapper(userSession));
         }
 
         @Override
@@ -106,9 +121,9 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String id = getArg(1);
-            UserSessionEntity userSession = (UserSessionEntity) cache.get(id);
+            UserSessionEntity userSession = (UserSessionEntity) cache.get(id).getEntity();
             printSession(id, userSession);
         }
 
@@ -127,13 +142,13 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String id = getArg(1);
             int count = getIntArg(2);
 
             long start = System.currentTimeMillis();
             for (int i=0 ; i<count ; i++) {
-                UserSessionEntity userSession = (UserSessionEntity) cache.get(id);
+                UserSessionEntity userSession = (UserSessionEntity) cache.get(id).getEntity();
                 //printSession(id, userSession);
             }
             long took = System.currentTimeMillis() - start;
@@ -155,7 +170,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String id = getArg(1);
             cache.remove(id);
         }
@@ -175,7 +190,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             cache.clear();
         }
     }
@@ -189,7 +204,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             log.info("Size: " + cache.size());
         }
     }
@@ -203,13 +218,13 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             for (String id : cache.keySet()) {
-                SessionEntity entity = cache.get(id);
+                SessionEntity entity = cache.get(id).getEntity();
                 if (!(entity instanceof UserSessionEntity)) {
                     continue;
                 }
-                UserSessionEntity userSession = (UserSessionEntity) cache.get(id);
+                UserSessionEntity userSession = (UserSessionEntity) cache.get(id).getEntity();
                 log.info("list: key=" + id + ", value=" + toString(userSession));
             }
         }
@@ -225,10 +240,10 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
 
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String id = getArg(1);
             cache = ((AdvancedCache) cache).withFlags(Flag.CACHE_MODE_LOCAL);
-            UserSessionEntity userSession = (UserSessionEntity) cache.get(id);
+            UserSessionEntity userSession = (UserSessionEntity) cache.get(id).getEntity();
             printSession(id, userSession);
         }
 
@@ -247,7 +262,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             log.info("Size local: " + cache.getAdvancedCache().withFlags(Flag.CACHE_MODE_LOCAL).size());
         }
     }
@@ -261,7 +276,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String realmName = getArg(1);
             int count = getIntArg(2);
             int batchCount = getIntArg(3);
@@ -275,7 +290,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
                     userSession.setRealm(realmName);
 
                     userSession.setLastSessionRefresh(Time.currentTime());
-                    cache.put(id, userSession);
+                    cache.put(id, new SessionEntityWrapper(userSession));
                 }
 
                 log.infof("Created '%d' sessions started from offset '%d'", countInIteration, firstInIteration);
@@ -301,7 +316,7 @@ public abstract class AbstractSessionCacheCommand extends AbstractCommand {
         }
 
         @Override
-        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntity> cache) {
+        protected void doRunCacheCommand(KeycloakSession session, Cache<String, SessionEntityWrapper> cache) {
             String realmName = getArg(1);
             String username = getArg(2);
             int count = getIntArg(3);