killbill-uncached

invoice: treat full item adjustments like repairs This fixes

11/6/2016 11:01:29 PM

Details

diff --git a/.idea/compiler.xml b/.idea/compiler.xml
index 9b1b3d8..c4af3f8 100644
--- a/.idea/compiler.xml
+++ b/.idea/compiler.xml
@@ -181,4 +181,4 @@
       <module name="killbill-util" target="1.6" />
     </bytecodeTargetLevel>
   </component>
-</project>
\ No newline at end of file
+</project>
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/Item.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/Item.java
index c5dd0d3..1b4ff98 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/Item.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/Item.java
@@ -93,6 +93,10 @@ public class Item {
     }
 
     public Item(final InvoiceItem item, final UUID targetInvoiceId, final ItemAction action) {
+        this(item, item.getStartDate(), item.getEndDate(), targetInvoiceId, action);
+    }
+
+    public Item(final InvoiceItem item, final LocalDate startDate, final LocalDate endDate, final UUID targetInvoiceId, final ItemAction action) {
         this.id = item.getId();
         this.accountId = item.getAccountId();
         this.bundleId = item.getBundleId();
@@ -101,8 +105,8 @@ public class Item {
         this.invoiceId = item.getInvoiceId();
         this.planName = item.getPlanName();
         this.phaseName = item.getPhaseName();
-        this.startDate = item.getStartDate();
-        this.endDate = item.getEndDate();
+        this.startDate = startDate;
+        this.endDate = endDate;
         this.amount = item.getAmount().abs();
         this.rate = item.getRate();
         this.currency = item.getCurrency();
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
index eeeeba7..c217011 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
@@ -1,7 +1,7 @@
 /*
  * Copyright 2010-2014 Ning, Inc.
- * Copyright 2014-2015 Groupon, Inc
- * Copyright 2014-2015 The Billing Project, LLC
+ * Copyright 2014-2016 Groupon, Inc
+ * Copyright 2014-2016 The Billing Project, LLC
  *
  * The Billing Project licenses this file to you under the Apache License, version 2.0
  * (the "License"); you may not use this file except in compliance with the
@@ -18,14 +18,11 @@
 
 package org.killbill.billing.invoice.tree;
 
-import java.math.BigDecimal;
-import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Map;
 import java.util.UUID;
 
 import org.joda.time.LocalDate;
@@ -35,7 +32,9 @@ import org.killbill.billing.invoice.tree.Item.ItemAction;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.collect.Iterables;
+import com.google.common.collect.LinkedListMultimap;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Multimap;
 
 /**
  * Keeps track of all the items existing on a specified interval.
@@ -59,23 +58,13 @@ public class ItemsInterval {
         }
     }
 
-    public boolean containsItem(final UUID targetId) {
+    public Item findItem(final UUID targetId) {
         return Iterables.tryFind(items, new Predicate<Item>() {
             @Override
             public boolean apply(final Item input) {
                 return input.getId().equals(targetId);
             }
-        }).orNull() != null;
-    }
-
-    public void setAdjustment(final BigDecimal amount, final UUID targetId) {
-        final Item item = Iterables.tryFind(items, new Predicate<Item>() {
-            @Override
-            public boolean apply(final Item input) {
-                return input.getId().equals(targetId);
-            }
-        }).get();
-        item.incrementAdjustedAmount(amount);
+        }).orNull();
     }
 
     public List<Item> getItems() {
@@ -109,25 +98,23 @@ public class ItemsInterval {
      * @return true if there is no more items
      */
     public boolean mergeCancellingPairs() {
-
-        final Map<UUID, List<Item>> tmp = new HashMap<UUID, List<Item>>();
-        for (Item cur : items) {
-            final UUID idToConsider = (cur.getAction() == ItemAction.ADD) ? cur.getId() : cur.getLinkedId();
-            List<Item> listForItem = tmp.get(idToConsider);
-            if (listForItem == null) {
-                listForItem = new ArrayList<Item>(2);
-                tmp.put(idToConsider, listForItem);
-            }
-            listForItem.add(cur);
+        final Multimap<UUID, Item> cancellingPairPerInvoiceItemId = LinkedListMultimap.<UUID, Item>create();
+        for (final Item item : items) {
+            final UUID invoiceItemId = (item.getAction() == ItemAction.ADD) ? item.getId() : item.getLinkedId();
+            cancellingPairPerInvoiceItemId.put(invoiceItemId, item);
         }
 
-        for (List<Item> listForIds : tmp.values()) {
-            if (listForIds.size() == 2) {
-                items.remove(listForIds.get(0));
-                items.remove(listForIds.get(1));
+        for (final UUID invoiceItemId : cancellingPairPerInvoiceItemId.keySet()) {
+            final Collection<Item> itemsToRemove = cancellingPairPerInvoiceItemId.get(invoiceItemId);
+            Preconditions.checkArgument(itemsToRemove.size() <= 2, "Too many repairs for invoiceItemId='%s': %s", invoiceItemId, itemsToRemove);
+            if (itemsToRemove.size() == 2) {
+                for (final Item itemToRemove : itemsToRemove) {
+                    items.remove(itemToRemove);
+                }
             }
         }
-        return items.size() == 0;
+
+        return items.isEmpty();
     }
 
     public Iterable<Item> get_ADD_items() {
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsNodeInterval.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsNodeInterval.java
index 12bad8b..bc27585 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsNodeInterval.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsNodeInterval.java
@@ -24,11 +24,14 @@ import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 
 import org.joda.time.LocalDate;
+import org.killbill.billing.invoice.api.InvoiceItem;
+import org.killbill.billing.invoice.tree.Item.ItemAction;
 import org.killbill.billing.util.jackson.ObjectMapper;
 
 import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -221,22 +224,31 @@ public class ItemsNodeInterval extends NodeInterval {
 
     /**
      * Add the adjustment amount on the item specified by the targetId.
-     *
-     * @param adjustementDate date of the adjustment
-     * @param amount          amount of the adjustment
-     * @param targetId        item that has been adjusted
      */
-    public void addAdjustment(final LocalDate adjustementDate, final BigDecimal amount, final UUID targetId) {
-        // TODO we should really be using findNode(adjustementDate, new SearchCallback() instead but wrong dates in test
-        // creates test panic.
+    public void addAdjustment(final InvoiceItem item) {
+        final UUID targetId = item.getLinkedItemId();
+
+        // TODO we should really be using findNode(adjustmentDate, callback) instead but wrong dates in test creates panic.
         final NodeInterval node = findNode(new SearchCallback() {
             @Override
             public boolean isMatch(final NodeInterval curNode) {
-                return ((ItemsNodeInterval) curNode).getItemsInterval().containsItem(targetId);
+                return ((ItemsNodeInterval) curNode).getItemsInterval().findItem(targetId) != null;
             }
         });
-        Preconditions.checkNotNull(node, "Cannot add adjustment for item = " + targetId + ", date = " + adjustementDate);
-        ((ItemsNodeInterval) node).setAdjustment(amount.negate(), targetId);
+        Preconditions.checkNotNull(item, "Unable to find item interval for id='%s', tree=%s", targetId, this);
+
+        final ItemsInterval targetItemsInterval = ((ItemsNodeInterval) node).getItemsInterval();
+        final List<Item> targetItems = targetItemsInterval.getItems();
+        final Item targetItem = targetItemsInterval.findItem(targetId);
+        Preconditions.checkNotNull(item, "Unable to find item with id='%s', items=%s", targetId, targetItems);
+
+        final BigDecimal adjustmentAmount = item.getAmount().negate();
+        if (targetItem.getAmount().compareTo(adjustmentAmount) == 0) {
+            // Full item adjustment - treat it like a repair
+            addExistingItem(new ItemsNodeInterval(this, targetInvoiceId, new Item(item, targetItem.getStartDate(), targetItem.getEndDate(), targetInvoiceId, ItemAction.CANCEL)));
+        } else {
+            targetItem.incrementAdjustedAmount(adjustmentAmount);
+        }
     }
 
     public void jsonSerializeTree(final ObjectMapper mapper, final OutputStream output) throws IOException {
@@ -272,10 +284,6 @@ public class ItemsNodeInterval extends NodeInterval {
         generator.close();
     }
 
-    protected void setAdjustment(final BigDecimal amount, final UUID linkedId) {
-        items.setAdjustment(amount, linkedId);
-    }
-
     //
     // Before we build the tree, we make a first pass at removing full repaired items; those can come in two shapes:
     // Case A - The first one, is the mergeCancellingPairs logics which simply look for one CANCEL pointing to one ADD item in the same
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
index 0577d8b..0642778 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
@@ -343,6 +343,45 @@ public class NodeInterval {
         return result;
     }
 
+    @Override
+    public String toString() {
+        final StringBuilder sb = new StringBuilder("NodeInterval{");
+        sb.append("this=[")
+          .append(start)
+          .append(",")
+          .append(end)
+          .append("]");
+        if (parent == null) {
+            sb.append(", parent=").append(parent);
+        } else {
+            sb.append(", parent=[")
+              .append(parent.getStart())
+              .append(",")
+              .append(parent.getEnd())
+              .append("]");
+        }
+        if (leftChild == null) {
+            sb.append(", leftChild=").append(leftChild);
+        } else {
+            sb.append(", leftChild=[")
+              .append(leftChild.getStart())
+              .append(",")
+              .append(leftChild.getEnd())
+              .append("]");
+        }
+        if (rightSibling == null) {
+            sb.append(", rightSibling=").append(rightSibling);
+        } else {
+            sb.append(", rightSibling=[")
+              .append(rightSibling.getStart())
+              .append(",")
+              .append(rightSibling.getEnd())
+              .append("]");
+        }
+        sb.append('}');
+        return sb.toString();
+    }
+
     /**
      * Since items may be added out of order, there is no guarantee that we don't suddenly have a new node
      * whose interval emcompasses cuurent node(s). In which case we need to rebalance the tree.
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/SubscriptionItemTree.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/SubscriptionItemTree.java
index 2824c20..81ec626 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/SubscriptionItemTree.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/SubscriptionItemTree.java
@@ -89,8 +89,9 @@ public class SubscriptionItemTree {
      */
     public void build() {
         Preconditions.checkState(!isBuilt);
+
         for (InvoiceItem item : pendingItemAdj) {
-            root.addAdjustment(item.getStartDate(), item.getAmount(), item.getLinkedItemId());
+            root.addAdjustment(item);
         }
         pendingItemAdj.clear();
         root.buildForExistingItems(items);
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/tree/TestSubscriptionItemTree.java b/invoice/src/test/java/org/killbill/billing/invoice/tree/TestSubscriptionItemTree.java
index a965ade..b8467f4 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/tree/TestSubscriptionItemTree.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/tree/TestSubscriptionItemTree.java
@@ -38,13 +38,22 @@ import org.killbill.billing.util.jackson.ObjectMapper;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNull;
 import static org.testng.Assert.assertTrue;
 
 public class TestSubscriptionItemTree extends InvoiceTestSuiteNoDB {
 
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    static {
+        OBJECT_MAPPER.enable(SerializationFeature.INDENT_OUTPUT);
+    }
+
     private final UUID invoiceId = UUID.randomUUID();
     private final UUID accountId = UUID.randomUUID();
     private final UUID subscriptionId = UUID.randomUUID();
@@ -1091,7 +1100,7 @@ public class TestSubscriptionItemTree extends InvoiceTestSuiteNoDB {
     }
 
     @Test(groups = "fast", description = "https://github.com/killbill/killbill/issues/286")
-    public void testMaxedOutProRation() {
+    public void testMaxedOutProRation() throws IOException {
         final LocalDate startDate = new LocalDate(2014, 1, 1);
         final LocalDate cancelDate = new LocalDate(2014, 1, 25);
         final LocalDate endDate = new LocalDate(2014, 2, 1);
@@ -1108,14 +1117,113 @@ public class TestSubscriptionItemTree extends InvoiceTestSuiteNoDB {
         tree.addItem(existingItemAdj1);
         tree.flatten(true);
 
+        //printTree(tree);
+
         final InvoiceItem proposed1 = new RecurringInvoiceItem(invoiceId, accountId, bundleId, subscriptionId, planName, phaseName, startDate, cancelDate, monthlyAmount1, monthlyRate1, currency);
         tree.mergeProposedItem(proposed1);
         tree.buildForMerge();
 
-        final List<InvoiceItem> expectedResult = Lists.newLinkedList();
+        //printTree(tree);
+
+        // We expect the proposed item because item adjustments don't change the subscription view of invoice
+        final List<InvoiceItem> expectedResult = ImmutableList.<InvoiceItem>of(proposed1);
+        verifyResult(tree.getView(), expectedResult);
+    }
+
+    @Test(groups = "fast")
+    public void testPartialProRation() {
+        final LocalDate startDate = new LocalDate(2014, 1, 1);
+        final LocalDate cancelDate = new LocalDate(2014, 1, 25);
+        final LocalDate endDate = new LocalDate(2014, 2, 1);
+
+        final BigDecimal monthlyRate1 = new BigDecimal("12.00");
+        final BigDecimal monthlyAmount1 = monthlyRate1;
+
+        final SubscriptionItemTree tree = new SubscriptionItemTree(subscriptionId, invoiceId);
+
+        final InvoiceItem existing1 = new RecurringInvoiceItem(invoiceId, accountId, bundleId, subscriptionId, planName, phaseName, startDate, endDate, monthlyAmount1, monthlyRate1, currency);
+        tree.addItem(existing1);
+        // Partially item adjust the recurring item
+        final InvoiceItem existingItemAdj1 = new ItemAdjInvoiceItem(existing1, startDate, monthlyRate1.negate().add(BigDecimal.ONE), currency);
+        tree.addItem(existingItemAdj1);
+        tree.flatten(true);
+
+        final InvoiceItem proposed1 = new RecurringInvoiceItem(invoiceId, accountId, bundleId, subscriptionId, planName, phaseName, startDate, cancelDate, monthlyAmount1, monthlyRate1, currency);
+        tree.mergeProposedItem(proposed1);
+        tree.buildForMerge();
+
+        final InvoiceItem repair = new RepairAdjInvoiceItem(invoiceId, accountId, cancelDate, endDate, BigDecimal.ONE.negate(), Currency.USD, existing1.getId());
+        final List<InvoiceItem> expectedResult = ImmutableList.<InvoiceItem>of(repair);
         verifyResult(tree.getView(), expectedResult);
     }
 
+    @Test(groups = "fast")
+    public void testWithWrongInitialItem() throws IOException {
+        final LocalDate wrongStartDate = new LocalDate(2016, 9, 9);
+        final LocalDate correctStartDate = new LocalDate(2016, 9, 8);
+        final LocalDate endDate = new LocalDate(2016, 10, 8);
+
+        final BigDecimal rate = new BigDecimal("12.00");
+        final BigDecimal amount = rate;
+
+        final SubscriptionItemTree tree = new SubscriptionItemTree(subscriptionId, invoiceId);
+
+        final InvoiceItem wrongInitialItem = new RecurringInvoiceItem(invoiceId,
+                                                                      accountId,
+                                                                      bundleId,
+                                                                      subscriptionId,
+                                                                      planName,
+                                                                      phaseName,
+                                                                      wrongStartDate,
+                                                                      endDate,
+                                                                      amount,
+                                                                      rate,
+                                                                      currency);
+        tree.addItem(wrongInitialItem);
+
+        final InvoiceItem itemAdj = new ItemAdjInvoiceItem(wrongInitialItem,
+                                                           new LocalDate(2016, 10, 2),
+                                                           amount.negate(),
+                                                           currency);
+        tree.addItem(itemAdj);
+
+        final InvoiceItem correctInitialItem = new RecurringInvoiceItem(invoiceId,
+                                                                        accountId,
+                                                                        bundleId,
+                                                                        subscriptionId,
+                                                                        planName,
+                                                                        phaseName,
+                                                                        correctStartDate,
+                                                                        endDate,
+                                                                        amount,
+                                                                        rate,
+                                                                        currency);
+        tree.addItem(correctInitialItem);
+
+        assertEquals(tree.getRoot().getStart(), correctStartDate);
+        assertEquals(tree.getRoot().getEnd(), endDate);
+        assertEquals(tree.getRoot().getLeftChild().getStart(), correctStartDate);
+        assertEquals(tree.getRoot().getLeftChild().getEnd(), endDate);
+        assertEquals(tree.getRoot().getLeftChild().getLeftChild().getStart(), wrongStartDate);
+        assertEquals(tree.getRoot().getLeftChild().getLeftChild().getEnd(), endDate);
+        assertNull(tree.getRoot().getLeftChild().getLeftChild().getLeftChild());
+        assertNull(tree.getRoot().getLeftChild().getLeftChild().getRightSibling());
+        assertNull(tree.getRoot().getLeftChild().getRightSibling());
+        assertNull(tree.getRoot().getRightSibling());
+
+        tree.flatten(true);
+
+        assertEquals(tree.getRoot().getStart(), correctStartDate);
+        assertEquals(tree.getRoot().getEnd(), endDate);
+        assertEquals(tree.getRoot().getLeftChild().getStart(), correctStartDate);
+        assertEquals(tree.getRoot().getLeftChild().getEnd(), endDate);
+        assertNull(tree.getRoot().getLeftChild().getLeftChild());
+        assertNull(tree.getRoot().getLeftChild().getRightSibling());
+        assertNull(tree.getRoot().getRightSibling());
+
+        //printTree(tree);
+    }
+
     private void verifyResult(final List<InvoiceItem> result, final List<InvoiceItem> expectedResult) {
         assertEquals(result.size(), expectedResult.size());
         for (int i = 0; i < expectedResult.size(); i++) {
@@ -1123,7 +1231,6 @@ public class TestSubscriptionItemTree extends InvoiceTestSuiteNoDB {
         }
     }
 
-
     @Test(groups = "fast")
     public void testWithWrongInitialItemInLoop() {
 
@@ -1172,4 +1279,10 @@ public class TestSubscriptionItemTree extends InvoiceTestSuiteNoDB {
         // We have repaired wrongInitialItem and generated the correctInitialItem and stopped
         Assert.assertEquals(previousExistingSize, 3);
     }
+
+    private void printTree(final SubscriptionItemTree tree) throws IOException {
+        final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        tree.getRoot().jsonSerializeTree(OBJECT_MAPPER, outputStream);
+        System.out.println(outputStream.toString("UTF-8"));
+    }
 }