killbill-aplcache

shiro: fix wiring of EhCache to all Shiro Realms Make sure

11/10/2015 9:31:31 PM

Details

diff --git a/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillbillJdbcTenantRealmProvider.java b/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillbillJdbcTenantRealmProvider.java
new file mode 100644
index 0000000..2fb4840
--- /dev/null
+++ b/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillbillJdbcTenantRealmProvider.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2014-2015 Groupon, Inc
+ * Copyright 2014-2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.server.modules;
+
+import javax.inject.Named;
+import javax.sql.DataSource;
+
+import org.apache.shiro.cache.CacheManager;
+import org.killbill.billing.server.security.KillbillJdbcTenantRealm;
+import org.killbill.billing.util.config.SecurityConfig;
+import org.killbill.billing.util.glue.ShiroEhCacheInstrumentor;
+
+import com.google.inject.Inject;
+import com.google.inject.Provider;
+
+public class KillbillJdbcTenantRealmProvider implements Provider<KillbillJdbcTenantRealm> {
+
+    private final SecurityConfig securityConfig;
+    private final CacheManager cacheManager;
+    private final ShiroEhCacheInstrumentor ehCacheInstrumentor;
+    private final DataSource dataSource;
+
+    @Inject
+    public KillbillJdbcTenantRealmProvider(final SecurityConfig securityConfig, final CacheManager cacheManager, final ShiroEhCacheInstrumentor ehCacheInstrumentor, @Named(KillbillPlatformModule.SHIRO_DATA_SOURCE_ID_NAMED) final DataSource dataSource) {
+        this.securityConfig = securityConfig;
+        this.cacheManager = cacheManager;
+        this.ehCacheInstrumentor = ehCacheInstrumentor;
+        this.dataSource = dataSource;
+    }
+
+    @Override
+    public KillbillJdbcTenantRealm get() {
+        final KillbillJdbcTenantRealm killbillJdbcTenantRealm = new KillbillJdbcTenantRealm(dataSource, securityConfig);
+
+        // Set the cache manager
+        // Note: the DefaultWebSecurityManager used for RBAC will have all of its realms (set in KillBillShiroWebModule)
+        // automatically configured with the EhCache manager (see EhCacheManagerProvider)
+        killbillJdbcTenantRealm.setCacheManager(cacheManager);
+
+        // Instrument the cache
+        ehCacheInstrumentor.instrument(killbillJdbcTenantRealm);
+
+        return killbillJdbcTenantRealm;
+    }
+}
diff --git a/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillBillShiroWebModule.java b/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillBillShiroWebModule.java
index 84f15f0..721c7a7 100644
--- a/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillBillShiroWebModule.java
+++ b/profiles/killbill/src/main/java/org/killbill/billing/server/modules/KillBillShiroWebModule.java
@@ -27,7 +27,9 @@ import org.apache.shiro.authc.pam.ModularRealmAuthenticator;
 import org.apache.shiro.authc.pam.ModularRealmAuthenticatorWith540;
 import org.apache.shiro.cache.CacheManager;
 import org.apache.shiro.guice.web.ShiroWebModuleWith435;
+import org.apache.shiro.realm.Realm;
 import org.apache.shiro.session.mgt.SessionManager;
+import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
 import org.apache.shiro.web.filter.authc.BasicHttpAuthenticationFilter;
 import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
 import org.apache.shiro.web.mgt.WebSecurityManager;
@@ -35,18 +37,22 @@ import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
 import org.apache.shiro.web.util.WebUtils;
 import org.killbill.billing.jaxrs.resources.JaxrsResource;
 import org.killbill.billing.server.security.FirstSuccessfulStrategyWith540;
+import org.killbill.billing.server.security.KillbillJdbcTenantRealm;
 import org.killbill.billing.util.config.RbacConfig;
 import org.killbill.billing.util.glue.EhCacheManagerProvider;
 import org.killbill.billing.util.glue.IniRealmProvider;
 import org.killbill.billing.util.glue.JDBCSessionDaoProvider;
 import org.killbill.billing.util.glue.KillBillShiroModule;
