killbill-uncached

invoice: See #321 Initial implementation for billingMode

8/28/2015 1:34:12 AM

Details

diff --git a/invoice/src/main/java/org/killbill/billing/invoice/generator/BillingIntervalDetail.java b/invoice/src/main/java/org/killbill/billing/invoice/generator/BillingIntervalDetail.java
index 2d95f1f..2f8edca 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/generator/BillingIntervalDetail.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/generator/BillingIntervalDetail.java
@@ -17,7 +17,7 @@
 package org.killbill.billing.invoice.generator;
 
 import org.joda.time.LocalDate;
-
+import org.killbill.billing.catalog.api.BillingMode;
 import org.killbill.billing.catalog.api.BillingPeriod;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -29,17 +29,23 @@ public class BillingIntervalDetail {
     private final LocalDate targetDate;
     private final int billingCycleDay;
     private final BillingPeriod billingPeriod;
-
+    private final BillingMode billingMode;
     private LocalDate firstBillingCycleDate;
     private LocalDate effectiveEndDate;
     private LocalDate lastBillingCycleDate;
 
-    public BillingIntervalDetail(final LocalDate startDate, final LocalDate endDate, final LocalDate targetDate, final int billingCycleDay, final BillingPeriod billingPeriod) {
+    public BillingIntervalDetail(final LocalDate startDate,
+                                 final LocalDate endDate,
+                                 final LocalDate targetDate,
+                                 final int billingCycleDay,
+                                 final BillingPeriod billingPeriod,
+                                 final BillingMode billingMode) {
         this.startDate = startDate;
         this.endDate = endDate;
         this.targetDate = targetDate;
         this.billingCycleDay = billingCycleDay;
         this.billingPeriod = billingPeriod;
+        this.billingMode = billingMode;
         computeAll();
     }
 
@@ -61,6 +67,11 @@ public class BillingIntervalDetail {
         return lastBillingCycleDate;
     }
 
+    public boolean hasSomethingToBill() {
+        return effectiveEndDate != null /* IN_ARREAR mode prior we have reached our firstBillingCycleDate */ &&
+               (endDate == null || endDate.isAfter(startDate)); /* When there is an endDate, it should be > startDate since we don't bill for less than a day */
+    }
+
     private void computeAll() {
         calculateFirstBillingCycleDate();
         calculateEffectiveEndDate();
@@ -87,6 +98,44 @@ public class BillingIntervalDetail {
     }
 
     private void calculateEffectiveEndDate() {
+        if (billingMode == BillingMode.IN_ADVANCE) {
+            calculateInAdvanceEffectiveEndDate();
+        } else {
+            calculateInArrearEffectiveEndDate();
+        }
+    }
+
+    private void calculateInArrearEffectiveEndDate() {
+        if (targetDate.isBefore(firstBillingCycleDate)) {
+            // Nothing to bill for, hasSomethingToBill will return false
+            effectiveEndDate = null;
+            return;
+        }
+
+        if (endDate != null && endDate.isBefore(firstBillingCycleDate)) {
+            effectiveEndDate = endDate;
+            return;
+        }
+
+        final int numberOfMonthsInPeriod = billingPeriod.getNumberOfMonths();
+        int numberOfPeriods = 0;
+        LocalDate proposedDate = firstBillingCycleDate;
+
+        while (proposedDate.isBefore(targetDate)) {
+            proposedDate = firstBillingCycleDate.plusMonths(numberOfPeriods * numberOfMonthsInPeriod);
+            numberOfPeriods += 1;
+        }
+        proposedDate = alignProposedBillCycleDate(proposedDate, billingCycleDay);
+
+        // The proposedDate is greater to our endDate => return it
+        if (endDate != null && endDate.isBefore(proposedDate)) {
+            effectiveEndDate = endDate;
+        } else {
+            effectiveEndDate = proposedDate;
+        }
+    }
+
+    private void calculateInAdvanceEffectiveEndDate() {
 
         // We have an endDate and the targetDate is greater or equal to our endDate => return it
         if (endDate != null && !targetDate.isBefore(endDate)) {
@@ -117,9 +166,13 @@ public class BillingIntervalDetail {
         }
     }
 
-
     private void calculateLastBillingCycleDate() {
 
+        if (effectiveEndDate == null) {
+            lastBillingCycleDate = firstBillingCycleDate;
+            return;
+        }
+
         // Start from firstBillingCycleDate and billingPeriod until we pass the effectiveEndDate
         LocalDate proposedDate = firstBillingCycleDate;
         int numberOfPeriods = 0;
@@ -134,13 +187,12 @@ public class BillingIntervalDetail {
 
         if (proposedDate.isBefore(firstBillingCycleDate)) {
             // Make sure not to go too far in the past
-            lastBillingCycleDate =  firstBillingCycleDate;
+            lastBillingCycleDate = firstBillingCycleDate;
         } else {
-            lastBillingCycleDate =  proposedDate;
+            lastBillingCycleDate = proposedDate;
         }
     }
 
-
     //
     // We start from a billCycleDate
     //
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/generator/DefaultInvoiceGenerator.java b/invoice/src/main/java/org/killbill/billing/invoice/generator/DefaultInvoiceGenerator.java
index 1bc6c58..af96f5d 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/generator/DefaultInvoiceGenerator.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/generator/DefaultInvoiceGenerator.java
@@ -47,7 +47,7 @@ import org.killbill.billing.invoice.api.InvoiceItemType;
 import org.killbill.billing.invoice.model.BillingModeGenerator;
 import org.killbill.billing.invoice.model.DefaultInvoice;
 import org.killbill.billing.invoice.model.FixedPriceInvoiceItem;
-import org.killbill.billing.invoice.model.InAdvanceBillingMode;
+import org.killbill.billing.invoice.model.DefaultBillingModeGenerator;
 import org.killbill.billing.invoice.model.InvalidDateSequenceException;
 import org.killbill.billing.invoice.model.RecurringInvoiceItem;
 import org.killbill.billing.invoice.model.RecurringInvoiceItemData;
@@ -57,7 +57,6 @@ import org.killbill.billing.invoice.usage.RawUsageOptimizer.RawUsageOptimizerRes
 import org.killbill.billing.invoice.usage.SubscriptionConsumableInArrear;
 import org.killbill.billing.junction.BillingEvent;
 import org.killbill.billing.junction.BillingEventSet;
-import org.killbill.billing.usage.RawUsage;
 import org.killbill.billing.util.config.InvoiceConfig;
 import org.killbill.billing.util.currency.KillBillMoney;
 import org.killbill.clock.Clock;
@@ -79,12 +78,14 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
     private final Clock clock;
     private final InvoiceConfig config;
     private final RawUsageOptimizer rawUsageOptimizer;
+    final BillingModeGenerator billingModeGenerator;
 
     @Inject
     public DefaultInvoiceGenerator(final Clock clock, final InvoiceConfig config, final RawUsageOptimizer rawUsageOptimizer) {
         this.clock = clock;
         this.config = config;
         this.rawUsageOptimizer = rawUsageOptimizer;
+        this.billingModeGenerator = new DefaultBillingModeGenerator();
     }
 
     /*
@@ -105,7 +106,7 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
         final Invoice invoice = new DefaultInvoice(account.getId(), new LocalDate(clock.getUTCNow(), account.getTimeZone()), adjustedTargetDate, targetCurrency);
         final UUID invoiceId = invoice.getId();
 
-        final List<InvoiceItem> inAdvanceItems = generateInAdvanceInvoiceItems(account.getId(), invoiceId, events, existingInvoices, adjustedTargetDate, targetCurrency);
+        final List<InvoiceItem> inAdvanceItems = generateFixedAndRecurringInvoiceItems(account.getId(), invoiceId, events, existingInvoices, adjustedTargetDate, targetCurrency);
         invoice.addInvoiceItems(inAdvanceItems);
 
         final List<InvoiceItem> usageItems = generateUsageConsumableInArrearItems(account, invoiceId, events, existingInvoices, targetDate, context);
@@ -209,9 +210,9 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
         return result;
     }
 
-    private List<InvoiceItem> generateInAdvanceInvoiceItems(final UUID accountId, final UUID invoiceId, final BillingEventSet eventSet,
-                                                            @Nullable final List<Invoice> existingInvoices, final LocalDate targetDate,
-                                                            final Currency targetCurrency) throws InvoiceApiException {
+    private List<InvoiceItem> generateFixedAndRecurringInvoiceItems(final UUID accountId, final UUID invoiceId, final BillingEventSet eventSet,
+                                                                    @Nullable final List<Invoice> existingInvoices, final LocalDate targetDate,
+                                                                    final Currency targetCurrency) throws InvoiceApiException {
         final AccountItemTree accountItemTree = new AccountItemTree(accountId, invoiceId);
         if (existingInvoices != null) {
             for (final Invoice invoice : existingInvoices) {
@@ -226,7 +227,9 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
         }
 
         // Generate list of proposed invoice items based on billing events from junction-- proposed items are ALL items since beginning of time
-        final List<InvoiceItem> proposedItems = generateInAdvanceInvoiceItems(invoiceId, accountId, eventSet, targetDate, targetCurrency);
+        final List<InvoiceItem> proposedItems = new ArrayList<InvoiceItem>();
+        generateRecurringInvoiceItems(invoiceId, accountId, eventSet, targetDate, targetCurrency, proposedItems);
+        processFixedPriceEvents(invoiceId, accountId, eventSet, targetDate, targetCurrency, proposedItems);
 
         accountItemTree.mergeWithProposedItems(proposedItems);
         return accountItemTree.getResultingItemList();
@@ -255,12 +258,10 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
         return maxDate;
     }
 
-    private List<InvoiceItem> generateInAdvanceInvoiceItems(final UUID invoiceId, final UUID accountId, final BillingEventSet events,
-                                                            final LocalDate targetDate, final Currency currency) throws InvoiceApiException {
-        final List<InvoiceItem> items = new ArrayList<InvoiceItem>();
-
+    private List<InvoiceItem> generateRecurringInvoiceItems(final UUID invoiceId, final UUID accountId, final BillingEventSet events,
+                                                            final LocalDate targetDate, final Currency currency, final List<InvoiceItem> proposedItems) throws InvoiceApiException {
         if (events.size() == 0) {
-            return items;
+            return proposedItems;
         }
 
         // Pretty-print the generated invoice items from the junction events
@@ -277,32 +278,39 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
             if (!events.getSubscriptionIdsWithAutoInvoiceOff().
                     contains(thisEvent.getSubscription().getId())) { // don't consider events for subscriptions that have auto_invoice_off
                 final BillingEvent adjustedNextEvent = (thisEvent.getSubscription().getId() == nextEvent.getSubscription().getId()) ? nextEvent : null;
-                items.addAll(processInAdvanceEvents(invoiceId, accountId, thisEvent, adjustedNextEvent, targetDate, currency, logStringBuilder));
+                proposedItems.addAll(processRecurringEvents(invoiceId, accountId, thisEvent, adjustedNextEvent, targetDate, currency, logStringBuilder, events.getRecurringBillingMode()));
             }
         }
-        items.addAll(processInAdvanceEvents(invoiceId, accountId, nextEvent, null, targetDate, currency, logStringBuilder));
+        proposedItems.addAll(processRecurringEvents(invoiceId, accountId, nextEvent, null, targetDate, currency, logStringBuilder, events.getRecurringBillingMode()));
 
         log.info(logStringBuilder.toString());
 
-        return items;
+        return proposedItems;
     }
 
+    private List<InvoiceItem> processFixedPriceEvents(final UUID invoiceId, final UUID accountId, final BillingEventSet events, final LocalDate targetDate, final Currency currency, final List<InvoiceItem> proposedItems) {
+        final Iterator<BillingEvent> eventIt = events.iterator();
+        while (eventIt.hasNext()) {
+            final BillingEvent thisEvent = eventIt.next();
+
+            final InvoiceItem fixedPriceInvoiceItem = generateFixedPriceItem(invoiceId, accountId, thisEvent, targetDate, currency);
+            if (fixedPriceInvoiceItem != null) {
+                proposedItems.add(fixedPriceInvoiceItem);
+            }
+        }
+        return proposedItems;
+    }
+
+
     // Turn a set of events into a list of invoice items. Note that the dates on the invoice items will be rounded (granularity of a day)
-    private List<InvoiceItem> processInAdvanceEvents(final UUID invoiceId, final UUID accountId, final BillingEvent thisEvent, @Nullable final BillingEvent nextEvent,
+    private List<InvoiceItem> processRecurringEvents(final UUID invoiceId, final UUID accountId, final BillingEvent thisEvent, @Nullable final BillingEvent nextEvent,
                                                      final LocalDate targetDate, final Currency currency,
-                                                     final StringBuilder logStringBuilder) throws InvoiceApiException {
+                                                     final StringBuilder logStringBuilder, final BillingMode billingMode) throws InvoiceApiException {
         final List<InvoiceItem> items = new ArrayList<InvoiceItem>();
 
-        // Handle fixed price items
-        final InvoiceItem fixedPriceInvoiceItem = generateFixedPriceItem(invoiceId, accountId, thisEvent, targetDate, currency);
-        if (fixedPriceInvoiceItem != null) {
-            items.add(fixedPriceInvoiceItem);
-        }
-
         // Handle recurring items
         final BillingPeriod billingPeriod = thisEvent.getBillingPeriod();
         if (billingPeriod != BillingPeriod.NO_BILLING_PERIOD) {
-            final BillingModeGenerator billingModeGenerator = instantiateBillingMode(thisEvent.getBillingMode());
             final LocalDate startDate = new LocalDate(thisEvent.getEffectiveDate(), thisEvent.getTimeZone());
 
             if (!startDate.isAfter(targetDate)) {
@@ -312,7 +320,7 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
 
                 final List<RecurringInvoiceItemData> itemData;
                 try {
-                    itemData = billingModeGenerator.generateInvoiceItemData(startDate, endDate, targetDate, billCycleDayLocal, billingPeriod);
+                    itemData = billingModeGenerator.generateInvoiceItemData(startDate, endDate, targetDate, billCycleDayLocal, billingPeriod, billingMode);
                 } catch (InvalidDateSequenceException e) {
                     throw new InvoiceApiException(ErrorCode.INVOICE_INVALID_DATE_SEQUENCE, startDate, endDate, targetDate);
                 }
@@ -348,15 +356,6 @@ public class DefaultInvoiceGenerator implements InvoiceGenerator {
         return items;
     }
 
-    private BillingModeGenerator instantiateBillingMode(final BillingMode billingMode) {
-        switch (billingMode) {
-            case IN_ADVANCE:
-                return new InAdvanceBillingMode();
-            default:
-                throw new UnsupportedOperationException();
-        }
-    }
-
     InvoiceItem generateFixedPriceItem(final UUID invoiceId, final UUID accountId, final BillingEvent thisEvent,
                                        final LocalDate targetDate, final Currency currency) {
         final LocalDate roundedStartDate = new LocalDate(thisEvent.getEffectiveDate(), thisEvent.getTimeZone());
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/model/BillingModeGenerator.java b/invoice/src/main/java/org/killbill/billing/invoice/model/BillingModeGenerator.java
index 4116f3c..26fbfd9 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/model/BillingModeGenerator.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/model/BillingModeGenerator.java
@@ -23,10 +23,11 @@ import javax.annotation.Nullable;
 import org.joda.time.DateTimeZone;
 import org.joda.time.LocalDate;
 
+import org.killbill.billing.catalog.api.BillingMode;
 import org.killbill.billing.catalog.api.BillingPeriod;
 
 public interface BillingModeGenerator {
 
     List<RecurringInvoiceItemData> generateInvoiceItemData(LocalDate startDate, @Nullable LocalDate endDate, LocalDate targetDate,
-                                                           int billingCycleDay, BillingPeriod billingPeriod) throws InvalidDateSequenceException;
+                                                           int billingCycleDay, BillingPeriod billingPeriod, BillingMode billingMode) throws InvalidDateSequenceException;
 }
diff --git a/invoice/src/main/java/org/killbill/billing/invoice/usage/ContiguousIntervalConsumableInArrear.java b/invoice/src/main/java/org/killbill/billing/invoice/usage/ContiguousIntervalConsumableInArrear.java
index f1747c2..a7d78e4 100644
--- a/invoice/src/main/java/org/killbill/billing/invoice/usage/ContiguousIntervalConsumableInArrear.java
+++ b/invoice/src/main/java/org/killbill/billing/invoice/usage/ContiguousIntervalConsumableInArrear.java
@@ -106,7 +106,7 @@ public class ContiguousIntervalConsumableInArrear {
         }
         final LocalDate endDate = closedInterval ? new LocalDate(billingEvents.get(billingEvents.size() - 1).getEffectiveDate(), getAccountTimeZone()) : targetDate;
 
-        final BillingIntervalDetail bid = new BillingIntervalDetail(startDate, endDate, targetDate, getBCD(), usage.getBillingPeriod());
+        final BillingIntervalDetail bid = new BillingIntervalDetail(startDate, endDate, targetDate, getBCD(), usage.getBillingPeriod(), usage.getBillingMode());
 
         int numberOfPeriod = 0;
         // First billingCycleDate prior startDate
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/generator/TestBillingIntervalDetail.java b/invoice/src/test/java/org/killbill/billing/invoice/generator/TestBillingIntervalDetail.java
index 9941eb9..8021c2c 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/generator/TestBillingIntervalDetail.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/generator/TestBillingIntervalDetail.java
@@ -17,6 +17,7 @@
 package org.killbill.billing.invoice.generator;
 
 import org.joda.time.LocalDate;
+import org.killbill.billing.catalog.api.BillingMode;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -35,7 +36,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     public void testCalculateFirstBillingCycleDate1() throws Exception {
         final LocalDate from = new LocalDate("2012-01-16");
         final int bcd = 17;
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL, BillingMode.IN_ADVANCE);
         billingIntervalDetail.calculateFirstBillingCycleDate();
         Assert.assertEquals(billingIntervalDetail.getFirstBillingCycleDate(), new LocalDate("2012-01-17"));
     }
@@ -50,7 +51,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     public void testCalculateFirstBillingCycleDate2() throws Exception {
         final LocalDate from = new LocalDate("2012-02-16");
         final int bcd = 30;
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL, BillingMode.IN_ADVANCE);
         billingIntervalDetail.calculateFirstBillingCycleDate();
         Assert.assertEquals(billingIntervalDetail.getFirstBillingCycleDate(), new LocalDate("2012-02-29"));
     }
@@ -69,7 +70,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     public void testCalculateFirstBillingCycleDate4() throws Exception {
         final LocalDate from = new LocalDate("2012-01-31");
         final int bcd = 30;
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         billingIntervalDetail.calculateFirstBillingCycleDate();
         Assert.assertEquals(billingIntervalDetail.getFirstBillingCycleDate(), new LocalDate("2012-02-29"));
     }
@@ -84,7 +85,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     public void testCalculateFirstBillingCycleDate3() throws Exception {
         final LocalDate from = new LocalDate("2012-02-16");
         final int bcd = 14;
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), bcd, BillingPeriod.ANNUAL, BillingMode.IN_ADVANCE);
         billingIntervalDetail.calculateFirstBillingCycleDate();
         Assert.assertEquals(billingIntervalDetail.getFirstBillingCycleDate(), new LocalDate("2013-02-14"));
     }
@@ -92,7 +93,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     @Test(groups = "fast")
     public void testNextBCDShouldNotBeInThePast() throws Exception {
         final LocalDate from = new LocalDate("2012-07-16");
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 15, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 15, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate to = billingIntervalDetail.getFirstBillingCycleDate();
         Assert.assertEquals(to, new LocalDate("2012-08-15"));
     }
@@ -100,7 +101,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     @Test(groups = "fast")
     public void testBeforeBCDWithOnOrAfter() throws Exception {
         final LocalDate from = new LocalDate("2012-03-02");
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate to = billingIntervalDetail.getFirstBillingCycleDate();
         Assert.assertEquals(to, new LocalDate("2012-03-03"));
     }
@@ -108,7 +109,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     @Test(groups = "fast")
     public void testEqualBCDWithOnOrAfter() throws Exception {
         final LocalDate from = new LocalDate("2012-03-03");
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate to = billingIntervalDetail.getFirstBillingCycleDate();
         Assert.assertEquals(to, new LocalDate("2012-03-03"));
     }
@@ -116,7 +117,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
     @Test(groups = "fast")
     public void testAfterBCDWithOnOrAfter() throws Exception {
         final LocalDate from = new LocalDate("2012-03-04");
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(from, null, new LocalDate(), 3, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate to = billingIntervalDetail.getFirstBillingCycleDate();
         Assert.assertEquals(to, new LocalDate("2012-04-03"));
     }
@@ -127,7 +128,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
         final LocalDate targetDate = new LocalDate(2012, 8, 16);
         final BillingPeriod billingPeriod = BillingPeriod.MONTHLY;
 
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(firstBCD, null, targetDate, 16, billingPeriod);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(firstBCD, null, targetDate, 16, billingPeriod, BillingMode.IN_ADVANCE);
         final LocalDate effectiveEndDate = billingIntervalDetail.getEffectiveEndDate();
         Assert.assertEquals(effectiveEndDate, new LocalDate(2012, 9, 16));
     }
@@ -138,7 +139,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
         final LocalDate endDate = new LocalDate(2012, 9, 15); // so we get effectiveEndDate on 9/15
         final LocalDate targetDate = new LocalDate(2012, 8, 16);
 
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, endDate, targetDate, 16, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, endDate, targetDate, 16, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate lastBCD = billingIntervalDetail.getLastBillingCycleDate();
         Assert.assertEquals(lastBCD, new LocalDate(2012, 8, 16));
     }
@@ -148,7 +149,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
         final LocalDate start = new LocalDate("2012-07-16");
         final int bcdLocal = 15;
 
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, null, start, bcdLocal, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, null, start, bcdLocal, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate lastBCD = billingIntervalDetail.getLastBillingCycleDate();
         Assert.assertEquals(lastBCD, new LocalDate("2012-08-15"));
     }
@@ -160,7 +161,7 @@ public class TestBillingIntervalDetail extends InvoiceTestSuiteNoDB {
         final LocalDate end = null;
         final int bcdLocal = 31;
 
-        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, end, targetDate, bcdLocal, BillingPeriod.MONTHLY);
+        final BillingIntervalDetail billingIntervalDetail = new BillingIntervalDetail(start, end, targetDate, bcdLocal, BillingPeriod.MONTHLY, BillingMode.IN_ADVANCE);
         final LocalDate effectiveEndDate = billingIntervalDetail.getEffectiveEndDate();
         Assert.assertEquals(effectiveEndDate, new LocalDate("2012-05-31"));
     }
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/model/TestInAdvanceBillingMode.java b/invoice/src/test/java/org/killbill/billing/invoice/model/TestInAdvanceBillingMode.java
index c5c35b9..db121f6 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/model/TestInAdvanceBillingMode.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/model/TestInAdvanceBillingMode.java
@@ -23,6 +23,7 @@ import java.util.List;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
 import org.joda.time.LocalDate;
+import org.killbill.billing.catalog.api.BillingMode;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -148,9 +149,9 @@ public class TestInAdvanceBillingMode extends InvoiceTestSuiteNoDB {
     private void verifyInvoiceItems(final LocalDate startDate, final LocalDate endDate, final LocalDate targetDate,
                                     final DateTimeZone dateTimeZone, final int billingCycleDayLocal, final BillingPeriod billingPeriod,
                                     final LinkedHashMap<LocalDate, LocalDate> expectedDates) throws InvalidDateSequenceException {
-        final InAdvanceBillingMode billingMode = new InAdvanceBillingMode();
+        final DefaultBillingModeGenerator billingMode = new DefaultBillingModeGenerator();
 
-        final List<RecurringInvoiceItemData> invoiceItems = billingMode.generateInvoiceItemData(startDate, endDate, targetDate, billingCycleDayLocal, billingPeriod);
+        final List<RecurringInvoiceItemData> invoiceItems = billingMode.generateInvoiceItemData(startDate, endDate, targetDate, billingCycleDayLocal, billingPeriod, BillingMode.IN_ADVANCE);
 
         int i = 0;
         for (final LocalDate periodStartDate : expectedDates.keySet()) {
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/ProRationInAdvanceTestBase.java b/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/ProRationInAdvanceTestBase.java
index 6fd8e15..4572c5b 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/ProRationInAdvanceTestBase.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/ProRationInAdvanceTestBase.java
@@ -16,14 +16,13 @@
 
 package org.killbill.billing.invoice.tests.inAdvance;
 
-import org.killbill.billing.invoice.model.BillingModeGenerator;
-import org.killbill.billing.invoice.model.InAdvanceBillingMode;
+import org.killbill.billing.catalog.api.BillingMode;
 import org.killbill.billing.invoice.tests.ProRationTestBase;
 
 public abstract class ProRationInAdvanceTestBase extends ProRationTestBase {
 
     @Override
-    protected BillingModeGenerator getBillingMode() {
-        return new InAdvanceBillingMode();
+    protected BillingMode getBillingMode() {
+        return BillingMode.IN_ADVANCE;
     }
 }
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/TestValidationProRation.java b/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/TestValidationProRation.java
index 741316e..eeec015 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/TestValidationProRation.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/tests/inAdvance/TestValidationProRation.java
@@ -17,11 +17,12 @@
 package org.killbill.billing.invoice.tests.inAdvance;
 
 import org.joda.time.LocalDate;
+import org.killbill.billing.catalog.api.BillingMode;
 import org.killbill.billing.invoice.model.BillingModeGenerator;
 import org.testng.annotations.Test;
 
 import org.killbill.billing.catalog.api.BillingPeriod;
-import org.killbill.billing.invoice.model.InAdvanceBillingMode;
+import org.killbill.billing.invoice.model.DefaultBillingModeGenerator;
 import org.killbill.billing.invoice.model.InvalidDateSequenceException;
 import org.killbill.billing.invoice.tests.ProRationTestBase;
 
@@ -35,8 +36,13 @@ public class TestValidationProRation extends ProRationTestBase {
     }
 
     @Override
-    protected BillingModeGenerator getBillingMode() {
-        return new InAdvanceBillingMode();
+    protected BillingMode getBillingMode() {
+        return BillingMode.IN_ADVANCE;
+    }
+
+    @Override
+    protected BillingModeGenerator getBillingModeGenerator() {
+        return new DefaultBillingModeGenerator();
     }
 
     @Test(groups = "fast", expectedExceptions = InvalidDateSequenceException.class)
diff --git a/invoice/src/test/java/org/killbill/billing/invoice/tests/ProRationTestBase.java b/invoice/src/test/java/org/killbill/billing/invoice/tests/ProRationTestBase.java
index 4ce38af..e86af76 100644
--- a/invoice/src/test/java/org/killbill/billing/invoice/tests/ProRationTestBase.java
+++ b/invoice/src/test/java/org/killbill/billing/invoice/tests/ProRationTestBase.java
@@ -23,9 +23,11 @@ import java.util.List;
 
 import org.joda.time.LocalDate;
 
+import org.killbill.billing.catalog.api.BillingMode;
 import org.killbill.billing.catalog.api.BillingPeriod;
 import org.killbill.billing.invoice.InvoiceTestSuiteNoDB;
 import org.killbill.billing.invoice.model.BillingModeGenerator;
+import org.killbill.billing.invoice.model.DefaultBillingModeGenerator;
 import org.killbill.billing.invoice.model.InvalidDateSequenceException;
 import org.killbill.billing.invoice.model.RecurringInvoiceItemData;
 
@@ -34,10 +36,15 @@ import static org.testng.Assert.fail;
 
 public abstract class ProRationTestBase extends InvoiceTestSuiteNoDB {
 
-    protected abstract BillingModeGenerator getBillingMode();
+    protected BillingModeGenerator getBillingModeGenerator() {
+        return new DefaultBillingModeGenerator();
+    }
 
     protected abstract BillingPeriod getBillingPeriod();
 
+    protected abstract BillingMode getBillingMode();
+
+
     protected void testCalculateNumberOfBillingCycles(final LocalDate startDate, final LocalDate targetDate, final int billingCycleDay, final BigDecimal expectedValue) throws InvalidDateSequenceException {
         try {
             final BigDecimal numberOfBillingCycles;
@@ -65,7 +72,7 @@ public abstract class ProRationTestBase extends InvoiceTestSuiteNoDB {
     }
 
     protected BigDecimal calculateNumberOfBillingCycles(final LocalDate startDate, final LocalDate endDate, final LocalDate targetDate, final int billingCycleDay) throws InvalidDateSequenceException {
-        final List<RecurringInvoiceItemData> items = getBillingMode().generateInvoiceItemData(startDate, endDate, targetDate, billingCycleDay, getBillingPeriod());
+        final List<RecurringInvoiceItemData> items = getBillingModeGenerator().generateInvoiceItemData(startDate, endDate, targetDate, billingCycleDay, getBillingPeriod(), getBillingMode());
 
         BigDecimal numberOfBillingCycles = ZERO;
         for (final RecurringInvoiceItemData item : items) {
@@ -76,7 +83,7 @@ public abstract class ProRationTestBase extends InvoiceTestSuiteNoDB {
     }
 
     protected BigDecimal calculateNumberOfBillingCycles(final LocalDate startDate, final LocalDate targetDate, final int billingCycleDay) throws InvalidDateSequenceException {
-        final List<RecurringInvoiceItemData> items = getBillingMode().generateInvoiceItemData(startDate, null, targetDate, billingCycleDay, getBillingPeriod());
+        final List<RecurringInvoiceItemData> items = getBillingModeGenerator().generateInvoiceItemData(startDate, null, targetDate, billingCycleDay, getBillingPeriod(), getBillingMode());
 
         BigDecimal numberOfBillingCycles = ZERO;
         for (final RecurringInvoiceItemData item : items) {