killbill-memoizeit

Merge pull request #959 from killbill/ro-db-iteration util:

4/25/2018 6:34:32 PM

Details

diff --git a/profiles/killbill/src/test/java/org/killbill/billing/server/security/TestKillbillJdbcTenantRealm.java b/profiles/killbill/src/test/java/org/killbill/billing/server/security/TestKillbillJdbcTenantRealm.java
index 39b1b13..292cd25 100644
--- a/profiles/killbill/src/test/java/org/killbill/billing/server/security/TestKillbillJdbcTenantRealm.java
+++ b/profiles/killbill/src/test/java/org/killbill/billing/server/security/TestKillbillJdbcTenantRealm.java
@@ -51,7 +51,7 @@ public class TestKillbillJdbcTenantRealm extends TestJaxrsBase {
         super.beforeMethod();
 
         // Create the tenant
-        final DefaultTenantDao tenantDao = new DefaultTenantDao(dbi, roDbi, clock, cacheControllerDispatcher, new DefaultNonEntityDao(dbi), Mockito.mock(InternalCallContextFactory.class), securityConfig);
+        final DefaultTenantDao tenantDao = new DefaultTenantDao(dbi, roDbi, clock, cacheControllerDispatcher, new DefaultNonEntityDao(dbi, roDbi), Mockito.mock(InternalCallContextFactory.class), securityConfig);
         tenant = new DefaultTenant(UUID.randomUUID(), null, null, UUID.randomUUID().toString(),
                                    UUID.randomUUID().toString(), UUID.randomUUID().toString());
         tenantDao.create(new TenantModelDao(tenant), internalCallContext);
diff --git a/usage/src/main/java/org/killbill/billing/usage/dao/DefaultRolledUpUsageDao.java b/usage/src/main/java/org/killbill/billing/usage/dao/DefaultRolledUpUsageDao.java
index ea5aabd..c8987b8 100644
--- a/usage/src/main/java/org/killbill/billing/usage/dao/DefaultRolledUpUsageDao.java
+++ b/usage/src/main/java/org/killbill/billing/usage/dao/DefaultRolledUpUsageDao.java
@@ -27,43 +27,42 @@ import javax.inject.Named;
 import org.joda.time.LocalDate;
 import org.killbill.billing.callcontext.InternalCallContext;
 import org.killbill.billing.callcontext.InternalTenantContext;
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.skife.jdbi.v2.IDBI;
 
 import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 
 public class DefaultRolledUpUsageDao implements RolledUpUsageDao {
 
-    private final RolledUpUsageSqlDao rolledUpUsageSqlDao;
-    private final RolledUpUsageSqlDao roRolledUpUsageSqlDao;
+    private final DBRouter<RolledUpUsageSqlDao> dbRouter;
 
     @Inject
     public DefaultRolledUpUsageDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
-        this.rolledUpUsageSqlDao = dbi.onDemand(RolledUpUsageSqlDao.class);
-        this.roRolledUpUsageSqlDao = roDbi.onDemand(RolledUpUsageSqlDao.class);
+        this.dbRouter = new DBRouter<RolledUpUsageSqlDao>(dbi, roDbi, RolledUpUsageSqlDao.class);
     }
 
     @Override
     public void record(final Iterable<RolledUpUsageModelDao> usages, final InternalCallContext context) {
-        rolledUpUsageSqlDao.create(usages, context);
+        dbRouter.onDemand(false).create(usages, context);
     }
 
     @Override
     public Boolean recordsWithTrackingIdExist(final UUID subscriptionId, final String trackingId, final InternalTenantContext context) {
-        return rolledUpUsageSqlDao.recordsWithTrackingIdExist(subscriptionId, trackingId, context) != null;
+        return dbRouter.onDemand(false).recordsWithTrackingIdExist(subscriptionId, trackingId, context) != null;
     }
 
     @Override
     public List<RolledUpUsageModelDao> getUsageForSubscription(final UUID subscriptionId, final LocalDate startDate, final LocalDate endDate, final String unitType, final InternalTenantContext context) {
-        return roRolledUpUsageSqlDao.getUsageForSubscription(subscriptionId, startDate.toDate(), endDate.toDate(), unitType, context);
+        return dbRouter.onDemand(true).getUsageForSubscription(subscriptionId, startDate.toDate(), endDate.toDate(), unitType, context);
     }
 
     @Override
     public List<RolledUpUsageModelDao> getAllUsageForSubscription(final UUID subscriptionId, final LocalDate startDate, final LocalDate endDate, final InternalTenantContext context) {
-        return roRolledUpUsageSqlDao.getAllUsageForSubscription(subscriptionId, startDate.toDate(), endDate.toDate(), context);
+        return dbRouter.onDemand(true).getAllUsageForSubscription(subscriptionId, startDate.toDate(), endDate.toDate(), context);
     }
 
     @Override
     public List<RolledUpUsageModelDao> getRawUsageForAccount(final LocalDate startDate, final LocalDate endDate, final InternalTenantContext context) {
-        return roRolledUpUsageSqlDao.getRawUsageForAccount(startDate.toDate(), endDate.toDate(), context);
+        return dbRouter.onDemand(true).getRawUsageForAccount(startDate.toDate(), endDate.toDate(), context);
     }
 }
