diff --git a/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java b/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
index b86f476..32bba1c 100644
--- a/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
+++ b/payment/src/test/java/org/killbill/billing/payment/provider/MockPaymentProviderPlugin.java
@@ -48,6 +48,8 @@ import org.killbill.billing.util.callcontext.CallContext;
import org.killbill.billing.util.callcontext.TenantContext;
import org.killbill.billing.util.entity.DefaultPagination;
import org.killbill.billing.util.entity.Pagination;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE;
import org.killbill.clock.Clock;
import com.google.common.base.Preconditions;
@@ -87,6 +89,8 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
private final Clock clock;
+ private THREAD_STATE lastThreadState = null;
+
private class InternalPaymentInfo {
private BigDecimal authAmount;
@@ -272,48 +276,60 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
}
}
+ public THREAD_STATE getLastThreadState() {
+ return lastThreadState;
+ }
+
@Override
public PaymentTransactionInfoPlugin authorizePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
throws PaymentPluginApiException {
+ updateLastThreadState();
return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.AUTHORIZE, amount, currency, properties);
}
@Override
public PaymentTransactionInfoPlugin capturePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
throws PaymentPluginApiException {
+ updateLastThreadState();
return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.CAPTURE, amount, currency, properties);
}
@Override
public PaymentTransactionInfoPlugin purchasePayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.PURCHASE, amount, currency, properties);
}
@Override
public PaymentTransactionInfoPlugin voidPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context)
throws PaymentPluginApiException {
+ updateLastThreadState();
return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.VOID, null, null, properties);
}
@Override
public PaymentTransactionInfoPlugin creditPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal amount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context)
throws PaymentPluginApiException {
+ updateLastThreadState();
return getPaymentTransactionInfoPluginResult(kbPaymentId, kbTransactionId, TransactionType.CREDIT, amount, currency, properties);
}
@Override
public List<PaymentTransactionInfoPlugin> getPaymentInfo(final UUID kbAccountId, final UUID kbPaymentId, final Iterable<PluginProperty> properties, final TenantContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
final List<PaymentTransactionInfoPlugin> result = paymentTransactions.get(kbPaymentId.toString());
return result != null ? result : ImmutableList.<PaymentTransactionInfoPlugin>of();
}
@Override
public Pagination<PaymentTransactionInfoPlugin> searchPayments(final String searchKey, final Long offset, final Long limit, final Iterable<PluginProperty> properties, final TenantContext tenantContext) throws PaymentPluginApiException {
+ updateLastThreadState();
throw new IllegalStateException("Not implemented");
}
@Override
public void addPaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final PaymentMethodPlugin paymentMethodProps, final boolean setDefault, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
// externalPaymentMethodId is set to a random value
final PaymentMethodPlugin realWithID = new TestPaymentMethodPlugin(kbPaymentMethodId, paymentMethodProps, UUID.randomUUID().toString());
paymentMethods.put(kbPaymentMethodId.toString(), realWithID);
@@ -324,26 +340,31 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
@Override
public void deletePaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
paymentMethods.remove(kbPaymentMethodId.toString());
paymentMethodsInfo.remove(kbPaymentMethodId.toString());
}
@Override
public PaymentMethodPlugin getPaymentMethodDetail(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final TenantContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
return paymentMethods.get(kbPaymentMethodId.toString());
}
@Override
public void setDefaultPaymentMethod(final UUID kbAccountId, final UUID kbPaymentMethodId, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
}
@Override
public List<PaymentMethodInfoPlugin> getPaymentMethods(final UUID kbAccountId, final boolean refreshFromGateway, final Iterable<PluginProperty> properties, final CallContext context) {
+ updateLastThreadState();
return ImmutableList.<PaymentMethodInfoPlugin>copyOf(paymentMethodsInfo.values());
}
@Override
public Pagination<PaymentMethodPlugin> searchPaymentMethods(final String searchKey, final Long offset, final Long limit, final Iterable<PluginProperty> properties, final TenantContext tenantContext) throws PaymentPluginApiException {
+ updateLastThreadState();
final ImmutableList<PaymentMethodPlugin> results = ImmutableList.<PaymentMethodPlugin>copyOf(Iterables.<PaymentMethodPlugin>filter(paymentMethods.values(), new Predicate<PaymentMethodPlugin>() {
@Override
public boolean apply(final PaymentMethodPlugin input) {
@@ -362,6 +383,7 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
@Override
public void resetPaymentMethods(final UUID kbAccountId, final List<PaymentMethodInfoPlugin> input, final Iterable<PluginProperty> properties, final CallContext callContext) {
+ updateLastThreadState();
paymentMethodsInfo.clear();
if (input != null) {
for (final PaymentMethodInfoPlugin cur : input) {
@@ -372,16 +394,19 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
@Override
public HostedPaymentPageFormDescriptor buildFormDescriptor(final UUID kbAccountId, final Iterable<PluginProperty> customFields, final Iterable<PluginProperty> properties, final CallContext callContext) {
+ updateLastThreadState();
return new DefaultNoOpHostedPaymentPageFormDescriptor(kbAccountId);
}
@Override
public GatewayNotification processNotification(final String notification, final Iterable<PluginProperty> properties, final CallContext callContext) throws PaymentPluginApiException {
+ updateLastThreadState();
return new DefaultNoOpGatewayNotification();
}
@Override
public PaymentTransactionInfoPlugin refundPayment(final UUID kbAccountId, final UUID kbPaymentId, final UUID kbTransactionId, final UUID kbPaymentMethodId, final BigDecimal refundAmount, final Currency currency, final Iterable<PluginProperty> properties, final CallContext context) throws PaymentPluginApiException {
+ updateLastThreadState();
final InternalPaymentInfo info = payments.get(kbPaymentId.toString());
if (info == null) {
@@ -479,4 +504,8 @@ public class MockPaymentProviderPlugin implements PaymentPluginApi {
return result;
}
+
+ private void updateLastThreadState() {
+ lastThreadState = DBRouterUntyped.getCurrentState();
+ }
}
diff --git a/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java b/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
index ecaecf1..1f2e03a 100644
--- a/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
+++ b/payment/src/test/java/org/killbill/billing/payment/TestJanitor.java
@@ -56,7 +56,10 @@ import org.killbill.billing.payment.provider.DefaultNoOpPaymentInfoPlugin;
import org.killbill.billing.payment.provider.MockPaymentProviderPlugin;
import org.killbill.billing.platform.api.KillbillConfigSource;
import org.killbill.billing.util.callcontext.InternalCallContextFactory;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped;
+import org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE;
import org.killbill.bus.api.PersistentBus.EventBusException;
+import org.killbill.commons.profiling.Profiling.WithProfilingCallback;
import org.killbill.notificationq.api.NotificationEvent;
import org.killbill.notificationq.api.NotificationEventWithMetadata;
import org.killbill.notificationq.api.NotificationQueueService;
@@ -75,8 +78,8 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.inject.Inject;
-import static org.awaitility.Awaitility.await;
import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.awaitility.Awaitility.await;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;
@@ -472,6 +475,45 @@ public class TestJanitor extends PaymentTestSuiteWithEmbeddedDB {
});
}
+ @Test(groups = "slow")
+ public void testDBRouterThreadState() throws Throwable {
+ final Payment payment = (Payment) DBRouterUntyped.withRODBIAllowed(true,
+ new WithProfilingCallback<Object, Throwable>() {
+ @Override
+ public Payment execute() throws Throwable {
+ // Shouldn't happen in practice, but it's just to verify the behavior
+ assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+
+ final BigDecimal requestedAmount = BigDecimal.TEN;
+ testListener.pushExpectedEvent(NextEvent.PAYMENT);
+ final Payment payment = paymentApi.createAuthorization(account, account.getPaymentMethodId(), null, requestedAmount, account.getCurrency(), null, UUID.randomUUID().toString(),
+ UUID.randomUUID().toString(), ImmutableList.<PluginProperty>of(), callContext);
+ testListener.assertListenerStatus();
+
+ // Thread switch, RW by default
+ assertEquals(mockPaymentProviderPlugin.getLastThreadState(), THREAD_STATE.RW_ONLY);
+ // Switched to RW, because of RW DAO call
+ assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RW_ONLY);
+ return payment;
+ }
+ });
+
+ DBRouterUntyped.withRODBIAllowed(true,
+ new WithProfilingCallback<Object, Throwable>() {
+ @Override
+ public Object execute() throws Throwable {
+ assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+
+ final Payment retrievedPayment2 = paymentApi.getPayment(payment.getId(), true, false, ImmutableList.<PluginProperty>of(), callContext);
+ Assert.assertEquals(retrievedPayment2.getTransactions().get(0).getTransactionStatus(), TransactionStatus.SUCCESS);
+
+ // No thread switch, RO as well
+ assertEquals(mockPaymentProviderPlugin.getLastThreadState(), THREAD_STATE.RO_ALLOWED);
+ assertEquals(DBRouterUntyped.getCurrentState(), THREAD_STATE.RO_ALLOWED);
+ return null;
+ }
+ });
+ }
private List<PluginProperty> createPropertiesForInvoice(final Invoice invoice) {
final List<PluginProperty> result = new ArrayList<PluginProperty>();
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
index 40eb556..abd2b4b 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
@@ -24,6 +24,8 @@ import org.skife.jdbi.v2.TransactionCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
+
import static org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE.RO_ALLOWED;
import static org.killbill.billing.util.entity.dao.DBRouterUntyped.THREAD_STATE.RW_ONLY;
@@ -48,7 +50,7 @@ public class DBRouterUntyped {
public static Object withRODBIAllowed(final boolean allowRODBI,
final WithProfilingCallback<Object, Throwable> callback) throws Throwable {
- final THREAD_STATE currentState = CURRENT_THREAD_STATE.get();
+ final THREAD_STATE currentState = getCurrentState();
CURRENT_THREAD_STATE.set(allowRODBI ? RO_ALLOWED : RW_ONLY);
try {
@@ -58,6 +60,11 @@ public class DBRouterUntyped {
}
}
+ @VisibleForTesting
+ public static THREAD_STATE getCurrentState() {
+ return CURRENT_THREAD_STATE.get();
+ }
+
boolean shouldUseRODBI(final boolean requestedRO) {
if (requestedRO) {
if (isRODBIAllowed()) {
@@ -65,7 +72,7 @@ public class DBRouterUntyped {
return true;
} else {
// Redirect to the rw instance, to work-around any replication delay
- logger.debug("RO DBI requested, but thread state is {}, using RW DBI", CURRENT_THREAD_STATE.get());
+ logger.debug("RO DBI requested, but thread state is {}, using RW DBI", getCurrentState());
return false;
}
} else {
@@ -77,7 +84,7 @@ public class DBRouterUntyped {
}
private boolean isRODBIAllowed() {
- return CURRENT_THREAD_STATE.get() == RO_ALLOWED;
+ return getCurrentState() == RO_ALLOWED;
}
private void disallowRODBI() {