keycloak-uncached

session transaction

7/11/2014 8:29:11 PM

Details

diff --git a/model/api/src/main/java/org/keycloak/models/KeycloakSession.java b/model/api/src/main/java/org/keycloak/models/KeycloakSession.java
index 409e77e..6eab3fd 100755
--- a/model/api/src/main/java/org/keycloak/models/KeycloakSession.java
+++ b/model/api/src/main/java/org/keycloak/models/KeycloakSession.java
@@ -23,8 +23,22 @@ public interface KeycloakSession {
 
     <T extends Provider> Set<T> getAllProviders(Class<T> clazz);
 
+    /**
+     * Returns a managed provider instance.  Will start a provider transaction.  This transaction is managed by the KeycloakSession
+     * transaction.
+     *
+     * @return
+     * @throws IllegalStateException if transaction is not active
+     */
     ModelProvider model();
 
+    /**
+     * Returns a managed provider instance.  Will start a provider transaction.  This transaction is managed by the KeycloakSession
+     * transaction.
+     *
+     * @return
+     * @throws IllegalStateException if transaction is not active
+     */
     UserSessionProvider sessions();
 
     void close();
diff --git a/model/api/src/main/java/org/keycloak/models/UserProvider.java b/model/api/src/main/java/org/keycloak/models/UserProvider.java
new file mode 100755
index 0000000..f4e2bbd
--- /dev/null
+++ b/model/api/src/main/java/org/keycloak/models/UserProvider.java
@@ -0,0 +1,37 @@
+package org.keycloak.models;
+
+import org.keycloak.provider.Provider;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
+ * @version $Revision: 1 $
+ */
+public interface UserProvider extends Provider {
+    // Note: The reason there are so many query methods here is for layering a cache on top of an persistent KeycloakSession
+
+    KeycloakTransaction getTransaction();
+
+    UserModel addUser(RealmModel realm, String id, String username, boolean addDefaultRoles);
+    UserModel addUser(RealmModel realm, String username);
+    boolean removeUser(RealmModel realm, String name);
+
+    UserModel getUserById(String id, RealmModel realm);
+    UserModel getUserByUsername(String username, RealmModel realm);
+    UserModel getUserByEmail(String email, RealmModel realm);
+    UserModel getUserBySocialLink(SocialLinkModel socialLink, RealmModel realm);
+    List<UserModel> getUsers(RealmModel realm);
+    List<UserModel> searchForUser(String search, RealmModel realm);
+    List<UserModel> searchForUserByAttributes(Map<String, String> attributes, RealmModel realm);
+    Set<SocialLinkModel> getSocialLinks(UserModel user, RealmModel realm);
+    SocialLinkModel getSocialLink(UserModel user, String socialProvider, RealmModel realm);
+
+    void preRemove(RealmModel realm);
+    void preRemove(RoleModel role);
+
+
+    void close();
+}
diff --git a/model/api/src/main/java/org/keycloak/models/UserProviderFactory.java b/model/api/src/main/java/org/keycloak/models/UserProviderFactory.java
new file mode 100755
index 0000000..f052e39
--- /dev/null
+++ b/model/api/src/main/java/org/keycloak/models/UserProviderFactory.java
@@ -0,0 +1,10 @@
+package org.keycloak.models;
+
+import org.keycloak.provider.ProviderFactory;
+
+/**
+ * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
+ * @version $Revision: 1 $
+ */
+public interface UserProviderFactory extends ProviderFactory<UserProvider> {
+}
diff --git a/model/jpa/pom.xml b/model/jpa/pom.xml
index 987abd2..fb1ba9d 100755
--- a/model/jpa/pom.xml
+++ b/model/jpa/pom.xml
@@ -53,6 +53,13 @@
             <scope>provided</scope>
         </dependency>
         <dependency>
+            <groupId>org.keycloak</groupId>
+            <artifactId>keycloak-model-sessions-mem</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+
+        </dependency>
+        <dependency>
             <groupId>org.jboss.resteasy</groupId>
             <artifactId>resteasy-jaxrs</artifactId>
             <scope>provided</scope>
diff --git a/model/jpa/src/main/java/org/keycloak/models/jpa/JpaUserProvider.java b/model/jpa/src/main/java/org/keycloak/models/jpa/JpaUserProvider.java
new file mode 100755
index 0000000..48b4730
--- /dev/null
+++ b/model/jpa/src/main/java/org/keycloak/models/jpa/JpaUserProvider.java
@@ -0,0 +1,266 @@
+package org.keycloak.models.jpa;
+
+import org.keycloak.models.ApplicationModel;
+import org.keycloak.models.KeycloakSession;
+import org.keycloak.models.KeycloakTransaction;
+import org.keycloak.models.ModelProvider;
+import org.keycloak.models.OAuthClientModel;
+import org.keycloak.models.RealmModel;
+import org.keycloak.models.RoleModel;
+import org.keycloak.models.SocialLinkModel;
+import org.keycloak.models.UserModel;
+import org.keycloak.models.UserProvider;
+import org.keycloak.models.jpa.entities.ApplicationEntity;
+import org.keycloak.models.jpa.entities.OAuthClientEntity;
+import org.keycloak.models.jpa.entities.RealmEntity;
+import org.keycloak.models.jpa.entities.RoleEntity;
+import org.keycloak.models.jpa.entities.SocialLinkEntity;
+import org.keycloak.models.jpa.entities.UserEntity;
+import org.keycloak.models.jpa.entities.UserRoleMappingEntity;
+import org.keycloak.models.utils.KeycloakModelUtils;
+
+import javax.persistence.EntityManager;
+import javax.persistence.TypedQuery;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
+ * @version $Revision: 1 $
+ */
+public class JpaUserProvider implements UserProvider {
+    private final KeycloakSession session;
+    protected EntityManager em;
+
+    public JpaUserProvider(KeycloakSession session, EntityManager em) {
+        this.session = session;
+        this.em = em;
+        this.em = PersistenceExceptionConverter.create(em);
+    }
+
+    @Override
+    public UserModel addUser(RealmModel realm, String id, String username, boolean addDefaultRoles) {
+        if (id == null) {
+            id = KeycloakModelUtils.generateId();
+        }
+
+        UserEntity entity = new UserEntity();
+        entity.setId(id);
+        entity.setUsername(username);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        entity.setRealm(realmEntity);
+        em.persist(entity);
+        em.flush();
+        UserModel userModel = new UserAdapter(realm, em, entity);
+
+        if (addDefaultRoles) {
+            for (String r : realm.getDefaultRoles()) {
+                userModel.grantRole(realm.getRole(r));
+            }
+
+            for (ApplicationModel application : realm.getApplications()) {
+                for (String r : application.getDefaultRoles()) {
+                    userModel.grantRole(application.getRole(r));
+                }
+            }
+        }
+
+        return userModel;
+    }
+
+    @Override
+    public UserModel addUser(RealmModel realm, String username) {
+        return addUser(realm, KeycloakModelUtils.generateId(), username, true);
+    }
+
+    @Override
+    public boolean removeUser(RealmModel realm, String name) {
+        TypedQuery<UserEntity> query = em.createNamedQuery("getRealmUserByUsername", UserEntity.class);
+        query.setParameter("username", name);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        List<UserEntity> results = query.getResultList();
+        if (results.size() == 0) return false;
+        removeUser(results.get(0));
+        return true;
+    }
+
+    private void removeUser(UserEntity user) {
+        em.createQuery("delete from " + UserRoleMappingEntity.class.getSimpleName() + " where user = :user").setParameter("user", user).executeUpdate();
+        em.createQuery("delete from " + SocialLinkEntity.class.getSimpleName() + " where user = :user").setParameter("user", user).executeUpdate();
+        if (user.getAuthenticationLink() != null) {
+            em.remove(user.getAuthenticationLink());
+        }
+        em.remove(user);
+    }
+
+    @Override
+    public void preRemove(RealmModel realm) {
+        TypedQuery<UserEntity> query = em.createQuery("select u from UserEntity u where u.realm = :realm", UserEntity.class);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        for (UserEntity u : query.getResultList()) {
+            em.remove(u);
+        }
+    }
+
+    @Override
+    public void preRemove(RoleModel role) {
+        RoleEntity roleEntity = em.getReference(RoleEntity.class, role.getId());
+        em.createQuery("delete from " + UserRoleMappingEntity.class.getSimpleName() + " where role = :role").setParameter("role", roleEntity).executeUpdate();
+    }
+
+
+    @Override
+    public KeycloakTransaction getTransaction() {
+        return new JpaKeycloakTransaction(em);
+    }
+
+    @Override
+    public UserModel getUserById(String id, RealmModel realmModel) {
+        TypedQuery<UserEntity> query = em.createNamedQuery("getRealmUserById", UserEntity.class);
+        query.setParameter("id", id);
+        RealmEntity realm = em.getReference(RealmEntity.class, realmModel.getId());
+        query.setParameter("realm", realm);
+        List<UserEntity> entities = query.getResultList();
+        if (entities.size() == 0) return null;
+        return new UserAdapter(realmModel, em, entities.get(0));
+    }
+
+    @Override
+    public UserModel getUserByUsername(String username, RealmModel realmModel) {
+        TypedQuery<UserEntity> query = em.createNamedQuery("getRealmUserByUsername", UserEntity.class);
+        query.setParameter("username", username);
+        RealmEntity realm = em.getReference(RealmEntity.class, realmModel.getId());
+        query.setParameter("realm", realm);
+        List<UserEntity> results = query.getResultList();
+        if (results.size() == 0) return null;
+        return new UserAdapter(realmModel, em, results.get(0));
+    }
+
+    @Override
+    public UserModel getUserByEmail(String email, RealmModel realmModel) {
+        TypedQuery<UserEntity> query = em.createNamedQuery("getRealmUserByEmail", UserEntity.class);
+        query.setParameter("email", email);
+        RealmEntity realm = em.getReference(RealmEntity.class, realmModel.getId());
+        query.setParameter("realm", realm);
+        List<UserEntity> results = query.getResultList();
+        return results.isEmpty() ? null : new UserAdapter(realmModel, em, results.get(0));
+    }
+
+     @Override
+    public void close() {
+        if (em.getTransaction().isActive()) em.getTransaction().rollback();
+        if (em.isOpen()) em.close();
+    }
+
+    @Override
+    public UserModel getUserBySocialLink(SocialLinkModel socialLink, RealmModel realm) {
+        TypedQuery<UserEntity> query = em.createNamedQuery("findUserByLinkAndRealm", UserEntity.class);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        query.setParameter("socialProvider", socialLink.getSocialProvider());
+        query.setParameter("socialUserId", socialLink.getSocialUserId());
+        List<UserEntity> results = query.getResultList();
+        if (results.isEmpty()) {
+            return null;
+        } else if (results.size() > 1) {
+            throw new IllegalStateException("More results found for socialProvider=" + socialLink.getSocialProvider() +
+                    ", socialUserId=" + socialLink.getSocialUserId() + ", results=" + results);
+        } else {
+            UserEntity user = results.get(0);
+            return new UserAdapter(realm, em, user);
+        }
+    }
+
+    @Override
+    public List<UserModel> getUsers(RealmModel realm) {
+        TypedQuery<UserEntity> query = em.createQuery("select u from UserEntity u where u.realm = :realm", UserEntity.class);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        List<UserEntity> results = query.getResultList();
+        List<UserModel> users = new ArrayList<UserModel>();
+        for (UserEntity entity : results) users.add(new UserAdapter(realm, em, entity));
+        return users;
+    }
+
+    @Override
+    public List<UserModel> searchForUser(String search, RealmModel realm) {
+        TypedQuery<UserEntity> query = em.createQuery("select u from UserEntity u where u.realm = :realm and ( lower(u.username) like :search or lower(concat(u.firstName, ' ', u.lastName)) like :search or u.email like :search )", UserEntity.class);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        query.setParameter("search", "%" + search.toLowerCase() + "%");
+        List<UserEntity> results = query.getResultList();
+        List<UserModel> users = new ArrayList<UserModel>();
+        for (UserEntity entity : results) users.add(new UserAdapter(realm, em, entity));
+        return users;
+    }
+
+    @Override
+    public List<UserModel> searchForUserByAttributes(Map<String, String> attributes, RealmModel realm) {
+        StringBuilder builder = new StringBuilder("select u from UserEntity u");
+        boolean first = true;
+        for (Map.Entry<String, String> entry : attributes.entrySet()) {
+            String attribute = null;
+            if (entry.getKey().equals(UserModel.LOGIN_NAME)) {
+                attribute = "lower(username)";
+            } else if (entry.getKey().equalsIgnoreCase(UserModel.FIRST_NAME)) {
+                attribute = "lower(firstName)";
+            } else if (entry.getKey().equalsIgnoreCase(UserModel.LAST_NAME)) {
+                attribute = "lower(lastName)";
+            } else if (entry.getKey().equalsIgnoreCase(UserModel.EMAIL)) {
+                attribute = "lower(email)";
+            }
+            if (attribute == null) continue;
+            if (first) {
+                first = false;
+                builder.append(" where realm = :realm");
+            } else {
+                builder.append(" and ");
+            }
+            builder.append(attribute).append(" like '%").append(entry.getValue().toLowerCase()).append("%'");
+        }
+        String q = builder.toString();
+        TypedQuery<UserEntity> query = em.createQuery(q, UserEntity.class);
+        RealmEntity realmEntity = em.getReference(RealmEntity.class, realm.getId());
+        query.setParameter("realm", realmEntity);
+        List<UserEntity> results = query.getResultList();
+        List<UserModel> users = new ArrayList<UserModel>();
+        for (UserEntity entity : results) users.add(new UserAdapter(realm, em, entity));
+        return users;
+    }
+
+    private SocialLinkEntity findSocialLink(UserModel user, String socialProvider) {
+        TypedQuery<SocialLinkEntity> query = em.createNamedQuery("findSocialLinkByUserAndProvider", SocialLinkEntity.class);
+        UserEntity userEntity = em.getReference(UserEntity.class, user.getId());
+        query.setParameter("user", userEntity);
+        query.setParameter("socialProvider", socialProvider);
+        List<SocialLinkEntity> results = query.getResultList();
+        return results.size() > 0 ? results.get(0) : null;
+    }
+
+
+    @Override
+    public Set<SocialLinkModel> getSocialLinks(UserModel user, RealmModel realm) {
+        TypedQuery<SocialLinkEntity> query = em.createNamedQuery("findSocialLinkByUser", SocialLinkEntity.class);
+        UserEntity userEntity = em.getReference(UserEntity.class, user.getId());
+        query.setParameter("user", userEntity);
+        List<SocialLinkEntity> results = query.getResultList();
+        Set<SocialLinkModel> set = new HashSet<SocialLinkModel>();
+        for (SocialLinkEntity entity : results) {
+            set.add(new SocialLinkModel(entity.getSocialProvider(), entity.getSocialUserId(), entity.getSocialUsername()));
+        }
+        return set;
+    }
+
+    @Override
+    public SocialLinkModel getSocialLink(UserModel user, String socialProvider, RealmModel realm) {
+        SocialLinkEntity entity = findSocialLink(user, socialProvider);
+        return (entity != null) ? new SocialLinkModel(entity.getSocialProvider(), entity.getSocialUserId(), entity.getSocialUsername()) : null;
+    }
+
+}
diff --git a/model/mongo/pom.xml b/model/mongo/pom.xml
index 31d513d..e9cb48f 100755
--- a/model/mongo/pom.xml
+++ b/model/mongo/pom.xml
@@ -53,7 +53,13 @@
             <artifactId>mongo-java-driver</artifactId>
             <scope>provided</scope>
         </dependency>
+        <dependency>
+            <groupId>org.keycloak</groupId>
+            <artifactId>keycloak-model-sessions-mem</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
 
+        </dependency>
         <dependency>
             <groupId>org.keycloak</groupId>
             <artifactId>keycloak-model-tests</artifactId>
diff --git a/services/src/main/java/org/keycloak/services/DefaultKeycloakSession.java b/services/src/main/java/org/keycloak/services/DefaultKeycloakSession.java
index e26cd74..837f45a 100755
--- a/services/src/main/java/org/keycloak/services/DefaultKeycloakSession.java
+++ b/services/src/main/java/org/keycloak/services/DefaultKeycloakSession.java
@@ -3,13 +3,16 @@ package org.keycloak.services;
 import org.keycloak.models.KeycloakSession;
 import org.keycloak.models.KeycloakTransaction;
 import org.keycloak.models.ModelProvider;
+import org.keycloak.models.UserProvider;
 import org.keycloak.models.UserSessionProvider;
 import org.keycloak.models.cache.CacheModelProvider;
 import org.keycloak.provider.Provider;
 import org.keycloak.provider.ProviderFactory;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
@@ -20,21 +23,82 @@ public class DefaultKeycloakSession implements KeycloakSession {
 
     private final DefaultKeycloakSessionFactory factory;
     private final Map<Integer, Provider> providers = new HashMap<Integer, Provider>();
-    private final ModelProvider model;
+    private ModelProvider model;
+    private UserSessionProvider sessionProvider;
+    private final List<KeycloakTransaction> managedTransactions = new ArrayList<KeycloakTransaction>();
+
+    private final KeycloakTransaction transaction = new KeycloakTransaction() {
+        protected boolean active;
+        protected boolean rollback;
+
+        @Override
+        public void begin() {
+            active = true;
+        }
+
+        @Override
+        public void commit() {
+            if (!active) throw new IllegalStateException("Transaction not active");
+            try {
+                if (rollback) {
+                    rollback();
+                    throw new RuntimeException("Transaction markedfor rollback, so rollback happend");
+                }
+                for (KeycloakTransaction transaction : managedTransactions) {
+                    transaction.commit();
+                }
+            } finally {
+                active = false;
+            }
+
+        }
+
+        @Override
+        public void rollback() {
+            if (!active) throw new IllegalStateException("Transaction not active");
+            try {
+                for (KeycloakTransaction transaction : managedTransactions) {
+                    transaction.rollback();
+                }
+            } finally {
+                active = false;
+            }
+        }
+
+        @Override
+        public void setRollbackOnly() {
+            if (!active) throw new IllegalStateException("Transaction not active");
+            rollback = true;
+        }
+
+        @Override
+        public boolean getRollbackOnly() {
+            if (!active) throw new IllegalStateException("Transaction not active");
+            return rollback;
+        }
+
+        @Override
+        public boolean isActive() {
+            return active;
+        }
+    };
 
     public DefaultKeycloakSession(DefaultKeycloakSessionFactory factory) {
         this.factory = factory;
+    }
 
+    private ModelProvider getModelProvider() {
         if (factory.getDefaultProvider(CacheModelProvider.class) != null) {
-            model = getProvider(CacheModelProvider.class);
+            return getProvider(CacheModelProvider.class);
         } else {
-            model = getProvider(ModelProvider.class);
+            return getProvider(ModelProvider.class);
         }
     }
 
+
     @Override
     public KeycloakTransaction getTransaction() {
-        return model.getTransaction();
+        return transaction;
     }
 
     public <T extends Provider> T getProvider(Class<T> clazz) {
@@ -76,13 +140,26 @@ public class DefaultKeycloakSession implements KeycloakSession {
         return providers;
     }
 
+    @Override
     public ModelProvider model() {
+        if (!transaction.isActive()) throw new IllegalStateException("Transaction is not active");
+        if (model == null) {
+            model = getModelProvider();
+            model.getTransaction().begin();
+            managedTransactions.add(model.getTransaction());
+        }
         return model;
     }
 
     @Override
     public UserSessionProvider sessions() {
-        return getProvider(UserSessionProvider.class);
+        if (!transaction.isActive()) throw new IllegalStateException("Transaction is not active");
+        if (sessionProvider == null) {
+            sessionProvider = getProvider(UserSessionProvider.class);
+            sessionProvider.getTransaction().begin();
+            managedTransactions.add(sessionProvider.getTransaction());
+        }
+        return sessionProvider;
     }
 
     public void close() {
diff --git a/services/src/main/java/org/keycloak/services/resources/TokenService.java b/services/src/main/java/org/keycloak/services/resources/TokenService.java
index f2691a7..2a37360 100755
--- a/services/src/main/java/org/keycloak/services/resources/TokenService.java
+++ b/services/src/main/java/org/keycloak/services/resources/TokenService.java
@@ -99,8 +99,6 @@ public class TokenService {
     @Context
     protected KeycloakSession session;
     @Context
-    protected KeycloakTransaction transaction;
-    @Context
     protected ClientConnection clientConnection;
 
     /*