killbill-uncached

invoice: use the tag user api instead of the dao directly Add

6/11/2012 7:39:06 PM

Details

invoice/pom.xml 5(+5 -0)

diff --git a/invoice/pom.xml b/invoice/pom.xml
index 3767525..d2ee9d0 100644
--- a/invoice/pom.xml
+++ b/invoice/pom.xml
@@ -86,6 +86,11 @@
             <scope>test</scope>
         </dependency>
         <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-all</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
             <groupId>org.testng</groupId>
             <artifactId>testng</artifactId>
             <scope>test</scope>
diff --git a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
index 779a359..c3bd916 100644
--- a/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
+++ b/invoice/src/main/java/com/ning/billing/invoice/dao/DefaultInvoiceDao.java
@@ -21,48 +21,48 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.UUID;
 
-import com.ning.billing.ErrorCode;
-import com.ning.billing.catalog.api.Currency;
-import com.ning.billing.invoice.api.InvoiceApiException;
-import com.ning.billing.invoice.model.CreditInvoiceItem;
-import com.ning.billing.invoice.model.DefaultInvoice;
 import org.joda.time.DateTime;
 import org.skife.jdbi.v2.IDBI;
 import org.skife.jdbi.v2.Transaction;
 import org.skife.jdbi.v2.TransactionStatus;
 
 import com.google.inject.Inject;
+import com.ning.billing.ErrorCode;
+import com.ning.billing.catalog.api.Currency;
 import com.ning.billing.invoice.api.Invoice;
+import com.ning.billing.invoice.api.InvoiceApiException;
 import com.ning.billing.invoice.api.InvoiceItem;
 import com.ning.billing.invoice.api.InvoicePayment;
+import com.ning.billing.invoice.model.CreditInvoiceItem;
+import com.ning.billing.invoice.model.DefaultInvoice;
 import com.ning.billing.invoice.model.FixedPriceInvoiceItem;
 import com.ning.billing.invoice.model.RecurringInvoiceItem;
 import com.ning.billing.invoice.notification.NextBillingDatePoster;
 import com.ning.billing.util.ChangeType;
+import com.ning.billing.util.api.TagUserApi;
 import com.ning.billing.util.callcontext.CallContext;
 import com.ning.billing.util.dao.EntityAudit;
 import com.ning.billing.util.dao.ObjectType;
 import com.ning.billing.util.dao.TableName;
 import com.ning.billing.util.tag.ControlTagType;
-import com.ning.billing.util.tag.dao.TagDao;
 
 public class DefaultInvoiceDao implements InvoiceDao {
     private final InvoiceSqlDao invoiceSqlDao;
     private final InvoicePaymentSqlDao invoicePaymentSqlDao;
     private final CreditInvoiceItemSqlDao creditInvoiceItemSqlDao;
-    private final TagDao tagDao;
+    private final TagUserApi tagUserApi;
 
-	private final NextBillingDatePoster nextBillingDatePoster;
+    private final NextBillingDatePoster nextBillingDatePoster;
 
     @Inject
     public DefaultInvoiceDao(final IDBI dbi,
                              final NextBillingDatePoster nextBillingDatePoster,
-                             final TagDao tagDao) {
+                             final TagUserApi tagUserApi) {
         this.invoiceSqlDao = dbi.onDemand(InvoiceSqlDao.class);
         this.invoicePaymentSqlDao = dbi.onDemand(InvoicePaymentSqlDao.class);
         this.creditInvoiceItemSqlDao = dbi.onDemand(CreditInvoiceItemSqlDao.class);
         this.nextBillingDatePoster = nextBillingDatePoster;
-        this.tagDao = tagDao;
+        this.tagUserApi = tagUserApi;
     }
 
     @Override
@@ -70,7 +70,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
         return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
             @Override
             public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                List<Invoice> invoices = invoiceDao.getInvoicesByAccount(accountId.toString());
+                final List<Invoice> invoices = invoiceDao.getInvoicesByAccount(accountId.toString());
 
                 populateChildren(invoices, invoiceDao);
 
@@ -81,16 +81,16 @@ public class DefaultInvoiceDao implements InvoiceDao {
 
     @Override
     public List<Invoice> getAllInvoicesByAccount(final UUID accountId) {
-    	return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
-    		@Override
-    		public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-    			List<Invoice> invoices = invoiceDao.getAllInvoicesByAccount(accountId.toString());
+        return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
+            @Override
+            public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
+                final List<Invoice> invoices = invoiceDao.getAllInvoicesByAccount(accountId.toString());
 
                 populateChildren(invoices, invoiceDao);
 
-    			return invoices;
-    		}
-    	});
+                return invoices;
+            }
+        });
     }
 
     @Override
