killbill-memoizeit

Implement auto relbalancing of CBA

11/29/2012 11:44:00 PM

Details

diff --git a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
index c1920f4..e928a24 100644
--- a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
@@ -19,6 +19,7 @@ package com.ning.billing.invoice.dao;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
@@ -55,8 +56,8 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Objects;
 import com.google.common.base.Predicate;
 import com.google.common.collect.Collections2;
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap.Builder;
+import com.google.common.collect.Ordering;
 import com.google.inject.Inject;
 
 public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, InvoiceApiException> implements InvoiceDao {
@@ -199,13 +200,16 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                         transInvoiceItemSqlDao.create(invoiceItemModelDao, context);
                     }
 
+                    // Now we check whether we generated any credit that could be used on some unpaid invoices
+                    useExistingCBAFromTransaction(invoice.getAccountId(), entitySqlDaoWrapperFactory, context);
+
                     notifyOfFutureBillingEvents(entitySqlDaoWrapperFactory, invoice.getAccountId(), callbackDateTimePerSubscriptions);
 
                     // Create associated payments
                     final InvoicePaymentSqlDao invoicePaymentSqlDao = entitySqlDaoWrapperFactory.become(InvoicePaymentSqlDao.class);
                     invoicePaymentSqlDao.batchCreateFromTransaction(invoicePayments, context);
-                }
 
+                }
                 return null;
             }
         });
@@ -259,18 +263,12 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
         return transactionalSqlDao.execute(new EntitySqlDaoTransactionWrapper<List<InvoiceModelDao>>() {
             @Override
             public List<InvoiceModelDao> inTransaction(final EntitySqlDaoWrapperFactory<EntitySqlDao> entitySqlDaoWrapperFactory) throws Exception {
-                final List<InvoiceModelDao> invoices = getAllInvoicesByAccountFromTransaction(accountId, entitySqlDaoWrapperFactory, context);
-                final Collection<InvoiceModelDao> unpaidInvoices = Collections2.filter(invoices, new Predicate<InvoiceModelDao>() {
-                    @Override
-                    public boolean apply(final InvoiceModelDao in) {
-                        return (InvoiceModelDaoHelper.getBalance(in).compareTo(BigDecimal.ZERO) >= 1) && (upToDate == null || !in.getTargetDate().isAfter(upToDate));
-                    }
-                });
-                return new ArrayList<InvoiceModelDao>(unpaidInvoices);
+                return getUnpaidInvoicesByAccountFromTransaction(accountId, entitySqlDaoWrapperFactory, upToDate, context);
             }
         });
     }
 
+
     @Override
     public UUID getInvoiceIdByPaymentId(final UUID paymentId, final InternalTenantContext context) {
         return transactionalSqlDao.execute(new EntitySqlDaoTransactionWrapper<UUID>() {
@@ -292,7 +290,6 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
     }
 
     @Override
-
     public InvoicePaymentModelDao createRefund(final UUID paymentId, final BigDecimal requestedRefundAmount, final boolean isInvoiceAdjusted,
                                                final Map<UUID, BigDecimal> invoiceItemIdsWithNullAmounts, final UUID paymentCookieId,
                                                final InternalCallContext context)
@@ -368,6 +365,9 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                     }
                 }
 
+                // Now we check whether we have any credit that could be used on some unpaid invoices (for which payment was just refunded)
+                useExistingCBAFromTransaction(invoice.getAccountId(), entitySqlDaoWrapperFactory, context);
+
                 // Notify the bus since the balance of the invoice changed
                 notifyBusOfInvoiceAdjustment(entitySqlDaoWrapperFactory, invoice.getId(), invoice.getAccountId(), context.getUserToken(), context);
 
