killbill-uncached

payment: properly unbind Shiro context in plugin threads Shiro

11/5/2015 4:26:42 PM

Details

payment/pom.xml 4(+4 -0)

diff --git a/payment/pom.xml b/payment/pom.xml
index acad52a..f99ab53 100644
--- a/payment/pom.xml
+++ b/payment/pom.xml
@@ -69,6 +69,10 @@
             <artifactId>joda-time</artifactId>
         </dependency>
         <dependency>
+            <groupId>org.apache.shiro</groupId>
+            <artifactId>shiro-core</artifactId>
+        </dependency>
+        <dependency>
             <groupId>org.kill-bill.billing</groupId>
             <artifactId>killbill-account</artifactId>
             <scope>test</scope>
diff --git a/payment/src/main/java/org/killbill/billing/payment/dispatcher/CallableWithRequestData.java b/payment/src/main/java/org/killbill/billing/payment/dispatcher/CallableWithRequestData.java
index e78d8ad..b5be8c5 100644
--- a/payment/src/main/java/org/killbill/billing/payment/dispatcher/CallableWithRequestData.java
+++ b/payment/src/main/java/org/killbill/billing/payment/dispatcher/CallableWithRequestData.java
@@ -19,16 +19,23 @@ package org.killbill.billing.payment.dispatcher;
 
 import java.util.concurrent.Callable;
 
+import org.apache.shiro.mgt.SecurityManager;
+import org.apache.shiro.subject.Subject;
+import org.apache.shiro.util.ThreadContext;
 import org.killbill.commons.request.Request;
 import org.killbill.commons.request.RequestData;
 
 public class CallableWithRequestData<T> implements Callable<T> {
 
     private final RequestData requestData;
+    private final SecurityManager securityManager;
+    private final Subject subject;
     private final Callable<T> delegate;
 
-    public CallableWithRequestData(final RequestData requestData, final Callable<T> delegate) {
+    public CallableWithRequestData(final RequestData requestData, final SecurityManager securityManager, final Subject subject, final Callable<T> delegate) {
         this.requestData = requestData;
+        this.securityManager = securityManager;
+        this.subject = subject;
         this.delegate = delegate;
     }
 
@@ -36,9 +43,13 @@ public class CallableWithRequestData<T> implements Callable<T> {
     public T call() throws Exception {
         try {
             Request.setPerThreadRequestData(requestData);
+            ThreadContext.bind(securityManager);
+            ThreadContext.bind(subject);
             return delegate.call();
         } finally {
             Request.resetPerThreadRequestData();
+            ThreadContext.unbindSecurityManager();
+            ThreadContext.unbindSubject();
         }
     }
 }
diff --git a/payment/src/main/java/org/killbill/billing/payment/dispatcher/PluginDispatcher.java b/payment/src/main/java/org/killbill/billing/payment/dispatcher/PluginDispatcher.java
index d2de676..e023d9f 100644
--- a/payment/src/main/java/org/killbill/billing/payment/dispatcher/PluginDispatcher.java
+++ b/payment/src/main/java/org/killbill/billing/payment/dispatcher/PluginDispatcher.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
+import org.apache.shiro.util.ThreadContext;
 import org.killbill.billing.payment.core.PaymentExecutors;
 import org.killbill.commons.profiling.Profiling;
 import org.killbill.commons.profiling.ProfilingData;
@@ -53,7 +54,7 @@ public class PluginDispatcher<ReturnType> {
         final ExecutorService pluginExecutor = paymentExecutors.getPluginExecutorService();
 
         // Wrap existing callable to keep the original requestId
-        final Callable<PluginDispatcherReturnType<ReturnType>> callableWithRequestData = new CallableWithRequestData(Request.getPerThreadRequestData(), task);
+        final Callable<PluginDispatcherReturnType<ReturnType>> callableWithRequestData = new CallableWithRequestData(Request.getPerThreadRequestData(), ThreadContext.getSecurityManager(), ThreadContext.getSubject(), task);
 
         final Future<PluginDispatcherReturnType<ReturnType>> future = pluginExecutor.submit(callableWithRequestData);
         final PluginDispatcherReturnType<ReturnType> pluginDispatcherResult = future.get(timeout, unit);
diff --git a/payment/src/test/java/org/killbill/billing/payment/dispatcher/TestPluginDispatcher.java b/payment/src/test/java/org/killbill/billing/payment/dispatcher/TestPluginDispatcher.java
index 835df39..ad324a9 100644
--- a/payment/src/test/java/org/killbill/billing/payment/dispatcher/TestPluginDispatcher.java
+++ b/payment/src/test/java/org/killbill/billing/payment/dispatcher/TestPluginDispatcher.java
@@ -18,7 +18,6 @@ package org.killbill.billing.payment.dispatcher;
 
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
@@ -26,7 +25,6 @@ import org.killbill.billing.ErrorCode;
 import org.killbill.billing.payment.PaymentTestSuiteNoDB;
 import org.killbill.billing.payment.api.PaymentApiException;
 import org.killbill.billing.payment.dispatcher.PluginDispatcher.PluginDispatcherReturnType;
-import org.killbill.commons.profiling.Profiling;
 import org.killbill.commons.request.Request;
 import org.killbill.commons.request.RequestData;
 import org.testng.Assert;
@@ -136,7 +134,9 @@ public class TestPluginDispatcher extends PaymentTestSuiteNoDB {
         };
 
         final CallableWithRequestData<PluginDispatcherReturnType<String>> callable = new CallableWithRequestData<PluginDispatcherReturnType<String>>(new RequestData(requestId),
-                                                                                                                                                        delegate);
+                                                                                                                                                     null,
+                                                                                                                                                     null,
+                                                                                                                                                     delegate);
 
         final String actualRequestId = stringPluginDispatcher.dispatchWithTimeout(callable, 100, TimeUnit.MILLISECONDS);
         Assert.assertEquals(actualRequestId, requestId);