killbill-aplcache

util: limit the number of queries in EntitySqlDaoWrapperInvocationHandler Revisit

2/6/2019 8:13:43 AM

Details

diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 1d459ec..3515f02 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -5,6 +5,7 @@
     <inspection_tool class="CheckTagEmptyBody" enabled="false" level="WARNING" enabled_by_default="false" />
     <inspection_tool class="ConfusingOctalEscape" enabled="true" level="WARNING" enabled_by_default="true" />
     <inspection_tool class="ControlFlowStatementWithoutBraces" enabled="true" level="WARNING" enabled_by_default="true" />
+    <inspection_tool class="Convert2Diamond" enabled="false" level="WARNING" enabled_by_default="false" />
     <inspection_tool class="Convert2Lambda" enabled="false" level="WARNING" enabled_by_default="false" />
     <inspection_tool class="CyclicClassDependency" enabled="true" level="WARNING" enabled_by_default="true" />
     <inspection_tool class="FieldMayBeFinal" enabled="true" level="WARNING" enabled_by_default="true" />
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
index 447b76b..974e23b 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDao.java
@@ -31,6 +31,7 @@ import org.killbill.commons.jdbi.binder.SmartBindBean;
 import org.killbill.commons.jdbi.statement.SmartFetchSize;
 import org.killbill.commons.jdbi.template.KillBillSqlDaoStringTemplate;
 import org.skife.jdbi.v2.sqlobject.Bind;
+import org.skife.jdbi.v2.sqlobject.GetGeneratedKeys;
 import org.skife.jdbi.v2.sqlobject.SqlBatch;
 import org.skife.jdbi.v2.sqlobject.SqlQuery;
 import org.skife.jdbi.v2.sqlobject.SqlUpdate;
@@ -38,16 +39,19 @@ import org.skife.jdbi.v2.sqlobject.customizers.BatchChunkSize;
 import org.skife.jdbi.v2.sqlobject.customizers.Define;
 import org.skife.jdbi.v2.sqlobject.mixins.CloseMe;
 import org.skife.jdbi.v2.sqlobject.mixins.Transactional;
