killbill-uncached

payment: test RO DBI behavior with payment plugins Signed-off-by:

4/30/2018 6:48:42 AM

Details

diff --git a/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java b/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
index b86f476..32bba1c 100644
--- a/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
+++ b/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
@@ -48,6 +48,8 @@ import org.killbill.billing.util.callcontext.CallContext;
 import org.killbill.billing.util.callcontext.TenantContext;
 import org.killbill.billing.util.entity.DefaultPagination;
 import org.killbill.billing.util.entity.Pagination;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE;
 import org.killbill.clock.Clock;
 
 import com.google.common.base.Preconditions;
@@ -87,6 +89,8 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
 
     private final Clock clock;
 
+    private THREAD_STATE lastThreadState = null;
+
     private class InternalPaymentInfo {
 
         private BigDecimal authAmount;
@@ -272,48 +276,60 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
         }
     }
 
+    public THREAD_STATE getLastThreadState() {
+        return lastThreadState;
+    }
+
     @Override
     public PaymentTransactionInfoPlugin authorizePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
             throws PaymentPluginApiException {
+        updateLastThreadState();
         return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.AUTHORIZE, amount, currency, properties);
     }
 
     @Override
     public PaymentTransactionInfoPlugin capturePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
             throws PaymentPluginApiException {
+        updateLastThreadState();
         return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.CAPTURE, amount, currency, properties);
     }
 
     @Override
     public PaymentTransactionInfoPlugin purchasePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
         return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.PURCHASE, amount, currency, properties);
     }
 
     @Override
     public PaymentTransactionInfoPlugin voidPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context)
             throws PaymentPluginApiException {
+        updateLastThreadState();
         return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.VOID, null, null, properties);
     }
 
     @Override
     public PaymentTransactionInfoPlugin creditPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
             throws PaymentPluginApiException {
+        updateLastThreadState();
         return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.CREDIT, amount, currency, properties);
     }
 
     @Override
     public List<PaymentTransactionInfoPlugin> getPaymentInfo(final UUID kbAccountId, final UUID kbPaymentId, final Iterable<PluginProperty> properties, final TenantContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
         final List<PaymentTransactionInfoPlugin> result = paymentTransactions.get(kbPaymentId.toString());
         return result != null ? result : ImmutableList.<PaymentTransactionInfoPlugin>of();
     }
 
     @Override
     public Pagination<PaymentTransactionInfoPlugin> searchPayments(final String searchKey, final Long offset, final Long limit, final Iterable<PluginProperty> properties, final TenantContext tenantContext) throws PaymentPluginApiException {
+        updateLastThreadState();
         throw new IllegalStateException("Not implemented");
     }
 
     @Override
     public void addPaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final PaymentMethodPlugin paymentMethodProps, final boolean setDefault, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
         // externalPaymentMethodId is set to a random value
         final PaymentMethodPlugin realWithID = new TestPaymentMethodPlugin(kbPaymentMethodId, paymentMethodProps, UUID.randomUUID().toString());
         paymentMethods.put(kbPaymentMethodId.toString(), realWithID);
