EntitySqlDaoWrapperInvocationHandler.java

540 lines | 26.094 kB Blame History Raw Download
/*
 * Copyright 2010-2012 Ning, Inc.
 * Copyright 2014-2019 Groupon, Inc
 * Copyright 2014-2019 The Billing Project, LLC
 *
 * The Billing Project licenses this file to you under the Apache License, version 2.0
 * (the "License"); you may not use this file except in compliance with the
 * License.  You may obtain a copy of the License at:
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package org.killbill.billing.util.entity.dao;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import javax.annotation.Nullable;

import org.killbill.billing.callcontext.InternalCallContext;
import org.killbill.billing.util.audit.ChangeType;
import org.killbill.billing.util.cache.Cachable.CacheType;
import org.killbill.billing.util.cache.CacheController;
import org.killbill.billing.util.cache.CacheControllerDispatcher;
import org.killbill.billing.util.callcontext.InternalCallContextFactory;
import org.killbill.billing.util.dao.EntityAudit;
import org.killbill.billing.util.dao.EntityHistoryModelDao;
import org.killbill.billing.util.dao.TableName;
import org.killbill.billing.util.entity.Entity;
import org.killbill.clock.Clock;
import org.killbill.commons.profiling.Profiling;
import org.killbill.commons.profiling.Profiling.WithProfilingCallback;
import org.killbill.commons.profiling.ProfilingFeature.ProfilingFeatureType;
import org.skife.jdbi.v2.Binding;
import org.skife.jdbi.v2.Handle;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.exceptions.DBIException;
import org.skife.jdbi.v2.exceptions.StatementException;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlBatch;
import org.skife.jdbi.v2.unstable.BindIn;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Iterables;

/**
 * Wraps an instance of EntitySqlDao, performing extra work around each method (Sql query)
 *
 * @param <S> EntitySqlDao type of the wrapped instance
 * @param <M> EntityModel associated with S
 * @param <E> Entity associated with M
 */
public class EntitySqlDaoWrapperInvocationHandler<S extends EntitySqlDao<M, E>, M extends EntityModelDao<E>, E extends Entity> implements InvocationHandler {

    private static final Logger logger = LoggerFactory.getLogger(EntitySqlDaoWrapperInvocationHandler.class);

    // Method.getParameterAnnotations() generates lots of garbage objects, so the result is cached per method.
    // Keyed by the Method itself (Method implements equals/hashCode): building Method#toString() on each call is expensive.
    private final Map<Method, Annotation[][]> parameterAnnotationsByMethod = new ConcurrentHashMap<Method, Annotation[][]>();

    private final Class<S> sqlDaoClass;
    private final S sqlDao;
    private final Handle handle;

    // May be null for special DAOs that don't require caching (see constructor)
    private final CacheControllerDispatcher cacheControllerDispatcher;
    private final Clock clock;
    private final InternalCallContextFactory internalCallContextFactory;
    private final Profiling<Object, Throwable> prof;

    /**
     * @param sqlDaoClass                interface of the wrapped DAO (used in log messages and profiling ids)
     * @param sqlDao                     the DAO instance every call is delegated to
     * @param handle                     JDBI handle, used to read and clear SQL warnings on its connection
     * @param clock                      clock (stored, not read by this class)
     * @param cacheControllerDispatcher  caches populated on getById calls, or null for DAOs that don't require caching
     * @param internalCallContextFactory used to build the context of a newly inserted account
     */
    public EntitySqlDaoWrapperInvocationHandler(final Class<S> sqlDaoClass,
                                                final S sqlDao,
                                                final Handle handle,
                                                final Clock clock,
                                                // Special DAO that don't require caching can invoke EntitySqlDaoWrapperInvocationHandler with no caching (e.g NoCachingTenantDao)
                                                @Nullable final CacheControllerDispatcher cacheControllerDispatcher,
                                                final InternalCallContextFactory internalCallContextFactory) {
        this.sqlDaoClass = sqlDaoClass;
        this.sqlDao = sqlDao;
        this.handle = handle;
        this.clock = clock;
        this.cacheControllerDispatcher = cacheControllerDispatcher;
        this.internalCallContextFactory = internalCallContextFactory;
        this.prof = new Profiling<Object, Throwable>();
    }

