killbill-aplcache

util: optimize Batch inserts codepath Rely on @GetGeneratedKeys

2/12/2019 4:34:38 AM

Details

diff --git a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
index 2af7799..b0fe792 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
@@ -363,7 +363,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                     existingInvoiceMetadata = existingInvoiceMetadataOrNull;
                 }
 
-                final Collection<InvoiceItemModelDao> invoiceItemsToCreate = new LinkedList<InvoiceItemModelDao>();
+                final List<InvoiceItemModelDao> invoiceItemsToCreate = new LinkedList<InvoiceItemModelDao>();
                 for (final InvoiceModelDao invoiceModelDao : invoices) {
                     invoiceByInvoiceId.put(invoiceModelDao.getId(), invoiceModelDao);
                     final boolean isNotShellInvoice = invoiceIdsReferencedFromItems.remove(invoiceModelDao.getId());
@@ -1160,7 +1160,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
     }
 
     private void createInvoiceItemsFromTransaction(final InvoiceItemSqlDao invoiceItemSqlDao,
-                                                   final Iterable<InvoiceItemModelDao> invoiceItemModelDaos,
+                                                   final List<InvoiceItemModelDao> invoiceItemModelDaos,
                                                    final InternalCallContext context) throws EntityPersistenceException, InvoiceApiException {
         for (final InvoiceItemModelDao invoiceItemModelDao : invoiceItemModelDaos) {
             validateInvoiceItemToBeAdjustedIfNeeded(invoiceItemSqlDao, invoiceItemModelDao, context);
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
index 0f4f0ce..14f8902 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
@@ -43,7 +43,6 @@ import org.killbill.billing.util.entity.dao.DefaultPaginationSqlDaoHelper.Pagina
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
 
 public abstract class EntityDaoBase<M extends EntityModelDao<E>, E extends Entity, U extends BillingExceptionBase> implements EntityDao<M, E, U> {
 
@@ -131,12 +130,14 @@ public abstract class EntityDaoBase<M extends EntityModelDao<E>, E extends Entit
         return (F) transactional.create(entity, context);
     }
 
-    protected <F extends EntityModelDao> void bulkCreate(final EntitySqlDao transactional, final Iterable<F> entities, final InternalCallContext context) {
-        if (Iterables.<F>isEmpty(entities)) {
+    protected <F extends EntityModelDao> void bulkCreate(final EntitySqlDao transactional, final List<F> entities, final InternalCallContext context) {
+        if (entities.isEmpty()) {
             return;
+        } else if (entities.size() == 1) {
+            transactional.create(entities.get(0), context);
+        } else {
+            transactional.create(entities, context);
         }
-
-        transactional.create(entities, context);
     }
 
     protected boolean checkEntityAlreadyExists(final EntitySqlDao<M, E> transactional, final M entity, final InternalCallContext context) {
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
index 284bb60..6bb5bd6 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
@@ -53,11 +53,12 @@ public interface EntitySqlDao<M extends EntityModelDao<E>, E extends Entity> ext
                          @SmartBindBean final InternalCallContext context);
 
     @SqlBatch
-    // We don't @GetGeneratedKeys here, as it's unclear if the ordering through the batches is respected by all JDBC drivers
     @BatchChunkSize(1000) // Arbitrary value, just a safety mechanism in case of very large datasets
+    @GetGeneratedKeys(value = LongMapper.class)
     @Audited(ChangeType.INSERT)
-    public void create(@SmartBindBean final Iterable<M> entity,
-                       @SmartBindBean final InternalCallContext context);
+    // Note that you cannot rely on the ordering here
+    public List<Long> create(@SmartBindBean final Iterable<M> entity,
+                             @SmartBindBean final InternalCallContext context);
 
     @SqlQuery
     public M getById(@Bind("id") final String id,
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
index 99bb03f..0378027 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
@@ -214,7 +214,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
         final InternalCallContext contextMaybeWithoutAccountRecordId = retrieveContextFromArguments(args);
         final List<String> entityIds = retrieveEntityIdsFromArguments(method, args);
         // We cannot always infer the TableName from the signature
-        TableName tableName = retrieveTableNameFromArgumentsIfPossible(args);
+        TableName tableName = retrieveTableNameFromArgumentsIfPossible(Arrays.asList(args));
         final ChangeType changeType = auditedAnnotation.value();
         final boolean isBatchQuery = method.getAnnotation(SqlBatch.class) != null;
 
@@ -256,21 +256,24 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
                     Preconditions.checkState(tableName == entity.getTableName(), "Entities with different TableName: %s", deletedAndUpdatedEntities);
                 }
             }
-        } else if (changeType == ChangeType.INSERT && !isBatchQuery) {
+        } else if (changeType == ChangeType.INSERT) {
             Preconditions.checkNotNull(tableName, "Insert query should have an EntityModelDao as argument: %s", args);
-            // For non-batch inserts, rely on GetGeneratedKeys
-            Preconditions.checkState(entityIds.size() == 1, "Batch insert not annotated with @SqlBatch?");
-            final long accountRecordId = Long.parseLong(obj.toString());
-            entityRecordIds.add(accountRecordId);
+
+            if (isBatchQuery) {
+                entityRecordIds.addAll((Collection<? extends Long>) obj);
+            } else {
+                entityRecordIds.add((Long) obj);
+            }
 
             // Snowflake
             if (TableName.ACCOUNT.equals(tableName)) {
+                Preconditions.checkState(entityIds.size() == 1, "Bulk insert of accounts isn't supported");
                 // AccountModelDao in practice
                 final TimeZoneAwareEntity accountModelDao = retrieveTimeZoneAwareEntityFromArguments(args);
-                context = internalCallContextFactory.createInternalCallContext(accountModelDao, accountRecordId, contextMaybeWithoutAccountRecordId);
+                context = internalCallContextFactory.createInternalCallContext(accountModelDao, entityRecordIds.get(0), contextMaybeWithoutAccountRecordId);
             }
         } else {
-            // For batch inserts and updates, easiest is to go back to the database
+            // For updates, easiest is to go back to the database
             final List<M> retrievedEntities = sqlDao.getByIds(entityIds, contextMaybeWithoutAccountRecordId);
             printSQLWarnings();
             for (final M entity : retrievedEntities) {
@@ -283,6 +286,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
                 }
             }
         }
+        Preconditions.checkState(entityIds.size() == entityRecordIds.size(), "SqlDao method has %s as ids but found %s as recordIds", entityIds, entityRecordIds);
 
         // Context validations
         if (context != null) {
@@ -299,8 +303,8 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
         if (method.getReturnType().equals(Void.TYPE)) {
             // Return early
             return null;
-        } else if (entityRecordIds.size() > 1) {
-            // Return the raw jdbc response
+        } else if (isBatchQuery) {
+            // Return the raw jdbc response (generated keys)
             return obj;
         } else {
             // PERF: override the return value with the reHydrated entity to avoid an extra 'get' in the transaction,
@@ -395,6 +399,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
 
                     final Collection<Long> auditTargetRecordIds = insertHistories(reHydratedEntities, changeType, context);
                     // Note: audit entries point to the history record id
+                    Preconditions.checkState(auditTargetRecordIds.size() == entityRecordIds.size(), "Wrong number of auditTargetRecordIds=%s (entityRecordIds=%s)", auditTargetRecordIds, entityRecordIds);
                     insertAudits(auditTargetRecordIds, tableName, changeType, context);
 
                     return reHydratedEntities;
@@ -476,16 +481,20 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
         throw new IllegalStateException("No InternalCallContext specified in args: " + Arrays.toString(args));
     }
 
-    private TableName retrieveTableNameFromArgumentsIfPossible(final Object[] args) {
+    private TableName retrieveTableNameFromArgumentsIfPossible(final Iterable args) {
         TableName tableName = null;
         for (final Object arg : args) {
+            TableName argTableName = null;
             if (arg instanceof EntityModelDao) {
-                final TableName argTableName = ((EntityModelDao) arg).getTableName();
-                if (tableName == null) {
-                    tableName = argTableName;
-                } else {
-                    Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", args);
-                }
+                argTableName = ((EntityModelDao) arg).getTableName();
+            } else if (arg instanceof Iterable) {
+                argTableName = retrieveTableNameFromArgumentsIfPossible((Iterable) arg);
+            }
+
+            if (tableName == null) {
+                tableName = argTableName;
+            } else if (argTableName != null) {
+                Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", args);
             }
         }
         return tableName;