+import org.killbill.billing.util.glue.ShiroEhCacheInstrumentor;
 import org.killbill.billing.util.security.shiro.dao.JDBCSessionDao;
 import org.killbill.billing.util.security.shiro.realm.KillBillJdbcRealm;
 import org.killbill.billing.util.security.shiro.realm.KillBillJndiLdapRealm;
 import org.skife.config.ConfigSource;
 import org.skife.config.ConfigurationObjectFactory;
 
+import com.google.inject.Inject;
 import com.google.inject.Key;
+import com.google.inject.Provider;
 import com.google.inject.TypeLiteral;
 import com.google.inject.binder.AnnotatedBindingBuilder;
 import com.google.inject.matcher.AbstractMatcher;
@@ -67,47 +73,50 @@ public class KillBillShiroWebModule extends ShiroWebModuleWith435 {
     }
 
     @Override
+    public void configure() {
+        super.configure();
+
+        bind(ShiroEhCacheInstrumentor.class).asEagerSingleton();
+    }
+
+    @Override
     protected void configureShiroWeb() {
+        // Magic provider to configure the cache manager
+        bind(CacheManager.class).toProvider(EhCacheManagerProvider.class).asEagerSingleton();
+
+        configureShiroForRBAC();
+
+        configureShiroForTenants();
+    }
+
+    private void configureShiroForRBAC() {
         final RbacConfig config = new ConfigurationObjectFactory(configSource).build(RbacConfig.class);
         bind(RbacConfig.class).toInstance(config);
 
+        // Note: order matters (the first successful match will win, see below)
         bindRealm().toProvider(IniRealmProvider.class).asEagerSingleton();
-
         bindRealm().to(KillBillJdbcRealm.class).asEagerSingleton();
-
         if (KillBillShiroModule.isLDAPEnabled()) {
             bindRealm().to(KillBillJndiLdapRealm.class).asEagerSingleton();
         }
 
-        // Magic provider to configure the cache manager
-        bind(CacheManager.class).toProvider(EhCacheManagerProvider.class).asEagerSingleton();
-
-        if (KillBillShiroModule.isRBACEnabled()) {
-            addFilterChain(JaxrsResource.PREFIX + "/**", Key.get(CorsBasicHttpAuthenticationFilter.class));
-        }
-
         bindListener(new AbstractMatcher<TypeLiteral<?>>() {
                          @Override
                          public boolean matches(final TypeLiteral<?> o) {
                              return Matchers.subclassesOf(WebSecurityManager.class).matches(o.getRawType());
                          }
                      },
-                     new TypeListener() {
-                         @Override
-                         public <I> void hear(final TypeLiteral<I> typeLiteral, final TypeEncounter<I> typeEncounter) {
-                             typeEncounter.register(new InjectionListener<I>() {
-                                 @Override
-                                 public void afterInjection(final Object o) {
-                                     final DefaultWebSecurityManager webSecurityManager = (DefaultWebSecurityManager) o;
-                                     if (webSecurityManager.getAuthenticator() instanceof ModularRealmAuthenticator) {
-                                         final ModularRealmAuthenticator authenticator = (ModularRealmAuthenticator) webSecurityManager.getAuthenticator();
-                                         authenticator.setAuthenticationStrategy(new FirstSuccessfulStrategyWith540());
-                                         webSecurityManager.setAuthenticator(new ModularRealmAuthenticatorWith540(authenticator));
-                                     }
-                                 }
-                             });
-                         }
-                     });
+                     new DefaultWebSecurityManagerTypeListener(getProvider(ShiroEhCacheInstrumentor.class)));
+
+        if (KillBillShiroModule.isRBACEnabled()) {
+            addFilterChain(JaxrsResource.PREFIX + "/**", Key.get(CorsBasicHttpAuthenticationFilter.class));
+        }
+    }
+
+    private void configureShiroForTenants() {
+        // Realm binding for the tenants (see TenantFilter)
+        bind(KillbillJdbcTenantRealm.class).toProvider(KillbillJdbcTenantRealmProvider.class).asEagerSingleton();
+        expose(KillbillJdbcTenantRealm.class);
     }
 
     @Override
@@ -131,4 +140,36 @@ public class KillBillShiroWebModule extends ShiroWebModuleWith435 {
             return "OPTIONS".equalsIgnoreCase(httpMethod) || super.isAccessAllowed(request, response, mappedValue);
         }
     }