    /**
     * Proxy entry point: profiles and executes the DAO call.
     * <p>
     * On failure, the most relevant cause is logged (together with the SQL statement and bindings when JDBI exposes them)
     * and rethrown through {@link #errorDuringTransaction}, which always throws.
     */
    @Override
    public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
        try {
            return prof.executeWithProfiling(ProfilingFeatureType.DAO, getProfilingId(null, method), new WithProfilingCallback<Object, Throwable>() {
                @Override
                public Object execute() throws Throwable {
                    return invokeSafely(method, args);
                }
            });
        } catch (final Throwable t) {
            if (t.getCause() != null && t.getCause().getCause() != null && DBIException.class.isAssignableFrom(t.getCause().getClass())) {
                // Likely a JDBC error, try to extract the SQL statement and JDBI bindings
                if (t.getCause() instanceof StatementException) {
                    final StatementContext statementContext = ((StatementException) t.getCause()).getStatementContext();

                    if (statementContext != null) {
                        // JDBI doesn't expose the bindings in a structured way: we have to rely on Binding#toString().
                        // String.valueOf() prevents a missing binding from masking the original error with an NPE.
                        final Binding binding = statementContext.getBinding();
                        final String bindingDescription = String.valueOf(binding);

                        final PreparedStatement statement = statementContext.getStatement();
                        if (statement != null) {
                            // Note: we rely on the JDBC driver to have a sane toString() method...
                            errorDuringTransaction(t.getCause().getCause(), method, statement.toString() + "\n" + bindingDescription);
                        } else {
                            errorDuringTransaction(t.getCause().getCause(), method, bindingDescription);
                        }

                        // Never reached
                        return null;
                    }
                }

                errorDuringTransaction(t.getCause().getCause(), method);
            } else if (t.getCause() != null) {
                // t is likely not interesting (java.lang.reflect.InvocationTargetException)
                errorDuringTransaction(t.getCause(), method);
            } else {
                errorDuringTransaction(t, method);
            }
        }