@@ -496,18 +496,20 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                 final InvoicePaymentModelDao payment = entitySqlDaoWrapperFactory.become(InvoicePaymentSqlDao.class).getById(invoicePaymentId.toString(), context);
                 if (payment == null) {
                     throw new InvoiceApiException(ErrorCode.INVOICE_PAYMENT_NOT_FOUND, invoicePaymentId.toString());
-                } else {
-                    final InvoicePaymentModelDao chargeBack = new InvoicePaymentModelDao(UUID.randomUUID(), context.getCreatedDate(), InvoicePaymentType.CHARGED_BACK,
-                                                                                         payment.getInvoiceId(), payment.getPaymentId(), context.getCreatedDate(),
-                                                                                         requestedChargedBackAmout.negate(), payment.getCurrency(), null, payment.getId());
-                    transactional.create(chargeBack, context);
+                }
+                final InvoicePaymentModelDao chargeBack = new InvoicePaymentModelDao(UUID.randomUUID(), context.getCreatedDate(), InvoicePaymentType.CHARGED_BACK,
+                                                                                     payment.getInvoiceId(), payment.getPaymentId(), context.getCreatedDate(),
+                                                                                     requestedChargedBackAmout.negate(), payment.getCurrency(), null, payment.getId());
+                transactional.create(chargeBack, context);
 
-                    // Notify the bus since the balance of the invoice changed
-                    final UUID accountId = transactional.getAccountIdFromInvoicePaymentId(chargeBack.getId().toString(), context);
-                    notifyBusOfInvoiceAdjustment(entitySqlDaoWrapperFactory, payment.getInvoiceId(), accountId, context.getUserToken(), context);
+                // Notify the bus since the balance of the invoice changed
+                final UUID accountId = transactional.getAccountIdFromInvoicePaymentId(chargeBack.getId().toString(), context);
+                notifyBusOfInvoiceAdjustment(entitySqlDaoWrapperFactory, payment.getInvoiceId(), accountId, context.getUserToken(), context);
 
-                    return chargeBack;
-                }
+                // Now we check whether we have any credit that could be used on some unpaid invoices (for which payment was just charged back)
+                useExistingCBAFromTransaction(accountId, entitySqlDaoWrapperFactory, context);
+
+                return chargeBack;
             }
         });
     }
@@ -625,18 +627,8 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
                 }
                 populateChildren(invoice, entitySqlDaoWrapperFactory, context);
 
-                final BigDecimal accountCbaAvailable = getAccountCBAFromTransaction(invoice.getAccountId(), entitySqlDaoWrapperFactory, context);
-                final BigDecimal balance = InvoiceModelDaoHelper.getBalance(invoice);
-                if (accountCbaAvailable.compareTo(BigDecimal.ZERO) > 0 && balance.compareTo(BigDecimal.ZERO) > 0) {
-                    final BigDecimal cbaAmountToConsume = accountCbaAvailable.compareTo(balance) > 0 ? balance.negate() : accountCbaAvailable.negate();
-                    final InvoiceItemModelDao cbaAdjItem = new InvoiceItemModelDao(context.getCreatedDate(), InvoiceItemType.CBA_ADJ,
-                                                                                   invoice.getId(), invoice.getAccountId(),
-                                                                                   null, null, null, null,
-                                                                                   context.getCreatedDate().toLocalDate(),
-                                                                                   null, cbaAmountToConsume, null,
-                                                                                   invoice.getCurrency(), null);
-                    transInvoiceItemDao.create(cbaAdjItem, context);
-                }
+                // Now we check whether we have any credit that could be used towards that charge
+                useExistingCBAFromTransaction(accountId, entitySqlDaoWrapperFactory, context);
 
                 // Notify the bus since the balance of the invoice changed
                 notifyBusOfInvoiceAdjustment(entitySqlDaoWrapperFactory, invoiceId, accountId, context.getUserToken(), context);
@@ -799,6 +791,58 @@ public class DefaultInvoiceDao extends EntityDaoBase<InvoiceModelDao, Invoice, I
         });
     }
 
