killbill-memoizeit

util: fix JDBC connection leak in pagination API This fixes

1/31/2018 2:51:32 AM

Details

diff --git a/util/src/main/java/org/killbill/billing/util/entity/dao/DefaultPaginationSqlDaoHelper.java b/util/src/main/java/org/killbill/billing/util/entity/dao/DefaultPaginationSqlDaoHelper.java
index 574d3b0..2216e44 100644
--- a/util/src/main/java/org/killbill/billing/util/entity/dao/DefaultPaginationSqlDaoHelper.java
+++ b/util/src/main/java/org/killbill/billing/util/entity/dao/DefaultPaginationSqlDaoHelper.java
@@ -1,7 +1,7 @@
 /*
  * Copyright 2010-2014 Ning, Inc.
- * Copyright 2014-2015 Groupon, Inc
- * Copyright 2014-2015 The Billing Project, LLC
+ * Copyright 2014-2018 Groupon, Inc
+ * Copyright 2014-2018 The Billing Project, LLC
  *
  * The Billing Project licenses this file to you under the Apache License, version 2.0
  * (the "License"); you may not use this file except in compliance with the
@@ -18,6 +18,8 @@
 
 package org.killbill.billing.util.entity.dao;
 
+import java.io.Closeable;
+import java.io.IOException;
 import java.util.Iterator;
 
 import javax.annotation.Nullable;
@@ -26,9 +28,13 @@ import org.killbill.billing.callcontext.InternalTenantContext;
 import org.killbill.billing.util.entity.DefaultPagination;
 import org.killbill.billing.util.entity.Entity;
 import org.killbill.billing.util.entity.Pagination;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class DefaultPaginationSqlDaoHelper {
 
+    private static final Logger logger = LoggerFactory.getLogger(DefaultPaginationSqlDaoHelper.class);
+
     // Number large enough so that small installations have access to an accurate count
     // but small enough to not impact very large deployments
     // TODO Should this be configurable per tenant?
@@ -66,14 +72,35 @@ public class DefaultPaginationSqlDaoHelper {
         // We usually always want to wrap our queries in an EntitySqlDaoTransactionWrapper... except here.
         // Since we want to stream the results out, we don't want to auto-commit when this method returns.
         final EntitySqlDao<M, E> sqlDao = transactionalSqlDao.onDemandForStreamingResults(sqlDaoClazz);
-        // The count to get maxNbRecords can be expensive on very large datasets. As a heuristic to check how large that number is,
-        // we retrieve 1 record at offset SIMPLE_PAGINATION_THRESHOLD (pretty fast). If we've found a record, that means the count is larger
-        // than this threshold and we don't issue the full count query
         final Long maxNbRecords;
-        if (context == null || paginationIteratorBuilder.build((S) sqlDao, SIMPLE_PAGINATION_THRESHOLD, 1L, ordering, context).hasNext()) {
+        if (context == null) {
             maxNbRecords = null;
         } else {
-            maxNbRecords = sqlDao.getCount(context);
+            // The count to get maxNbRecords can be expensive on very large datasets. As a heuristic to check how large that number is,
+            // we retrieve 1 record at offset SIMPLE_PAGINATION_THRESHOLD (pretty fast). If we've found a record, that means the count is larger
+            // than this threshold and we don't issue the full count query
+            final Iterator<M> simplePaginationIterator = paginationIteratorBuilder.build((S) sqlDao, SIMPLE_PAGINATION_THRESHOLD, 1L, ordering, context);
+            final boolean veryLargeDataSet = simplePaginationIterator.hasNext();
+
+            // Make sure to free resources (https://github.com/killbill/killbill/issues/853)
+            if (simplePaginationIterator instanceof Closeable) {
+                // Always the case with the current implementation (delegateIterator is a org.skife.jdbi.v2.ResultIterator)
+                try {
+                    ((Closeable) simplePaginationIterator).close();
+                } catch (final IOException e) {
+                    logger.warn("Unable to close iterator", e);
+                }
+            } else {
+                while (simplePaginationIterator.hasNext()) {
+                    simplePaginationIterator.next();
+                }
+            }
+
+            if (veryLargeDataSet) {
+                maxNbRecords = null;
+            } else {
+                maxNbRecords = sqlDao.getCount(context);
+            }
         }
         final Iterator<M> results = paginationIteratorBuilder.build((S) sqlDao, offset, limit, ordering, context);