killbill-memoizeit

invoice: add more sanity checks * Allow only EXTERNAL_CHARGE,

2/5/2015 4:09:52 PM

Details

diff --git a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
index 23e49be..5add0ff 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/dao/DefaultInvoiceDao.java
@@ -19,6 +19,7 @@
 package org.killbill.billing.invoice.dao;
 
 import java.math.BigDecimal;
+import java.util.Collection;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
@@ -33,6 +34,7 @@ import org.killbill.billing.ErrorCode;
 import org.killbill.billing.callcontext.InternalCallContext;
 import org.killbill.billing.callcontext.InternalTenantContext;
 import org.killbill.billing.catalog.api.Currency;
+import org.killbill.billing.entity.EntityPersistenceException;
 import org.killbill.billing.invoice.api.Invoice;
 import org.killbill.billing.invoice.api.InvoiceApiException;
 import org.killbill.billing.invoice.api.InvoiceItemType;
@@ -76,6 +78,12 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                                                                                             }
                                                                                         });
 
+    private static final Collection<InvoiceItemType> INVOICE_ITEM_TYPES_ADJUSTABLE = ImmutableList.<InvoiceItemType>of(InvoiceItemType.EXTERNAL_CHARGE,
+                                                                                                                       InvoiceItemType.FIXED,
+                                                                                                                       InvoiceItemType.RECURRING,
+                                                                                                                       InvoiceItemType.TAX,
+                                                                                                                       InvoiceItemType.USAGE);
+
     private final NextBillingDatePoster nextBillingDatePoster;
     private final PersistentBus eventBus;
     private final InternalCallContextFactory internalCallContextFactory;
@@ -221,7 +229,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                     // Create the invoice items
                     final InvoiceItemSqlDao transInvoiceItemSqlDao = entitySqlDaoWrapperFactory.become(InvoiceItemSqlDao.class);
                     for (final InvoiceItemModelDao invoiceItemModelDao : invoiceItems) {
-                        transInvoiceItemSqlDao.create(invoiceItemModelDao, context);
+                        createInvoiceItemFromTransaction(transInvoiceItemSqlDao, invoiceItemModelDao, context);
                     }
 
                     cbaDao.addCBAComplexityFromTransaction(invoice, entitySqlDaoWrapperFactory, context);
@@ -254,7 +262,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                     // Create the invoice items if needed
                     for (final InvoiceItemModelDao invoiceItemModelDao : invoiceModelDao.getInvoiceItems()) {
                         if (transInvoiceItemSqlDao.getById(invoiceItemModelDao.getId().toString(), context) == null) {
-                            transInvoiceItemSqlDao.create(invoiceItemModelDao, context);
+                            createInvoiceItemFromTransaction(transInvoiceItemSqlDao, invoiceItemModelDao, context);
                             createdInvoiceItems.add(transInvoiceItemSqlDao.getById(invoiceItemModelDao.getId().toString(), context));
                             madeChanges = true;
                         }
@@ -455,7 +463,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                         final InvoiceItemModelDao adjItem = new InvoiceItemModelDao(context.getCreatedDate(), InvoiceItemType.REFUND_ADJ, invoice.getId(), invoice.getAccountId(),
                                                                                     null, null, null, null, null, null, context.getCreatedDate().toLocalDate(), null,
                                                                                     requestedPositiveAmountToAdjust.negate(), null, invoice.getCurrency(), null);
-                        transInvoiceItemDao.create(adjItem, context);
+                        createInvoiceItemFromTransaction(transInvoiceItemDao, adjItem, context);
                         invoice.addInvoiceItem(adjItem);
                     }
                 } else if (isInvoiceAdjusted) {
@@ -465,7 +473,8 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                         final InvoiceItemModelDao item = invoiceDaoHelper.createAdjustmentItem(entitySqlDaoWrapperFactory, invoice.getId(), invoiceItemId, adjAmount,
                                                                                                invoice.getCurrency(), context.getCreatedDate().toLocalDate(),
                                                                                                context);
-                        transInvoiceItemDao.create(item, context);
+
+                        createInvoiceItemFromTransaction(transInvoiceItemDao, item, context);
                         invoice.addInvoiceItem(item);
                     }
                 }
@@ -679,7 +688,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                 final InvoiceItemModelDao cbaAdjItem = new InvoiceItemModelDao(context.getCreatedDate(), InvoiceItemType.CBA_ADJ, invoice.getId(), invoice.getAccountId(),
                                                                                null, null, null, null, null, null, context.getCreatedDate().toLocalDate(),
                                                                                null, cbaItem.getAmount().negate(), null, cbaItem.getCurrency(), cbaItem.getId());
-                invoiceItemSqlDao.create(cbaAdjItem, context);
+                createInvoiceItemFromTransaction(invoiceItemSqlDao, cbaAdjItem, context);
 
                 // Verify the final invoice balance is not negative
                 invoiceDaoHelper.populateChildren(invoice, entitySqlDaoWrapperFactory, context);