+
+    private static final class DefaultWebSecurityManagerTypeListener implements TypeListener {
+
+        private final Provider<ShiroEhCacheInstrumentor> instrumentorProvider;
+
+        @Inject
+        public DefaultWebSecurityManagerTypeListener(final Provider<ShiroEhCacheInstrumentor> instrumentorProvider) {
+            this.instrumentorProvider = instrumentorProvider;
+        }
+
+        @Override
+        public <I> void hear(final TypeLiteral<I> typeLiteral, final TypeEncounter<I> typeEncounter) {
+            typeEncounter.register(new InjectionListener<I>() {
+                @Override
+                public void afterInjection(final Object o) {
+                    final ShiroEhCacheInstrumentor ehCacheInstrumentor = instrumentorProvider.get();
+                    ehCacheInstrumentor.instrument(CachingSessionDAO.ACTIVE_SESSION_CACHE_NAME);
+
+                    final DefaultWebSecurityManager webSecurityManager = (DefaultWebSecurityManager) o;
+                    if (webSecurityManager.getAuthenticator() instanceof ModularRealmAuthenticator) {
+                        final ModularRealmAuthenticator authenticator = (ModularRealmAuthenticator) webSecurityManager.getAuthenticator();
+                        authenticator.setAuthenticationStrategy(new FirstSuccessfulStrategyWith540());
+                        webSecurityManager.setAuthenticator(new ModularRealmAuthenticatorWith540(authenticator));
+
+                        for (final Realm realm : webSecurityManager.getRealms()) {
+                            ehCacheInstrumentor.instrument(realm);
+                        }
+                    }
+                }
+            });
+        }
+    }
 }
diff --git a/profiles/killbill/src/main/java/org/killbill/billing/server/security/TenantFilter.java b/profiles/killbill/src/main/java/org/killbill/billing/server/security/TenantFilter.java
index 82f63d1..a0af60a 100644
--- a/profiles/killbill/src/main/java/org/killbill/billing/server/security/TenantFilter.java
+++ b/profiles/killbill/src/main/java/org/killbill/billing/server/security/TenantFilter.java
@@ -21,7 +21,6 @@ package org.killbill.billing.server.security;
 import java.io.IOException;
 
 import javax.inject.Inject;
-import javax.inject.Named;
 import javax.inject.Singleton;
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
@@ -31,7 +30,6 @@ import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
-import javax.sql.DataSource;
 
 import org.apache.shiro.authc.AuthenticationException;
 import org.apache.shiro.authc.AuthenticationToken;
@@ -40,11 +38,9 @@ import org.apache.shiro.authc.pam.ModularRealmAuthenticator;
 import org.apache.shiro.realm.Realm;
 import org.killbill.billing.jaxrs.resources.JaxrsResource;
 import org.killbill.billing.server.listeners.KillbillGuiceListener;
-import org.killbill.billing.server.modules.KillbillPlatformModule;
 import org.killbill.billing.tenant.api.Tenant;
 import org.killbill.billing.tenant.api.TenantApiException;
 import org.killbill.billing.tenant.api.TenantUserApi;
