killbill-memoizeit

Details

diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
index 07b571c..8047cd5 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/ItemsInterval.java
@@ -165,7 +165,7 @@ public class ItemsInterval {
         // - Nothing at all; this valid, this just means its original items got removed during mergeCancellingPairs logic,
         //   but its NodeInterval has children so it could not be deleted.
         //
-        Preconditions.checkState(items.size() <= 2);
+        Preconditions.checkState(items.size() <= 2, "Double billing detected: %s", items);
 
         final Item item = items.size() > 0 && items.get(0).getAction() == ItemAction.ADD ? items.get(0) : null;
         return item;
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java b/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
index 884e31d..3edb7c0 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/tree/NodeInterval.java
@@ -150,22 +150,40 @@ public class NodeInterval {
     }
 
     public void removeChild(final NodeInterval toBeRemoved) {
-
         NodeInterval prevChild = null;
         NodeInterval curChild = leftChild;
         while (curChild != null) {
             if (curChild.isSame(toBeRemoved)) {
                 if (prevChild == null) {
-                    leftChild = curChild.getRightSibling();
+                    if (curChild.getLeftChild() == null) {
+                        leftChild = curChild.getRightSibling();
+                    } else {
+                        leftChild = curChild.getLeftChild();
+                        adjustRightMostChildSibling(curChild);
+                    }
                 } else {
-                    prevChild.rightSibling = curChild.getRightSibling();
+                    if (curChild.getLeftChild() == null) {
+                        prevChild.rightSibling = curChild.getRightSibling();
+                    } else {
+                        prevChild.rightSibling = curChild.getLeftChild();
+                        adjustRightMostChildSibling(curChild);
+                    }
                 }
                 break;
             }
             prevChild = curChild;
             curChild = curChild.getRightSibling();
         }
+    }
 
+    private void adjustRightMostChildSibling(final NodeInterval curNode) {
+        NodeInterval tmpChild = curNode.getLeftChild();
+        NodeInterval preTmpChild = null;
+        while (tmpChild != null) {
+            preTmpChild = tmpChild;
+            tmpChild = tmpChild.getRightSibling();
+        }
+        preTmpChild.rightSibling = curNode.getRightSibling();
     }
 
     @JsonIgnore
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/tree/TestNodeInterval.java b/invoice/src/test/java/org/killbill/billing/invoice/tree/TestNodeInterval.java
index ba35fff..3e1a2ba 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/tree/TestNodeInterval.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/tree/TestNodeInterval.java
@@ -16,17 +16,18 @@
 
 package org.killbill.billing.invoice.tree;
 
+import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.UUID;
 
 import org.joda.time.LocalDate;
-import org.testng.annotations.Test;
-
 import org.killbill.billing.invoice.tree.NodeInterval.AddNodeCallback;
 import org.killbill.billing.invoice.tree.NodeInterval.BuildNodeCallback;
 import org.killbill.billing.invoice.tree.NodeInterval.SearchCallback;
 import org.killbill.billing.invoice.tree.NodeInterval.WalkCallback;
+import org.testng.Assert;
+import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertNull;
@@ -52,6 +53,28 @@ public class TestNodeInterval /* extends InvoiceTestSuiteNoDB  */ {
         public UUID getId() {
             return id;
         }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (!(o instanceof DummyNodeInterval)) {
+                return false;
+            }
+
+            final DummyNodeInterval that = (DummyNodeInterval) o;
+
+            if (id != null ? !id.equals(that.id) : that.id != null) {
+                return false;
+            }
+            return true;
+        }
+
+        @Override
+        public int hashCode() {
+            return id != null ? id.hashCode() : 0;
+        }
     }
 
     public class DummyAddNodeCallback implements AddNodeCallback {
@@ -67,6 +90,9 @@ public class TestNodeInterval /* extends InvoiceTestSuiteNoDB  */ {
         }
     }
 
+
+
+
     @Test(groups = "fast")
     public void testAddExistingItemSimple() {
         final DummyNodeInterval root = new DummyNodeInterval();
@@ -289,6 +315,118 @@ public class TestNodeInterval /* extends InvoiceTestSuiteNoDB  */ {
         }
     }
 