+import org.skife.jdbi.v2.util.LongMapper;
 
 @KillBillSqlDaoStringTemplate
 public interface EntitySqlDao<M extends EntityModelDao<E>, E extends Entity> extends AuditSqlDao, HistorySqlDao<M, E>, Transactional<EntitySqlDao<M, E>>, CloseMe {
 
     @SqlUpdate
+    @GetGeneratedKeys(value = LongMapper.class)
     @Audited(ChangeType.INSERT)
     public Object create(@SmartBindBean final M entity,
                          @SmartBindBean final InternalCallContext context);
 
     @SqlBatch
+    // We don't @GetGeneratedKeys here, as it's unclear if the ordering through the batches is respected by all JDBC drivers
     @BatchChunkSize(1000) // Arbitrary value, just a safety mechanism in case of very large datasets
     @Audited(ChangeType.INSERT)
     public void create(@SmartBindBean final Iterable<M> entity,
diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
index 21c7d38..63a99ff 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/EntitySqlDaoWrapperInvocationHandler.java
@@ -25,10 +25,11 @@ import java.lang.reflect.Method;
 import java.sql.PreparedStatement;
 import java.sql.SQLException;
 import java.sql.SQLWarning;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -210,16 +211,20 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
     }
 
     private Object invokeWithAuditAndHistory(final Audited auditedAnnotation, final Method method, final Object[] args) throws Throwable {
-        final InternalCallContext context = retrieveContextFromArguments(args);
+        final InternalCallContext contextMaybeWithoutAccountRecordId = retrieveContextFromArguments(args);
         final List<String> entityIds = retrieveEntityIdsFromArguments(method, args);
-
+        // We cannot always infer the TableName from the signature
+        TableName tableName = retrieveTableNameFromArgumentsIfPossible(args);
         final ChangeType changeType = auditedAnnotation.value();
+        final boolean isBatchQuery = method.getAnnotation(SqlBatch.class) != null;
 
         // Get the current state before deletion for the history tables
-        final Map<String, M> deletedEntities = new HashMap<String, M>();
+        final Map<Long, M> deletedEntities = new HashMap<Long, M>();
         if (changeType == ChangeType.DELETE) {
             for (final String entityId : entityIds) {
-                deletedEntities.put(entityId, sqlDao.getById(entityId, context));
+                // TODO Switch to getByIds
+                final M entityToBeDeleted = sqlDao.getById(entityId, contextMaybeWithoutAccountRecordId);
+                deletedEntities.put(entityToBeDeleted.getRecordId(), entityToBeDeleted);
                 printSQLWarnings();
             }
         }
@@ -236,11 +241,79 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
             return obj;
         }
 
-        // PERF: override the return value with the reHydrated entity to avoid an extra 'get' in the transaction,
-        // (see EntityDaoBase#createAndRefresh for an example, but it works for updates as well).
-        final List<M> ms = updateHistoryAndAudit(entityIds, deletedEntities, changeType, context);
-        final boolean isBatchQuery = method.getAnnotation(SqlBatch.class) != null;
-        return isBatchQuery ? ms : Iterables.<M>getFirst(ms, null);
+        InternalCallContext context = null;
+        // Retrieve record_id(s) for audit and history tables
+        final List<Long> entityRecordIds = new LinkedList<Long>();
+        if (changeType == ChangeType.DELETE) {
+            for (final Long entityRecordId : deletedEntities.keySet()) {
+                final M entity = deletedEntities.get(entityRecordId);
+                entityRecordIds.add(entityRecordId);
+                if (tableName == null) {
+                    tableName = entity.getTableName();
+                } else {
+                    Preconditions.checkState(tableName == entity.getTableName(), "Entities with different TableName: %s", deletedEntities);
+                }
+            }
+        } else if (changeType == ChangeType.INSERT && !isBatchQuery) {
+            Preconditions.checkNotNull(tableName, "Insert query should have an EntityModelDao as argument: %s", args);
+            // For non-batch inserts, rely on GetGeneratedKeys
+            Preconditions.checkState(entityIds.size() == 1, "Batch insert not annotated with @SqlBatch?");
+            final long accountRecordId = Long.parseLong(obj.toString());
+            entityRecordIds.add(accountRecordId);
+
+            // Snowflake
+            if (TableName.ACCOUNT.equals(tableName)) {
+                // AccountModelDao in practice
+                final TimeZoneAwareEntity accountModelDao = retrieveTimeZoneAwareEntityFromArguments(args);
+                context = internalCallContextFactory.createInternalCallContext(accountModelDao, accountRecordId, contextMaybeWithoutAccountRecordId);
+            }
+        } else {
+            for (final String entityId : entityIds) {
+                // For batch inserts and updates, easiest is to go back to the database
+                // TODO Do we go to the cache here?
+                // TODO Switch to getByIds
+                final M entity = sqlDao.getById(entityId, contextMaybeWithoutAccountRecordId);
+                printSQLWarnings();
+                entityRecordIds.add(entity.getRecordId());
+                if (tableName == null) {
+                    tableName = entity.getTableName();
+                } else {
+                    Preconditions.checkState(tableName == entity.getTableName(), "Entities with different TableName");
+                }
+            }
+        }
+
+        // Context validations
+        if (context != null) {
+            // context was already updated, see above (createAccount code path). Just make sure we don't attempt to bulk create
+            Preconditions.checkState(entityIds.size() == 1, "Bulk insert of accounts isn't supported");
+        } else {
+            context = contextMaybeWithoutAccountRecordId;
+            final boolean tableWithoutAccountRecordId = tableName == TableName.TENANT || tableName == TableName.TENANT_BROADCASTS || tableName == TableName.TENANT_KVS || tableName == TableName.TAG_DEFINITIONS || tableName == TableName.SERVICE_BRODCASTS || tableName == TableName.NODE_INFOS;
+            Preconditions.checkState(context.getAccountRecordId() != null || tableWithoutAccountRecordId,
+                                     "accountRecordId should be set for tableName=%s and changeType=%s", tableName, changeType);
+        }
+
+        final List<M> reHydratedEntitiesOrNull = updateHistoryAndAudit(entityRecordIds, deletedEntities, tableName, changeType, context);
+        if (method.getReturnType().equals(Void.TYPE)) {
+            // Return early
+            return null;
+        } else if (entityRecordIds.size() > 1) {
+            // Return the raw jdbc response
+            return obj;
+        } else {
+            // PERF: override the return value with the reHydrated entity to avoid an extra 'get' in the transaction,
+            // (see EntityDaoBase#createAndRefresh for an example, but it works for updates as well).
+            Preconditions.checkState(entityRecordIds.size() == 1, "Invalid number of entityRecordIds: %s", entityRecordIds);
+
+            if (reHydratedEntitiesOrNull != null) {
+                Preconditions.checkState(reHydratedEntitiesOrNull.size() == 1, "Invalid number of entities: %s", reHydratedEntitiesOrNull);
+                return Iterables.<M>getFirst(reHydratedEntitiesOrNull, null);
+            } else {
+                // Updated entity not retrieved yet, we have to go back to the database
+                return sqlDao.getByRecordId(entityRecordIds.get(0), context);
+            }
+        }
     }
 
     private Object executeJDBCCall(final Method method, final Object[] args) throws IllegalAccessException, InvocationTargetException {
@@ -292,51 +365,40 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
                rawKey;
     }
 
