diff --git a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
index c6ee181..6806e9c 100644
--- a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
+++ b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/BlockingCalculator.java
@@ -18,8 +18,11 @@ package org.killbill.billing.junction.plumbing.billing;
import java.math.BigDecimal;
import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;
+import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.UUID;
@@ -37,12 +40,17 @@ import org.killbill.billing.catalog.api.Currency;
import org.killbill.billing.catalog.api.Plan;
import org.killbill.billing.catalog.api.PlanPhase;
import org.killbill.billing.entitlement.api.BlockingState;
+import org.killbill.billing.entitlement.api.BlockingStateType;
import org.killbill.billing.junction.BillingEvent;
import org.killbill.billing.junction.BlockingInternalApi;
import org.killbill.billing.subscription.api.SubscriptionBase;
import org.killbill.billing.subscription.api.SubscriptionBaseTransitionType;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Predicate;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
import com.google.inject.Inject;
public class BlockingCalculator {
@@ -94,12 +102,29 @@ public class BlockingCalculator {
final SortedSet<BillingEvent> billingEventsToAdd = new TreeSet<BillingEvent>();
final SortedSet<BillingEvent> billingEventsToRemove = new TreeSet<BillingEvent>();
+
final List<BlockingState> blockingEvents = blockingApi.getBlockingAllForAccount(context);
- final List<DisabledDuration> blockingDurations = createBlockingDurations(blockingEvents);
+
+ final Iterable<BlockingState> accountBlockingEvents = Iterables.filter(blockingEvents, new Predicate<BlockingState>() {
+ @Override
+ public boolean apply(final BlockingState input) {
+ return BlockingStateType.ACCOUNT == input.getType();
+ }
+ });
+
+ final Map<UUID, List<BlockingState>> perBundleBlockingEvents = getPerTypeBlockingEvents(BlockingStateType.SUBSCRIPTION_BUNDLE, blockingEvents);
+ final Map<UUID, List<BlockingState>> perSubscriptionBlockingEvents = getPerTypeBlockingEvents(BlockingStateType.SUBSCRIPTION, blockingEvents);
+
for (final UUID bundleId : bundleMap.keySet()) {
for (final SubscriptionBase subscription : bundleMap.get(bundleId)) {
- billingEventsToAdd.addAll(createNewEvents(blockingDurations, billingEvents, subscription));
- billingEventsToRemove.addAll(eventsToRemove(blockingDurations, billingEvents, subscription));
+
+ final List<BlockingState> subscriptionBlockingEvents = perSubscriptionBlockingEvents.get(subscription.getId()) != null ? perSubscriptionBlockingEvents.get(subscription.getId()) : ImmutableList.<BlockingState>of();
+ final List<BlockingState> bundleBlockingEvents = perBundleBlockingEvents.get(bundleId) != null ? perBundleBlockingEvents.get(bundleId) : ImmutableList.<BlockingState>of();
+ final List<BlockingState> aggregateSubscriptionBlockingEvents = getAggregateBlockingEventsPerSubscription(subscriptionBlockingEvents, bundleBlockingEvents, accountBlockingEvents);
+ final List<DisabledDuration> accountBlockingDurations = createBlockingDurations(aggregateSubscriptionBlockingEvents);
+
+ billingEventsToAdd.addAll(createNewEvents(accountBlockingDurations, billingEvents, subscription));
+ billingEventsToRemove.addAll(eventsToRemove(accountBlockingDurations, billingEvents, subscription));
}
}
@@ -112,6 +137,32 @@ public class BlockingCalculator {
}
}
+
+ final List<BlockingState> getAggregateBlockingEventsPerSubscription(final Iterable<BlockingState> subscriptionBlockingEvents, final Iterable<BlockingState> bundleBlockingEvents, final Iterable<BlockingState> accountBlockingEvents) {
+ final Iterable<BlockingState> tmp = Iterables.concat(subscriptionBlockingEvents, bundleBlockingEvents, accountBlockingEvents);
+ final List<BlockingState> result = Lists.newArrayList(tmp);
+ Collections.sort(result);
+ return result;
+ }
+
+ final Map<UUID, List<BlockingState>> getPerTypeBlockingEvents(final BlockingStateType type, final List<BlockingState> blockingEvents) {
+ final Iterable<BlockingState> bundleBlockingEvents = Iterables.filter(blockingEvents, new Predicate<BlockingState>() {
+ @Override
+ public boolean apply(final BlockingState input) {
+ return type == input.getType();
+ }
+ });
+
+ final Map<UUID, List<BlockingState>> perTypeBlockingEvents = new HashMap<UUID, List<BlockingState>>();
+ for (final BlockingState cur : bundleBlockingEvents) {
+ if (!perTypeBlockingEvents.containsKey(cur.getBlockedId())) {
+ perTypeBlockingEvents.put(cur.getBlockedId(), new ArrayList<BlockingState>());
+ }
+ perTypeBlockingEvents.get(cur.getBlockedId()).add(cur);
+ }
+ return perTypeBlockingEvents;
+ }
+
protected SortedSet<BillingEvent> eventsToRemove(final List<DisabledDuration> disabledDuration,
final SortedSet<BillingEvent> billingEvents, final SubscriptionBase subscription) {
final SortedSet<BillingEvent> result = new TreeSet<BillingEvent>();