+    @Test(groups = "fast")
+    public void testRemoveLeftChildWithGrandChildren() {
+        final DummyNodeInterval root = new DummyNodeInterval();
+
+        final DummyNodeInterval top = createNodeInterval("2014-01-01", "2014-02-01");
+        root.addNode(top, CALLBACK);
+
+        final DummyNodeInterval firstChildLevel1 = createNodeInterval("2014-01-01", "2014-01-20");
+        final DummyNodeInterval secondChildLevel1 = createNodeInterval("2014-01-21", "2014-01-31");
+        root.addNode(firstChildLevel1, CALLBACK);
+        root.addNode(secondChildLevel1, CALLBACK);
+
+
+        final DummyNodeInterval firstChildLevel2 = createNodeInterval("2014-01-01", "2014-01-03");
+        final DummyNodeInterval secondChildLevel2 = createNodeInterval("2014-01-04", "2014-01-10");
+        final DummyNodeInterval thirdChildLevel2 = createNodeInterval("2014-01-11", "2014-01-20");
+        root.addNode(firstChildLevel2, CALLBACK);
+        root.addNode(secondChildLevel2, CALLBACK);
+        root.addNode(thirdChildLevel2, CALLBACK);
+
+        // Let's verify we get it right prior removing the node
+        final List<NodeInterval> expectedNodes = new ArrayList<NodeInterval>();
+        expectedNodes.add(root);
+        expectedNodes.add(top);
+        expectedNodes.add(firstChildLevel1);
+        expectedNodes.add(firstChildLevel2);
+        expectedNodes.add(secondChildLevel2);
+        expectedNodes.add(thirdChildLevel2);
+        expectedNodes.add(secondChildLevel1);
+
+        root.walkTree(new WalkCallback() {
+            @Override
+            public void onCurrentNode(final int depth, final NodeInterval curNode, final NodeInterval parent) {
+                Assert.assertEquals(curNode, expectedNodes.remove(0));
+            }
+        });
+
+        // Remove node and verify again
+        top.removeChild(firstChildLevel1);
+
+        final List<NodeInterval> expectedNodesAfterRemoval = new ArrayList<NodeInterval>();
+        expectedNodesAfterRemoval.add(root);
+        expectedNodesAfterRemoval.add(top);
+        expectedNodesAfterRemoval.add(firstChildLevel2);
+        expectedNodesAfterRemoval.add(secondChildLevel2);
+        expectedNodesAfterRemoval.add(thirdChildLevel2);
+        expectedNodesAfterRemoval.add(secondChildLevel1);
+
+        root.walkTree(new WalkCallback() {
+            @Override
+            public void onCurrentNode(final int depth, final NodeInterval curNode, final NodeInterval parent) {
+                Assert.assertEquals(curNode, expectedNodesAfterRemoval.remove(0));
+            }
+        });
+    }
+
+    @Test(groups = "fast")
+    public void testRemoveMiddleChildWithGrandChildren() {
+        final DummyNodeInterval root = new DummyNodeInterval();
+
+        final DummyNodeInterval top = createNodeInterval("2014-01-01", "2014-02-01");
+        root.addNode(top, CALLBACK);
+
+        final DummyNodeInterval firstChildLevel1 = createNodeInterval("2014-01-01", "2014-01-20");
+        final DummyNodeInterval secondChildLevel1 = createNodeInterval("2014-01-21", "2014-01-31");
+        root.addNode(firstChildLevel1, CALLBACK);
+        root.addNode(secondChildLevel1, CALLBACK);
+
+
+        final DummyNodeInterval firstChildLevel2 = createNodeInterval("2014-01-21", "2014-01-23");
+        final DummyNodeInterval secondChildLevel2 = createNodeInterval("2014-01-24", "2014-01-25");
+        final DummyNodeInterval thirdChildLevel2 = createNodeInterval("2014-01-26", "2014-01-31");
+        root.addNode(firstChildLevel2, CALLBACK);
+        root.addNode(secondChildLevel2, CALLBACK);
+        root.addNode(thirdChildLevel2, CALLBACK);
+
+        // Original List without removing node:
+        final List<NodeInterval> expectedNodes = new ArrayList<NodeInterval>();
+        expectedNodes.add(root);
+        expectedNodes.add(top);
+        expectedNodes.add(firstChildLevel1);
+        expectedNodes.add(secondChildLevel1);
+        expectedNodes.add(firstChildLevel2);
+        expectedNodes.add(secondChildLevel2);
+        expectedNodes.add(thirdChildLevel2);
+
+        root.walkTree(new WalkCallback() {
+            @Override
+            public void onCurrentNode(final int depth, final NodeInterval curNode, final NodeInterval parent) {
+                Assert.assertEquals(curNode, expectedNodes.remove(0));
+            }
+        });
+
+        top.removeChild(secondChildLevel1);
+
+        final List<NodeInterval> expectedNodesAfterRemoval = new ArrayList<NodeInterval>();
+        expectedNodesAfterRemoval.add(root);
+        expectedNodesAfterRemoval.add(top);
+        expectedNodesAfterRemoval.add(firstChildLevel1);
+        expectedNodesAfterRemoval.add(firstChildLevel2);
+        expectedNodesAfterRemoval.add(secondChildLevel2);
+        expectedNodesAfterRemoval.add(thirdChildLevel2);
+
+        root.walkTree(new WalkCallback() {
+            @Override
+            public void onCurrentNode(final int depth, final NodeInterval curNode, final NodeInterval parent) {
+                Assert.assertEquals(curNode, expectedNodesAfterRemoval.remove(0));
+            }
+        });
+
+    }
+
     private void checkInterval(final NodeInterval real, final NodeInterval expected) {
         assertEquals(real.getStart(), expected.getStart());
         assertEquals(real.getEnd(), expected.getEnd());