diff --git a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
index 2af7799..b0fe792 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
@@ -363,7 +363,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
existingInvoiceMetadata = existingInvoiceMetadataOrNull;
}
- final Collection<InvoiceItemModelDao> invoiceItemsToCreate = new LinkedList<InvoiceItemModelDao>();
+ final List<InvoiceItemModelDao> invoiceItemsToCreate = new LinkedList<InvoiceItemModelDao>();
for (final InvoiceModelDao invoiceModelDao : invoices) {
invoiceByInvoiceId.put(invoiceModelDao.getId(), invoiceModelDao);
final boolean isNotShellInvoice = invoiceIdsReferencedFromItems.remove(invoiceModelDao.getId());
@@ -1160,7 +1160,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
}
private void createInvoiceItemsFromTransaction(final InvoiceItemSqlDao invoiceItemSqlDao,
- final Iterable<InvoiceItemModelDao> invoiceItemModelDaos,
+ final List<InvoiceItemModelDao> invoiceItemModelDaos,
final InternalCallContext context) throws EntityPersistenceException, InvoiceApiException {
for (final InvoiceItemModelDao invoiceItemModelDao : invoiceItemModelDaos) {
validateInvoiceItemToBeAdjustedIfNeeded(invoiceItemSqlDao, invoiceItemModelDao, context);
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
index 0f4f0ce..14f8902 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntityDaoBase.java
@@ -43,7 +43,6 @@ import org.killbill.billing.util.entity.dao.DefaultPaginationSqlDaoHelper.Pagina
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
public abstract class EntityDaoBase<M extends EntityModelDao<E>, E extends Entity, U extends BillingExceptionBase> implements EntityDao<M, E, U> {
@@ -131,12 +130,14 @@ public abstract class EntityDaoBase<M extends EntityModelDao<E>, E extends Entit
return (F) transactional.create(entity, context);
}
- protected <F extends EntityModelDao> void bulkCreate(final EntitySqlDao transactional, final Iterable<F> entities, final InternalCallContext context) {
- if (Iterables.<F>isEmpty(entities)) {
+ protected <F extends EntityModelDao> void bulkCreate(final EntitySqlDao transactional, final List<F> entities, final InternalCallContext context) {
+ if (entities.isEmpty()) {
return;
+ } else if (entities.size() == 1) {
+ transactional.create(entities.get(0), context);
+ } else {
+ transactional.create(entities, context);
}
-
- transactional.create(entities, context);
}
protected boolean checkEntityAlreadyExists(final EntitySqlDao<M, E> transactional, final M entity, final InternalCallContext context) {
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
index 284bb60..6bb5bd6 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
@@ -53,11 +53,12 @@ public interface EntitySqlDao<M extends EntityModelDao<E>, E extends Entity> ext
@SmartBindBean final InternalCallContext context);
@SqlBatch
- // We don't @GetGeneratedKeys here, as it's unclear if the ordering through the batches is respected by all JDBC drivers
@BatchChunkSize(1000) // Arbitrary value, just a safety mechanism in case of very large datasets
+ @GetGeneratedKeys(value = LongMapper.class)
@Audited(ChangeType.INSERT)
- public void create(@SmartBindBean final Iterable<M> entity,
- @SmartBindBean final InternalCallContext context);
+ // Note that you cannot rely on the ordering here
+ public List<Long> create(@SmartBindBean final Iterable<M> entity,
+ @SmartBindBean final InternalCallContext context);
@SqlQuery
public M getById(@Bind("id") final String id,
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
index 99bb03f..0378027 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
@@ -214,7 +214,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
final InternalCallContext contextMaybeWithoutAccountRecordId = retrieveContextFromArguments(args);
final List<String> entityIds = retrieveEntityIdsFromArguments(method, args);
// We cannot always infer the TableName from the signature
- TableName tableName = retrieveTableNameFromArgumentsIfPossible(args);
+ TableName tableName = retrieveTableNameFromArgumentsIfPossible(Arrays.asList(args));
final ChangeType changeType = auditedAnnotation.value();
final boolean isBatchQuery = method.getAnnotation(SqlBatch.class) != null;
@@ -256,21 +256,24 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
Preconditions.checkState(tableName == entity.getTableName(), "Entities with different TableName: %s", deletedAndUpdatedEntities);
}
}
- } else if (changeType == ChangeType.INSERT && !isBatchQuery) {
+ } else if (changeType == ChangeType.INSERT) {
Preconditions.checkNotNull(tableName, "Insert query should have an EntityModelDao as argument: %s", args);
- // For non-batch inserts, rely on GetGeneratedKeys
- Preconditions.checkState(entityIds.size() == 1, "Batch insert not annotated with @SqlBatch?");
- final long accountRecordId = Long.parseLong(obj.toString());
- entityRecordIds.add(accountRecordId);
+
+ if (isBatchQuery) {
+ entityRecordIds.addAll((Collection<? extends Long>) obj);
+ } else {
+ entityRecordIds.add((Long) obj);
+ }
// Snowflake
if (TableName.ACCOUNT.equals(tableName)) {
+ Preconditions.checkState(entityIds.size() == 1, "Bulk insert of accounts isn't supported");
// AccountModelDao in practice
final TimeZoneAwareEntity accountModelDao = retrieveTimeZoneAwareEntityFromArguments(args);
- context = internalCallContextFactory.createInternalCallContext(accountModelDao, accountRecordId, contextMaybeWithoutAccountRecordId);
+ context = internalCallContextFactory.createInternalCallContext(accountModelDao, entityRecordIds.get(0), contextMaybeWithoutAccountRecordId);
}
} else {
- // For batch inserts and updates, easiest is to go back to the database
+ // For updates, easiest is to go back to the database
final List<M> retrievedEntities = sqlDao.getByIds(entityIds, contextMaybeWithoutAccountRecordId);
printSQLWarnings();
for (final M entity : retrievedEntities) {
@@ -283,6 +286,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
}
}
}
+ Preconditions.checkState(entityIds.size() == entityRecordIds.size(), "SqlDao method has %s as ids but found %s as recordIds", entityIds, entityRecordIds);
// Context validations
if (context != null) {
@@ -299,8 +303,8 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
if (method.getReturnType().equals(Void.TYPE)) {
// Return early
return null;
- } else if (entityRecordIds.size() > 1) {
- // Return the raw jdbc response
+ } else if (isBatchQuery) {
+ // Return the raw jdbc response (generated keys)
return obj;
} else {
// PERF: override the return value with the reHydrated entity to avoid an extra 'get' in the transaction,
@@ -395,6 +399,7 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
final Collection<Long> auditTargetRecordIds = insertHistories(reHydratedEntities, changeType, context);
// Note: audit entries point to the history record id
+ Preconditions.checkState(auditTargetRecordIds.size() == entityRecordIds.size(), "Wrong number of auditTargetRecordIds=%s (entityRecordIds=%s)", auditTargetRecordIds, entityRecordIds);
insertAudits(auditTargetRecordIds, tableName, changeType, context);
return reHydratedEntities;
@@ -476,16 +481,20 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>,
throw new IllegalStateException("No InternalCallContext specified in args: " + Arrays.toString(args));
}
- private TableName retrieveTableNameFromArgumentsIfPossible(final Object[] args) {
+ private TableName retrieveTableNameFromArgumentsIfPossible(final Iterable args) {
TableName tableName = null;
for (final Object arg : args) {
+ TableName argTableName = null;
if (arg instanceof EntityModelDao) {
- final TableName argTableName = ((EntityModelDao) arg).getTableName();
- if (tableName == null) {
- tableName = argTableName;
- } else {
- Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", args);
- }
+ argTableName = ((EntityModelDao) arg).getTableName();
+ } else if (arg instanceof Iterable) {
+ argTableName = retrieveTableNameFromArgumentsIfPossible((Iterable) arg);
+ }
+
+ if (tableName == null) {
+ tableName = argTableName;
+ } else if (argTableName != null) {
+ Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", args);
}
}
return tableName;