@@ -324,26 +340,31 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
 
     @Override
     public void deletePaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
         paymentMethods.remove(kbPaymentMethodId.toString());
         paymentMethodsInfo.remove(kbPaymentMethodId.toString());
     }
 
     @Override
     public PaymentMethodPlugin getPaymentMethodDetail(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final TenantContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
         return paymentMethods.get(kbPaymentMethodId.toString());
     }
 
     @Override
     public void setDefaultPaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
     }
 
     @Override
     public List<PaymentMethodInfoPlugin> getPaymentMethods(final UUID kbAccountId, final boolean refreshFromGateway, final Iterable<PluginProperty> properties, final CallContext context) {
+        updateLastThreadState();
         return ImmutableList.<PaymentMethodInfoPlugin>copyOf(paymentMethodsInfo.values());
     }
 
     @Override
     public Pagination<PaymentMethodPlugin> searchPaymentMethods(final String searchKey, final Long offset, final Long limit, final Iterable<PluginProperty> properties, final TenantContext tenantContext) throws PaymentPluginApiException {
+        updateLastThreadState();
         final ImmutableList<PaymentMethodPlugin> results = ImmutableList.<PaymentMethodPlugin>copyOf(Iterables.<PaymentMethodPlugin>filter(paymentMethods.values(), new Predicate<PaymentMethodPlugin>() {
             @Override
             public boolean apply(final PaymentMethodPlugin input) {
@@ -362,6 +383,7 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
 
     @Override
     public void resetPaymentMethods(final UUID kbAccountId, final List<PaymentMethodInfoPlugin> input, final Iterable<PluginProperty> properties, final CallContext callContext) {
+        updateLastThreadState();
         paymentMethodsInfo.clear();
         if (input != null) {
             for (final PaymentMethodInfoPlugin cur : input) {
@@ -372,16 +394,19 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
 
     @Override
     public HostedPaymentPageFormDescriptor buildFormDescriptor(final UUID kbAccountId, final Iterable<PluginProperty> customFields, final Iterable<PluginProperty> properties, final CallContext callContext) {
+        updateLastThreadState();
         return new DefaultNoOpHostedPaymentPageFormDescriptor(kbAccountId);
     }
 
     @Override
     public GatewayNotification processNotification(final String notification, final Iterable<PluginProperty> properties, final CallContext callContext) throws PaymentPluginApiException {
+        updateLastThreadState();
         return new DefaultNoOpGatewayNotification();
     }
 
     @Override
     public PaymentTransactionInfoPlugin refundPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal refundAmount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+        updateLastThreadState();
 
         final InternalPaymentInfo info = payments.get(kbPaymentId.toString());
         if (info == null) {
@@ -479,4 +504,8 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
 
         return result;
     }
+
+    private void updateLastThreadState() {
+        lastThreadState = DBRouterUntyped.getCurrentState();
+    }
 }
diff --git a/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java b/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
index ecaecf1..1f2e03a 100644
--- a/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
+++ b/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
@@ -56,7 +56,10 @@ import org.killbill.billing.payment.provider.DefaultNoOpPaymentInfoPlugin;
 import org.killbill.billing.payment.provider.MockPaymentProviderPlugin;
 import org.killbill.billing.platform.api.KillbillConfigSource;
 import org.killbill.billing.util.callcontext.InternalCallContextFactory;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE;
 import org.killbill.bus.api.PersistentBus.EventBusException;
+import org.killbill.commons.profiling.Profiling.WithProfilingCallback;
 import org.killbill.notificationq.api.NotificationEvent;
 import org.killbill.notificationq.api.NotificationEventWithMetadata;
 import org.killbill.notificationq.api.NotificationQueueService;
@@ -75,8 +78,8 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.inject.Inject;
 
-import static org.awaitility.Awaitility.await;
 import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.awaitility.Awaitility.await;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.fail;
 
@@ -472,6 +475,45 @@ public class TestJanitor extends PaymentTestSuiteWithEmbeddedDB {
         });
     }
 
+    @Test(groups = "slow")
+    public void testDBRouterThreadState() throws Throwable {
+        final Payment payment = (Payment) DBRouterUntyped.withRODBIAllowed(true,
+                                                                           new WithProfilingCallback<Object, Throwable>() {
+                                                                               @Override
+                                                                               public Payment execute() throws Throwable {
+                                                                                   // Shouldn't happen in practice, but it's just to verify the behavior
+                                                                                   assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+
+                                                                                   final BigDecimal requestedAmount = BigDecimal.TEN;
+                                                                                   testListener.pushExpectedEvent(NextEvent.PAYMENT);
+                                                                                   final Payment payment = paymentApi.createAuthorization(account, account.getPaymentMethodId(), null, requestedAmount, account.getCurrency(), null, UUID.randomUUID().toString(),
+                                                                                                                                          UUID.randomUUID().toString(), ImmutableList.<PluginProperty>of(), callContext);
+                                                                                   testListener.assertListenerStatus();
+
+                                                                                   // Thread switch, RW by default
+                                                                                   assertEquals(mockPaymentProviderPlugin.getLastThreadState(), THREAD_STATE.RW_ONLY);
+                                                                                   // Switched to RW, because of RW DAO call
+                                                                                   assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RW_ONLY);
+                                                                                   return payment;
+                                                                               }
+                                                                           });
+
+        DBRouterUntyped.withRODBIAllowed(true,
+                                         new WithProfilingCallback<Object, Throwable>() {
+                                             @Override
+                                             public Object execute() throws Throwable {
+                                                 assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+
+                                                 final Payment retrievedPayment2 = paymentApi.getPayment(payment.getId(), true, false, ImmutableList.<PluginProperty>of(), callContext);
+                                                 Assert.assertEquals(retrievedPayment2.getTransactions().get(0).getTransactionStatus(), TransactionStatus.SUCCESS);
+
+                                                 // No thread switch, RO as well
+                                                 assertEquals(mockPaymentProviderPlugin.getLastThreadState(), THREAD_STATE.RO_ALLOWED);
+                                                 assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+                                                 return null;
+                                             }
+                                         });
+    }
 
     private List<PluginProperty> createPropertiesForInvoice(final Invoice invoice) {
         final List<PluginProperty> result = new ArrayList<PluginProperty>();
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
index 40eb556..abd2b4b 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
@@ -24,6 +24,8 @@ import org.skife.jdbi.v2.TransactionCallback;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import static org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE.RO_ALLOWED;
 import static org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE.RW_ONLY;
 
@@ -48,7 +50,7 @@ public class DBRouterUntyped {
 
     public static Object withRODBIAllowed(final boolean allowRODBI,
                                           final WithProfilingCallback<Object, Throwable> callback) throws Throwable {
-        final THREAD_STATE currentState = CURRENT_THREAD_STATE.get();
+        final THREAD_STATE currentState = getCurrentState();
         CURRENT_THREAD_STATE.set(allowRODBI ? RO_ALLOWED : RW_ONLY);
 
         try {
@@ -58,6 +60,11 @@ public class DBRouterUntyped {
         }
     }
 
+    @VisibleForTesting
+    public static THREAD_STATE getCurrentState() {
+        return CURRENT_THREAD_STATE.get();
+    }
+
     boolean shouldUseRODBI(final boolean requestedRO) {
         if (requestedRO) {
             if (isRODBIAllowed()) {
@@ -65,7 +72,7 @@ public class DBRouterUntyped {
                 return true;
             } else {
                 // Redirect to the rw instance, to work-around any replication delay
-                logger.debug("RO DBI requested, but thread state is {}, using RW DBI", CURRENT_THREAD_STATE.get());
+                logger.debug("RO DBI requested, but thread state is {}, using RW DBI", getCurrentState());
                 return false;
             }
         } else {
@@ -77,7 +84,7 @@ public class DBRouterUntyped {
     }
 
     private boolean isRODBIAllowed() {
-        return CURRENT_THREAD_STATE.get() == RO_ALLOWED;
+        return getCurrentState() == RO_ALLOWED;
     }
 
     private void disallowRODBI() {