@@ -736,7 +745,7 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                                                                                            invoice.getAccountId(), null, null, null, null, null, null,
                                                                                            context.getCreatedDate().toLocalDate(), null,
                                                                                            positiveCBAAdjItemAmount, null, cbaItem.getCurrency(), cbaItem.getId());
-                        invoiceItemSqlDao.create(nextCBAAdjItem, context);
+                        createInvoiceItemFromTransaction(invoiceItemSqlDao, nextCBAAdjItem, context);
                         if (positiveRemainderToAdjust.compareTo(BigDecimal.ZERO) == 0) {
                             break;
                         }
@@ -778,4 +787,25 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
             log.warn("Failed to post adjustment event for invoice " + invoiceId, e);
         }
     }
+
+    private void createInvoiceItemFromTransaction(final InvoiceItemSqlDao invoiceItemSqlDao, final InvoiceItemModelDao invoiceItemModelDao, final InternalCallContext context) throws EntityPersistenceException, InvoiceApiException {
+        // There is no efficient way to retrieve an invoice item given an ID today (and invoice plugins can put item adjustments
+        // on a different invoice than the original item), so it's easier to do the check in the DAO rather than in the API layer
+        // See also https://github.com/killbill/killbill/issues/7
+        if (InvoiceItemType.ITEM_ADJ.equals(invoiceItemModelDao.getType())) {
+            validateInvoiceItemToBeAdjusted(invoiceItemSqlDao, invoiceItemModelDao, context);
+        }
+
+        invoiceItemSqlDao.create(invoiceItemModelDao, context);
+    }
+
+    private void validateInvoiceItemToBeAdjusted(final InvoiceItemSqlDao invoiceItemSqlDao, final InvoiceItemModelDao invoiceItemModelDao, final InternalCallContext context) throws InvoiceApiException {
+        Preconditions.checkNotNull(invoiceItemModelDao.getLinkedItemId(), "LinkedItemId cannot be null for ITEM_ADJ item: " + invoiceItemModelDao);
+        // Note: this assumes the linked item has already been created in or prior to the transaction, which should almost always be the case
+        // (unless some whacky plugin creates an out-of-order item adjustment on a subsequent external charge)
+        final InvoiceItemModelDao invoiceItemToBeAdjusted = invoiceItemSqlDao.getById(invoiceItemModelDao.getLinkedItemId().toString(), context);
+        if (!INVOICE_ITEM_TYPES_ADJUSTABLE.contains(invoiceItemToBeAdjusted.getType())) {
+            throw new InvoiceApiException(ErrorCode.INVOICE_ITEM_ADJUSTMENT_ITEM_INVALID, invoiceItemToBeAdjusted.getId());
+        }
+    }
 }
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/InvoicePluginDispatcher.java b/invoice/src/main/java/org/killbill/billing/invoice/InvoicePluginDispatcher.java
index 8af7af0..396942e 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/InvoicePluginDispatcher.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/InvoicePluginDispatcher.java
@@ -18,12 +18,15 @@
 package org.killbill.billing.invoice;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
 
 import javax.inject.Inject;
 
+import org.killbill.billing.ErrorCode;
 import org.killbill.billing.invoice.api.Invoice;
+import org.killbill.billing.invoice.api.InvoiceApiException;
 import org.killbill.billing.invoice.api.InvoiceItem;
 import org.killbill.billing.invoice.api.InvoiceItemType;
 import org.killbill.billing.invoice.model.DefaultInvoice;
