diff --git a/api/src/main/java/org/killbill/billing/tag/TagInternalApi.java b/api/src/main/java/org/killbill/billing/tag/TagInternalApi.java
index 90eda82..d8843b2 100644
--- a/api/src/main/java/org/killbill/billing/tag/TagInternalApi.java
+++ b/api/src/main/java/org/killbill/billing/tag/TagInternalApi.java
@@ -43,6 +43,8 @@ public interface TagInternalApi {
public List<Tag> getTagsForAccountType(ObjectType objectType, boolean includedDeleted, InternalTenantContext internalTenantContext);
+ public List<Tag> getTagsForAccount(boolean includedDeleted, InternalTenantContext context);
+
public void addTag(final UUID objectId, final ObjectType objectType, UUID tagDefinitionId, InternalCallContext context) throws TagApiException;
public void removeTag(final UUID objectId, final ObjectType objectType, final UUID tagDefinitionId, InternalCallContext context) throws TagApiException;
diff --git a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/DefaultInternalBillingApi.java b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/DefaultInternalBillingApi.java
index a3499b5..8217696 100644
--- a/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/DefaultInternalBillingApi.java
+++ b/junction/src/main/java/org/killbill/billing/junction/plumbing/billing/DefaultInternalBillingApi.java
@@ -26,6 +26,8 @@ import java.util.Set;
import java.util.SortedSet;
import java.util.UUID;
+import javax.annotation.Nullable;
+
import org.killbill.billing.ObjectType;
import org.killbill.billing.account.api.AccountApiException;
import org.killbill.billing.account.api.AccountInternalApi;
@@ -39,7 +41,6 @@ import org.killbill.billing.catalog.api.CatalogInternalApi;
import org.killbill.billing.catalog.api.Plan;
import org.killbill.billing.catalog.api.PlanPhase;
import org.killbill.billing.catalog.api.PlanPhaseSpecifier;
-import org.killbill.billing.catalog.api.StaticCatalog;
import org.killbill.billing.entitlement.api.SubscriptionEventType;
import org.killbill.billing.events.EffectiveSubscriptionInternalEvent;
import org.killbill.billing.invoice.api.DryRunArguments;
@@ -62,6 +63,7 @@ import org.slf4j.LoggerFactory;
import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.inject.Inject;
@@ -93,7 +95,8 @@ public class DefaultInternalBillingApi implements BillingInternalApi {
final Catalog currentCatalog = catalogInternalApi.getFullCatalog(true, true, context);
// Check to see if billing is off for the account
- final List<Tag> accountTags = tagApi.getTags(accountId, ObjectType.ACCOUNT, context);
+ final List<Tag> tagsForAccount = tagApi.getTagsForAccount(false, context);
+ final List<Tag> accountTags = getTagsForObjectType(ObjectType.ACCOUNT, tagsForAccount, null);
final boolean found_AUTO_INVOICING_OFF = is_AUTO_INVOICING_OFF(accountTags);
final boolean found_INVOICING_DRAFT = is_AUTO_INVOICING_DRAFT(accountTags);
final boolean found_INVOICING_REUSE_DRAFT = is_AUTO_INVOICING_REUSE_DRAFT(accountTags);
@@ -108,7 +111,7 @@ public class DefaultInternalBillingApi implements BillingInternalApi {
final ImmutableAccountData account = accountApi.getImmutableAccountDataById(accountId, context);
result = new DefaultBillingEventSet(false, found_INVOICING_DRAFT, found_INVOICING_REUSE_DRAFT);
- addBillingEventsForBundles(bundles, account, dryRunArguments, context, result, skippedSubscriptions, currentCatalog);
+ addBillingEventsForBundles(bundles, account, dryRunArguments, context, result, skippedSubscriptions, currentCatalog, tagsForAccount);
}
if (result.isEmpty()) {
@@ -135,7 +138,7 @@ public class DefaultInternalBillingApi implements BillingInternalApi {
}
private void addBillingEventsForBundles(final List<SubscriptionBaseBundle> bundles, final ImmutableAccountData account, final DryRunArguments dryRunArguments, final InternalCallContext context,
- final DefaultBillingEventSet result, final Set<UUID> skipSubscriptionsSet, final Catalog catalog) throws AccountApiException, CatalogApiException, SubscriptionBaseApiException {
+ final DefaultBillingEventSet result, final Set<UUID> skipSubscriptionsSet, final Catalog catalog, final List<Tag> tagsForAccount) throws AccountApiException, CatalogApiException, SubscriptionBaseApiException {
final boolean dryRunMode = dryRunArguments != null;
@@ -151,22 +154,30 @@ public class DefaultInternalBillingApi implements BillingInternalApi {
}
+ final Map<UUID, List<SubscriptionBase>> subscriptionsForAccount = subscriptionApi.getSubscriptionsForAccount(context);
+
for (final SubscriptionBaseBundle bundle : bundles) {
final DryRunArguments dryRunArgumentsForBundle = (dryRunArguments != null &&
dryRunArguments.getBundleId() != null &&
dryRunArguments.getBundleId().equals(bundle.getId())) ?
- dryRunArguments : null;
- final List<SubscriptionBase> subscriptions = subscriptionApi.getSubscriptionsForBundle(bundle.getId(), dryRunArgumentsForBundle, context);
+ dryRunArguments : null;
+ final List<SubscriptionBase> subscriptions;
+ // In dryRun mode, optimization is intentionally left as is, since is not a common path.
+ if (dryRunArgumentsForBundle == null || dryRunArgumentsForBundle.getAction() == null) {
+ subscriptions = getSubscriptionsForAccountByBundleId(subscriptionsForAccount,bundle.getId());
+ } else {
+ subscriptions = subscriptionApi.getSubscriptionsForBundle(bundle.getId(), dryRunArgumentsForBundle, context);
+ }
- //Check if billing is off for the bundle
- final List<Tag> bundleTags = tagApi.getTags(bundle.getId(), ObjectType.BUNDLE, context);
+ // Check if billing is off for the bundle
+ final List<Tag> bundleTags = getTagsForObjectType(ObjectType.BUNDLE, tagsForAccount, bundle.getId());
boolean found_AUTO_INVOICING_OFF = is_AUTO_INVOICING_OFF(bundleTags);
if (found_AUTO_INVOICING_OFF) {
for (final SubscriptionBase subscription : subscriptions) { // billing is off so list sub ids in set to be excluded
result.getSubscriptionIdsWithAutoInvoiceOff().add(subscription.getId());
}
} else { // billing is not off
- final SubscriptionBase baseSubscription = !subscriptions.isEmpty() ? subscriptions.get(0) : null;
+ final SubscriptionBase baseSubscription = subscriptions != null && !subscriptions.isEmpty() ? subscriptions.get(0) : null;
addBillingEventsForSubscription(account, subscriptions, baseSubscription, dryRunMode, context, result, skipSubscriptionsSet, catalog);
}
}
@@ -270,5 +281,23 @@ public class DefaultInternalBillingApi implements BillingInternalApi {
});
}
+ private List<Tag> getTagsForObjectType(final ObjectType objectType, final List<Tag> tags, final @Nullable UUID objectId) {
+ return ImmutableList.<Tag>copyOf(Iterables.<Tag>filter(tags,
+ new Predicate<Tag>() {
+ @Override
+ public boolean apply(final Tag input) {
+ if (objectId == null) {
+ return objectType == input.getObjectType();
+ } else {
+ return objectType == input.getObjectType() && objectId.equals(input.getObjectId());
+ }
+
+ }
+ }));
+ }
+
+ private List<SubscriptionBase> getSubscriptionsForAccountByBundleId(final Map<UUID, List<SubscriptionBase>> subscriptionsForAccount, final UUID bundleId) {
+ return subscriptionsForAccount.containsKey(bundleId) ? subscriptionsForAccount.get(bundleId) : ImmutableList.<SubscriptionBase>of();
+ }
}
diff --git a/junction/src/test/java/org/killbill/billing/junction/plumbing/billing/TestBillingApi.java b/junction/src/test/java/org/killbill/billing/junction/plumbing/billing/TestBillingApi.java
index b422e78..12b0e0e 100644
--- a/junction/src/test/java/org/killbill/billing/junction/plumbing/billing/TestBillingApi.java
+++ b/junction/src/test/java/org/killbill/billing/junction/plumbing/billing/TestBillingApi.java
@@ -21,6 +21,7 @@ package org.killbill.billing.junction.plumbing.billing;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.SortedSet;
import java.util.UUID;
@@ -99,6 +100,9 @@ public class TestBillingApi extends JunctionTestSuiteNoDB {
Mockito.when(subscriptionInternalApi.getBundlesForAccount(Mockito.<UUID>any(), Mockito.<InternalTenantContext>any())).thenReturn(bundles);
Mockito.when(subscriptionInternalApi.getSubscriptionsForBundle(Mockito.<UUID>any(), Mockito.<DryRunArguments>any(), Mockito.<InternalTenantContext>any())).thenReturn(subscriptions);
+ Mockito.when(subscriptionInternalApi.getSubscriptionsForAccount(Mockito.<InternalTenantContext>any())).thenReturn(ImmutableMap.<UUID, List<SubscriptionBase>>builder()
+ .put(bunId, subscriptions)
+ .build());
Mockito.when(subscriptionInternalApi.getSubscriptionFromId(Mockito.<UUID>any(), Mockito.<InternalTenantContext>any())).thenReturn(subscription);
Mockito.when(subscriptionInternalApi.getBundleFromId(Mockito.<UUID>any(), Mockito.<InternalTenantContext>any())).thenReturn(bundle);
Mockito.when(subscriptionInternalApi.getBaseSubscription(Mockito.<UUID>any(), Mockito.<InternalTenantContext>any())).thenReturn(subscription);
diff --git a/util/src/test/java/org/killbill/billing/util/tag/dao/MockTagDao.java b/util/src/test/java/org/killbill/billing/util/tag/dao/MockTagDao.java
index a2238e9..636c76f 100644
--- a/util/src/test/java/org/killbill/billing/util/tag/dao/MockTagDao.java
+++ b/util/src/test/java/org/killbill/billing/util/tag/dao/MockTagDao.java
@@ -44,7 +44,13 @@ public class MockTagDao extends MockEntityDaoBase<TagModelDao, Tag, TagApiExcept
if (tagStore.get(tag.getObjectId()) == null) {
tagStore.put(tag.getObjectId(), new ArrayList<TagModelDao>());
}
+
+ // add it to the account tags
+ if (tagStore.get(getAccountId(context.getAccountRecordId())) == null) {
+ tagStore.put(getAccountId(context.getAccountRecordId()), new ArrayList<TagModelDao>());
+ }
tagStore.get(tag.getObjectId()).add(tag);
+ tagStore.get(getAccountId(context.getAccountRecordId())).add(tag);
}
@Override
@@ -93,10 +99,18 @@ public class MockTagDao extends MockEntityDaoBase<TagModelDao, Tag, TagApiExcept
@Override
public List<TagModelDao> getTagsForAccount(final boolean includedDeleted, final InternalTenantContext internalTenantContext) {
- throw new UnsupportedOperationException();
+ if (tagStore.get(getAccountId(internalTenantContext.getAccountRecordId())) == null) {
+ return ImmutableList.<TagModelDao>of();
+ }
+
+ return tagStore.get(getAccountId(internalTenantContext.getAccountRecordId()));
}
public void clear() {
tagStore.clear();
}
+
+ private UUID getAccountId(final Long accountRecordId) {
+ return new UUID(0L, accountRecordId);
+ }
}