+    private void useExistingCBAFromTransaction(final UUID accountId, final EntitySqlDaoWrapperFactory<EntitySqlDao> entitySqlDaoWrapperFactory, final InternalCallContext context) throws InvoiceApiException, EntityPersistenceException {
+
+        final BigDecimal accountCBA = getAccountCBAFromTransaction(accountId, entitySqlDaoWrapperFactory, context);
+        if (accountCBA.compareTo(BigDecimal.ZERO) <= 0) {
+            return;
+        }
+
+        final List<InvoiceModelDao> unpaidInvoices = getUnpaidInvoicesByAccountFromTransaction(accountId, entitySqlDaoWrapperFactory, null, context);
+        // We order the same os BillingStateCalculator-- should really share the comparator
+        final List<InvoiceModelDao> orderedUnpaidInvoices = Ordering.from(new Comparator<InvoiceModelDao>() {
+            @Override
+            public int compare(final InvoiceModelDao i1, final InvoiceModelDao i2) {
+                return i1.getInvoiceDate().compareTo(i2.getInvoiceDate());
+            }
+        }).immutableSortedCopy(unpaidInvoices);
+
+        BigDecimal remainingAccountCBA = accountCBA;
+        for (InvoiceModelDao cur : orderedUnpaidInvoices) {
+            final BigDecimal curInvoiceBalance = InvoiceModelDaoHelper.getBalance(cur);
+            final BigDecimal cbaToApplyOnInvoice = remainingAccountCBA.compareTo(curInvoiceBalance) <= 0 ? remainingAccountCBA : curInvoiceBalance;
+            remainingAccountCBA = remainingAccountCBA.subtract(cbaToApplyOnInvoice);
+
+
+            final InvoiceItemModelDao cbaAdjItem = new InvoiceItemModelDao(context.getCreatedDate(), InvoiceItemType.CBA_ADJ,
+                                                                           cur.getId(), cur.getAccountId(),
+                                                                           null, null, null, null,
+                                                                           context.getCreatedDate().toLocalDate(),
+                                                                           null, cbaToApplyOnInvoice.negate(), null,
+                                                                           cur.getCurrency(), null);
+
+            final InvoiceItemSqlDao transInvoiceItemDao = entitySqlDaoWrapperFactory.become(InvoiceItemSqlDao.class);
+            transInvoiceItemDao.create(cbaAdjItem, context);
+
+            if (remainingAccountCBA.compareTo(BigDecimal.ZERO) <= 0) {
+                break;
+            }
+        }
+    }
+
+
+    private List<InvoiceModelDao> getUnpaidInvoicesByAccountFromTransaction(final UUID accountId, final EntitySqlDaoWrapperFactory<EntitySqlDao> entitySqlDaoWrapperFactory, final LocalDate upToDate, final InternalTenantContext context) {
+        final List<InvoiceModelDao> invoices = getAllInvoicesByAccountFromTransaction(accountId, entitySqlDaoWrapperFactory, context);
+        final Collection<InvoiceModelDao> unpaidInvoices = Collections2.filter(invoices, new Predicate<InvoiceModelDao>() {
+            @Override
+            public boolean apply(final InvoiceModelDao in) {
+                return (InvoiceModelDaoHelper.getBalance(in).compareTo(BigDecimal.ZERO) >= 1) && (upToDate == null || !in.getTargetDate().isAfter(upToDate));
+            }
+        });
+        return new ArrayList<InvoiceModelDao>(unpaidInvoices);
+    }
+
+
     /**
      * Create an adjustment for a given invoice item. This just creates the object in memory, it doesn't write it to disk.
      *
diff --git a/invoice/src/test/java/com/ning/billing/invoice/dao/TestInvoiceDao.java b/invoice/src/test/java/com/ning/billing/invoice/dao/TestInvoiceDao.java
index 43269b5..736bfe3 100644
--- a/invoice/src/test/java/com/ning/billing/invoice/dao/TestInvoiceDao.java
+++ b/invoice/src/test/java/com/ning/billing/invoice/dao/TestInvoiceDao.java
@@ -631,14 +631,16 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
         final boolean partialRefund = refundAmount.compareTo(amount) < 0;
         final BigDecimal cba = invoiceDao.getAccountCBA(accountId, internalCallContext);
         final InvoiceModelDao savedInvoice = invoiceDao.getById(invoice.getId(), internalCallContext);
-        assertEquals(cba.compareTo(new BigDecimal("20.0")), 0);
+
+        final BigDecimal expectedCba = balance.compareTo(BigDecimal.ZERO) < 0 ? balance.negate() : BigDecimal.ZERO;
+        assertEquals(cba.compareTo(expectedCba), 0);
         if (partialRefund) {
             // IB = 20 (rec) - 20 (repair) + 20 (cba) - (20 -7) = 7;  AB = IB - CBA = 7 - 20 = -13
             assertEquals(balance.compareTo(new BigDecimal("-13.0")), 0);
-            assertEquals(savedInvoice.getInvoiceItems().size(), 3);
+            assertEquals(savedInvoice.getInvoiceItems().size(), 4);
         } else {
             assertEquals(balance.compareTo(new BigDecimal("0.0")), 0);
-            assertEquals(savedInvoice.getInvoiceItems().size(), 3);
+            assertEquals(savedInvoice.getInvoiceItems().size(), 4);
         }
     }
 
@@ -730,7 +732,8 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
         balance = invoiceDao.getAccountBalance(accountId, internalCallContext);
         assertEquals(balance.compareTo(expectedFinalBalance), 0);
         cba = invoiceDao.getAccountCBA(accountId, internalCallContext);
-        assertEquals(cba.compareTo(new BigDecimal("10.00")), 0);
+        final BigDecimal expectedCba = balance.compareTo(BigDecimal.ZERO) < 0 ? balance.negate() : BigDecimal.ZERO;
+        assertEquals(cba.compareTo(expectedCba), 0);
     }
 
     @Test(groups = "slow")
@@ -738,13 +741,8 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
 
         final UUID accountId = UUID.randomUUID();
         final UUID bundleId = UUID.randomUUID();
-        final LocalDate targetDate1 = new LocalDate(2011, 10, 6);
-        final Invoice invoice1 = new DefaultInvoice(accountId, clock.getUTCToday(), targetDate1, Currency.USD);
-        createInvoice(invoice1, true, internalCallContext);
 
-        // CREATE INVOICE WITH A (just) CBA. Should not happen, but that does not matter for that test
-        final CreditBalanceAdjInvoiceItem cbaItem = new CreditBalanceAdjInvoiceItem(invoice1.getId(), accountId, new LocalDate(), new BigDecimal("20.0"), Currency.USD);
-        createInvoiceItem(cbaItem, internalCallContext);
+        invoiceDao.insertCredit(accountId, null,  new BigDecimal("20.0"), new LocalDate(), Currency.USD, internalCallContext);
 
         final InvoiceItemModelDao charge = invoiceDao.insertExternalCharge(accountId, null, bundleId, "bla", new BigDecimal("15.0"), clock.getUTCNow().toLocalDate(), Currency.USD, internalCallContext);
 
@@ -1358,7 +1356,7 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
     }
 
     @Test(groups = "slow")
-    public void testDeleteCBAPartiallyConsumed() throws Exception {
+    public void testRefundWithCBAPartiallyConsumed() throws Exception {
         final UUID accountId = UUID.randomUUID();
 
         // Create invoice 1
@@ -1377,6 +1375,12 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
         final CreditBalanceAdjInvoiceItem creditBalanceAdjInvoiceItem1 = new CreditBalanceAdjInvoiceItem(fixedItem1.getInvoiceId(), fixedItem1.getAccountId(),
                                                                                                          fixedItem1.getStartDate(), fixedItem1.getAmount(),
                                                                                                          fixedItem1.getCurrency());
+
+        final UUID paymentId = UUID.randomUUID();
+        final DefaultInvoicePayment defaultInvoicePayment = new DefaultInvoicePayment(InvoicePaymentType.ATTEMPT, paymentId, invoice1.getId(), clock.getUTCNow().plusDays(12), new BigDecimal("10.0"), Currency.USD);
+
+        invoiceDao.notifyOfPayment(new InvoicePaymentModelDao(defaultInvoicePayment), internalCallContext);
+
         createInvoice(invoice1, true, internalCallContext);
         createInvoiceItem(fixedItem1, internalCallContext);
         createInvoiceItem(repairAdjInvoiceItem, internalCallContext);
@@ -1398,20 +1402,20 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
 
         // Verify scenario - half of the CBA should have been used
         Assert.assertEquals(invoiceDao.getAccountCBA(accountId, internalCallContext).doubleValue(), 5.00);
-        verifyInvoice(invoice1.getId(), 10.00, 10.00);
+        verifyInvoice(invoice1.getId(), 0.00, 10.00);
         verifyInvoice(invoice2.getId(), 0.00, -5.00);
 
-        // Delete the CBA on invoice 1
-        invoiceDao.deleteCBA(accountId, invoice1.getId(), creditBalanceAdjInvoiceItem1.getId(), internalCallContext);
+        // Refund Payment before we can deleted CBA
+        invoiceDao.createRefund(paymentId, new BigDecimal("10.0"), false, ImmutableMap.<UUID,BigDecimal>of(), UUID.randomUUID(), internalCallContext);
 
         // Verify all three invoices were affected
         Assert.assertEquals(invoiceDao.getAccountCBA(accountId, internalCallContext).doubleValue(), 0.00);
-        verifyInvoice(invoice1.getId(), 0.00, 0.00);
-        verifyInvoice(invoice2.getId(), 5.00, 0.00);
+        verifyInvoice(invoice1.getId(), 5.00, 5.00);
+        verifyInvoice(invoice2.getId(), 0.00, -5.00);
     }
 
     @Test(groups = "slow")
-    public void testDeleteCBAFullyConsumedTwice() throws Exception {
+    public void testRefundCBAFullyConsumedTwice() throws Exception {
         final UUID accountId = UUID.randomUUID();
 
         // Create invoice 1
@@ -1435,6 +1439,13 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
         createInvoiceItem(repairAdjInvoiceItem, internalCallContext);
         createInvoiceItem(creditBalanceAdjInvoiceItem1, internalCallContext);
 
+
+        final BigDecimal paymentAmount = new BigDecimal("10.00");
+        final UUID paymentId = UUID.randomUUID();
+
+        final DefaultInvoicePayment defaultInvoicePayment = new DefaultInvoicePayment(InvoicePaymentType.ATTEMPT, paymentId, invoice1.getId(), clock.getUTCNow().plusDays(12), paymentAmount, Currency.USD);
+        invoiceDao.notifyOfPayment(new InvoicePaymentModelDao(defaultInvoicePayment), internalCallContext);
+
         // Create invoice 2
         // Scenario: single item
         // * $5 item
@@ -1465,18 +1476,17 @@ public class TestInvoiceDao extends InvoiceDaoTestBase {
 
         // Verify scenario - all CBA should have been used
         Assert.assertEquals(invoiceDao.getAccountCBA(accountId, internalCallContext).doubleValue(), 0.00);
-        verifyInvoice(invoice1.getId(), 10.00, 10.00);
+        verifyInvoice(invoice1.getId(), 0.00, 10.00);
         verifyInvoice(invoice2.getId(), 0.00, -5.00);
         verifyInvoice(invoice3.getId(), 0.00, -5.00);
 
-        // Delete the CBA on invoice 1
-        invoiceDao.deleteCBA(accountId, invoice1.getId(), creditBalanceAdjInvoiceItem1.getId(), internalCallContext);
+        invoiceDao.createRefund(paymentId, paymentAmount, false, ImmutableMap.<UUID, BigDecimal>of(), UUID.randomUUID(), internalCallContext);
 
         // Verify all three invoices were affected
         Assert.assertEquals(invoiceDao.getAccountCBA(accountId, internalCallContext).doubleValue(), 0.00);
-        verifyInvoice(invoice1.getId(), 0.00, 0.00);
-        verifyInvoice(invoice2.getId(), 5.00, 0.00);
-        verifyInvoice(invoice3.getId(), 5.00, 0.00);
+        verifyInvoice(invoice1.getId(), 10.00, 10.00);
+        verifyInvoice(invoice2.getId(), 0.00, -5.00);
+        verifyInvoice(invoice3.getId(), 0.00, -5.00);
     }
 
     @Test(groups = "slow")