@@ -40,6 +43,10 @@ public class InvoicePluginDispatcher {
 
     private static final Logger log = LoggerFactory.getLogger(InvoicePluginDispatcher.class);
 
+    private static final Collection<InvoiceItemType> ALLOWED_INVOICE_ITEM_TYPES = ImmutableList.<InvoiceItemType>of(InvoiceItemType.EXTERNAL_CHARGE,
+                                                                                                                    InvoiceItemType.ITEM_ADJ,
+                                                                                                                    InvoiceItemType.TAX);
+
     private final OSGIServiceRegistration<InvoicePluginApi> pluginRegistry;
 
     @Inject
@@ -51,8 +58,7 @@ public class InvoicePluginDispatcher {
     // If we have multiple plugins there is a question of plugin ordering and also a 'product' questions to decide whether
     // subsequent plugins should have access to items added by previous plugins
     //
-    public List<InvoiceItem> getAdditionalInvoiceItems(final Invoice originalInvoice, final CallContext callContext) {
-
+    public List<InvoiceItem> getAdditionalInvoiceItems(final Invoice originalInvoice, final CallContext callContext) throws InvoiceApiException {
         // We clone the original invoice so plugins don't remove/add items
         final Invoice clonedInvoice = (Invoice) ((DefaultInvoice) originalInvoice).clone();
         final List<InvoiceItem> additionalInvoiceItems = new LinkedList<InvoiceItem>();
@@ -61,23 +67,21 @@ public class InvoicePluginDispatcher {
             final List<InvoiceItem> items = invoicePlugin.getAdditionalInvoiceItems(clonedInvoice, ImmutableList.<PluginProperty>of(), callContext);
             if (items != null) {
                 for (final InvoiceItem item : items) {
-                    if (item.getInvoiceItemType() != InvoiceItemType.FIXED &&
-                        item.getInvoiceItemType() != InvoiceItemType.RECURRING &&
-                        item.getInvoiceItemType() != InvoiceItemType.REPAIR_ADJ &&
-                        item.getInvoiceItemType() != InvoiceItemType.CBA_ADJ &&
-                        item.getInvoiceItemType() != InvoiceItemType.CREDIT_ADJ &&
-                        item.getInvoiceItemType() != InvoiceItemType.REFUND_ADJ &&
-                        item.getInvoiceItemType() != InvoiceItemType.USAGE) {
-                        additionalInvoiceItems.add(item);
-                    } else {
-                        log.warn("Ignoring invoice item of type {} from InvoicePlugin {}: {}", item.getInvoiceItemType(), invoicePlugin, item);
-                    }
+                    validateInvoiceItemFromPlugin(item, invoicePlugin);
+                    additionalInvoiceItems.add(item);
                 }
             }
         }
         return additionalInvoiceItems;
     }
 
+    private void validateInvoiceItemFromPlugin(final InvoiceItem invoiceItem, final InvoicePluginApi invoicePlugin) throws InvoiceApiException {
+        if (!ALLOWED_INVOICE_ITEM_TYPES.contains(invoiceItem.getInvoiceItemType())) {
+            log.warn("Ignoring invoice item of type {} from InvoicePlugin {}: {}", invoiceItem.getInvoiceItemType(), invoicePlugin, invoiceItem);
+            throw new InvoiceApiException(ErrorCode.INVOICE_ITEM_TYPE_INVALID, invoiceItem.getInvoiceItemType());
+        }
+    }
+
     private List<InvoicePluginApi> getInvoicePlugins() {
         final List<InvoicePluginApi> invoicePlugins = new ArrayList<InvoicePluginApi>();
         for (final String name : pluginRegistry.getAllServices()) {
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/AccountItemTree.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/AccountItemTree.java
index 2210894..1f6ec8d 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/AccountItemTree.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/AccountItemTree.java
@@ -125,6 +125,7 @@ public class AccountItemTree {
                 linkedInvoiceItem.getInvoiceItemType() != InvoiceItemType.RECURRING &&
                 linkedInvoiceItem.getInvoiceItemType() != InvoiceItemType.FIXED) {
                 // We only care about adjustments for recurring and fixed items when building the tree
+                // (we assume that REPAIR_ADJ and ITEM_ADJ items cannot be adjusted)
                 return;
             }
         }
diff --git a/jaxrs/src/main/java/org/killbill/billing/jaxrs/mappers/InvoiceApiExceptionMapper.java b/jaxrs/src/main/java/org/killbill/billing/jaxrs/mappers/InvoiceApiExceptionMapper.java
index ef4e0e9..bb5d255 100644
--- a/jaxrs/src/main/java/org/killbill/billing/jaxrs/mappers/InvoiceApiExceptionMapper.java
+++ b/jaxrs/src/main/java/org/killbill/billing/jaxrs/mappers/InvoiceApiExceptionMapper.java
@@ -1,7 +1,9 @@
 /*
  * Copyright 2010-2013 Ning, Inc.
+ * Copyright 2014-2015 Groupon, Inc
+ * Copyright 2014-2015 The Billing Project, LLC
  *
- * Ning licenses this file to you under the Apache License, version 2.0
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
  * (the "License"); you may not use this file except in compliance with the
  * License.  You may obtain a copy of the License at:
  *
@@ -70,6 +72,10 @@ public class InvoiceApiExceptionMapper extends ExceptionMapperBase implements Ex
             return buildBadRequestResponse(exception, uriInfo);
         } else if (exception.getCode() == ErrorCode.CURRENCY_INVALID.getCode()) {
             return buildBadRequestResponse(exception, uriInfo);
+        } else if (exception.getCode() == ErrorCode.INVOICE_ITEM_ADJUSTMENT_ITEM_INVALID.getCode()) {
+            return buildBadRequestResponse(exception, uriInfo);
+        } else if (exception.getCode() == ErrorCode.INVOICE_ITEM_TYPE_INVALID.getCode()) {
+            return buildBadRequestResponse(exception, uriInfo);
         } else {
             return fallback(exception, uriInfo);
         }