-    private List<M> updateHistoryAndAudit(final Collection<String> entityIds,
-                                          final Map<String, M> deletedEntities,
+    // Update history and audit tables.
+    // PERF: if the latest entities had to be fetched from the database, return them. Otherwise, return null.
+    private List<M> updateHistoryAndAudit(final Collection<Long> entityRecordIds,
+                                          final Map<Long, M> deletedEntities,
+                                          final TableName tableName,
                                           final ChangeType changeType,
                                           final InternalCallContext context) throws Throwable {
-        final Object reHydratedEntities = prof.executeWithProfiling(ProfilingFeatureType.DAO_DETAILS, getProfilingId("history/audit", null), new WithProfilingCallback<Object, Throwable>() {
+        final Object reHydratedEntitiesOrNull = prof.executeWithProfiling(ProfilingFeatureType.DAO_DETAILS, getProfilingId("history/audit", null), new WithProfilingCallback<Object, Throwable>() {
             @Override
             public List<M> execute() {
-                TableName tableName = null;
-                // We'll keep the ordering
-                final Map<M, Long> reHydratedEntityModelDaoAndHistoryRecordIds = new LinkedHashMap<M, Long>(entityIds.size());
-                for (final String entityId : entityIds) {
-                    final M reHydratedEntity;
-                    if (changeType == ChangeType.DELETE) {
-                        reHydratedEntity = deletedEntities.get(entityId);
-                    } else {
-                        // TODO Could we avoid this query?
-                        reHydratedEntity = sqlDao.getById(entityId, context);
-                        printSQLWarnings();
-                    }
-                    Preconditions.checkNotNull(reHydratedEntity, "reHydratedEntity cannot be null");
-                    final Long entityRecordId = reHydratedEntity.getRecordId();
-                    if (tableName == null) {
-                        tableName = reHydratedEntity.getTableName();
-                    }
-
-                    // Note: audit entries point to the history record id
-                    final Long historyRecordId;
-                    if (tableName.getHistoryTableName() != null) {
+                if (tableName.getHistoryTableName() == null) {
+                    insertAudits(entityRecordIds, tableName, changeType, context);
+                    return null;
+                } else {
+                    // We'll keep the ordering
+                    final Collection<Long> auditTargetRecordIds = new ArrayList<>(entityRecordIds.size());
+                    final List<M> reHydratedEntities = new ArrayList<>(entityRecordIds.size());
+                    for (final Long entityRecordId : entityRecordIds) {
+                        // Make sure to re-hydrate the objects first (especially needed for create calls)
                         // TODO Could we do this in bulk too?
-                        historyRecordId = insertHistory(entityRecordId, reHydratedEntity, changeType, context);
-                    } else {
-                        historyRecordId = entityRecordId;
+                        final M reHydratedEntityModelDao = MoreObjects.firstNonNull(deletedEntities.get(entityRecordId), sqlDao.getByRecordId(entityRecordId, context));
+                        final Long auditTargetRecordId = insertHistory(reHydratedEntityModelDao, changeType, context);
+                        auditTargetRecordIds.add(auditTargetRecordId);
+                        reHydratedEntities.add(reHydratedEntityModelDao);
                     }
+                    // Note: audit entries point to the history record id
+                    insertAudits(auditTargetRecordIds, tableName, changeType, context);
 
-                    reHydratedEntityModelDaoAndHistoryRecordIds.put(reHydratedEntity, historyRecordId);
+                    return reHydratedEntities;
                 }
-
-                // Make sure to re-hydrate the objects first (especially needed for create calls)
-                insertAudits(tableName, reHydratedEntityModelDaoAndHistoryRecordIds, changeType, context);
-
-                return ImmutableList.<M>copyOf(reHydratedEntityModelDaoAndHistoryRecordIds.keySet());
             }
         });
         //noinspection unchecked
-        return (List<M>) reHydratedEntities;
+        return (List<M>) reHydratedEntitiesOrNull;
     }
 
     private List<String> retrieveEntityIdsFromArguments(final Method method, final Object[] args) {
@@ -407,42 +469,51 @@ public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, 
             }
             return (InternalCallContext) arg;
         }
-        return null;
+        throw new IllegalStateException("No InternalCallContext specified in args: " + Arrays.toString(args));
+    }
+
+    private TableName retrieveTableNameFromArgumentsIfPossible(final Object[] args) {
+        TableName tableName = null;
+        for (final Object arg : args) {
+            if (arg instanceof EntityModelDao) {
+                final TableName argTableName = ((EntityModelDao) arg).getTableName();
+                if (tableName == null) {
+                    tableName = argTableName;
+                } else {
+                    Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", args);
+                }
+            }
+        }
+        return tableName;
     }
 
-    private Long insertHistory(final Long entityRecordId, final M entityModelDao, final ChangeType changeType, final InternalCallContext context) {
-        final EntityHistoryModelDao<M, E> history = new EntityHistoryModelDao<M, E>(entityModelDao, entityRecordId, changeType, null, context.getCreatedDate());
+    private TimeZoneAwareEntity retrieveTimeZoneAwareEntityFromArguments(final Object[] args) {
+        for (final Object arg : args) {
+            if (!(arg instanceof TimeZoneAwareEntity)) {
+                continue;
+            }
+            return (TimeZoneAwareEntity) arg;
+        }
+        throw new IllegalStateException("TimeZoneAwareEntity should have been found among " + args);
+    }
+
+    private Long insertHistory(final M reHydratedEntityModelDao, final ChangeType changeType, final InternalCallContext context) {
+        final EntityHistoryModelDao<M, E> history = new EntityHistoryModelDao<M, E>(reHydratedEntityModelDao, reHydratedEntityModelDao.getRecordId(), changeType, null, context.getCreatedDate());
         final Long recordId = sqlDao.addHistoryFromTransaction(history, context);
         printSQLWarnings();
         return recordId;
     }
 
-    private void insertAudits(final TableName tableName,
-                              final Map<M, Long> entityModelDaoAndHistoryRecordIds,
+    // Bulk insert all audit logs for this operation
+    private void insertAudits(final Iterable<Long> auditTargetRecordIds,
+                              final TableName tableName,
                               final ChangeType changeType,
-                              final InternalCallContext contextMaybeWithoutAccountRecordId) {
+                              final InternalCallContext context) {
         final TableName destinationTableName = MoreObjects.firstNonNull(tableName.getHistoryTableName(), tableName);
 
-        final InternalCallContext context;
-        if (TableName.ACCOUNT.equals(tableName) && ChangeType.INSERT.equals(changeType)) {
-            Preconditions.checkState(entityModelDaoAndHistoryRecordIds.size() == 1, "Bulk insert of accounts isn't supported");
-            final M entityModelDao = Iterables.<M>getFirst(entityModelDaoAndHistoryRecordIds.keySet(), null);
-            // AccountModelDao in practice
-            final TimeZoneAwareEntity accountModelDao = (TimeZoneAwareEntity) entityModelDao;
-            context = internalCallContextFactory.createInternalCallContext(accountModelDao, entityModelDao.getRecordId(), contextMaybeWithoutAccountRecordId);
-        } else if (contextMaybeWithoutAccountRecordId.getAccountRecordId() == null) {
-            Preconditions.checkState(tableName == TableName.TENANT || tableName == TableName.TENANT_BROADCASTS || tableName == TableName.TENANT_KVS || tableName == TableName.TAG_DEFINITIONS || tableName == TableName.SERVICE_BRODCASTS || tableName == TableName.NODE_INFOS,
-                                     "accountRecordId should be set for tableName=%s and changeType=%s", tableName, changeType);
-            context = contextMaybeWithoutAccountRecordId;
-
-        } else {
-            context = contextMaybeWithoutAccountRecordId;
-        }
-
         final Collection<EntityAudit> audits = new LinkedList<EntityAudit>();
-        for (final M entityModelDao : entityModelDaoAndHistoryRecordIds.keySet()) {
-            final Long targetRecordId = entityModelDaoAndHistoryRecordIds.get(entityModelDao);
-            final EntityAudit audit = new EntityAudit(destinationTableName, targetRecordId, changeType, context.getCreatedDate());
+        for (final Long auditTargetRecordId : auditTargetRecordIds) {
+            final EntityAudit audit = new EntityAudit(destinationTableName, auditTargetRecordId, changeType, context.getCreatedDate());
             audits.add(audit);
         }