killbill-uncached

invoice: implement item adjustments in DefaultInvoiceDao Signed-off-by:

8/2/2012 7:41:25 PM

Details

diff --git a/api/src/main/java/com/ning/billing/ErrorCode.java b/api/src/main/java/com/ning/billing/ErrorCode.java
index 8e000da..189d61e 100644
--- a/api/src/main/java/com/ning/billing/ErrorCode.java
+++ b/api/src/main/java/com/ning/billing/ErrorCode.java
@@ -203,6 +203,7 @@ public enum ErrorCode {
     CHARGE_BACK_DOES_NOT_EXIST(4004, "Could not find chargeback for id %s."),
     INVOICE_PAYMENT_BY_ATTEMPT_NOT_FOUND(4905, "No invoice payment could be found for paymentAttempt id %s."),
     REFUND_AMOUNT_TOO_HIGH(4906, "Tried to refund %s of a %s payment."),
+    REFUND_AMOUNT_DONT_MATCH_ITEMS_TO_ADJUST(4907, "You can't specify a refund amount of %s that doesn't match the invoice items amount of %s."),
 
     /*
      *
diff --git a/api/src/main/java/com/ning/billing/invoice/api/InvoiceItemType.java b/api/src/main/java/com/ning/billing/invoice/api/InvoiceItemType.java
index 548bfd5..63cd82a 100644
--- a/api/src/main/java/com/ning/billing/invoice/api/InvoiceItemType.java
+++ b/api/src/main/java/com/ning/billing/invoice/api/InvoiceItemType.java
@@ -28,8 +28,8 @@ public enum InvoiceItemType {
     // Credit adjustment, either at the account level (on its own invoice) or against an existing invoice
     // (invoice level adjustment)
     CREDIT_ADJ,
-    // Invoice item adjustment
+    // Invoice item adjustment (by itself or triggered by a refund)
     ITEM_ADJ,
-    // Refund adjustment (against a posted payment)
+    // Refund adjustment (against a posted payment), used when adjusting invoices
     REFUND_ADJ
 }
diff --git a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
index a2f9a9b..4fe738e 100644
--- a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
@@ -58,9 +58,11 @@ import com.ning.billing.util.dao.ObjectType;
 import com.ning.billing.util.dao.TableName;
 import com.ning.billing.util.tag.ControlTagType;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Objects;
 import com.google.common.base.Predicate;
 import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap.Builder;
 import com.google.inject.Inject;
 
 public class DefaultInvoiceDao implements InvoiceDao {
@@ -307,25 +309,28 @@ public class DefaultInvoiceDao implements InvoiceDao {
     }
 
     @Override
-    public InvoicePayment createRefund(final UUID paymentId, final BigDecimal amount, final boolean isInvoiceAdjusted,
-                                       final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts, final UUID paymentCookieId,
+    public InvoicePayment createRefund(final UUID paymentId, final BigDecimal requestedRefundAmount, final boolean isInvoiceAdjusted,
+                                       final Map<UUID, BigDecimal> invoiceItemIdsWithNullAmounts, final UUID paymentCookieId,
                                        final CallContext context)
             throws InvoiceApiException {
         return invoicePaymentSqlDao.inTransaction(new Transaction<InvoicePayment, InvoicePaymentSqlDao>() {
             @Override
             public InvoicePayment inTransaction(final InvoicePaymentSqlDao transactional, final TransactionStatus status) throws Exception {
 
+                final InvoiceSqlDao transInvoiceDao = transactional.become(InvoiceSqlDao.class);
+
                 final InvoicePayment payment = transactional.getByPaymentId(paymentId.toString());
                 if (payment == null) {
                     throw new InvoiceApiException(ErrorCode.INVOICE_PAYMENT_BY_ATTEMPT_NOT_FOUND, paymentId);
                 }
-                final BigDecimal maxRefundAmount = payment.getAmount() == null ? BigDecimal.ZERO : payment.getAmount();
-                final BigDecimal requestedPositiveAmount = amount == null ? maxRefundAmount : amount;
-                // This check is good but not enough, we need to also take into account previous refunds
-                // (But that should have been checked in the payment call already)
-                if (requestedPositiveAmount.compareTo(maxRefundAmount) > 0) {
-                    throw new InvoiceApiException(ErrorCode.REFUND_AMOUNT_TOO_HIGH, requestedPositiveAmount, maxRefundAmount);
-                }
+
+                // Retrieve the amounts to adjust, if needed
+                final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts = computeItemAdjustments(payment.getInvoiceId().toString(),
+                                                                                               transInvoiceDao,
+                                                                                               invoiceItemIdsWithNullAmounts);
+
+                // Compute the actual amount to refund
+                final BigDecimal requestedPositiveAmount = computePositiveRefundAmount(payment, requestedRefundAmount, invoiceItemIdsWithAmounts);
 
                 // Before we go further, check if that refund already got inserted -- the payment system keeps a state machine
                 // and so this call may be called several time for the same  paymentCookieId (which is really the refundId)
@@ -340,7 +345,6 @@ public class DefaultInvoiceDao implements InvoiceDao {
                 transactional.create(refund, context);
 
                 // Retrieve invoice after the Refund
-                final InvoiceSqlDao transInvoiceDao = transactional.become(InvoiceSqlDao.class);
                 final Invoice invoice = transInvoiceDao.getById(payment.getInvoiceId().toString());
                 if (invoice != null) {
                     populateChildren(invoice, transInvoiceDao);
@@ -351,8 +355,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
                 final BigDecimal invoiceBalanceAfterRefund = invoice.getBalance();
                 final InvoiceItemSqlDao transInvoiceItemDao = transInvoiceDao.become(InvoiceItemSqlDao.class);
 
-                // If we have an existing CBA > 0, we need to adjust it
-                //final BigDecimal cbaAmountAfterRefund = invoice.getCBAAmount();
+                // If we have an existing CBA > 0 at the account level, we need to use it
                 final BigDecimal accountCbaAvailable = getAccountCBAFromTransaction(invoice.getAccountId(), transInvoiceDao);
                 BigDecimal cbaAdjAmount = BigDecimal.ZERO;
                 if (accountCbaAvailable.compareTo(BigDecimal.ZERO) > 0) {
@@ -362,19 +365,102 @@ public class DefaultInvoiceDao implements InvoiceDao {
                 }
                 final BigDecimal requestedPositiveAmountAfterCbaAdj = requestedPositiveAmount.add(cbaAdjAmount);
 
-                if (isInvoiceAdjusted) {
+                // At this point, we created the refund which made the invoice balance positive and applied any existing
+                // available CBA to that invoice.
+                // We now need to adjust the invoice and/or invoice items if needed and specified.
+                if (isInvoiceAdjusted && invoiceItemIdsWithAmounts.size() == 0) {
+                    // Invoice adjustment
                     final BigDecimal maxBalanceToAdjust = (invoiceBalanceAfterRefund.compareTo(BigDecimal.ZERO) <= 0) ? BigDecimal.ZERO : invoiceBalanceAfterRefund;
                     final BigDecimal requestedPositiveAmountToAdjust = requestedPositiveAmountAfterCbaAdj.compareTo(maxBalanceToAdjust) > 0 ? maxBalanceToAdjust : requestedPositiveAmountAfterCbaAdj;
                     if (requestedPositiveAmountToAdjust.compareTo(BigDecimal.ZERO) > 0) {
                         final InvoiceItem adjItem = new RefundAdjInvoiceItem(invoice.getId(), invoice.getAccountId(), context.getCreatedDate().toLocalDate(), requestedPositiveAmountToAdjust.negate(), invoice.getCurrency());
                         transInvoiceItemDao.create(adjItem, context);
                     }
+                } else if (isInvoiceAdjusted) {
+                    // Invoice item adjustment
+                    for (final UUID invoiceItemId : invoiceItemIdsWithAmounts.keySet()) {
+                        final BigDecimal adjAmount = invoiceItemIdsWithAmounts.get(invoiceItemId);
+                        final InvoiceItem item = createAdjustmentItem(transInvoiceDao, invoice.getId(), invoiceItemId, adjAmount,
+                                                                      invoice.getCurrency(), context.getCreatedDate().toLocalDate());
+                        transInvoiceItemDao.create(item, context);
+                    }
                 }
+
                 return refund;
             }
         });
     }
 
+    /**
+     * Find amounts to adjust for individual items, if not specified.
+     * The user gives us a list of items to adjust associated with a given amount (how much to refund per invoice item).
+     * In case of full adjustments, the amount can be null: in this case, we retrieve the original amount for the invoice
+     * item.
+     *
+     * @param invoiceId                     original invoice id
+     * @param transInvoiceDao               the transactional InvoiceSqlDao
+     * @param invoiceItemIdsWithNullAmounts the original mapping between invoice item ids and amount to refund (contains null)
+     * @return the final mapping between invoice item ids and amount to refund
+     * @throws InvoiceApiException
+     */
+    private Map<UUID, BigDecimal> computeItemAdjustments(final String invoiceId, final InvoiceSqlDao transInvoiceDao,
+                                                         final Map<UUID, BigDecimal> invoiceItemIdsWithNullAmounts) throws InvoiceApiException {
+        // Populate the missing amounts for individual items, if needed
+        final Builder<UUID, BigDecimal> invoiceItemIdsWithAmountsBuilder = new Builder<UUID, BigDecimal>();
+        if (invoiceItemIdsWithNullAmounts.size() == 0) {
+            return invoiceItemIdsWithAmountsBuilder.build();
+        }
+
+        // Retrieve invoice before the Refund
+        final Invoice invoice = transInvoiceDao.getById(invoiceId);
+        if (invoice != null) {
+            populateChildren(invoice, transInvoiceDao);
+        } else {
+            throw new IllegalStateException("Invoice shouldn't be null for id " + invoiceId);
+        }
+
+        for (final UUID invoiceItemId : invoiceItemIdsWithNullAmounts.keySet()) {
+            final BigDecimal adjAmount = Objects.firstNonNull(invoiceItemIdsWithNullAmounts.get(invoiceItemId),
+                                                              getInvoiceItemAmountForId(invoice, invoiceItemId));
+            invoiceItemIdsWithAmountsBuilder.put(invoiceItemId, adjAmount);
+        }
+
+        return invoiceItemIdsWithAmountsBuilder.build();
+    }
+
+    private BigDecimal getInvoiceItemAmountForId(final Invoice invoice, final UUID invoiceItemId) throws InvoiceApiException {
+        for (final InvoiceItem invoiceItem : invoice.getInvoiceItems()) {
+            if (invoiceItem.getId().equals(invoiceItemId)) {
+                return invoiceItem.getAmount();
+            }
+        }
+
+        throw new InvoiceApiException(ErrorCode.INVOICE_ITEM_NOT_FOUND, invoiceItemId);
+    }
+
+    @VisibleForTesting
+    BigDecimal computePositiveRefundAmount(final InvoicePayment payment, final BigDecimal requestedAmount, final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts) throws InvoiceApiException {
+        final BigDecimal maxRefundAmount = payment.getAmount() == null ? BigDecimal.ZERO : payment.getAmount();
+        final BigDecimal requestedPositiveAmount = requestedAmount == null ? maxRefundAmount : requestedAmount;
+        // This check is good but not enough, we need to also take into account previous refunds
+        // (But that should have been checked in the payment call already)
+        if (requestedPositiveAmount.compareTo(maxRefundAmount) > 0) {
+            throw new InvoiceApiException(ErrorCode.REFUND_AMOUNT_TOO_HIGH, requestedPositiveAmount, maxRefundAmount);
+        }
+
+        // Verify if the requested amount matches the invoice items to adjust, if specified
+        BigDecimal amountFromItems = BigDecimal.ZERO;
+        for (final BigDecimal itemAmount : invoiceItemIdsWithAmounts.values()) {
+            amountFromItems = amountFromItems.add(itemAmount);
+        }
+
+        // Sanity check: if some items were specified, then the sum should be equal to specified refund amount, if specified
+        if (amountFromItems.compareTo(BigDecimal.ZERO) != 0 && requestedPositiveAmount.compareTo(amountFromItems) != 0) {
+            throw new InvoiceApiException(ErrorCode.REFUND_AMOUNT_DONT_MATCH_ITEMS_TO_ADJUST, requestedPositiveAmount, amountFromItems);
+        }
+        return requestedPositiveAmount;
+    }
+
     @Override
     public InvoicePayment postChargeback(final UUID invoicePaymentId, final BigDecimal amount, final CallContext context) throws InvoiceApiException {
 
diff --git a/invoice/src/test/java/com/ning/billing/invoice/api/invoice/TestDefaultInvoicePaymentApi.java b/invoice/src/test/java/com/ning/billing/invoice/api/invoice/TestDefaultInvoicePaymentApi.java
new file mode 100644
index 0000000..9ccbf84
--- /dev/null
+++ b/invoice/src/test/java/com/ning/billing/invoice/api/invoice/TestDefaultInvoicePaymentApi.java
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2010-2012 Ning, Inc.
+ *
+ * Ning licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.ning.billing.invoice.api.invoice;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import org.skife.jdbi.v2.IDBI;
+import org.testng.Assert;
+import org.testng.annotations.BeforeSuite;
+import org.testng.annotations.Test;
+
+import com.ning.billing.KillbillTestSuiteWithEmbeddedDB;
+import com.ning.billing.catalog.api.Currency;
+import com.ning.billing.dbi.MysqlTestingHelper;
+import com.ning.billing.invoice.InvoiceTestSuiteWithEmbeddedDB;
+import com.ning.billing.invoice.api.Invoice;
+import com.ning.billing.invoice.api.InvoiceApiException;
+import com.ning.billing.invoice.api.InvoicePayment;
+import com.ning.billing.invoice.api.InvoicePayment.InvoicePaymentType;
+import com.ning.billing.invoice.api.InvoicePaymentApi;
+import com.ning.billing.invoice.dao.DefaultInvoiceDao;
+import com.ning.billing.invoice.dao.InvoiceDao;
+import com.ning.billing.invoice.dao.InvoiceItemSqlDao;
+import com.ning.billing.invoice.dao.InvoiceSqlDao;
+import com.ning.billing.invoice.notification.MockNextBillingDatePoster;
+import com.ning.billing.invoice.notification.NextBillingDatePoster;
+import com.ning.billing.util.api.TagUserApi;
+import com.ning.billing.util.callcontext.CallContext;
+import com.ning.billing.util.callcontext.TestCallContext;
+import com.ning.billing.util.clock.Clock;
+import com.ning.billing.util.clock.ClockMock;
+import com.ning.billing.util.tag.api.DefaultTagUserApi;
+import com.ning.billing.util.tag.dao.MockTagDao;
+import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
+import com.ning.billing.util.tag.dao.TagDao;
+import com.ning.billing.util.tag.dao.TagDefinitionDao;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import static com.ning.billing.invoice.tests.InvoiceTestUtils.createAndPersistInvoice;
+import static com.ning.billing.invoice.tests.InvoiceTestUtils.createAndPersistPayment;
+
+public class TestDefaultInvoicePaymentApi extends InvoiceTestSuiteWithEmbeddedDB {
+
+    private static final BigDecimal THIRTY = new BigDecimal("30.00");
+    private static final Currency CURRENCY = Currency.EUR;
+
+    private final Clock clock = new ClockMock();
+
+    private InvoiceSqlDao invoiceSqlDao;
+    private InvoiceItemSqlDao invoiceItemSqlDao;
+    private InvoicePaymentApi invoicePaymentApi;
+    private CallContext context;
+
+    @BeforeSuite(groups = "slow")
+    public void setup() throws IOException {
+        final MysqlTestingHelper helper = KillbillTestSuiteWithEmbeddedDB.getMysqlTestingHelper();
+        final IDBI dbi = helper.getDBI();
+
+        invoiceSqlDao = dbi.onDemand(InvoiceSqlDao.class);
+        invoiceSqlDao.test();
+
+        invoiceItemSqlDao = dbi.onDemand(InvoiceItemSqlDao.class);
+        invoiceItemSqlDao.test();
+
+        final NextBillingDatePoster nextBillingDatePoster = new MockNextBillingDatePoster();
+        final TagDefinitionDao tagDefinitionDao = new MockTagDefinitionDao();
+        final TagDao tagDao = new MockTagDao();
+        final TagUserApi tagUserApi = new DefaultTagUserApi(tagDefinitionDao, tagDao);
+        final InvoiceDao invoiceDao = new DefaultInvoiceDao(dbi, nextBillingDatePoster, tagUserApi, clock);
+        invoicePaymentApi = new DefaultInvoicePaymentApi(invoiceDao);
+
+        context = new TestCallContext("Invoice payment tests");
+    }
+
+    @Test(groups = "slow")
+    public void testFullRefundWithNoAdjustment() throws Exception {
+        verifyRefund(THIRTY, THIRTY, THIRTY, false, ImmutableMap.<UUID, BigDecimal>of());
+    }
+
+    @Test(groups = "slow")
+    public void testPartialRefundWithNoAdjustment() throws Exception {
+        verifyRefund(THIRTY, BigDecimal.TEN, BigDecimal.TEN, false, ImmutableMap.<UUID, BigDecimal>of());
+    }
+
+    @Test(groups = "slow")
+    public void testFullRefundWithInvoiceAdjustment() throws Exception {
+        verifyRefund(THIRTY, THIRTY, BigDecimal.ZERO, true, ImmutableMap.<UUID, BigDecimal>of());
+    }
+
+    @Test(groups = "slow")
+    public void testPartialRefundWithInvoiceAdjustment() throws Exception {
+        verifyRefund(THIRTY, BigDecimal.TEN, BigDecimal.ZERO, true, ImmutableMap.<UUID, BigDecimal>of());
+    }
+
+    @Test(groups = "slow")
+    public void testFullRefundWithBothInvoiceItemAdjustments() throws Exception {
+        // Create an invoice with two items (30 \u20ac and 10 \u20ac)
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock,
+                                                        ImmutableList.<BigDecimal>of(THIRTY, BigDecimal.TEN), CURRENCY, context);
+
+        // Fully adjust both items
+        final Map<UUID, BigDecimal> adjustments = new HashMap<UUID, BigDecimal>();
+        adjustments.put(invoice.getInvoiceItems().get(0).getId(), null);
+        adjustments.put(invoice.getInvoiceItems().get(1).getId(), null);
+
+        verifyRefund(invoice, new BigDecimal("40"), new BigDecimal("40"), BigDecimal.ZERO, true, adjustments);
+    }
+
+    @Test(groups = "slow")
+    public void testPartialRefundWithSingleInvoiceItemAdjustment() throws Exception {
+        // Create an invoice with two items (30 \u20ac and 10 \u20ac)
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock,
+                                                        ImmutableList.<BigDecimal>of(THIRTY, BigDecimal.TEN), CURRENCY, context);
+
+        // Fully adjust both items
+        final Map<UUID, BigDecimal> adjustments = new HashMap<UUID, BigDecimal>();
+        adjustments.put(invoice.getInvoiceItems().get(0).getId(), null);
+
+        verifyRefund(invoice, new BigDecimal("40"), new BigDecimal("30"), BigDecimal.ZERO, true, adjustments);
+    }
+
+    @Test(groups = "slow")
+    public void testPartialRefundWithTwoInvoiceItemAdjustment() throws Exception {
+        // Create an invoice with two items (30 \u20ac and 10 \u20ac)
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock,
+                                                        ImmutableList.<BigDecimal>of(THIRTY, BigDecimal.TEN), CURRENCY, context);
+        // Adjust partially both items: the invoice posted was 40 \u20ac, but we should really just have charged you 2 \u20ac
+        final ImmutableMap<UUID, BigDecimal> adjustments = ImmutableMap.<UUID, BigDecimal>of(invoice.getInvoiceItems().get(0).getId(), new BigDecimal("29"),
+                                                                                             invoice.getInvoiceItems().get(1).getId(), new BigDecimal("9"));
+        verifyRefund(invoice, new BigDecimal("40"), new BigDecimal("38"), BigDecimal.ZERO, true, adjustments);
+    }
+
+    private void verifyRefund(final BigDecimal invoiceAmount, final BigDecimal refundAmount, final BigDecimal finalInvoiceAmount,
+                              final boolean adjusted, final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts) throws InvoiceApiException {
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, invoiceAmount, CURRENCY, context);
+        verifyRefund(invoice, invoiceAmount, refundAmount, finalInvoiceAmount, adjusted, invoiceItemIdsWithAmounts);
+    }
+
+    private void verifyRefund(final Invoice invoice, final BigDecimal invoiceAmount, final BigDecimal refundAmount, final BigDecimal finalInvoiceAmount,
+                              final boolean adjusted, final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts) throws InvoiceApiException {
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), invoiceAmount, CURRENCY, context);
+
+        // Verify the initial invoice balance
+        final BigDecimal initialInvoiceBalance = invoicePaymentApi.getInvoice(invoice.getId()).getBalance();
+        Assert.assertEquals(initialInvoiceBalance.compareTo(BigDecimal.ZERO), 0);
+
+        // Create a full refund with no adjustment
+        final InvoicePayment refund = invoicePaymentApi.createRefund(payment.getPaymentId(), refundAmount, adjusted, invoiceItemIdsWithAmounts,
+                                                                     UUID.randomUUID(), context);
+        Assert.assertEquals(refund.getAmount().compareTo(refundAmount.negate()), 0);
+        Assert.assertEquals(refund.getCurrency(), CURRENCY);
+        Assert.assertEquals(refund.getInvoiceId(), invoice.getId());
+        Assert.assertEquals(refund.getPaymentId(), payment.getPaymentId());
+        Assert.assertEquals(refund.getType(), InvoicePaymentType.REFUND);
+
+        // Verify the current invoice balance
+        final BigDecimal newInvoiceBalance = invoicePaymentApi.getInvoice(invoice.getId()).getBalance().setScale(2, RoundingMode.HALF_UP);
+        Assert.assertEquals(newInvoiceBalance.compareTo(finalInvoiceAmount.setScale(2, RoundingMode.HALF_UP)), 0);
+    }
+}
diff --git a/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java b/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java
index 2bccd7b..9c8a92a 100644
--- a/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java
+++ b/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java
@@ -16,6 +16,7 @@
 
 package com.ning.billing.invoice.dao;
 
