keycloak-aplcache

Merge pull request #2320 from mposolda/master KEYCLOAK-2523

3/3/2016 8:34:29 AM

Details

diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamCacheRealmProvider.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamCacheRealmProvider.java
index e53f510..c10bba4 100755
--- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamCacheRealmProvider.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamCacheRealmProvider.java
@@ -305,7 +305,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             if (invalidations.contains(id)) return model;
             cached = new CachedRealm(loaded, model);
-            cache.addRevisioned(cached);
+            cache.addRevisioned(cached, session);
         } else if (invalidations.contains(id)) {
             return getDelegate().getRealm(id);
         } else if (managedRealms.containsKey(id)) {
@@ -329,7 +329,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             if (invalidations.contains(model.getId())) return model;
             query = new RealmListQuery(loaded, cacheKey, model.getId());
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         } else if (invalidations.contains(cacheKey)) {
             return getDelegate().getRealmByName(name);
@@ -435,7 +435,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             for (ClientModel client : model) ids.add(client.getId());
             query = new ClientListQuery(loaded, cacheKey, realm, ids);
             logger.tracev("adding realm clients cache miss: realm {0} key {1}", realm.getName(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         List<ClientModel> list = new LinkedList<>();
@@ -508,7 +508,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             for (RoleModel role : model) ids.add(role.getId());
             query = new RoleListQuery(loaded, cacheKey, realm, ids);
             logger.tracev("adding realm roles cache miss: realm {0} key {1}", realm.getName(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         Set<RoleModel> list = new HashSet<>();
@@ -544,7 +544,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             for (RoleModel role : model) ids.add(role.getId());
             query = new RoleListQuery(loaded, cacheKey, realm, ids, client.getClientId());
             logger.tracev("adding client roles cache miss: client {0} key {1}", client.getClientId(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         Set<RoleModel> list = new HashSet<>();
@@ -593,7 +593,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             query = new RoleListQuery(loaded, cacheKey, realm, model.getId());
             logger.tracev("adding realm role cache miss: client {0} key {1}", realm.getName(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@@ -623,7 +623,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             query = new RoleListQuery(loaded, cacheKey, realm, model.getId(), client.getClientId());
             logger.tracev("adding client role cache miss: client {0} key {1}", client.getClientId(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@@ -660,7 +660,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             } else {
                 cached = new CachedRealmRole(loaded, model, realm);
             }
-            cache.addRevisioned(cached);
+            cache.addRevisioned(cached, session);
 
         } else if (invalidations.contains(id)) {
             return getDelegate().getRoleById(id, realm);
@@ -685,7 +685,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             if (invalidations.contains(id)) return model;
             cached = new CachedGroup(loaded, realm, model);
-            cache.addRevisioned(cached);
+            cache.addRevisioned(cached, session);
 
         } else if (invalidations.contains(id)) {
             return getDelegate().getGroupById(id, realm);
@@ -725,7 +725,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             for (GroupModel client : model) ids.add(client.getId());
             query = new GroupListQuery(loaded, cacheKey, realm, ids);
             logger.tracev("adding realm getGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         List<GroupModel> list = new LinkedList<>();
@@ -761,7 +761,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             for (GroupModel client : model) ids.add(client.getId());
             query = new GroupListQuery(loaded, cacheKey, realm, ids);
             logger.tracev("adding realm getTopLevelGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
             return model;
         }
         List<GroupModel> list = new LinkedList<>();
@@ -837,7 +837,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (invalidations.contains(id)) return model;
             cached = new CachedClient(loaded, realm, model);
             logger.tracev("adding client by id cache miss: {0}", cached.getClientId());
-            cache.addRevisioned(cached);
+            cache.addRevisioned(cached, session);
         } else if (invalidations.contains(id)) {
             return getDelegate().getClientById(id, realm);
         } else if (managedApplications.containsKey(id)) {
@@ -866,7 +866,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             id = model.getId();
             query = new ClientListQuery(loaded, cacheKey, realm, id);
             logger.tracev("adding client by name cache miss: {0}", clientId);
-            cache.addRevisioned(query);
+            cache.addRevisioned(query, session);
         } else if (invalidations.contains(cacheKey)) {
             return getDelegate().getClientByClientId(clientId, realm);
         } else {
@@ -895,7 +895,7 @@ public class StreamCacheRealmProvider implements CacheRealmProvider {
             if (model == null) return null;
             if (invalidations.contains(id)) return model;
             cached = new CachedClientTemplate(loaded, realm, model);
-            cache.addRevisioned(cached);
+            cache.addRevisioned(cached, session);
         } else if (invalidations.contains(id)) {
             return getDelegate().getClientTemplateById(id, realm);
         } else if (managedClientTemplates.containsKey(id)) {
diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamRealmCache.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamRealmCache.java
index 7ca6b7a..4815fb0 100755
--- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamRealmCache.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamRealmCache.java
@@ -24,6 +24,7 @@ import org.infinispan.notifications.cachelistener.annotation.CacheEntryInvalidat
 import org.infinispan.notifications.cachelistener.event.CacheEntriesEvictedEvent;
 import org.infinispan.notifications.cachelistener.event.CacheEntryInvalidatedEvent;
 import org.jboss.logging.Logger;
+import org.keycloak.models.KeycloakSession;
 import org.keycloak.models.cache.infinispan.entities.AbstractRevisioned;
 import org.keycloak.models.cache.infinispan.entities.CachedClient;
 import org.keycloak.models.cache.infinispan.entities.CachedClientTemplate;
@@ -38,7 +39,7 @@ import org.keycloak.models.cache.infinispan.stream.HasRolePredicate;
 import org.keycloak.models.cache.infinispan.stream.InClientPredicate;
 import org.keycloak.models.cache.infinispan.stream.InRealmPredicate;
 import org.keycloak.models.cache.infinispan.stream.RealmQueryPredicate;
-import org.keycloak.models.cache.infinispan.stream.RoleQueryPredicate;
+import org.keycloak.models.utils.UpdateCounter;
 
 import java.util.HashSet;
 import java.util.Iterator;
@@ -73,7 +74,9 @@ public class StreamRealmCache {
 
     public Long getCurrentRevision(String id) {
         Long revision = revisions.get(id);
-        if (revision == null) revision = UpdateCounter.current();
+        if (revision == null) {
+            revision = UpdateCounter.current();
+        }
         // if you do cache.remove() on node 1 and the entry doesn't exist on node 2, node 2 never receives a invalidation event
         // so, we do this to force this.
         String invalidationKey = "invalidation.key" + id;
@@ -121,7 +124,7 @@ public class StreamRealmCache {
         Object rev = revisions.put(id, next);
     }
 
-    public void addRevisioned(Revisioned object) {
+    public void addRevisioned(Revisioned object, KeycloakSession session) {
         //startRevisionBatch();
         String id = object.getId();
         try {
@@ -135,12 +138,19 @@ public class StreamRealmCache {
             revisions.startBatch();
             if (!revisions.getAdvancedCache().lock(id)) {
                 logger.trace("Could not obtain version lock");
+                return;
             }
             rev = revisions.get(id);
             if (rev == null) {
                 if (id.endsWith("realm.clients")) logger.trace("addRevisioned rev2 == null realm.clients");
                 return;
             }
+            if (rev > session.getTransaction().getStartupRevision()) { // revision is ahead transaction start. Other transaction updated in the meantime. Don't cache
+                if (logger.isTraceEnabled()) {
+                    logger.tracev("Skipped cache. Current revision {0}, Transaction start revision {1}", object.getRevision(), session.getTransaction().getStartupRevision());
+                }
+                return;
+            }
             if (rev.equals(object.getRevision())) {
                 if (id.endsWith("realm.clients")) logger.tracev("adding Object.revision {0} rev {1}", object.getRevision(), rev);
                 cache.putForExternalRead(id, object);
diff --git a/server-spi/src/main/java/org/keycloak/models/KeycloakTransactionManager.java b/server-spi/src/main/java/org/keycloak/models/KeycloakTransactionManager.java
index 0e2dcbe..456de8a 100755
--- a/server-spi/src/main/java/org/keycloak/models/KeycloakTransactionManager.java
+++ b/server-spi/src/main/java/org/keycloak/models/KeycloakTransactionManager.java
@@ -23,6 +23,8 @@ package org.keycloak.models;
  */
 public interface KeycloakTransactionManager extends KeycloakTransaction {
 
+    long getStartupRevision();
+
     void enlist(KeycloakTransaction transaction);
     void enlistAfterCompletion(KeycloakTransaction transaction);
 
diff --git a/services/src/main/java/org/keycloak/services/DefaultKeycloakTransactionManager.java b/services/src/main/java/org/keycloak/services/DefaultKeycloakTransactionManager.java
index fca6a9e..0a208d6 100755
--- a/services/src/main/java/org/keycloak/services/DefaultKeycloakTransactionManager.java
+++ b/services/src/main/java/org/keycloak/services/DefaultKeycloakTransactionManager.java
@@ -18,6 +18,7 @@ package org.keycloak.services;
 
 import org.keycloak.models.KeycloakTransaction;
 import org.keycloak.models.KeycloakTransactionManager;
+import org.keycloak.models.utils.UpdateCounter;
 import org.keycloak.services.ServicesLogger;
 
 import java.util.LinkedList;
@@ -35,6 +36,12 @@ public class DefaultKeycloakTransactionManager implements KeycloakTransactionMan
     private List<KeycloakTransaction> afterCompletion = new LinkedList<KeycloakTransaction>();
     private boolean active;
     private boolean rollback;
+    private long startupRevision;
+
+    @Override
+    public long getStartupRevision() {
+        return startupRevision;
+    }
 
     @Override
     public void enlist(KeycloakTransaction transaction) {
@@ -69,6 +76,8 @@ public class DefaultKeycloakTransactionManager implements KeycloakTransactionMan
              throw new IllegalStateException("Transaction already active");
         }
 
+        startupRevision = UpdateCounter.current();
+
         for (KeycloakTransaction tx : transactions) {
             tx.begin();
         }
diff --git a/testsuite/integration/src/test/java/org/keycloak/testsuite/broker/AbstractKeycloakIdentityProviderTest.java b/testsuite/integration/src/test/java/org/keycloak/testsuite/broker/AbstractKeycloakIdentityProviderTest.java
index b5f21a0..bfe3298 100755
--- a/testsuite/integration/src/test/java/org/keycloak/testsuite/broker/AbstractKeycloakIdentityProviderTest.java
+++ b/testsuite/integration/src/test/java/org/keycloak/testsuite/broker/AbstractKeycloakIdentityProviderTest.java
@@ -464,72 +464,83 @@ public abstract class AbstractKeycloakIdentityProviderTest extends AbstractIdent
         setUpdateProfileFirstLogin(IdentityProviderRepresentation.UPFLM_ON);
         IdentityProviderModel identityProviderModel = getIdentityProviderModel();
 
-        identityProviderModel.setStoreToken(true);
+        setStoreToken(identityProviderModel, true);
+        try {
+            authenticateWithIdentityProvider(identityProviderModel, "test-user", true);
 
-        authenticateWithIdentityProvider(identityProviderModel, "test-user", true);
+            brokerServerRule.stopSession(session, true);
+            session = brokerServerRule.startSession();
 
-        brokerServerRule.stopSession(session, true);
-        session = brokerServerRule.startSession();
+            UserModel federatedUser = getFederatedUser();
+            RealmModel realm = getRealm();
+            Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
 
-        UserModel federatedUser = getFederatedUser();
-        RealmModel realm = getRealm();
-        Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
+            assertFalse(federatedIdentities.isEmpty());
+            assertEquals(1, federatedIdentities.size());
 
-        assertFalse(federatedIdentities.isEmpty());
-        assertEquals(1, federatedIdentities.size());
+            FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
 
-        FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
-
-        assertNotNull(identityModel.getToken());
-
-        UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
-        String accessToken = userSessionStatus.getAccessTokenString();
-        URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
-        final String authHeader = "Bearer " + accessToken;
-        ClientRequestFilter authFilter = new ClientRequestFilter() {
-            @Override
-            public void filter(ClientRequestContext requestContext) throws IOException {
-                requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
-            }
-        };
-        Client client = ClientBuilder.newBuilder().register(authFilter).build();
-        WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
-        Response response = tokenEndpoint.request().get();
-        assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
-        assertNotNull(response.readEntity(String.class));
-        revokeGrant();
+            assertNotNull(identityModel.getToken());
 
+            UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
+            String accessToken = userSessionStatus.getAccessTokenString();
+            URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
+            final String authHeader = "Bearer " + accessToken;
+            ClientRequestFilter authFilter = new ClientRequestFilter() {
+                @Override
+                public void filter(ClientRequestContext requestContext) throws IOException {
+                    requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
+                }
+            };
+            Client client = ClientBuilder.newBuilder().register(authFilter).build();
+            WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
+            Response response = tokenEndpoint.request().get();
+            assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
+            assertNotNull(response.readEntity(String.class));
+            revokeGrant();
 
-        driver.navigate().to("http://localhost:8081/test-app/logout");
-        String currentUrl = this.driver.getCurrentUrl();
-        System.out.println("after logout currentUrl: " + currentUrl);
-        assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
-
-        unconfigureUserRetrieveToken("test-user");
-        loginIDP("test-user");
-        //authenticateWithIdentityProvider(identityProviderModel, "test-user");
-        assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
-
-        userSessionStatus = retrieveSessionStatus();
-        accessToken = userSessionStatus.getAccessTokenString();
-        final String authHeader2 = "Bearer " + accessToken;
-        ClientRequestFilter authFilter2 = new ClientRequestFilter() {
-            @Override
-            public void filter(ClientRequestContext requestContext) throws IOException {
-                requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
-            }
-        };
-        client = ClientBuilder.newBuilder().register(authFilter2).build();
-        tokenEndpoint = client.target(tokenEndpointUrl);
-        response = tokenEndpoint.request().get();
-
-        assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
 
-        revokeGrant();
-        driver.navigate().to("http://localhost:8081/test-app/logout");
-        driver.navigate().to("http://localhost:8081/test-app");
+            driver.navigate().to("http://localhost:8081/test-app/logout");
+            String currentUrl = this.driver.getCurrentUrl();
+            System.out.println("after logout currentUrl: " + currentUrl);
+            assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
+
+            unconfigureUserRetrieveToken("test-user");
+            loginIDP("test-user");
+            //authenticateWithIdentityProvider(identityProviderModel, "test-user");
+            assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
+
+            userSessionStatus = retrieveSessionStatus();
+            accessToken = userSessionStatus.getAccessTokenString();
+            final String authHeader2 = "Bearer " + accessToken;
+            ClientRequestFilter authFilter2 = new ClientRequestFilter() {
+                @Override
+                public void filter(ClientRequestContext requestContext) throws IOException {
+                    requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
+                }
+            };
+            client = ClientBuilder.newBuilder().register(authFilter2).build();
+            tokenEndpoint = client.target(tokenEndpointUrl);
+            response = tokenEndpoint.request().get();
+
+            assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
+
+            revokeGrant();
+            driver.navigate().to("http://localhost:8081/test-app/logout");
+            driver.navigate().to("http://localhost:8081/test-app");
 
-        assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
+            assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
+        } finally {
+            setStoreToken(identityProviderModel, false);
+        }
+    }
+
+    private void setStoreToken(IdentityProviderModel identityProviderModel, boolean storeToken) {
+        identityProviderModel.setStoreToken(storeToken);
+        getRealm().updateIdentityProvider(identityProviderModel);
+
+        brokerServerRule.stopSession(session, storeToken);
+        session = brokerServerRule.startSession();
     }
 
     protected abstract void doAssertTokenRetrieval(String pageSource);