@@ -98,7 +98,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
         return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
             @Override
             public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                List<Invoice> invoices = invoiceDao.getInvoicesByAccountAfterDate(accountId.toString(), fromDate.toDate());
+                final List<Invoice> invoices = invoiceDao.getInvoicesByAccountAfterDate(accountId.toString(), fromDate.toDate());
 
                 populateChildren(invoices, invoiceDao);
 
@@ -110,14 +110,14 @@ public class DefaultInvoiceDao implements InvoiceDao {
     @Override
     public List<Invoice> get() {
         return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
-             @Override
-             public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                 List<Invoice> invoices = invoiceDao.get();
+            @Override
+            public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
+                final List<Invoice> invoices = invoiceDao.get();
 
-                 populateChildren(invoices, invoiceDao);
+                populateChildren(invoices, invoiceDao);
 
-                 return invoices;
-             }
+                return invoices;
+            }
         });
     }
 
@@ -126,7 +126,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
         return invoiceSqlDao.inTransaction(new Transaction<Invoice, InvoiceSqlDao>() {
             @Override
             public Invoice inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                Invoice invoice = invoiceDao.getById(invoiceId.toString());
+                final Invoice invoice = invoiceDao.getById(invoiceId.toString());
 
                 if (invoice != null) {
                     populateChildren(invoice, invoiceDao);
@@ -139,45 +139,45 @@ public class DefaultInvoiceDao implements InvoiceDao {
 
     @Override
     public void create(final Invoice invoice, final CallContext context) {
-        
+
         invoiceSqlDao.inTransaction(new Transaction<Void, InvoiceSqlDao>() {
             @Override
             public Void inTransaction(final InvoiceSqlDao transactional, final TransactionStatus status) throws Exception {
 
                 // STEPH this seems useless
-                Invoice currentInvoice = transactional.getById(invoice.getId().toString());
+                final Invoice currentInvoice = transactional.getById(invoice.getId().toString());
 
                 if (currentInvoice == null) {
-                    List<EntityAudit> audits = new ArrayList<EntityAudit>();
+                    final List<EntityAudit> audits = new ArrayList<EntityAudit>();
 
                     transactional.create(invoice, context);
-                    Long recordId = transactional.getRecordId(invoice.getId().toString());
+                    final Long recordId = transactional.getRecordId(invoice.getId().toString());
                     audits.add(new EntityAudit(TableName.INVOICES, recordId, ChangeType.INSERT));
 
                     List<Long> recordIdList;
 
-                    List<InvoiceItem> recurringInvoiceItems = invoice.getInvoiceItems(RecurringInvoiceItem.class);
-                    RecurringInvoiceItemSqlDao recurringInvoiceItemDao = transactional.become(RecurringInvoiceItemSqlDao.class);
+                    final List<InvoiceItem> recurringInvoiceItems = invoice.getInvoiceItems(RecurringInvoiceItem.class);
+                    final RecurringInvoiceItemSqlDao recurringInvoiceItemDao = transactional.become(RecurringInvoiceItemSqlDao.class);
                     recurringInvoiceItemDao.batchCreateFromTransaction(recurringInvoiceItems, context);
                     recordIdList = recurringInvoiceItemDao.getRecordIds(invoice.getId().toString());
                     audits.addAll(createAudits(TableName.RECURRING_INVOICE_ITEMS, recordIdList));
 
                     notifyOfFutureBillingEvents(transactional, recurringInvoiceItems);
 
-                    List<InvoiceItem> fixedPriceInvoiceItems = invoice.getInvoiceItems(FixedPriceInvoiceItem.class);
-                    FixedPriceInvoiceItemSqlDao fixedPriceInvoiceItemDao = transactional.become(FixedPriceInvoiceItemSqlDao.class);
+                    final List<InvoiceItem> fixedPriceInvoiceItems = invoice.getInvoiceItems(FixedPriceInvoiceItem.class);
+                    final FixedPriceInvoiceItemSqlDao fixedPriceInvoiceItemDao = transactional.become(FixedPriceInvoiceItemSqlDao.class);
                     fixedPriceInvoiceItemDao.batchCreateFromTransaction(fixedPriceInvoiceItems, context);
                     recordIdList = fixedPriceInvoiceItemDao.getRecordIds(invoice.getId().toString());
                     audits.addAll(createAudits(TableName.FIXED_INVOICE_ITEMS, recordIdList));
 
-                    List<InvoiceItem> creditInvoiceItems = invoice.getInvoiceItems(CreditInvoiceItem.class);
-                    CreditInvoiceItemSqlDao creditInvoiceItemSqlDao = transactional.become(CreditInvoiceItemSqlDao.class);
+                    final List<InvoiceItem> creditInvoiceItems = invoice.getInvoiceItems(CreditInvoiceItem.class);
+                    final CreditInvoiceItemSqlDao creditInvoiceItemSqlDao = transactional.become(CreditInvoiceItemSqlDao.class);
                     creditInvoiceItemSqlDao.batchCreateFromTransaction(creditInvoiceItems, context);
                     recordIdList = creditInvoiceItemSqlDao.getRecordIds(invoice.getId().toString());
                     audits.addAll(createAudits(TableName.CREDIT_INVOICE_ITEMS, recordIdList));
 
-                    List<InvoicePayment> invoicePayments = invoice.getPayments();
-                    InvoicePaymentSqlDao invoicePaymentSqlDao = transactional.become(InvoicePaymentSqlDao.class);
+                    final List<InvoicePayment> invoicePayments = invoice.getPayments();
+                    final InvoicePaymentSqlDao invoicePaymentSqlDao = transactional.become(InvoicePaymentSqlDao.class);
                     invoicePaymentSqlDao.batchCreateFromTransaction(invoicePayments, context);
                     recordIdList = invoicePaymentSqlDao.getRecordIds(invoice.getId().toString());
                     audits.addAll(createAudits(TableName.INVOICE_PAYMENTS, recordIdList));
@@ -191,8 +191,8 @@ public class DefaultInvoiceDao implements InvoiceDao {
     }
 
     private List<EntityAudit> createAudits(final TableName tableName, final List<Long> recordIdList) {
-        List<EntityAudit> entityAuditList = new ArrayList<EntityAudit>();
-        for (Long recordId : recordIdList) {
+        final List<EntityAudit> entityAuditList = new ArrayList<EntityAudit>();
+        for (final Long recordId : recordIdList) {
             entityAuditList.add(new EntityAudit(tableName, recordId, ChangeType.INSERT));
         }
 
@@ -204,7 +204,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
         return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
             @Override
             public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                List<Invoice> invoices = invoiceDao.getInvoicesBySubscription(subscriptionId.toString());
+                final List<Invoice> invoices = invoiceDao.getInvoicesBySubscription(subscriptionId.toString());
 
                 populateChildren(invoices, invoiceDao);
 
@@ -222,12 +222,12 @@ public class DefaultInvoiceDao implements InvoiceDao {
     public void notifyOfPaymentAttempt(final InvoicePayment invoicePayment, final CallContext context) {
         invoicePaymentSqlDao.inTransaction(new Transaction<Void, InvoicePaymentSqlDao>() {
             @Override
-            public Void inTransaction(InvoicePaymentSqlDao transactional, TransactionStatus status) throws Exception {
+            public Void inTransaction(final InvoicePaymentSqlDao transactional, final TransactionStatus status) throws Exception {
                 transactional.notifyOfPaymentAttempt(invoicePayment, context);
 
-                String invoicePaymentId = invoicePayment.getId().toString();
-                Long recordId = transactional.getRecordId(invoicePaymentId);
-                EntityAudit audit = new EntityAudit(TableName.INVOICE_PAYMENTS, recordId, ChangeType.INSERT);
+                final String invoicePaymentId = invoicePayment.getId().toString();
+                final Long recordId = transactional.getRecordId(invoicePaymentId);
+                final EntityAudit audit = new EntityAudit(TableName.INVOICE_PAYMENTS, recordId, ChangeType.INSERT);
                 transactional.insertAuditFromTransaction(audit, context);
 
                 return null;
@@ -240,7 +240,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
         return invoiceSqlDao.inTransaction(new Transaction<List<Invoice>, InvoiceSqlDao>() {
             @Override
             public List<Invoice> inTransaction(final InvoiceSqlDao invoiceDao, final TransactionStatus status) throws Exception {
-                List<Invoice> invoices = invoiceSqlDao.getUnpaidInvoicesByAccountId(accountId.toString(), upToDate.toDate());
+                final List<Invoice> invoices = invoiceSqlDao.getUnpaidInvoicesByAccountId(accountId.toString(), upToDate.toDate());
 
                 populateChildren(invoices, invoiceDao);
 
@@ -261,17 +261,17 @@ public class DefaultInvoiceDao implements InvoiceDao {
 
     @Override
     public void setWrittenOff(final UUID invoiceId, final CallContext context) {
-        tagDao.insertTag(invoiceId, ObjectType.INVOICE, ControlTagType.WRITTEN_OFF.toTagDefinition(), context);
+        tagUserApi.addTag(invoiceId, ObjectType.INVOICE, ControlTagType.WRITTEN_OFF.toTagDefinition(), context);
     }
 
     @Override
     public void removeWrittenOff(final UUID invoiceId, final CallContext context) throws InvoiceApiException {
-        tagDao.deleteTag(invoiceId, ObjectType.INVOICE, ControlTagType.WRITTEN_OFF.toTagDefinition(), context);
+        tagUserApi.removeTag(invoiceId, ObjectType.INVOICE, ControlTagType.WRITTEN_OFF.toTagDefinition(), context);
     }
 
     @Override
     public void postChargeback(final UUID invoicePaymentId, final BigDecimal amount, final CallContext context) throws InvoiceApiException {
-        InvoicePayment payment = invoicePaymentSqlDao.getById(invoicePaymentId.toString());
+        final InvoicePayment payment = invoicePaymentSqlDao.getById(invoicePaymentId.toString());
         if (payment == null) {
             throw new InvoiceApiException(ErrorCode.INVOICE_PAYMENT_NOT_FOUND, invoicePaymentId.toString());
         } else {
@@ -279,20 +279,20 @@ public class DefaultInvoiceDao implements InvoiceDao {
                 throw new InvoiceApiException(ErrorCode.CHARGE_BACK_AMOUNT_IS_NEGATIVE);
             }
 
-            InvoicePayment chargeBack = payment.asChargeBack(amount, context.getCreatedDate());
+            final InvoicePayment chargeBack = payment.asChargeBack(amount, context.getCreatedDate());
             invoicePaymentSqlDao.create(chargeBack, context);
         }
     }
 
     @Override
-    public BigDecimal getRemainingAmountPaid(UUID invoicePaymentId) {
-        BigDecimal amount = invoicePaymentSqlDao.getRemainingAmountPaid(invoicePaymentId.toString());
+    public BigDecimal getRemainingAmountPaid(final UUID invoicePaymentId) {
+        final BigDecimal amount = invoicePaymentSqlDao.getRemainingAmountPaid(invoicePaymentId.toString());
         return amount == null ? BigDecimal.ZERO : amount;
     }
 
     @Override
-    public UUID getAccountIdFromInvoicePaymentId(UUID invoicePaymentId) throws InvoiceApiException {
-        UUID accountId = invoicePaymentSqlDao.getAccountIdFromInvoicePaymentId(invoicePaymentId.toString());
+    public UUID getAccountIdFromInvoicePaymentId(final UUID invoicePaymentId) throws InvoiceApiException {
+        final UUID accountId = invoicePaymentSqlDao.getAccountIdFromInvoicePaymentId(invoicePaymentId.toString());
         if (accountId == null) {
             throw new InvoiceApiException(ErrorCode.CHARGE_BACK_COULD_NOT_FIND_ACCOUNT_ID, invoicePaymentId);
         } else {
@@ -301,7 +301,7 @@ public class DefaultInvoiceDao implements InvoiceDao {
     }
 
     @Override
-    public List<InvoicePayment> getChargebacksByAccountId(UUID accountId) {
+    public List<InvoicePayment> getChargebacksByAccountId(final UUID accountId) {
         return invoicePaymentSqlDao.getChargeBacksByAccountId(accountId.toString());
     }
 
@@ -311,8 +311,8 @@ public class DefaultInvoiceDao implements InvoiceDao {
     }
 
     @Override
-    public InvoicePayment getChargebackById(UUID chargebackId) throws InvoiceApiException {
-        InvoicePayment chargeback = invoicePaymentSqlDao.getById(chargebackId.toString());
+    public InvoicePayment getChargebackById(final UUID chargebackId) throws InvoiceApiException {
+        final InvoicePayment chargeback = invoicePaymentSqlDao.getById(chargebackId.toString());
         if (chargeback == null) {
             throw new InvoiceApiException(ErrorCode.CHARGE_BACK_DOES_NOT_EXIST, chargebackId);
         } else {
@@ -328,12 +328,12 @@ public class DefaultInvoiceDao implements InvoiceDao {
     // TODO: make this transactional
     @Override
     public InvoiceItem insertCredit(final UUID accountId, final BigDecimal amount,
-                             final DateTime effectiveDate, final Currency currency,
-                             final CallContext context) {
-        Invoice invoice = new DefaultInvoice(accountId, effectiveDate, effectiveDate, currency);
+                                    final DateTime effectiveDate, final Currency currency,
+                                    final CallContext context) {
+        final Invoice invoice = new DefaultInvoice(accountId, effectiveDate, effectiveDate, currency);
         invoiceSqlDao.create(invoice, context);
 
-        InvoiceItem credit = new CreditInvoiceItem(invoice.getId(), accountId, effectiveDate, amount, currency);
+        final InvoiceItem credit = new CreditInvoiceItem(invoice.getId(), accountId, effectiveDate, amount, currency);
         creditInvoiceItemSqlDao.create(credit, context);
 
         return credit;
@@ -344,12 +344,12 @@ public class DefaultInvoiceDao implements InvoiceDao {
         invoiceSqlDao.test();
     }
 
-    private void populateChildren(final Invoice invoice, InvoiceSqlDao invoiceSqlDao) {
+    private void populateChildren(final Invoice invoice, final InvoiceSqlDao invoiceSqlDao) {
         getInvoiceItemsWithinTransaction(invoice, invoiceSqlDao);
         getInvoicePaymentsWithinTransaction(invoice, invoiceSqlDao);
     }
 
-    private void populateChildren(List<Invoice> invoices, InvoiceSqlDao invoiceSqlDao) {
+    private void populateChildren(final List<Invoice> invoices, final InvoiceSqlDao invoiceSqlDao) {
         getInvoiceItemsWithinTransaction(invoices, invoiceSqlDao);
         getInvoicePaymentsWithinTransaction(invoices, invoiceSqlDao);
     }
@@ -361,42 +361,42 @@ public class DefaultInvoiceDao implements InvoiceDao {
     }
 
     private void getInvoiceItemsWithinTransaction(final Invoice invoice, final InvoiceSqlDao invoiceDao) {
-        String invoiceId = invoice.getId().toString();
+        final String invoiceId = invoice.getId().toString();
 
-        RecurringInvoiceItemSqlDao recurringInvoiceItemDao = invoiceDao.become(RecurringInvoiceItemSqlDao.class);
-        List<InvoiceItem> recurringInvoiceItems = recurringInvoiceItemDao.getInvoiceItemsByInvoice(invoiceId);
+        final RecurringInvoiceItemSqlDao recurringInvoiceItemDao = invoiceDao.become(RecurringInvoiceItemSqlDao.class);
+        final List<InvoiceItem> recurringInvoiceItems = recurringInvoiceItemDao.getInvoiceItemsByInvoice(invoiceId);
         invoice.addInvoiceItems(recurringInvoiceItems);
 
-        FixedPriceInvoiceItemSqlDao fixedPriceInvoiceItemDao = invoiceDao.become(FixedPriceInvoiceItemSqlDao.class);
-        List<InvoiceItem> fixedPriceInvoiceItems = fixedPriceInvoiceItemDao.getInvoiceItemsByInvoice(invoiceId);
+        final FixedPriceInvoiceItemSqlDao fixedPriceInvoiceItemDao = invoiceDao.become(FixedPriceInvoiceItemSqlDao.class);
+        final List<InvoiceItem> fixedPriceInvoiceItems = fixedPriceInvoiceItemDao.getInvoiceItemsByInvoice(invoiceId);
         invoice.addInvoiceItems(fixedPriceInvoiceItems);
 
-        CreditInvoiceItemSqlDao creditInvoiceItemSqlDao = invoiceDao.become(CreditInvoiceItemSqlDao.class);
-        List<InvoiceItem> creditInvoiceItems = creditInvoiceItemSqlDao.getInvoiceItemsByInvoice(invoiceId);
+        final CreditInvoiceItemSqlDao creditInvoiceItemSqlDao = invoiceDao.become(CreditInvoiceItemSqlDao.class);
+        final List<InvoiceItem> creditInvoiceItems = creditInvoiceItemSqlDao.getInvoiceItemsByInvoice(invoiceId);
         invoice.addInvoiceItems(creditInvoiceItems);
     }
 
     private void getInvoicePaymentsWithinTransaction(final List<Invoice> invoices, final InvoiceSqlDao invoiceDao) {
-        for (Invoice invoice : invoices) {
+        for (final Invoice invoice : invoices) {
             getInvoicePaymentsWithinTransaction(invoice, invoiceDao);
         }
     }
 
     private void getInvoicePaymentsWithinTransaction(final Invoice invoice, final InvoiceSqlDao invoiceSqlDao) {
-        InvoicePaymentSqlDao invoicePaymentSqlDao = invoiceSqlDao.become(InvoicePaymentSqlDao.class);
-        String invoiceId = invoice.getId().toString();
-        List<InvoicePayment> invoicePayments = invoicePaymentSqlDao.getPaymentsForInvoice(invoiceId);
+        final InvoicePaymentSqlDao invoicePaymentSqlDao = invoiceSqlDao.become(InvoicePaymentSqlDao.class);
+        final String invoiceId = invoice.getId().toString();
+        final List<InvoicePayment> invoicePayments = invoicePaymentSqlDao.getPaymentsForInvoice(invoiceId);
         invoice.addPayments(invoicePayments);
     }
 
     private void notifyOfFutureBillingEvents(final InvoiceSqlDao dao, final List<InvoiceItem> invoiceItems) {
         for (final InvoiceItem item : invoiceItems) {
             if (item instanceof RecurringInvoiceItem) {
-                RecurringInvoiceItem recurringInvoiceItem = (RecurringInvoiceItem) item;
+                final RecurringInvoiceItem recurringInvoiceItem = (RecurringInvoiceItem) item;
                 if ((recurringInvoiceItem.getEndDate() != null) &&
                         (recurringInvoiceItem.getAmount() == null ||
                                 recurringInvoiceItem.getAmount().compareTo(BigDecimal.ZERO) >= 0)) {
-                	nextBillingDatePoster.insertNextBillingNotification(dao, item.getSubscriptionId(), recurringInvoiceItem.getEndDate());
+                    nextBillingDatePoster.insertNextBillingNotification(dao, item.getSubscriptionId(), recurringInvoiceItem.getEndDate());
                 }
             }
         }
diff --git a/invoice/src/test/java/com/ning/billing/invoice/dao/InvoiceDaoTestBase.java b/invoice/src/test/java/com/ning/billing/invoice/dao/InvoiceDaoTestBase.java
index afba58b..70ca921 100644
--- a/invoice/src/test/java/com/ning/billing/invoice/dao/InvoiceDaoTestBase.java
+++ b/invoice/src/test/java/com/ning/billing/invoice/dao/InvoiceDaoTestBase.java
@@ -23,10 +23,13 @@ import java.io.IOException;
 import com.ning.billing.dbi.MysqlTestingHelper;
 import com.ning.billing.invoice.notification.MockNextBillingDatePoster;
 import com.ning.billing.invoice.notification.NextBillingDatePoster;
+import com.ning.billing.util.api.TagUserApi;
 import com.ning.billing.util.callcontext.TestCallContext;
 import com.ning.billing.util.clock.ClockMock;
+import com.ning.billing.util.tag.api.DefaultTagUserApi;
 import com.ning.billing.util.tag.dao.AuditedTagDao;
 import com.ning.billing.util.tag.dao.MockTagDao;
+import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
 import com.ning.billing.util.tag.dao.TagDao;
 import org.apache.commons.io.IOUtils;
 import org.skife.jdbi.v2.Handle;
@@ -43,6 +46,7 @@ import com.ning.billing.invoice.model.InvoiceGenerator;
 import com.ning.billing.invoice.tests.InvoicingTestBase;
 import com.ning.billing.util.callcontext.CallContext;
 import com.ning.billing.util.clock.Clock;
+import com.ning.billing.util.tag.dao.TagDefinitionDao;
 
 public abstract class InvoiceDaoTestBase extends InvoicingTestBase {
     protected IDBI dbi;
@@ -78,8 +82,10 @@ public abstract class InvoiceDaoTestBase extends InvoicingTestBase {
         mysqlTestingHelper.initDb(utilDdl);
 
         NextBillingDatePoster nextBillingDatePoster = new MockNextBillingDatePoster();
-        TagDao tagDao = new AuditedTagDao(dbi);
-        invoiceDao = new DefaultInvoiceDao(dbi, nextBillingDatePoster, tagDao);
+        final TagDefinitionDao tagDefinitionDao = new MockTagDefinitionDao();
+        final TagDao tagDao = new AuditedTagDao(dbi);
+        final TagUserApi tagUserApi = new DefaultTagUserApi(tagDefinitionDao, tagDao);
+        invoiceDao = new DefaultInvoiceDao(dbi, nextBillingDatePoster, tagUserApi);
         invoiceDao.test();
 
         recurringInvoiceItemDao = dbi.onDemand(RecurringInvoiceItemSqlDao.class);
diff --git a/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java b/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java
new file mode 100644
index 0000000..169f2ff
--- /dev/null
+++ b/invoice/src/test/java/com/ning/billing/invoice/dao/TestDefaultInvoiceDao.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2010-2012 Ning, Inc.
+ *
+ * Ning licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.ning.billing.invoice.dao;
+
+import java.util.Map;
+import java.util.UUID;
+
+import org.mockito.Mockito;
+import org.skife.jdbi.v2.IDBI;
+import org.testng.Assert;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import com.ning.billing.invoice.notification.NextBillingDatePoster;
+import com.ning.billing.util.api.TagUserApi;
+import com.ning.billing.util.callcontext.CallContext;
+import com.ning.billing.util.dao.ObjectType;
+import com.ning.billing.util.tag.ControlTagType;
+import com.ning.billing.util.tag.Tag;
+import com.ning.billing.util.tag.api.DefaultTagUserApi;
+import com.ning.billing.util.tag.dao.MockTagDao;
+import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
+import com.ning.billing.util.tag.dao.TagDao;
+import com.ning.billing.util.tag.dao.TagDefinitionDao;
+
+public class TestDefaultInvoiceDao {
+    private TagUserApi tagUserApi;
+    private DefaultInvoiceDao dao;
+
+    @BeforeMethod(groups = "fast")
+    public void setUp() throws Exception {
+        final IDBI idbi = Mockito.mock(IDBI.class);
+        final NextBillingDatePoster poster = Mockito.mock(NextBillingDatePoster.class);
+        final TagDefinitionDao tagDefinitionDao = new MockTagDefinitionDao();
+        final TagDao tagDao = new MockTagDao();
+        tagUserApi = new DefaultTagUserApi(tagDefinitionDao, tagDao);
+        dao = new DefaultInvoiceDao(idbi, poster, tagUserApi);
+    }
+
+    @Test(groups = "fast")
+    public void testSetWrittenOff() throws Exception {
+        final UUID invoiceId = UUID.randomUUID();
+
+        final Map<String, Tag> beforeTags = tagUserApi.getTags(invoiceId, ObjectType.INVOICE);
+        Assert.assertEquals(beforeTags.keySet().size(), 0);
+
+        dao.setWrittenOff(invoiceId, Mockito.mock(CallContext.class));
+
+        final Map<String, Tag> afterTags = tagUserApi.getTags(invoiceId, ObjectType.INVOICE);
+        Assert.assertEquals(afterTags.keySet().size(), 1);
+        final String name = ControlTagType.WRITTEN_OFF.toTagDefinition().getName();
+        Assert.assertEquals(afterTags.get(name).getTagDefinitionName(), name);
+    }
+
+    @Test(groups = "fast")
+    public void testRemoveWrittenOff() throws Exception {
+        final UUID invoiceId = UUID.randomUUID();
+
+        dao.setWrittenOff(invoiceId, Mockito.mock(CallContext.class));
+
+        final Map<String, Tag> beforeTags = tagUserApi.getTags(invoiceId, ObjectType.INVOICE);
+        Assert.assertEquals(beforeTags.keySet().size(), 1);
+        final String name = ControlTagType.WRITTEN_OFF.toTagDefinition().getName();
+        Assert.assertEquals(beforeTags.get(name).getTagDefinitionName(), name);
+
+        dao.removeWrittenOff(invoiceId, Mockito.mock(CallContext.class));
+
+        final Map<String, Tag> afterTags = tagUserApi.getTags(invoiceId, ObjectType.INVOICE);
+        Assert.assertEquals(afterTags.keySet().size(), 0);
+    }
+}
diff --git a/invoice/src/test/java/com/ning/billing/invoice/tests/ChargeBackTests.java b/invoice/src/test/java/com/ning/billing/invoice/tests/ChargeBackTests.java
index 6eac0dc..8d5ac43 100644
--- a/invoice/src/test/java/com/ning/billing/invoice/tests/ChargeBackTests.java
+++ b/invoice/src/test/java/com/ning/billing/invoice/tests/ChargeBackTests.java
@@ -32,12 +32,17 @@ import com.ning.billing.invoice.notification.MockNextBillingDatePoster;
 import com.ning.billing.invoice.notification.NextBillingDatePoster;
 import com.ning.billing.mock.BrainDeadProxyFactory;
 import com.ning.billing.mock.BrainDeadProxyFactory.ZombieControl;
+import com.ning.billing.util.api.TagUserApi;
 import com.ning.billing.util.callcontext.CallContext;
 import com.ning.billing.util.callcontext.TestCallContext;
 import com.ning.billing.util.clock.Clock;
 import com.ning.billing.util.clock.ClockMock;
+import com.ning.billing.util.tag.api.DefaultTagUserApi;
 import com.ning.billing.util.tag.dao.MockTagDao;
+import com.ning.billing.util.tag.dao.MockTagDefinitionDao;
 import com.ning.billing.util.tag.dao.TagDao;
+import com.ning.billing.util.tag.dao.TagDefinitionDao;
+
 import org.skife.jdbi.v2.IDBI;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
@@ -71,8 +76,10 @@ public class ChargeBackTests {
         invoiceSqlDao.test();
 
         NextBillingDatePoster nextBillingDatePoster = new MockNextBillingDatePoster();
-        TagDao tagDao = new MockTagDao();
-        InvoiceDao invoiceDao = new DefaultInvoiceDao(dbi, nextBillingDatePoster, tagDao);
+        final TagDefinitionDao tagDefinitionDao = new MockTagDefinitionDao();
+        final TagDao tagDao = new MockTagDao();
+        final TagUserApi tagUserApi = new DefaultTagUserApi(tagDefinitionDao, tagDao);
+        InvoiceDao invoiceDao = new DefaultInvoiceDao(dbi, nextBillingDatePoster, tagUserApi);
         invoicePaymentApi = new DefaultInvoicePaymentApi(invoiceDao);
 
         context = new TestCallContext("Charge back tests");
diff --git a/util/src/test/java/com/ning/billing/util/tag/dao/MockTagDao.java b/util/src/test/java/com/ning/billing/util/tag/dao/MockTagDao.java
index e67ef75..9253c07 100644
--- a/util/src/test/java/com/ning/billing/util/tag/dao/MockTagDao.java
+++ b/util/src/test/java/com/ning/billing/util/tag/dao/MockTagDao.java
@@ -16,6 +16,8 @@
 
 package com.ning.billing.util.tag.dao;
 
+import javax.annotation.Nullable;
+
 import com.ning.billing.util.callcontext.CallContext;
 import com.ning.billing.util.dao.ObjectType;
 import com.ning.billing.util.tag.Tag;
@@ -53,10 +55,12 @@ public class MockTagDao implements TagDao {
         return getMap(tagStore.get(objectId));
     }
 
-    private Map<String, Tag> getMap(List<Tag> tags) {
+    private Map<String, Tag> getMap(@Nullable final List<Tag> tags) {
         Map<String, Tag> map = new HashMap<String, Tag>();
-        for (Tag tag : tags) {
-            map.put(tag.getTagDefinitionName(), tag);
+        if (tags != null) {
+            for (Tag tag : tags) {
+                map.put(tag.getTagDefinitionName(), tag);
+            }
         }
         return map;
     }