+import java.math.BigDecimal;
 import java.util.Map;
 import java.util.UUID;
 
@@ -25,8 +26,11 @@ import org.testng.Assert;
 import org.testng.annotations.BeforeMethod;
 import org.testng.annotations.Test;
 
+import com.ning.billing.ErrorCode;
 import com.ning.billing.invoice.InvoiceTestSuite;
 import com.ning.billing.invoice.api.Invoice;
+import com.ning.billing.invoice.api.InvoiceApiException;
+import com.ning.billing.invoice.api.InvoicePayment;
 import com.ning.billing.invoice.notification.NextBillingDatePoster;
 import com.ning.billing.util.api.TagUserApi;
 import com.ning.billing.util.callcontext.CallContext;
@@ -40,7 +44,10 @@ import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
 import com.ning.billing.util.tag.dao.TagDao;
 import com.ning.billing.util.tag.dao.TagDefinitionDao;
 
+import com.google.common.collect.ImmutableMap;
+
 public class TestDefaultInvoiceDao extends InvoiceTestSuite {
+
     private InvoiceSqlDao invoiceSqlDao;
     private TagUserApi tagUserApi;
     private DefaultInvoiceDao dao;
@@ -59,6 +66,43 @@ public class TestDefaultInvoiceDao extends InvoiceTestSuite {
     }
 
     @Test(groups = "fast")
+    public void testComputePositiveRefundAmount() throws Exception {
+        // Verify the cases with no adjustment first
+        final Map<UUID, BigDecimal> noItemAdjustment = ImmutableMap.<UUID, BigDecimal>of();
+        verifyComputedRefundAmount(null, null, noItemAdjustment, BigDecimal.ZERO);
+        verifyComputedRefundAmount(null, BigDecimal.ZERO, noItemAdjustment, BigDecimal.ZERO);
+        verifyComputedRefundAmount(BigDecimal.TEN, null, noItemAdjustment, BigDecimal.TEN);
+        verifyComputedRefundAmount(BigDecimal.TEN, BigDecimal.ONE, noItemAdjustment, BigDecimal.ONE);
+        try {
+            verifyComputedRefundAmount(BigDecimal.ONE, BigDecimal.TEN, noItemAdjustment, BigDecimal.TEN);
+            Assert.fail("Shouldn't have been able to compute a refund amount");
+        } catch (InvoiceApiException e) {
+            Assert.assertEquals(e.getCode(), ErrorCode.REFUND_AMOUNT_TOO_HIGH.getCode());
+        }
+
+        // Try with adjustments now
+        final Map<UUID, BigDecimal> itemAdjustments = ImmutableMap.<UUID, BigDecimal>of(UUID.randomUUID(), BigDecimal.ONE,
+                                                                                        UUID.randomUUID(), BigDecimal.TEN,
+                                                                                        UUID.randomUUID(), BigDecimal.ZERO);
+        verifyComputedRefundAmount(new BigDecimal("100"), new BigDecimal("11"), itemAdjustments, new BigDecimal("11"));
+        try {
+            verifyComputedRefundAmount(new BigDecimal("100"), BigDecimal.TEN, itemAdjustments, BigDecimal.TEN);
+            Assert.fail("Shouldn't have been able to compute a refund amount");
+        } catch (InvoiceApiException e) {
+            Assert.assertEquals(e.getCode(), ErrorCode.REFUND_AMOUNT_DONT_MATCH_ITEMS_TO_ADJUST.getCode());
+        }
+    }
+
+    private void verifyComputedRefundAmount(final BigDecimal paymentAmount, final BigDecimal requestedAmount,
+                                            final Map<UUID, BigDecimal> invoiceItemIdsWithAmounts, final BigDecimal expectedRefundAmount) throws InvoiceApiException {
+        final InvoicePayment invoicePayment = Mockito.mock(InvoicePayment.class);
+        Mockito.when(invoicePayment.getAmount()).thenReturn(paymentAmount);
+
+        final BigDecimal actualRefundAmount = dao.computePositiveRefundAmount(invoicePayment, requestedAmount, invoiceItemIdsWithAmounts);
+        Assert.assertEquals(actualRefundAmount, expectedRefundAmount);
+    }
+
+    @Test(groups = "fast")
     public void testFindByNumber() throws Exception {
         final Integer number = Integer.MAX_VALUE;
         final Invoice invoice = Mockito.mock(Invoice.class);
diff --git a/invoice/src/test/java/com/ning/billing/invoice/tests/InvoiceTestUtils.java b/invoice/src/test/java/com/ning/billing/invoice/tests/InvoiceTestUtils.java
new file mode 100644
index 0000000..8826134
--- /dev/null
+++ b/invoice/src/test/java/com/ning/billing/invoice/tests/InvoiceTestUtils.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2010-2012 Ning, Inc.
+ *
+ * Ning licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.ning.billing.invoice.tests;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+import org.mockito.Mockito;
+
+import com.ning.billing.catalog.api.Currency;
+import com.ning.billing.invoice.api.Invoice;
+import com.ning.billing.invoice.api.InvoiceItem;
+import com.ning.billing.invoice.api.InvoicePayment;
+import com.ning.billing.invoice.api.InvoicePayment.InvoicePaymentType;
+import com.ning.billing.invoice.api.InvoicePaymentApi;
+import com.ning.billing.invoice.dao.InvoiceItemSqlDao;
+import com.ning.billing.invoice.dao.InvoiceSqlDao;
+import com.ning.billing.invoice.model.FixedPriceInvoiceItem;
+import com.ning.billing.util.callcontext.CallContext;
+import com.ning.billing.util.clock.Clock;
+
+import com.google.common.collect.ImmutableList;
+
+public class InvoiceTestUtils {
+
+    private InvoiceTestUtils() {}
+
+    public static Invoice createAndPersistInvoice(final InvoiceSqlDao invoiceSqlDao,
+                                                  final InvoiceItemSqlDao invoiceItemSqlDao,
+                                                  final Clock clock,
+                                                  final BigDecimal amount,
+                                                  final Currency currency,
+                                                  final CallContext callContext) {
+        return createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, ImmutableList.<BigDecimal>of(amount),
+                                       currency, callContext);
+    }
+
+    public static Invoice createAndPersistInvoice(final InvoiceSqlDao invoiceSqlDao,
+                                                  final InvoiceItemSqlDao invoiceItemSqlDao,
+                                                  final Clock clock,
+                                                  final List<BigDecimal> amounts,
+                                                  final Currency currency,
+                                                  final CallContext callContext) {
+        final Invoice invoice = Mockito.mock(Invoice.class);
+        final UUID invoiceId = UUID.randomUUID();
+        final UUID accountId = UUID.randomUUID();
+
+        Mockito.when(invoice.getId()).thenReturn(invoiceId);
+        Mockito.when(invoice.getAccountId()).thenReturn(accountId);
+        Mockito.when(invoice.getInvoiceDate()).thenReturn(clock.getUTCToday());
+        Mockito.when(invoice.getTargetDate()).thenReturn(clock.getUTCToday());
+        Mockito.when(invoice.getCurrency()).thenReturn(currency);
+        Mockito.when(invoice.isMigrationInvoice()).thenReturn(false);
+
+        final List<InvoiceItem> invoiceItems = new ArrayList<InvoiceItem>();
+        for (final BigDecimal amount : amounts) {
+            final InvoiceItem invoiceItem = createInvoiceItem(clock, invoiceId, accountId, amount, currency);
+            invoiceItemSqlDao.create(invoiceItem, callContext);
+            invoiceItems.add(invoiceItem);
+        }
+        Mockito.when(invoice.getInvoiceItems()).thenReturn(invoiceItems);
+
+        invoiceSqlDao.create(invoice, callContext);
+
+        return invoice;
+    }
+
+    public static InvoiceItem createInvoiceItem(final Clock clock, final UUID invoiceId, final UUID accountId, final BigDecimal amount, final Currency currency) {
+        return new FixedPriceInvoiceItem(invoiceId, accountId, UUID.randomUUID(), UUID.randomUUID(),
+                                         "charge back test", "charge back phase", clock.getUTCToday(), amount, currency);
+    }
+
+    public static InvoicePayment createAndPersistPayment(final InvoicePaymentApi invoicePaymentApi,
+                                                         final Clock clock,
+                                                         final UUID invoiceId,
+                                                         final BigDecimal amount,
+                                                         final Currency currency,
+                                                         final CallContext callContext) {
+        final InvoicePayment payment = Mockito.mock(InvoicePayment.class);
+        Mockito.when(payment.getId()).thenReturn(UUID.randomUUID());
+        Mockito.when(payment.getType()).thenReturn(InvoicePaymentType.ATTEMPT);
+        Mockito.when(payment.getInvoiceId()).thenReturn(invoiceId);
+        Mockito.when(payment.getPaymentId()).thenReturn(UUID.randomUUID());
+        Mockito.when(payment.getPaymentCookieId()).thenReturn(UUID.randomUUID());
+        Mockito.when(payment.getPaymentDate()).thenReturn(clock.getUTCNow());
+        Mockito.when(payment.getAmount()).thenReturn(amount);
+        Mockito.when(payment.getCurrency()).thenReturn(currency);
+
+        invoicePaymentApi.notifyOfPayment(payment, callContext);
+
+        return payment;
+    }
+}
diff --git a/invoice/src/test/java/com/ning/billing/invoice/tests/TestChargeBacks.java b/invoice/src/test/java/com/ning/billing/invoice/tests/TestChargeBacks.java
index ef43230..689b11a 100644
--- a/invoice/src/test/java/com/ning/billing/invoice/tests/TestChargeBacks.java
+++ b/invoice/src/test/java/com/ning/billing/invoice/tests/TestChargeBacks.java
@@ -19,11 +19,9 @@ package com.ning.billing.invoice.tests;
 import java.io.IOException;
 import java.math.BigDecimal;
 import java.net.URL;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.UUID;
 
-import org.mockito.Mockito;
 import org.skife.jdbi.v2.IDBI;
 import org.skife.jdbi.v2.exceptions.TransactionFailedException;
 import org.testng.annotations.BeforeSuite;
@@ -35,16 +33,14 @@ import com.ning.billing.dbi.MysqlTestingHelper;
 import com.ning.billing.invoice.InvoiceTestSuiteWithEmbeddedDB;
 import com.ning.billing.invoice.api.Invoice;
 import com.ning.billing.invoice.api.InvoiceApiException;
-import com.ning.billing.invoice.api.InvoiceItem;
 import com.ning.billing.invoice.api.InvoicePayment;
-import com.ning.billing.invoice.api.InvoicePayment.InvoicePaymentType;
 import com.ning.billing.invoice.api.InvoicePaymentApi;
 import com.ning.billing.invoice.api.invoice.DefaultInvoicePaymentApi;
 import com.ning.billing.invoice.dao.DefaultInvoiceDao;
 import com.ning.billing.invoice.dao.InvoiceDao;
+import com.ning.billing.invoice.dao.InvoiceItemSqlDao;
 import com.ning.billing.invoice.dao.InvoiceSqlDao;
 import com.ning.billing.invoice.glue.InvoiceModuleWithEmbeddedDb;
-import com.ning.billing.invoice.model.FixedPriceInvoiceItem;
 import com.ning.billing.invoice.notification.MockNextBillingDatePoster;
 import com.ning.billing.invoice.notification.NextBillingDatePoster;
 import com.ning.billing.util.api.TagUserApi;
@@ -58,16 +54,20 @@ import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
 import com.ning.billing.util.tag.dao.TagDao;
 import com.ning.billing.util.tag.dao.TagDefinitionDao;
 
+import static com.ning.billing.invoice.tests.InvoiceTestUtils.createAndPersistInvoice;
+import static com.ning.billing.invoice.tests.InvoiceTestUtils.createAndPersistPayment;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertNotNull;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 
 public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
+
     private static final BigDecimal FIFTEEN = new BigDecimal("15.00");
     private static final BigDecimal THIRTY = new BigDecimal("30.00");
     private static final BigDecimal ONE_MILLION = new BigDecimal("1000000.00");
     private InvoiceSqlDao invoiceSqlDao;
+    private InvoiceItemSqlDao invoiceItemSqlDao;
     private InvoicePaymentApi invoicePaymentApi;
     private CallContext context;
     private final Clock clock = new ClockMock();
@@ -83,6 +83,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
         invoiceSqlDao = dbi.onDemand(InvoiceSqlDao.class);
         invoiceSqlDao.test();
 
+        invoiceItemSqlDao = dbi.onDemand(InvoiceItemSqlDao.class);
+        invoiceItemSqlDao.test();
         final NextBillingDatePoster nextBillingDatePoster = new MockNextBillingDatePoster();
         final TagDefinitionDao tagDefinitionDao = new MockTagDefinitionDao();
         final TagDao tagDao = new MockTagDao();
@@ -105,8 +107,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
 
     @Test(groups = "slow")
     public void testCompleteChargeBack() throws InvoiceApiException {
-        final Invoice invoice = createAndPersistInvoice(THIRTY);
-        final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
         // create a full charge back
         invoicePaymentApi.createChargeback(payment.getId(), THIRTY, context);
@@ -118,8 +120,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
 
     @Test(groups = "slow")
     public void testPartialChargeBack() throws InvoiceApiException {
-        final Invoice invoice = createAndPersistInvoice(THIRTY);
-        final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
         // create a partial charge back
         invoicePaymentApi.createChargeback(payment.getId(), FIFTEEN, context);
@@ -132,8 +134,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
     @Test(groups = "slow", expectedExceptions = InvoiceApiException.class)
     public void testChargeBackLargerThanPaymentAmount() throws InvoiceApiException {
         try {
-            final Invoice invoice = createAndPersistInvoice(THIRTY);
-            final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+            final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+            final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
             // create a large charge back
             invoicePaymentApi.createChargeback(payment.getId(), ONE_MILLION, context);
@@ -146,8 +148,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
     @Test(groups = "slow", expectedExceptions = InvoiceApiException.class)
     public void testNegativeChargeBackAmount() throws InvoiceApiException {
         try {
-            final Invoice invoice = createAndPersistInvoice(THIRTY);
-            final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+            final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+            final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
             // create a partial charge back
             invoicePaymentApi.createChargeback(payment.getId(), BigDecimal.ONE.negate(), context);
@@ -158,8 +160,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
 
     @Test(groups = "slow")
     public void testGetAccountIdFromPaymentIdHappyPath() throws InvoiceApiException {
-        final Invoice invoice = createAndPersistInvoice(THIRTY);
-        final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
         final UUID accountId = invoicePaymentApi.getAccountIdFromInvoicePaymentId(payment.getId());
         assertEquals(accountId, invoice.getAccountId());
     }
@@ -178,8 +180,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
 
     @Test(groups = "slow")
     public void testGetChargeBacksByAccountIdHappyPath() throws InvoiceApiException {
-        final Invoice invoice = createAndPersistInvoice(THIRTY);
-        final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
         // create a partial charge back
         invoicePaymentApi.createChargeback(payment.getId(), FIFTEEN, context);
@@ -199,8 +201,8 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
 
     @Test(groups = "slow")
     public void testGetChargeBacksByInvoicePaymentIdHappyPath() throws InvoiceApiException {
-        final Invoice invoice = createAndPersistInvoice(THIRTY);
-        final InvoicePayment payment = createAndPersistPayment(invoice.getId(), THIRTY);
+        final Invoice invoice = createAndPersistInvoice(invoiceSqlDao, invoiceItemSqlDao, clock, THIRTY, CURRENCY, context);
+        final InvoicePayment payment = createAndPersistPayment(invoicePaymentApi, clock, invoice.getId(), THIRTY, CURRENCY, context);
 
         // create a partial charge back
         invoicePaymentApi.createChargeback(payment.getId(), FIFTEEN, context);
@@ -210,45 +212,4 @@ public class TestChargeBacks extends InvoiceTestSuiteWithEmbeddedDB {
         assertEquals(chargebacks.size(), 1);
         assertEquals(chargebacks.get(0).getLinkedInvoicePaymentId(), payment.getId());
     }
-
-    private Invoice createAndPersistInvoice(final BigDecimal amount) {
-        final Invoice invoice = Mockito.mock(Invoice.class);
-        final UUID invoiceId = UUID.randomUUID();
-        final UUID accountId = UUID.randomUUID();
-
-        Mockito.when(invoice.getId()).thenReturn(invoiceId);
-        Mockito.when(invoice.getAccountId()).thenReturn(accountId);
-        Mockito.when(invoice.getInvoiceDate()).thenReturn(clock.getUTCToday());
-        Mockito.when(invoice.getTargetDate()).thenReturn(clock.getUTCToday());
-        Mockito.when(invoice.getCurrency()).thenReturn(CURRENCY);
-        Mockito.when(invoice.isMigrationInvoice()).thenReturn(false);
-
-        final List<InvoiceItem> items = new ArrayList<InvoiceItem>();
-        items.add(createInvoiceItem(invoiceId, accountId, amount));
-        Mockito.when(invoice.getInvoiceItems()).thenReturn(items);
-
-        invoiceSqlDao.create(invoice, context);
-
-        return invoice;
-    }
-
-    private InvoiceItem createInvoiceItem(final UUID invoiceId, final UUID accountId, final BigDecimal amount) {
-        return new FixedPriceInvoiceItem(invoiceId, accountId, UUID.randomUUID(), UUID.randomUUID(),
-                                         "charge back test", "charge back phase", clock.getUTCToday(), amount, CURRENCY);
-    }
-
-    private InvoicePayment createAndPersistPayment(final UUID invoiceId, final BigDecimal amount) {
-        final InvoicePayment payment = Mockito.mock(InvoicePayment.class);
-        Mockito.when(payment.getId()).thenReturn(UUID.randomUUID());
-        Mockito.when(payment.getType()).thenReturn(InvoicePaymentType.ATTEMPT);
-        Mockito.when(payment.getInvoiceId()).thenReturn(invoiceId);
-        Mockito.when(payment.getPaymentId()).thenReturn(UUID.randomUUID());
-        Mockito.when(payment.getPaymentDate()).thenReturn(clock.getUTCNow());
-        Mockito.when(payment.getAmount()).thenReturn(amount);
-        Mockito.when(payment.getCurrency()).thenReturn(CURRENCY);
-
-        invoicePaymentApi.notifyOfPayment(payment, context);
-
-        return payment;
-    }
 }