killbill-uncached

junction: Fix incorrect computation for blocking states when

11/2/2015 11:54:54 PM

Details

diff --git a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
index c6ee181..6806e9c 100644
--- a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
+++ b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
@@ -18,8 +18,11 @@ package org.killbill.billing.junction.plumbing.billing;
 
 import java.math.BigDecimal;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Hashtable;
 import java.util.List;
+import java.util.Map;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.UUID;
@@ -37,12 +40,17 @@ import org.killbill.billing.catalog.api.Currency;
 import org.killbill.billing.catalog.api.Plan;
 import org.killbill.billing.catalog.api.PlanPhase;
 import org.killbill.billing.entitlement.api.BlockingState;
+import org.killbill.billing.entitlement.api.BlockingStateType;
 import org.killbill.billing.junction.BillingEvent;
 import org.killbill.billing.junction.BlockingInternalApi;
 import org.killbill.billing.subscription.api.SubscriptionBase;
 import org.killbill.billing.subscription.api.SubscriptionBaseTransitionType;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Predicate;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
 import com.google.inject.Inject;
 
 public class BlockingCalculator {
@@ -94,12 +102,29 @@ public class BlockingCalculator {
         final SortedSet<BillingEvent> billingEventsToAdd = new TreeSet<BillingEvent>();
         final SortedSet<BillingEvent> billingEventsToRemove = new TreeSet<BillingEvent>();
 
+
         final List<BlockingState> blockingEvents = blockingApi.getBlockingAllForAccount(context);
-        final List<DisabledDuration> blockingDurations = createBlockingDurations(blockingEvents);
+
+        final Iterable<BlockingState> accountBlockingEvents = Iterables.filter(blockingEvents, new Predicate<BlockingState>() {
+            @Override
+            public boolean apply(final BlockingState input) {
+                return BlockingStateType.ACCOUNT == input.getType();
+            }
+        });
+
+        final Map<UUID, List<BlockingState>> perBundleBlockingEvents = getPerTypeBlockingEvents(BlockingStateType.SUBSCRIPTION_BUNDLE, blockingEvents);
+        final Map<UUID, List<BlockingState>> perSubscriptionBlockingEvents = getPerTypeBlockingEvents(BlockingStateType.SUBSCRIPTION, blockingEvents);
+
         for (final UUID bundleId : bundleMap.keySet()) {
             for (final SubscriptionBase subscription : bundleMap.get(bundleId)) {
-                billingEventsToAdd.addAll(createNewEvents(blockingDurations, billingEvents, subscription));
-                billingEventsToRemove.addAll(eventsToRemove(blockingDurations, billingEvents, subscription));
+
+                final List<BlockingState> subscriptionBlockingEvents = perSubscriptionBlockingEvents.get(subscription.getId()) != null ? perSubscriptionBlockingEvents.get(subscription.getId()) : ImmutableList.<BlockingState>of();
+                final List<BlockingState> bundleBlockingEvents = perBundleBlockingEvents.get(bundleId)  != null ? perBundleBlockingEvents.get(bundleId) : ImmutableList.<BlockingState>of();
+                final List<BlockingState> aggregateSubscriptionBlockingEvents = getAggregateBlockingEventsPerSubscription(subscriptionBlockingEvents, bundleBlockingEvents, accountBlockingEvents);
+                final List<DisabledDuration> accountBlockingDurations = createBlockingDurations(aggregateSubscriptionBlockingEvents);
+
+                billingEventsToAdd.addAll(createNewEvents(accountBlockingDurations, billingEvents, subscription));
+                billingEventsToRemove.addAll(eventsToRemove(accountBlockingDurations, billingEvents, subscription));
             }
         }
 
@@ -112,6 +137,32 @@ public class BlockingCalculator {
         }
     }
 
+
+    final List<BlockingState> getAggregateBlockingEventsPerSubscription(final Iterable<BlockingState> subscriptionBlockingEvents, final Iterable<BlockingState> bundleBlockingEvents, final Iterable<BlockingState> accountBlockingEvents) {
+        final Iterable<BlockingState> tmp = Iterables.concat(subscriptionBlockingEvents, bundleBlockingEvents, accountBlockingEvents);
+        final List<BlockingState> result  = Lists.newArrayList(tmp);
+        Collections.sort(result);
+        return result;
+    }
+
+    final Map<UUID, List<BlockingState>> getPerTypeBlockingEvents(final BlockingStateType type, final List<BlockingState> blockingEvents) {
+        final Iterable<BlockingState> bundleBlockingEvents = Iterables.filter(blockingEvents, new Predicate<BlockingState>() {
+            @Override
+            public boolean apply(final BlockingState input) {
+                return type == input.getType();
+            }
+        });
+
+        final Map<UUID, List<BlockingState>> perTypeBlockingEvents = new HashMap<UUID, List<BlockingState>>();
+        for  (final BlockingState cur : bundleBlockingEvents) {
+            if (!perTypeBlockingEvents.containsKey(cur.getBlockedId())) {
+                perTypeBlockingEvents.put(cur.getBlockedId(), new ArrayList<BlockingState>());
+            }
+            perTypeBlockingEvents.get(cur.getBlockedId()).add(cur);
+        }
+        return perTypeBlockingEvents;
+    }
+
     protected SortedSet<BillingEvent> eventsToRemove(final List<DisabledDuration> disabledDuration,
                                                      final SortedSet<BillingEvent> billingEvents, final SubscriptionBase subscription) {
         final SortedSet<BillingEvent> result = new TreeSet<BillingEvent>();