diff --git a/util/src/main/java/org/killbill/billing/util/audit/dao/DefaultAuditDao.java b/util/src/main/java/org/killbill/billing/util/audit/dao/DefaultAuditDao.java
index e454359..b5e08a8 100644
--- a/util/src/main/java/org/killbill/billing/util/audit/dao/DefaultAuditDao.java
+++ b/util/src/main/java/org/killbill/billing/util/audit/dao/DefaultAuditDao.java
@@ -45,6 +45,7 @@ import org.killbill.billing.util.dao.NonEntityDao;
 import org.killbill.billing.util.dao.NonEntitySqlDao;
 import org.killbill.billing.util.dao.RecordIdIdMappings;
 import org.killbill.billing.util.dao.TableName;
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.killbill.billing.util.entity.dao.EntitySqlDao;
 import org.killbill.billing.util.entity.dao.EntitySqlDaoTransactionWrapper;
 import org.killbill.billing.util.entity.dao.EntitySqlDaoTransactionalJdbiWrapper;
@@ -62,18 +63,18 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 
 public class DefaultAuditDao implements AuditDao {
 
-    private final NonEntitySqlDao roNonEntitySqlDao;
+    private final DBRouter<NonEntitySqlDao> dbRouter;
     private final EntitySqlDaoTransactionalJdbiWrapper transactionalSqlDao;
 
     @Inject
     public DefaultAuditDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi, final Clock clock, final CacheControllerDispatcher cacheControllerDispatcher, final NonEntityDao nonEntityDao, final InternalCallContextFactory internalCallContextFactory) {
-        this.roNonEntitySqlDao = roDbi.onDemand(NonEntitySqlDao.class);
+        this.dbRouter = new DBRouter<NonEntitySqlDao>(dbi, roDbi, NonEntitySqlDao.class);
         this.transactionalSqlDao = new EntitySqlDaoTransactionalJdbiWrapper(dbi, roDbi, clock, cacheControllerDispatcher, nonEntityDao, internalCallContextFactory);
     }
 
     @Override
     public DefaultAccountAuditLogs getAuditLogsForAccountRecordId(final AuditLevel auditLevel, final InternalTenantContext context) {
-        final UUID accountId = roNonEntitySqlDao.getIdFromObject(context.getAccountRecordId(), TableName.ACCOUNT.getTableName());
+        final UUID accountId = dbRouter.onDemand(true).getIdFromObject(context.getAccountRecordId(), TableName.ACCOUNT.getTableName());
 
         // Lazy evaluate records to minimize the memory footprint (these can yield a lot of results)
         // We usually always want to wrap our queries in an EntitySqlDaoTransactionWrapper... except here.
@@ -115,6 +116,7 @@ public class DefaultAuditDao implements AuditDao {
                                                                        // For tables without history, e.g. TENANT, originalTableNameForHistoryTableName will be null
                                                                        final TableName originalTableNameForHistoryTableName = findTableNameForHistoryTableName(input.getTableName());
 
+                                                                       final NonEntitySqlDao nonEntitySqlDao = dbRouter.onDemand(true);
                                                                        final ObjectType objectType;
                                                                        final UUID auditedEntityId;
                                                                        if (originalTableNameForHistoryTableName != null) {
@@ -123,19 +125,19 @@ public class DefaultAuditDao implements AuditDao {
 
                                                                            if (historyRecordIdIdsCache.get(originalTableNameForHistoryTableName) == null) {
                                                                                if (TableName.ACCOUNT.equals(originalTableNameForHistoryTableName)) {
-                                                                                   final Iterable<RecordIdIdMappings> mappings = roNonEntitySqlDao.getHistoryRecordIdIdMappingsForAccountsTable(originalTableNameForHistoryTableName.getTableName(),
-                                                                                                                                                                                                input.getTableName().getTableName(),
-                                                                                                                                                                                                tenantContext);
+                                                                                   final Iterable<RecordIdIdMappings> mappings = nonEntitySqlDao.getHistoryRecordIdIdMappingsForAccountsTable(originalTableNameForHistoryTableName.getTableName(),
+                                                                                                                                                                                              input.getTableName().getTableName(),
+                                                                                                                                                                                              tenantContext);
                                                                                    historyRecordIdIdsCache.put(originalTableNameForHistoryTableName, RecordIdIdMappings.toMap(mappings));
                                                                                } else if (TableName.TAG_DEFINITIONS.equals(originalTableNameForHistoryTableName)) {
-                                                                                   final Iterable<RecordIdIdMappings> mappings = roNonEntitySqlDao.getHistoryRecordIdIdMappingsForTablesWithoutAccountRecordId(originalTableNameForHistoryTableName.getTableName(),
-                                                                                                                                                                                                               input.getTableName().getTableName(),
-                                                                                                                                                                                                               tenantContext);
+                                                                                   final Iterable<RecordIdIdMappings> mappings = nonEntitySqlDao.getHistoryRecordIdIdMappingsForTablesWithoutAccountRecordId(originalTableNameForHistoryTableName.getTableName(),
+                                                                                                                                                                                                             input.getTableName().getTableName(),
+                                                                                                                                                                                                             tenantContext);
                                                                                    historyRecordIdIdsCache.put(originalTableNameForHistoryTableName, RecordIdIdMappings.toMap(mappings));
                                                                                } else {
-                                                                                   final Iterable<RecordIdIdMappings> mappings = roNonEntitySqlDao.getHistoryRecordIdIdMappings(originalTableNameForHistoryTableName.getTableName(),
-                                                                                                                                                                                input.getTableName().getTableName(),
-                                                                                                                                                                                tenantContext);
+                                                                                   final Iterable<RecordIdIdMappings> mappings = nonEntitySqlDao.getHistoryRecordIdIdMappings(originalTableNameForHistoryTableName.getTableName(),
+                                                                                                                                                                              input.getTableName().getTableName(),
+                                                                                                                                                                              tenantContext);
                                                                                    historyRecordIdIdsCache.put(originalTableNameForHistoryTableName, RecordIdIdMappings.toMap(mappings));
 
                                                                                }
@@ -146,8 +148,8 @@ public class DefaultAuditDao implements AuditDao {
                                                                            objectType = input.getTableName().getObjectType();
 
                                                                            if (recordIdIdsCache.get(input.getTableName()) == null) {
-                                                                               final Iterable<RecordIdIdMappings> mappings = roNonEntitySqlDao.getRecordIdIdMappings(input.getTableName().getTableName(),
-                                                                                                                                                                     tenantContext);
+                                                                               final Iterable<RecordIdIdMappings> mappings = nonEntitySqlDao.getRecordIdIdMappings(input.getTableName().getTableName(),
+                                                                                                                                                                   tenantContext);
                                                                                recordIdIdsCache.put(input.getTableName(), RecordIdIdMappings.toMap(mappings));
                                                                            }
 
@@ -188,35 +190,35 @@ public class DefaultAuditDao implements AuditDao {
         return transactionalSqlDao.execute(true, new EntitySqlDaoTransactionWrapper<List<AuditLogWithHistory>>() {
             @Override
             public List<AuditLogWithHistory> inTransaction(final EntitySqlDaoWrapperFactory entitySqlDaoWrapperFactory) throws Exception {
-                final Long targetRecordId = roNonEntitySqlDao.getRecordIdFromObject(objectId.toString(), tableName.getTableName());
+                final Long targetRecordId = dbRouter.onDemand(true).getRecordIdFromObject(objectId.toString(), tableName.getTableName());
                 final List<EntityHistoryModelDao> objectHistory = transactional.getHistoryForTargetRecordId(targetRecordId, context);
 
                 return ImmutableList.<AuditLogWithHistory>copyOf(Collections2.transform(entitySqlDaoWrapperFactory.become(EntitySqlDao.class).getAuditLogsViaHistoryForTargetRecordId(historyTableName.name(),
                                                                                                                                                                                       historyTableName.getTableName().toLowerCase(),
                                                                                                                                                                                       targetRecordId,
                                                                                                                                                                                       context),
-                                                                                 new Function<AuditLogModelDao, AuditLogWithHistory>() {
-                                                                                     @Override
-                                                                                     public AuditLogWithHistory apply(final AuditLogModelDao inputAuditLog) {
-                                                                                         EntityHistoryModelDao historyEntity = null;
-                                                                                         if ( objectHistory != null) {
-                                                                                             for (EntityHistoryModelDao history : objectHistory) {
-                                                                                                 if (history.getHistoryRecordId() == inputAuditLog.getTargetRecordId()) {
-                                                                                                     historyEntity = history;
-                                                                                                     break;
-                                                                                                 }
-                                                                                             }
-                                                                                         }
-
-                                                                                         return new DefaultAuditLogWithHistory((historyEntity == null ? null : historyEntity.getEntity()), inputAuditLog, tableName.getObjectType(), objectId);
-                                                                                     }
-                                                                                 }));
+                                                                                        new Function<AuditLogModelDao, AuditLogWithHistory>() {
+                                                                                            @Override
+                                                                                            public AuditLogWithHistory apply(final AuditLogModelDao inputAuditLog) {
+                                                                                                EntityHistoryModelDao historyEntity = null;
+                                                                                                if (objectHistory != null) {
+                                                                                                    for (final EntityHistoryModelDao history : objectHistory) {
+                                                                                                        if (history.getHistoryRecordId().equals(inputAuditLog.getTargetRecordId())) {
+                                                                                                            historyEntity = history;
+                                                                                                            break;
+                                                                                                        }
+                                                                                                    }
+                                                                                                }
+
+                                                                                                return new DefaultAuditLogWithHistory((historyEntity == null ? null : historyEntity.getEntity()), inputAuditLog, tableName.getObjectType(), objectId);
+                                                                                            }
+                                                                                        }));
             }
         });
     }
 
     private List<AuditLog> doGetAuditLogsForId(final TableName tableName, final UUID objectId, final AuditLevel auditLevel, final InternalTenantContext context) {
-        final Long recordId = roNonEntitySqlDao.getRecordIdFromObject(objectId.toString(), tableName.getTableName());
+        final Long recordId = dbRouter.onDemand(true).getRecordIdFromObject(objectId.toString(), tableName.getTableName());
         if (recordId == null) {
             return ImmutableList.<AuditLog>of();
         } else {
@@ -230,7 +232,7 @@ public class DefaultAuditDao implements AuditDao {
             throw new IllegalStateException("History table shouldn't be null for " + tableName);
         }
 
-        final Long targetRecordId = roNonEntitySqlDao.getRecordIdFromObject(objectId.toString(), tableName.getTableName());
+        final Long targetRecordId = dbRouter.onDemand(true).getRecordIdFromObject(objectId.toString(), tableName.getTableName());
         final List<AuditLog> allAuditLogs = transactionalSqlDao.execute(true, new EntitySqlDaoTransactionWrapper<List<AuditLog>>() {
             @Override
             public List<AuditLog> inTransaction(final EntitySqlDaoWrapperFactory entitySqlDaoWrapperFactory) throws Exception {
diff --git a/util/src/main/java/org/killbill/billing/util/broadcast/dao/DefaultBroadcastDao.java b/util/src/main/java/org/killbill/billing/util/broadcast/dao/DefaultBroadcastDao.java
index 834aeee..c2c4e5a 100644
--- a/util/src/main/java/org/killbill/billing/util/broadcast/dao/DefaultBroadcastDao.java
+++ b/util/src/main/java/org/killbill/billing/util/broadcast/dao/DefaultBroadcastDao.java
@@ -22,6 +22,7 @@ import java.util.List;
 import javax.inject.Inject;
 import javax.inject.Named;
 
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.skife.jdbi.v2.Handle;
 import org.skife.jdbi.v2.IDBI;
 import org.skife.jdbi.v2.TransactionCallback;
@@ -31,18 +32,16 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 
 public class DefaultBroadcastDao implements BroadcastDao {
 
-    private final IDBI dbi;
-    private final IDBI roDbi;
+    private final DBRouter<BroadcastSqlDao> dbRouter;
 
     @Inject
     public DefaultBroadcastDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
-        this.dbi = dbi;
-        this.roDbi = roDbi;
+        this.dbRouter = new DBRouter<BroadcastSqlDao>(dbi, roDbi, BroadcastSqlDao.class);
     }
 
     @Override
     public void create(final BroadcastModelDao broadcastModelDao) {
-        dbi.inTransaction(new TransactionCallback<Void>() {
+        dbRouter.inTransaction(false, new TransactionCallback<Void>() {
             @Override
             public Void inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final BroadcastSqlDao sqlDao = handle.attach(BroadcastSqlDao.class);
@@ -54,7 +53,7 @@ public class DefaultBroadcastDao implements BroadcastDao {
 
     @Override
     public List<BroadcastModelDao> getLatestEntriesFrom(final Long recordId) {
-        return roDbi.inTransaction(new TransactionCallback<List<BroadcastModelDao>>() {
+        return dbRouter.inTransaction(true, new TransactionCallback<List<BroadcastModelDao>>() {
             @Override
             public List<BroadcastModelDao> inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final BroadcastSqlDao sqlDao = handle.attach(BroadcastSqlDao.class);
@@ -65,7 +64,7 @@ public class DefaultBroadcastDao implements BroadcastDao {
 
     @Override
     public BroadcastModelDao getLatestEntry() {
-        return roDbi.inTransaction(new TransactionCallback<BroadcastModelDao>() {
+        return dbRouter.inTransaction(true, new TransactionCallback<BroadcastModelDao>() {
             @Override
             public BroadcastModelDao inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final BroadcastSqlDao sqlDao = handle.attach(BroadcastSqlDao.class);
diff --git a/util/src/main/java/org/killbill/billing/util/cache/AuditLogCacheLoader.java b/util/src/main/java/org/killbill/billing/util/cache/AuditLogCacheLoader.java
index 875920d..7274eab 100644
--- a/util/src/main/java/org/killbill/billing/util/cache/AuditLogCacheLoader.java
+++ b/util/src/main/java/org/killbill/billing/util/cache/AuditLogCacheLoader.java
@@ -28,6 +28,7 @@ import org.killbill.billing.callcontext.InternalTenantContext;
 import org.killbill.billing.util.audit.dao.AuditLogModelDao;
 import org.killbill.billing.util.cache.Cachable.CacheType;
 import org.killbill.billing.util.dao.AuditSqlDao;
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.skife.jdbi.v2.IDBI;
 
 import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
@@ -35,12 +36,12 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 @Singleton
 public class AuditLogCacheLoader extends BaseCacheLoader<String, List<AuditLogModelDao>> {
 
-    private final AuditSqlDao roAuditSqlDao;
+    private final DBRouter<AuditSqlDao> dbRouter;
 
     @Inject
-    public AuditLogCacheLoader(@Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
+    public AuditLogCacheLoader(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
         super();
-        this.roAuditSqlDao = roDbi.onDemand(AuditSqlDao.class);
+        this.dbRouter = new DBRouter<AuditSqlDao>(dbi, roDbi, AuditSqlDao.class);
     }
 
     @Override
@@ -55,6 +56,6 @@ public class AuditLogCacheLoader extends BaseCacheLoader<String, List<AuditLogMo
         final Long targetRecordId = (Long) args[1];
         final InternalTenantContext internalTenantContext = (InternalTenantContext) args[2];
 
-        return roAuditSqlDao.getAuditLogsForTargetRecordId(tableName, targetRecordId, internalTenantContext);
+        return dbRouter.onDemand(true).getAuditLogsForTargetRecordId(tableName, targetRecordId, internalTenantContext);
     }
 }
diff --git a/util/src/main/java/org/killbill/billing/util/dao/DefaultNonEntityDao.java b/util/src/main/java/org/killbill/billing/util/dao/DefaultNonEntityDao.java
index 5e8c63d..41aa896 100644
--- a/util/src/main/java/org/killbill/billing/util/dao/DefaultNonEntityDao.java
+++ b/util/src/main/java/org/killbill/billing/util/dao/DefaultNonEntityDao.java
@@ -29,6 +29,7 @@ import org.killbill.billing.util.cache.CacheController;
 import org.killbill.billing.util.cache.CacheControllerDispatcher;
 import org.killbill.billing.util.cache.CacheLoaderArgument;
 import org.killbill.billing.util.callcontext.InternalCallContextFactory;
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.killbill.commons.profiling.Profiling;
 import org.killbill.commons.profiling.Profiling.WithProfilingCallback;
 import org.killbill.commons.profiling.ProfilingFeature.ProfilingFeatureType;
@@ -42,13 +43,13 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 
 public class DefaultNonEntityDao implements NonEntityDao {
 
-    private final NonEntitySqlDao roNonEntitySqlDao;
+    private final DBRouter<NonEntitySqlDao> dbRouter;
     private final WithCaching<String, Long> withCachingObjectId;
     private final WithCaching<String, UUID> withCachingRecordId;
 
     @Inject
-    public DefaultNonEntityDao(@Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
-        this.roNonEntitySqlDao = roDbi.onDemand(NonEntitySqlDao.class);
+    public DefaultNonEntityDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
+        this.dbRouter = new DBRouter<NonEntitySqlDao>(dbi, roDbi, NonEntitySqlDao.class);
         this.withCachingObjectId = new WithCaching<String, Long>();
         this.withCachingRecordId = new WithCaching<String, UUID>();
     }
@@ -69,7 +70,7 @@ public class DefaultNonEntityDao implements NonEntityDao {
         return withCachingObjectId.withCaching(new OperationRetrieval<Long>() {
             @Override
             public Long doRetrieve(final ObjectType objectType) {
-                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? roNonEntitySqlDao : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
+                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? dbRouter.onDemand(true) : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
                 return inTransactionNonEntitySqlDao.getRecordIdFromObject(objectId.toString(), tableName.getTableName());
             }
         }, objectId.toString(), objectType, tableName, cache);
@@ -89,7 +90,7 @@ public class DefaultNonEntityDao implements NonEntityDao {
         return withCachingObjectId.withCaching(new OperationRetrieval<Long>() {
             @Override
             public Long doRetrieve(final ObjectType objectType) {
-                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? roNonEntitySqlDao : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
+                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? dbRouter.onDemand(true) : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
 
                 switch (tableName) {
                     case TENANT:
@@ -121,7 +122,7 @@ public class DefaultNonEntityDao implements NonEntityDao {
         return withCachingObjectId.withCaching(new OperationRetrieval<Long>() {
             @Override
             public Long doRetrieve(final ObjectType objectType) {
-                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? roNonEntitySqlDao : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
+                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? dbRouter.onDemand(true) : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
 
                 switch (tableName) {
                     case TENANT:
@@ -153,7 +154,7 @@ public class DefaultNonEntityDao implements NonEntityDao {
         return withCachingRecordId.withCaching(new OperationRetrieval<UUID>() {
             @Override
             public UUID doRetrieve(final ObjectType objectType) {
-                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? roNonEntitySqlDao : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
+                final NonEntitySqlDao inTransactionNonEntitySqlDao = handle == null ? dbRouter.onDemand(true) : SqlObjectBuilder.attach(handle, NonEntitySqlDao.class);
                 return inTransactionNonEntitySqlDao.getIdFromObject(recordId, tableName.getTableName());
             }
         }, String.valueOf(recordId), objectType, tableName, cache);
@@ -167,7 +168,7 @@ public class DefaultNonEntityDao implements NonEntityDao {
 
     @Override
     public Long retrieveHistoryTargetRecordId(@Nullable final Long recordId, final TableName tableName) {
-        return roNonEntitySqlDao.getHistoryTargetRecordId(recordId, tableName.getTableName());
+        return dbRouter.onDemand(true).getHistoryTargetRecordId(recordId, tableName.getTableName());
     }
 
     private interface OperationRetrieval<TypeOut> {
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouter.java b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouter.java
new file mode 100644
index 0000000..8fa8006
--- /dev/null
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouter.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2014-2018 Groupon, Inc
+ * Copyright 2014-2018 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.entity.dao;
+
+import org.skife.jdbi.v2.IDBI;
+
+public class DBRouter<C> extends DBRouterUntyped {
+
+    private final C onDemand;
+    private final C roOnDemand;
+
+    public DBRouter(final IDBI dbi, final IDBI roDbi, final Class<C> sqlObjectType) {
+        super(dbi, roDbi);
+        this.onDemand = dbi.onDemand(sqlObjectType);
+        this.roOnDemand = roDbi.onDemand(sqlObjectType);
+    }
+
+    public C onDemand(final boolean requestedRO) {
+        if (shouldUseRODBI(requestedRO)) {
+            return roOnDemand;
+        } else {
+            return onDemand;
+        }
+    }
+}
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
new file mode 100644
index 0000000..f369b84
--- /dev/null
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/DBRouterUntyped.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2014-2018 Groupon, Inc
+ * Copyright 2014-2018 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.entity.dao;
+
+import org.killbill.billing.util.glue.KillbillApiAopModule;
+import org.skife.jdbi.v2.Handle;
+import org.skife.jdbi.v2.IDBI;
+import org.skife.jdbi.v2.TransactionCallback;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class DBRouterUntyped {
+
+    private static final Logger logger = LoggerFactory.getLogger(DBRouterUntyped.class);
+
+    protected final IDBI dbi;
+    protected final IDBI roDbi;
+
+    public DBRouterUntyped(final IDBI dbi, final IDBI roDbi) {
+        this.dbi = dbi;
+        this.roDbi = roDbi;
+    }
+
+    public Handle getHandle(final boolean requestedRO) {
+        if (shouldUseRODBI(requestedRO)) {
+            return roDbi.open();
+        } else {
+            return dbi.open();
+        }
+    }
+
+    public <T> T onDemand(final boolean requestedRO, final Class<T> sqlObjectType) {
+        if (shouldUseRODBI(requestedRO)) {
+            return roDbi.onDemand(sqlObjectType);
+        } else {
+            return dbi.onDemand(sqlObjectType);
+        }
+    }
+
+    public <T> T inTransaction(final boolean requestedRO, final TransactionCallback<T> callback) {
+        if (shouldUseRODBI(requestedRO)) {
+            return roDbi.inTransaction(callback);
+        } else {
+            return dbi.inTransaction(callback);
+        }
+    }
+
+    boolean shouldUseRODBI(final boolean requestedRO) {
+        if (!requestedRO) {
+            KillbillApiAopModule.setDirtyDBFlag();
+            logger.debug("Dirty flag set, using RW DBI");
+            return false;
+        } else {
+            if (KillbillApiAopModule.getDirtyDBFlag()) {
+                // Redirect to the rw instance, to work-around any replication delay
+                logger.debug("RO DBI handle requested, but dirty flag set, using RW DBI");
+                return false;
+            } else {
+                logger.debug("Using RO DBI");
+                return true;
+            }
+        }
+    }
+}
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoTransactionalJdbiWrapper.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoTransactionalJdbiWrapper.java
index a6f6287..cb6c90e 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoTransactionalJdbiWrapper.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoTransactionalJdbiWrapper.java
@@ -24,7 +24,6 @@ import org.killbill.billing.util.cache.CacheControllerDispatcher;
 import org.killbill.billing.util.callcontext.InternalCallContextFactory;
 import org.killbill.billing.util.dao.NonEntityDao;
 import org.killbill.billing.util.entity.Entity;
-import org.killbill.billing.util.glue.KillbillApiAopModule;
 import org.killbill.clock.Clock;
 import org.skife.jdbi.v2.Handle;
 import org.skife.jdbi.v2.IDBI;
@@ -40,8 +39,7 @@ public class EntitySqlDaoTransactionalJdbiWrapper {
 
     private final Logger logger = LoggerFactory.getLogger(EntitySqlDaoTransactionalJdbiWrapper.class);
 
-    private final IDBI dbi;
-    private final IDBI roDbi;
+    private final DBRouterUntyped dbRouter;
     private final Clock clock;
     private final CacheControllerDispatcher cacheControllerDispatcher;
     private final NonEntityDao nonEntityDao;
@@ -49,12 +47,11 @@ public class EntitySqlDaoTransactionalJdbiWrapper {
 
     public EntitySqlDaoTransactionalJdbiWrapper(final IDBI dbi, final IDBI roDbi, final Clock clock, final CacheControllerDispatcher cacheControllerDispatcher,
                                                 final NonEntityDao nonEntityDao, final InternalCallContextFactory internalCallContextFactory) {
-        this.dbi = dbi;
-        this.roDbi = roDbi;
         this.clock = clock;
         this.cacheControllerDispatcher = cacheControllerDispatcher;
         this.nonEntityDao = nonEntityDao;
         this.internalCallContextFactory = internalCallContextFactory;
+        this.dbRouter = new DBRouterUntyped(dbi, roDbi);
     }
 
     public <M extends EntityModelDao> void populateCaches(final M refreshedEntity) {
@@ -89,39 +86,21 @@ public class EntitySqlDaoTransactionalJdbiWrapper {
      */
     public <ReturnType> ReturnType execute(final boolean requestedRO, final EntitySqlDaoTransactionWrapper<ReturnType> entitySqlDaoTransactionWrapper) {
         final String debugInfo = logger.isDebugEnabled() ? getDebugInfo() : null;
-        final boolean ro = shouldUseRODBI(requestedRO, debugInfo);
-        final String debugPrefix = ro ? "RO" : "RW";
 
-        final Handle handle = ro ? roDbi.open() : dbi.open();
-        logger.debug("[{}] DBI handle created, transaction: {}", debugPrefix, debugInfo);
+        final Handle handle = dbRouter.getHandle(requestedRO);
+        logger.debug("DBI handle created, transaction: {}", debugInfo);
         try {
             final EntitySqlDao<EntityModelDao<Entity>, Entity> entitySqlDao = handle.attach(InitialEntitySqlDao.class);
             // The transaction isolation level is now set at the pool level: this avoids 3 roundtrips for each transaction
             // Note that if the pool isn't used (tests or PostgreSQL), the transaction level will depend on the DB configuration
             //return entitySqlDao.inTransaction(TransactionIsolationLevel.READ_COMMITTED, new JdbiTransaction<ReturnType, EntityModelDao<Entity>, Entity>(handle, entitySqlDaoTransactionWrapper));
-            logger.debug("[{}] Starting transaction {}", debugPrefix, debugInfo);
+            logger.debug("Starting transaction {}", debugInfo);
             final ReturnType returnType = entitySqlDao.inTransaction(new JdbiTransaction<ReturnType, EntityModelDao<Entity>, Entity>(handle, entitySqlDaoTransactionWrapper));
-            logger.debug("[{}] Exiting  transaction {}, returning {}", debugPrefix, debugInfo, returnType);
+            logger.debug("Exiting  transaction {}, returning {}", debugInfo, returnType);
             return returnType;
         } finally {
             handle.close();
-            logger.debug("[{}] DBI handle closed,  transaction: {}", debugPrefix, debugInfo);
-        }
-    }
-
-    private boolean shouldUseRODBI(final boolean requestedRO, final String debugInfo) {
-        if (!requestedRO) {
-            KillbillApiAopModule.setDirtyDBFlag();
-            logger.debug("[RW] Dirty flag set, transaction: {}", debugInfo);
-            return false;
-        } else {
-            if (KillbillApiAopModule.getDirtyDBFlag()) {
-                // Redirect to the rw instance, to work-around any replication delay
-                logger.debug("[RW] RO DBI handle requested, but dirty flag set, transaction: {}", debugInfo);
-                return false;
-            } else {
-                return true;
-            }
+            logger.debug("DBI handle closed,  transaction: {}", debugInfo);
         }
     }
 
@@ -130,12 +109,7 @@ public class EntitySqlDaoTransactionalJdbiWrapper {
     // to send bus events, record notifications where we need to keep the Connection through the jDBI Handle.
     //
     public <M extends EntityModelDao<E>, E extends Entity, T extends EntitySqlDao<M, E>> T onDemandForStreamingResults(final Class<T> sqlObjectType) {
-        final String debugInfo = logger.isDebugEnabled() ? getDebugInfo() : null;
-        if (shouldUseRODBI(true, debugInfo)) {
-            return roDbi.onDemand(sqlObjectType);
-        } else {
-            return dbi.onDemand(sqlObjectType);
-        }
+        return dbRouter.onDemand(true, sqlObjectType);
     }
 
     /**
diff --git a/util/src/main/java/org/killbill/billing/util/nodes/dao/DefaultNodeInfoDao.java b/util/src/main/java/org/killbill/billing/util/nodes/dao/DefaultNodeInfoDao.java
index 892c1db..82d1fd4 100644
--- a/util/src/main/java/org/killbill/billing/util/nodes/dao/DefaultNodeInfoDao.java
+++ b/util/src/main/java/org/killbill/billing/util/nodes/dao/DefaultNodeInfoDao.java
@@ -22,6 +22,7 @@ import java.util.List;
 
 import javax.inject.Named;
 
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.killbill.clock.Clock;
 import org.skife.jdbi.v2.Handle;
 import org.skife.jdbi.v2.IDBI;
@@ -34,20 +35,18 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 
 public class DefaultNodeInfoDao implements NodeInfoDao {
 
-    private final IDBI dbi;
-    private final IDBI roDbi;
+    private final DBRouter<NodeInfoSqlDao> dbRouter;
     private final Clock clock;
 
     @Inject
     public DefaultNodeInfoDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi, final Clock clock) {
-        this.dbi = dbi;
-        this.roDbi = roDbi;
+        this.dbRouter = new DBRouter<NodeInfoSqlDao>(dbi, roDbi, NodeInfoSqlDao.class);
         this.clock = clock;
     }
 
     @Override
     public void create(final NodeInfoModelDao nodeInfoModelDao) {
-        dbi.inTransaction(new TransactionCallback<Void>() {
+        dbRouter.inTransaction(false, new TransactionCallback<Void>() {
             @Override
             public Void inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final NodeInfoSqlDao sqlDao = handle.attach(NodeInfoSqlDao.class);
@@ -62,7 +61,7 @@ public class DefaultNodeInfoDao implements NodeInfoDao {
 
     @Override
     public void updateNodeInfo(final String nodeName, final String nodeInfo) {
-        dbi.inTransaction(new TransactionCallback<Void>() {
+        dbRouter.inTransaction(false, new TransactionCallback<Void>() {
             @Override
             public Void inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final NodeInfoSqlDao sqlDao = handle.attach(NodeInfoSqlDao.class);
@@ -75,7 +74,7 @@ public class DefaultNodeInfoDao implements NodeInfoDao {
 
     @Override
     public void delete(final String nodeName) {
-        dbi.inTransaction(new TransactionCallback<Void>() {
+        dbRouter.inTransaction(false, new TransactionCallback<Void>() {
             @Override
             public Void inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final NodeInfoSqlDao sqlDao = handle.attach(NodeInfoSqlDao.class);
@@ -87,7 +86,7 @@ public class DefaultNodeInfoDao implements NodeInfoDao {
 
     @Override
     public List<NodeInfoModelDao> getAll() {
-        return roDbi.inTransaction(new TransactionCallback<List<NodeInfoModelDao>>() {
+        return dbRouter.inTransaction(true, new TransactionCallback<List<NodeInfoModelDao>>() {
             @Override
             public List<NodeInfoModelDao> inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final NodeInfoSqlDao sqlDao = handle.attach(NodeInfoSqlDao.class);
@@ -98,7 +97,7 @@ public class DefaultNodeInfoDao implements NodeInfoDao {
 
     @Override
     public NodeInfoModelDao getByNodeName(final String nodeName) {
-        return roDbi.inTransaction(new TransactionCallback<NodeInfoModelDao>() {
+        return dbRouter.inTransaction(true, new TransactionCallback<NodeInfoModelDao>() {
             @Override
             public NodeInfoModelDao inTransaction(final Handle handle, final TransactionStatus status) throws Exception {
                 final NodeInfoSqlDao sqlDao = handle.attach(NodeInfoSqlDao.class);
diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
index bf5c760..8bd91cc 100644
--- a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
@@ -29,6 +29,7 @@ import javax.inject.Named;
 import org.apache.shiro.session.Session;
 import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
 import org.killbill.billing.util.UUIDs;
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.skife.jdbi.v2.IDBI;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -42,27 +43,25 @@ public class JDBCSessionDao extends CachingSessionDAO {
 
     private static final Logger log = LoggerFactory.getLogger(JDBCSessionDao.class);
 
-    private final JDBCSessionSqlDao jdbcSessionSqlDao;
-    private final JDBCSessionSqlDao roJdbcSessionSqlDao;
+    private final DBRouter<JDBCSessionSqlDao> dbRouter;
 
     private final Cache<Serializable, Boolean> noUpdateSessionsCache = CacheBuilder.newBuilder().expireAfterWrite(5, TimeUnit.SECONDS).build();
 
     @Inject
     public JDBCSessionDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
-        this.jdbcSessionSqlDao = dbi.onDemand(JDBCSessionSqlDao.class);
-        this.roJdbcSessionSqlDao = roDbi.onDemand(JDBCSessionSqlDao.class);
+        this.dbRouter = new DBRouter<JDBCSessionSqlDao>(dbi, roDbi, JDBCSessionSqlDao.class);
     }
 
     @Override
     protected void doUpdate(final Session session) {
         if (shouldUpdateSession(session)) {
-            jdbcSessionSqlDao.update(new SessionModelDao(session));
+            dbRouter.onDemand(false).update(new SessionModelDao(session));
         }
     }
 
     @Override
     protected void doDelete(final Session session) {
-        jdbcSessionSqlDao.delete(new SessionModelDao(session));
+        dbRouter.onDemand(false).delete(new SessionModelDao(session));
     }
 
     @Override
@@ -71,7 +70,7 @@ public class JDBCSessionDao extends CachingSessionDAO {
         // See SessionModelDao#toSimpleSession for why we use toString()
         final String sessionIdAsString = sessionId.toString();
         assignSessionId(session, sessionIdAsString);
-        jdbcSessionSqlDao.create(new SessionModelDao(session));
+        dbRouter.onDemand(false).create(new SessionModelDao(session));
         // Make sure to return a String here as well, or Shiro will cache the Session with a UUID key
         // while it is expecting String
         return sessionIdAsString;
@@ -85,7 +84,7 @@ public class JDBCSessionDao extends CachingSessionDAO {
         }
 
         final String sessionIdString = sessionId.toString();
-        final SessionModelDao sessionModelDao = roJdbcSessionSqlDao.read(sessionIdString);
+        final SessionModelDao sessionModelDao = dbRouter.onDemand(true).read(sessionIdString);
 
         if (sessionModelDao == null) {
             return null;
diff --git a/util/src/main/java/org/killbill/billing/util/validation/dao/DatabaseSchemaDao.java b/util/src/main/java/org/killbill/billing/util/validation/dao/DatabaseSchemaDao.java
index 2922ce2..210feba 100644
--- a/util/src/main/java/org/killbill/billing/util/validation/dao/DatabaseSchemaDao.java
+++ b/util/src/main/java/org/killbill/billing/util/validation/dao/DatabaseSchemaDao.java
@@ -24,6 +24,7 @@ import javax.annotation.Nullable;
 import javax.inject.Named;
 import javax.inject.Singleton;
 
+import org.killbill.billing.util.entity.dao.DBRouter;
 import org.killbill.billing.util.validation.DefaultColumnInfo;
 import org.skife.jdbi.v2.IDBI;
 
@@ -34,11 +35,11 @@ import static org.killbill.billing.util.glue.IDBISetup.MAIN_RO_IDBI_NAMED;
 @Singleton
 public class DatabaseSchemaDao {
 
-    private final DatabaseSchemaSqlDao roDatabaseSchemaSqlDao;
+    private final DBRouter<DatabaseSchemaSqlDao> dbRouter;
 
     @Inject
-    public DatabaseSchemaDao(@Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
-        this.roDatabaseSchemaSqlDao = roDbi.onDemand(DatabaseSchemaSqlDao.class);
+    public DatabaseSchemaDao(final IDBI dbi, @Named(MAIN_RO_IDBI_NAMED) final IDBI roDbi) {
+        this.dbRouter = new DBRouter<DatabaseSchemaSqlDao>(dbi, roDbi, DatabaseSchemaSqlDao.class);
     }
 
     public List<DefaultColumnInfo> getColumnInfoList() {
@@ -46,6 +47,6 @@ public class DatabaseSchemaDao {
     }
 
     public List<DefaultColumnInfo> getColumnInfoList(@Nullable final String schemaName) {
-        return roDatabaseSchemaSqlDao.getSchemaInfo(schemaName);
+        return dbRouter.onDemand(true).getSchemaInfo(schemaName);
     }
 }
diff --git a/util/src/test/java/org/killbill/billing/util/validation/TestValidationManager.java b/util/src/test/java/org/killbill/billing/util/validation/TestValidationManager.java
index 68f8773..1a9761a 100644
--- a/util/src/test/java/org/killbill/billing/util/validation/TestValidationManager.java
+++ b/util/src/test/java/org/killbill/billing/util/validation/TestValidationManager.java
@@ -41,7 +41,7 @@ public class TestValidationManager extends UtilTestSuiteWithEmbeddedDB {
     @BeforeClass(groups = "slow")
     public void beforeClass() throws Exception {
         super.beforeClass();
-        final DatabaseSchemaDao dao = new DatabaseSchemaDao(dbi);
+        final DatabaseSchemaDao dao = new DatabaseSchemaDao(dbi, roDbi);
         vm = new ValidationManager(dao);
         vm.loadSchemaInformation(helper.getDatabaseName());
     }