diff --git a/util/src/test/java/org/killbill/billing/DBTestingHelper.java b/util/src/test/java/org/killbill/billing/DBTestingHelper.java
index de0640b..20dd965 100644
--- a/util/src/test/java/org/killbill/billing/DBTestingHelper.java
+++ b/util/src/test/java/org/killbill/billing/DBTestingHelper.java
@@ -1,7 +1,7 @@
/*
* Copyright 2010-2013 Ning, Inc.
- * Copyright 2014-2017 Groupon, Inc
- * Copyright 2014-2017 The Billing Project, LLC
+ * Copyright 2014-2018 Groupon, Inc
+ * Copyright 2014-2018 The Billing Project, LLC
*
* The Billing Project licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
@@ -19,18 +19,27 @@
package org.killbill.billing;
import java.io.IOException;
+import java.io.PrintWriter;
import java.net.URL;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.sql.SQLFeatureNotSupportedException;
+import java.sql.SQLNonTransientConnectionException;
import java.util.Enumeration;
-import java.util.concurrent.atomic.AtomicBoolean;
+
+import javax.sql.DataSource;
import org.killbill.billing.platform.test.PlatformDBTestingHelper;
import org.killbill.billing.util.glue.IDBISetup;
import org.killbill.billing.util.io.IOUtils;
import org.killbill.commons.embeddeddb.EmbeddedDB;
+import org.killbill.commons.jdbi.guice.DBIProvider;
import org.skife.jdbi.v2.DBI;
import org.skife.jdbi.v2.IDBI;
import org.skife.jdbi.v2.ResultSetMapperFactory;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import com.google.common.base.MoreObjects;
@@ -38,7 +47,7 @@ public class DBTestingHelper extends PlatformDBTestingHelper {
private static DBTestingHelper dbTestingHelper = null;
- private AtomicBoolean initialized;
+ private DBI dbi;
public static synchronized DBTestingHelper get() {
if (dbTestingHelper == null) {
@@ -49,18 +58,18 @@ public class DBTestingHelper extends PlatformDBTestingHelper {
private DBTestingHelper() {
super();
- initialized = new AtomicBoolean(false);
}
@Override
- public IDBI getDBI() {
- final DBI dbi = (DBI) super.getDBI();
- // Register KB specific mappers
- if (initialized.compareAndSet(false, true)) {
+ public synchronized IDBI getDBI() {
+ if (dbi == null) {
+ final RetryableDataSource retryableDataSource = new RetryableDataSource(getDataSource());
+ dbi = (DBI) new DBIProvider(null, retryableDataSource, null).get();
+
+ // Register KB specific mappers
for (final ResultSetMapperFactory resultSetMapperFactory : IDBISetup.mapperFactoriesToRegister()) {
dbi.registerMapper(resultSetMapperFactory);
}
-
for (final ResultSetMapper resultSetMapper : IDBISetup.mappersToRegister()) {
dbi.registerMapper(resultSetMapper);
}
@@ -201,4 +210,73 @@ public class DBTestingHelper extends PlatformDBTestingHelper {
}
}
}
+
+ // DataSource which will retry recreating a connection once in case of a connection exception.
+ // This is useful for transient network errors in tests when using a separate database (e.g. Docker container),
+ // as we don't use a connection pool.
+ private static final class RetryableDataSource implements DataSource {
+
+ private static final Logger logger = LoggerFactory.getLogger(RetryableDataSource.class);
+
+ private final DataSource delegate;
+
+ private RetryableDataSource(final DataSource delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public Connection getConnection() throws SQLException {
+ try {
+ return delegate.getConnection();
+ } catch (final SQLNonTransientConnectionException e) {
+ logger.warn("Unable to retrieve connection, attempting to retry", e);
+ return delegate.getConnection();
+ }
+ }
+
+ @Override
+ public Connection getConnection(final String username, final String password) throws SQLException {
+ try {
+ return delegate.getConnection(username, password);
+ } catch (final SQLNonTransientConnectionException e) {
+ logger.warn("Unable to retrieve connection, attempting to retry", e);
+ return delegate.getConnection(username, password);
+ }
+ }
+
+ @Override
+ public <T> T unwrap(final Class<T> iface) throws SQLException {
+ return delegate.unwrap(iface);
+ }
+
+ @Override
+ public boolean isWrapperFor(final Class<?> iface) throws SQLException {
+ return delegate.isWrapperFor(iface);
+ }
+
+ @Override
+ public PrintWriter getLogWriter() throws SQLException {
+ return delegate.getLogWriter();
+ }
+
+ @Override
+ public void setLogWriter(final PrintWriter out) throws SQLException {
+ delegate.setLogWriter(out);
+ }
+
+ @Override
+ public void setLoginTimeout(final int seconds) throws SQLException {
+ delegate.setLoginTimeout(seconds);
+ }
+
+ @Override
+ public int getLoginTimeout() throws SQLException {
+ return delegate.getLoginTimeout();
+ }
+
+ //@Override
+ public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException {
+ throw new SQLFeatureNotSupportedException("javax.sql.DataSource.getParentLogger() is not currently supported by " + this.getClass().getName());
+ }
+ }
}