-import org.killbill.billing.util.config.SecurityConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -61,17 +57,12 @@ public class TenantFilter implements Filter {
     @Inject
     protected TenantUserApi tenantUserApi;
     @Inject
-    protected SecurityConfig securityConfig;
-
-    @Inject
-    @Named(KillbillPlatformModule.SHIRO_DATA_SOURCE_ID_NAMED)
-    protected DataSource dataSource;
+    protected KillbillJdbcTenantRealm killbillJdbcTenantRealm;
 
     private ModularRealmAuthenticator modularRealmAuthenticator;
 
     @Override
     public void init(final FilterConfig filterConfig) throws ServletException {
-        final Realm killbillJdbcTenantRealm = new KillbillJdbcTenantRealm(dataSource, securityConfig);
         // We use Shiro to verify the api credentials - but the Shiro Subject is only used for RBAC
         modularRealmAuthenticator = new ModularRealmAuthenticator();
         modularRealmAuthenticator.setRealms(ImmutableList.<Realm>of(killbillJdbcTenantRealm));
diff --git a/profiles/killbill/src/test/java/org/killbill/billing/jaxrs/TestJaxrsBase.java b/profiles/killbill/src/test/java/org/killbill/billing/jaxrs/TestJaxrsBase.java
index fa48420..5044598 100644
--- a/profiles/killbill/src/test/java/org/killbill/billing/jaxrs/TestJaxrsBase.java
+++ b/profiles/killbill/src/test/java/org/killbill/billing/jaxrs/TestJaxrsBase.java
@@ -22,6 +22,7 @@ import java.util.EventListener;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 
 import javax.inject.Inject;
 import javax.inject.Named;
@@ -196,6 +197,9 @@ public class TestJaxrsBase extends KillbillClient {
         clock.resetDeltaFromReality();
         clock.setDay(new LocalDate(2012, 8, 25));
 
+        // Make sure to re-generate the api key and secret (could be cached by Shiro)
+        DEFAULT_API_KEY = UUID.randomUUID().toString();
+        DEFAULT_API_SECRET = UUID.randomUUID().toString();
         loginTenant(DEFAULT_API_KEY, DEFAULT_API_SECRET);
 
         // Recreate the tenant (tables have been cleaned-up)
diff --git a/util/src/main/java/org/killbill/billing/util/glue/CacheModule.java b/util/src/main/java/org/killbill/billing/util/glue/CacheModule.java
index 5b94fb4..f01e061 100644
--- a/util/src/main/java/org/killbill/billing/util/glue/CacheModule.java
+++ b/util/src/main/java/org/killbill/billing/util/glue/CacheModule.java
@@ -1,7 +1,7 @@
 /*
  * Copyright 2010-2013 Ning, Inc.
- * Copyright 2014 Groupon, Inc
- * Copyright 2014 The Billing Project, LLC
+ * Copyright 2014-2015 Groupon, Inc
+ * Copyright 2014-2015 The Billing Project, LLC
  *
  * The Billing Project licenses this file to you under the Apache License, version 2.0
  * (the "License"); you may not use this file except in compliance with the
diff --git a/util/src/main/java/org/killbill/billing/util/glue/EhCacheManagerProvider.java b/util/src/main/java/org/killbill/billing/util/glue/EhCacheManagerProvider.java
index d8b92a7..f6c6e21 100644
--- a/util/src/main/java/org/killbill/billing/util/glue/EhCacheManagerProvider.java
+++ b/util/src/main/java/org/killbill/billing/util/glue/EhCacheManagerProvider.java
@@ -24,27 +24,16 @@ import javax.inject.Provider;
 import org.apache.shiro.cache.ehcache.EhCacheManager;
 import org.apache.shiro.mgt.DefaultSecurityManager;
 import org.apache.shiro.mgt.SecurityManager;
-import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
-import com.codahale.metrics.MetricRegistry;
-import com.codahale.metrics.ehcache.InstrumentedEhcache;
-import net.sf.ehcache.CacheException;
 import net.sf.ehcache.CacheManager;
-import net.sf.ehcache.Ehcache;
 
 public class EhCacheManagerProvider implements Provider<EhCacheManager> {
 
-    private static final Logger logger = LoggerFactory.getLogger(EhCacheManagerProvider.class);
-
-    private final MetricRegistry metricRegistry;
     private final SecurityManager securityManager;
     private final CacheManager ehCacheCacheManager;
 
     @Inject
-    public EhCacheManagerProvider(final MetricRegistry metricRegistry, final SecurityManager securityManager, final CacheManager ehCacheCacheManager) {
-        this.metricRegistry = metricRegistry;
+    public EhCacheManagerProvider(final SecurityManager securityManager, final CacheManager ehCacheCacheManager) {
         this.securityManager = securityManager;
         this.ehCacheCacheManager = ehCacheCacheManager;
     }
@@ -55,21 +44,8 @@ public class EhCacheManagerProvider implements Provider<EhCacheManager> {
         // Same EhCache manager instance as the rest of the system
         shiroEhCacheManager.setCacheManager(ehCacheCacheManager);
 
-        // It looks like Shiro's cache manager is not thread safe. Concurrent requests on startup
-        // can throw org.apache.shiro.cache.CacheException: net.sf.ehcache.ObjectExistsException: Cache shiro-activeSessionCache already exists
-        // As a workaround, create the cache manually here
-        shiroEhCacheManager.getCache(CachingSessionDAO.ACTIVE_SESSION_CACHE_NAME);
-
-        // Instrument the cache
-        final Ehcache shiroActiveSessionEhcache = ehCacheCacheManager.getEhcache(CachingSessionDAO.ACTIVE_SESSION_CACHE_NAME);
-        final Ehcache decoratedCache = InstrumentedEhcache.instrument(metricRegistry, shiroActiveSessionEhcache);
-        try {
-            ehCacheCacheManager.replaceCacheWithDecoratedCache(shiroActiveSessionEhcache, decoratedCache);
-        } catch (final CacheException e) {
-            logger.warn("Unable to instrument cache {}: {}", shiroActiveSessionEhcache.getName(), e.getMessage());
-        }
-
         if (securityManager instanceof DefaultSecurityManager) {
+            // For RBAC only (see also KillbillJdbcTenantRealm)
             ((DefaultSecurityManager) securityManager).setCacheManager(shiroEhCacheManager);
         }
 
diff --git a/util/src/main/java/org/killbill/billing/util/glue/ShiroEhCacheInstrumentor.java b/util/src/main/java/org/killbill/billing/util/glue/ShiroEhCacheInstrumentor.java
new file mode 100644
index 0000000..6a0bb91
--- /dev/null
+++ b/util/src/main/java/org/killbill/billing/util/glue/ShiroEhCacheInstrumentor.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2014-2015 Groupon, Inc
+ * Copyright 2014-2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.glue;
+
+import javax.inject.Inject;
+
+import org.apache.shiro.cache.CacheManager;
+import org.apache.shiro.realm.AuthenticatingRealm;
+import org.apache.shiro.realm.AuthorizingRealm;
+import org.apache.shiro.realm.Realm;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.ehcache.InstrumentedEhcache;
+import net.sf.ehcache.CacheException;
+import net.sf.ehcache.Ehcache;
+
+public class ShiroEhCacheInstrumentor {
+
+    private static final Logger logger = LoggerFactory.getLogger(ShiroEhCacheInstrumentor.class);
+
+    private final MetricRegistry metricRegistry;
+    private final CacheManager shiroEhCacheManager;
+    private final net.sf.ehcache.CacheManager ehCacheCacheManager;
+
+    @Inject
+    public ShiroEhCacheInstrumentor(final MetricRegistry metricRegistry, final CacheManager shiroEhCacheManager, final net.sf.ehcache.CacheManager ehCacheCacheManager) {
+        this.metricRegistry = metricRegistry;
+        this.shiroEhCacheManager = shiroEhCacheManager;
+        this.ehCacheCacheManager = ehCacheCacheManager;
+    }
+
+    public void instrument(final Realm realm) {
+        if (realm instanceof AuthorizingRealm) {
+            instrument((AuthorizingRealm) realm);
+        } else if (realm instanceof AuthenticatingRealm) {
+            instrument((AuthenticatingRealm) realm);
+        }
+    }
+
+    public void instrument(final AuthorizingRealm realm) {
+        instrument(realm.getAuthenticationCacheName());
+        instrument(realm.getAuthorizationCacheName());
+    }
+
+    public void instrument(final AuthenticatingRealm realm) {
+        instrument(realm.getAuthenticationCacheName());
+    }
+
+    public void instrument(final String cacheName) {
+        // Initialize the cache, if it doesn't exist yet
+        // Note: Shiro's cache manager is not thread safe. Concurrent requests on startup
+        // can throw org.apache.shiro.cache.CacheException: net.sf.ehcache.ObjectExistsException: Cache shiro-activeSessionCache already exists
+        shiroEhCacheManager.getCache(cacheName);
+
+        final Ehcache shiroEhcache = ehCacheCacheManager.getEhcache(cacheName);
+        final Ehcache decoratedCache = InstrumentedEhcache.instrument(metricRegistry, shiroEhcache);
+        try {
+            ehCacheCacheManager.replaceCacheWithDecoratedCache(shiroEhcache, decoratedCache);
+        } catch (final CacheException e) {
+            logger.warn("Unable to instrument cache {}: {}", shiroEhcache.getName(), e.getMessage());
+        }
+    }
+}