        // Never reached
        return null;
    }

    /**
     * Logs the failure and rethrows it. This method never returns normally.
     * <p>
     * RuntimeExceptions are rethrown as-is. Any other Throwable is wrapped in a RuntimeException, so the proxy doesn't
     * surface it as an UndeclaredThrowableException.
     * Nice method name to ease debugging while looking at log files.
     *
     * @param t                 the failure to report
     * @param method            the DAO method that failed
     * @param extraErrorMessage optional extra details appended to the log line (usually the SQL statement), may be null
     */
    private void errorDuringTransaction(final Throwable t, final Method method, final String extraErrorMessage) throws Throwable {
        final StringBuilder errorMessageBuilder = new StringBuilder("Error during transaction for sql entity {} and method {}");
        if (t instanceof SQLException) {
            final SQLException sqlException = (SQLException) t;
            errorMessageBuilder.append(" [SQL State: ")
                               .append(sqlException.getSQLState())
                               .append(", Vendor Error Code: ")
                               .append(sqlException.getErrorCode())
                               .append("]");
        }
        if (extraErrorMessage != null) {
            // This is usually the SQL statement
            errorMessageBuilder.append("\n").append(extraErrorMessage);
        }
        logger.warn(errorMessageBuilder.toString(), sqlDaoClass, method.getName());

        // This is to avoid throwing an exception wrapped in an UndeclaredThrowableException
        if (!(t instanceof RuntimeException)) {
            throw new RuntimeException(t);
        } else {
            throw t;
        }
    }

    // Same as above, without extra details
    private void errorDuringTransaction(final Throwable t, final Method method) throws Throwable {
        errorDuringTransaction(t, method, null);
    }

    // Dispatches to the audit/history code path for @Audited methods, to a plain invocation otherwise
    private Object invokeSafely(final Method method, final Object[] args) throws Throwable {
        final Audited auditedAnnotation = method.getAnnotation(Audited.class);

        // This can't be AUDIT'ed and CACHABLE'd at the same time as we only cache 'get'
        if (auditedAnnotation != null) {
            return invokeWithAuditAndHistory(auditedAnnotation, method, args);
        } else {
            return invokeRaw(method, args);
        }
    }

    // Plain (non-audited) invocation; results of getById calls are also used to populate the caches
    private Object invokeRaw(final Method method, final Object[] args) throws Throwable {
        return prof.executeWithProfiling(ProfilingFeatureType.DAO_DETAILS, getProfilingId("raw", method), new WithProfilingCallback<Object, Throwable>() {
            @Override
            public Object execute() throws Throwable {
                // Real jdbc call
                final Object result = executeJDBCCall(method, args);
                // This is *almost* the default invocation except that we want to intercept getById calls to populate the caches; the pattern is to always fetch
                // the object after it was created, which means this method is (by pattern) first called right after object creation and contains all the goodies we care
                // about (record_id, account_record_id, object_id, tenant_record_id)
                //
                if (result != null && method.getName().equals("getById")) {
                    populateCacheOnGetByIdInvocation((M) result);
                }
                return result;
            }
        });
    }

    /**
     * Executes an @Audited DAO method, then writes the matching history and audit rows.
     * <p>
     * For DELETEs, the entities are fetched before the call so their last state can be recorded. For non-batch INSERTs,
     * the record_id is the value returned by the call (generated key). Otherwise the entities are re-read by id.
     * When a single entity is affected and the method is not void, the re-hydrated entity is returned instead of the
     * raw JDBC result.
     */
    private Object invokeWithAuditAndHistory(final Audited auditedAnnotation, final Method method, final Object[] args) throws Throwable {
        final InternalCallContext contextMaybeWithoutAccountRecordId = retrieveContextFromArguments(args);
        final List<String> entityIds = retrieveEntityIdsFromArguments(method, args);
        // We cannot always infer the TableName from the signature
        TableName tableName = retrieveTableNameFromArgumentsIfPossible(args);
        final ChangeType changeType = auditedAnnotation.value();
        final boolean isBatchQuery = method.getAnnotation(SqlBatch.class) != null;

        // Get the current state before deletion for the history tables
        final Map<Long, M> deletedEntities = new HashMap<Long, M>();
        if (changeType == ChangeType.DELETE) {
            for (final String entityId : entityIds) {
                // TODO Switch to getByIds
                final M entityToBeDeleted = getExistingEntityById(entityId, contextMaybeWithoutAccountRecordId);
                deletedEntities.put(entityToBeDeleted.getRecordId(), entityToBeDeleted);
            }
        }

        // Real jdbc call
        final Object obj = prof.executeWithProfiling(ProfilingFeatureType.DAO_DETAILS, getProfilingId("raw", method), new WithProfilingCallback<Object, Throwable>() {
            @Override
            public Object execute() throws Throwable {
                return executeJDBCCall(method, args);
            }
        });

        if (entityIds.isEmpty()) {
            return obj;
        }

        InternalCallContext context = null;
        // Retrieve record_id(s) for audit and history tables
        final List<Long> entityRecordIds = new ArrayList<Long>(entityIds.size());
        if (changeType == ChangeType.DELETE) {
            for (final Map.Entry<Long, M> deletedEntity : deletedEntities.entrySet()) {
                entityRecordIds.add(deletedEntity.getKey());
                final TableName entityTableName = deletedEntity.getValue().getTableName();
                if (tableName == null) {
                    tableName = entityTableName;
                } else {
                    Preconditions.checkState(tableName == entityTableName, "Entities with different TableName: %s", deletedEntities);
                }
            }
        } else if (changeType == ChangeType.INSERT && !isBatchQuery) {
            Preconditions.checkNotNull(tableName, "Insert query should have an EntityModelDao as argument: %s", Arrays.toString(args));
            // For non-batch inserts, rely on GetGeneratedKeys
            Preconditions.checkState(entityIds.size() == 1, "Batch insert not annotated with @SqlBatch?");
            Preconditions.checkNotNull(obj, "No generated key returned by %s", method.getName());
            // The generated key is the record_id of the new row (for the accounts table, this is also the account_record_id)
            final long recordId = Long.parseLong(obj.toString());
            entityRecordIds.add(recordId);

            // Snowflake
            if (TableName.ACCOUNT.equals(tableName)) {
                // AccountModelDao in practice: the context didn't know the account_record_id yet
                final TimeZoneAwareEntity accountModelDao = retrieveTimeZoneAwareEntityFromArguments(args);
                context = internalCallContextFactory.createInternalCallContext(accountModelDao, recordId, contextMaybeWithoutAccountRecordId);
            }
        } else {
            for (final String entityId : entityIds) {
                // For batch inserts and updates, easiest is to go back to the database
                // TODO Do we go to the cache here?
                // TODO Switch to getByIds
                final M entity = getExistingEntityById(entityId, contextMaybeWithoutAccountRecordId);
                entityRecordIds.add(entity.getRecordId());
                if (tableName == null) {
                    tableName = entity.getTableName();
                } else {
                    Preconditions.checkState(tableName == entity.getTableName(), "Entities with different TableName");
                }
            }
        }

        // Context validations
        if (context != null) {
            // context was already updated, see above (createAccount code path). Just make sure we don't attempt to bulk create
            Preconditions.checkState(entityIds.size() == 1, "Bulk insert of accounts isn't supported");
        } else {
            context = contextMaybeWithoutAccountRecordId;
            final boolean tableWithoutAccountRecordId = tableName == TableName.TENANT || tableName == TableName.TENANT_BROADCASTS || tableName == TableName.TENANT_KVS || tableName == TableName.TAG_DEFINITIONS || tableName == TableName.SERVICE_BRODCASTS || tableName == TableName.NODE_INFOS;
            Preconditions.checkState(context.getAccountRecordId() != null || tableWithoutAccountRecordId,
                                     "accountRecordId should be set for tableName=%s and changeType=%s", tableName, changeType);
        }

        final List<M> reHydratedEntitiesOrNull = updateHistoryAndAudit(entityRecordIds, deletedEntities, tableName, changeType, context);
        if (method.getReturnType().equals(Void.TYPE)) {
            // Return early
            return null;
        } else if (entityRecordIds.size() > 1) {
            // Return the raw jdbc response
            return obj;
        } else {
            // PERF: override the return value with the reHydrated entity to avoid an extra 'get' in the transaction,
            // (see EntityDaoBase#createAndRefresh for an example, but it works for updates as well).
            Preconditions.checkState(entityRecordIds.size() == 1, "Invalid number of entityRecordIds: %s", entityRecordIds);

            if (reHydratedEntitiesOrNull != null) {
                Preconditions.checkState(reHydratedEntitiesOrNull.size() == 1, "Invalid number of entities: %s", reHydratedEntitiesOrNull);
                return Iterables.<M>getFirst(reHydratedEntitiesOrNull, null);
            } else {
                // Updated entity not retrieved yet, we have to go back to the database
                return sqlDao.getByRecordId(entityRecordIds.get(0), context);
            }
        }
    }

    // Fetches an entity by id, failing (NullPointerException) with a descriptive message instead of a bare NPE when it doesn't exist
    private M getExistingEntityById(final String entityId, final InternalCallContext context) {
        final M entity = sqlDao.getById(entityId, context);
        printSQLWarnings();
        return Preconditions.checkNotNull(entity, "No entity found for id=%s in %s", entityId, sqlDaoClass.getSimpleName());
    }

    // Invokes the DAO method on the wrapped instance (checked exceptions are wrapped in InvocationTargetException by reflection)
    private Object executeJDBCCall(final Method method, final Object[] args) throws IllegalAccessException, InvocationTargetException {
        final Object invoke = method.invoke(sqlDao, args);
        printSQLWarnings();
        return invoke;
    }

    // When debug logging is enabled, logs then clears the SQL warnings accumulated on the connection (best effort)
    private void printSQLWarnings() {
        if (logger.isDebugEnabled()) {
            try {
                SQLWarning warning = handle.getConnection().getWarnings();
                while (warning != null) {
                    logger.debug("[SQL WARNING] {}", warning);
                    warning = warning.getNextWarning();
                }
                handle.getConnection().clearWarnings();
            } catch (final SQLException e) {
                logger.debug("Error whilst retrieving SQL warnings", e);
            }
        }
    }

    private void populateCacheOnGetByIdInvocation(final M model) {
        // DAOs created without caching (null dispatcher, see constructor) have no caches to populate
        if (cacheControllerDispatcher == null) {
            return;
        }
        populateCaches(cacheControllerDispatcher, model);
    }

    /**
     * Populates the id caches (record_id, object_id and, when set, tenant_record_id and account_record_id) from a model.
     * Existing entries are not overwritten (putIfAbsent).
     */
    public static void populateCaches(final CacheControllerDispatcher cacheControllerDispatcher, final EntityModelDao model) {
        final CacheController<String, Long> cacheRecordId = cacheControllerDispatcher.getCacheController(CacheType.RECORD_ID);
        cacheRecordId.putIfAbsent(getKey(model.getId().toString(), CacheType.RECORD_ID, model.getTableName()), model.getRecordId());

        final CacheController<String, UUID> cacheObjectId = cacheControllerDispatcher.getCacheController(CacheType.OBJECT_ID);
        cacheObjectId.putIfAbsent(getKey(model.getRecordId().toString(), CacheType.OBJECT_ID, model.getTableName()), model.getId());

        if (model.getTenantRecordId() != null) {
            final CacheController<String, Long> cacheTenantRecordId = cacheControllerDispatcher.getCacheController(CacheType.TENANT_RECORD_ID);
            cacheTenantRecordId.putIfAbsent(getKey(model.getId().toString(), CacheType.TENANT_RECORD_ID, model.getTableName()), model.getTenantRecordId());
        }

        if (model.getAccountRecordId() != null) {
            final CacheController<String, Long> cacheAccountRecordId = cacheControllerDispatcher.getCacheController(CacheType.ACCOUNT_RECORD_ID);
            cacheAccountRecordId.putIfAbsent(getKey(model.getId().toString(), CacheType.ACCOUNT_RECORD_ID, model.getTableName()), model.getAccountRecordId());
        }
    }

    // Builds a cache key, prefixed with the table name for cache types that require it
    private static String getKey(final String rawKey, final CacheType cacheType, final TableName tableName) {
        return cacheType.isKeyPrefixedWithTableName() ?
               tableName + CacheControllerDispatcher.CACHE_KEY_SEPARATOR + rawKey :
               rawKey;
    }

    // Update history and audit tables.
    // PERF: if the latest entities had to be fetched from the database, return them. Otherwise, return null.
    private List<M> updateHistoryAndAudit(final Collection<Long> entityRecordIds,
                                          final Map<Long, M> deletedEntities,
                                          final TableName tableName,
                                          final ChangeType changeType,
                                          final InternalCallContext context) throws Throwable {
        final Object reHydratedEntitiesOrNull = prof.executeWithProfiling(ProfilingFeatureType.DAO_DETAILS, getProfilingId("history/audit", null), new WithProfilingCallback<Object, Throwable>() {
            @Override
            public List<M> execute() {
                if (tableName.getHistoryTableName() == null) {
                    insertAudits(entityRecordIds, tableName, changeType, context);
                    return null;
                } else {
                    // We'll keep the ordering
                    final Collection<Long> auditTargetRecordIds = new ArrayList<>(entityRecordIds.size());
                    final List<M> reHydratedEntities = new ArrayList<>(entityRecordIds.size());
                    for (final Long entityRecordId : entityRecordIds) {
                        // Make sure to re-hydrate the objects first (especially needed for create calls)
                        // TODO Could we do this in bulk too?
                        final M reHydratedEntityModelDao = reHydrateEntity(entityRecordId, deletedEntities, context);
                        final Long auditTargetRecordId = insertHistory(reHydratedEntityModelDao, changeType, context);
                        auditTargetRecordIds.add(auditTargetRecordId);
                        reHydratedEntities.add(reHydratedEntityModelDao);
                    }
                    // Note: audit entries point to the history record id
                    insertAudits(auditTargetRecordIds, tableName, changeType, context);

                    return reHydratedEntities;
                }
            }
        });
        //noinspection unchecked
        return (List<M>) reHydratedEntitiesOrNull;
    }

    /**
     * Returns the latest state of an entity for the history table: the snapshot taken before a deletion when there is
     * one, otherwise the row read back from the database.
     * The database is only queried when needed; a missing row fails with a NullPointerException.
     */
    private M reHydrateEntity(final Long entityRecordId, final Map<Long, M> deletedEntities, final InternalCallContext context) {
        final M deletedEntity = deletedEntities.get(entityRecordId);
        if (deletedEntity != null) {
            return deletedEntity;
        }

        final M entity = sqlDao.getByRecordId(entityRecordId, context);
        printSQLWarnings();
        return Preconditions.checkNotNull(entity, "No entity found for recordId=%s in %s", entityRecordId, sqlDaoClass.getSimpleName());
    }

    /**
     * Finds the ids of the entities targeted by the call, scanning the arguments in order:
     * the first Entity argument, the first Iterable made only of entities (batch calls),
     * a String bound as "id", or a Collection bound in as "ids".
     * Returns an empty list when none is found.
     */
    private List<String> retrieveEntityIdsFromArguments(final Method method, final Object[] args) {
        final Annotation[][] parameterAnnotations = getAnnotations(method);

        int i = -1;
        for (final Object arg : args) {
            i++;

            // Assume the first argument of type Entity is our type of Entity (type E here)
            // This is true for e.g. create calls
            if (arg instanceof Entity) {
                return ImmutableList.<String>of(((Entity) arg).getId().toString());
            }

            // For Batch calls, the first argument will be of type List<Entity>
            if (arg instanceof Iterable) {
                final Builder<String> entityIds = extractEntityIdsFromBatchArgument((Iterable) arg);
                if (entityIds != null) {
                    return entityIds.build();
                }
            }

            for (final Annotation annotation : parameterAnnotations[i]) {
                if (arg instanceof String && Bind.class.equals(annotation.annotationType()) && ("id").equals(((Bind) annotation).value())) {
                    return ImmutableList.<String>of((String) arg);
                } else if (arg instanceof Collection && BindIn.class.equals(annotation.annotationType()) && ("ids").equals(((BindIn) annotation).value())) {
                    return ImmutableList.<String>copyOf((Collection) arg);
                }
            }
        }
        return ImmutableList.<String>of();
    }

    // Cached Method.getParameterAnnotations()
    private Annotation[][] getAnnotations(final Method method) {
        Annotation[][] parameterAnnotations = parameterAnnotationsByMethod.get(method);
        if (parameterAnnotations == null) {
            parameterAnnotations = method.getParameterAnnotations();
            // Concurrent callers may both compute it: harmless, they store identical values
            parameterAnnotationsByMethod.put(method, parameterAnnotations);
        }
        return parameterAnnotations;
    }

    // Returns a builder with the ids of all elements, or null if any element isn't an Entity
    private Builder<String> extractEntityIdsFromBatchArgument(final Iterable arg) {
        final Builder<String> entityIds = ImmutableList.builder();
        for (final Object object : arg) {
            if (!(object instanceof Entity)) {
                // No good - ignore
                return null;
            }
            entityIds.add(((Entity) object).getId().toString());
        }

        return entityIds;
    }

    // Returns the first InternalCallContext argument, failing if there is none
    private InternalCallContext retrieveContextFromArguments(final Object[] args) {
        for (final Object arg : args) {
            if (!(arg instanceof InternalCallContext)) {
                continue;
            }
            return (InternalCallContext) arg;
        }
        throw new IllegalStateException("No InternalCallContext specified in args: " + Arrays.toString(args));
    }

    // Returns the TableName shared by all EntityModelDao arguments, or null if there are none (fails if they disagree)
    private TableName retrieveTableNameFromArgumentsIfPossible(final Object[] args) {
        TableName tableName = null;
        for (final Object arg : args) {
            if (arg instanceof EntityModelDao) {
                final TableName argTableName = ((EntityModelDao) arg).getTableName();
                if (tableName == null) {
                    tableName = argTableName;
                } else {
                    Preconditions.checkState(tableName == argTableName, "SqlDao method with different TableName in args: %s", Arrays.toString(args));
                }
            }
        }
        return tableName;
    }

    // Returns the first TimeZoneAwareEntity argument, failing if there is none
    private TimeZoneAwareEntity retrieveTimeZoneAwareEntityFromArguments(final Object[] args) {
        for (final Object arg : args) {
            if (!(arg instanceof TimeZoneAwareEntity)) {
                continue;
            }
            return (TimeZoneAwareEntity) arg;
        }
        throw new IllegalStateException("TimeZoneAwareEntity should have been found among " + Arrays.toString(args));
    }

    // Inserts a history row for the entity and returns its record id (audits point to it)
    private Long insertHistory(final M reHydratedEntityModelDao, final ChangeType changeType, final InternalCallContext context) {
        final EntityHistoryModelDao<M, E> history = new EntityHistoryModelDao<M, E>(reHydratedEntityModelDao, reHydratedEntityModelDao.getRecordId(), changeType, null, context.getCreatedDate());
        final Long recordId = sqlDao.addHistoryFromTransaction(history, context);
        printSQLWarnings();
        return recordId;
    }

    // Bulk insert all audit logs for this operation
    // (audits target the history table when the table has one, the table itself otherwise)
    private void insertAudits(final Iterable<Long> auditTargetRecordIds,
                              final TableName tableName,
                              final ChangeType changeType,
                              final InternalCallContext context) {
        final TableName destinationTableName = MoreObjects.firstNonNull(tableName.getHistoryTableName(), tableName);

        final Collection<EntityAudit> audits = new ArrayList<EntityAudit>();
        for (final Long auditTargetRecordId : auditTargetRecordIds) {
            final EntityAudit audit = new EntityAudit(destinationTableName, auditTargetRecordId, changeType, context.getCreatedDate());
            audits.add(audit);
        }

        sqlDao.insertAuditsFromTransaction(audits, context);
        printSQLWarnings();
    }

    // Profiling id of the form "SqlDaoName (prefix): methodName", each optional part omitted when null
    private String getProfilingId(@Nullable final String prefix, @Nullable final Method method) {
        final StringBuilder stringBuilder = new StringBuilder().append(sqlDaoClass.getSimpleName());

        if (prefix != null) {
            stringBuilder.append(" (")
                         .append(prefix)
                         .append(")");
        }

        if (method != null) {
            stringBuilder.append(": ").append(method.getName());
        }

        return stringBuilder.toString();
    }
}