diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
index 0b0188b..01a8729 100644
--- a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
@@ -24,6 +24,8 @@ import java.io.Serializable;
import javax.inject.Inject;
import org.apache.shiro.session.Session;
+import org.apache.shiro.session.UnknownSessionException;
+import org.apache.shiro.session.mgt.SimpleSession;
import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
@@ -46,10 +48,25 @@ public class JDBCSessionDao extends CachingSessionDAO {
@Override
protected void doUpdate(final Session session) {
- // Assume only the last access time attribute was updated (see https://github.com/killbill/killbill/issues/326)
- final DateTime lastAccessTime = new DateTime(session.getLastAccessTime(), DateTimeZone.UTC);
- final Long sessionId = Long.valueOf(session.getId().toString());
- jdbcSessionSqlDao.updateLastAccessTime(lastAccessTime, sessionId);
+ // The look-up should be cheap (most likely cached)
+ final Session previousSession = readSession(session.getId());
+
+ if (SessionUtils.sameSession(previousSession, session)) {
+ // Only the last access time attribute was updated.
+ // Avoid writing the state to disk for each request: we don't care so much about precision in the database,
+ // we just want to make sure the session doesn't timeout too early.
+ // Note also that in the case of a single node (or distributed cache), the timeout computation
+ // will be correct (because the cache value is correct).
+ // See https://github.com/killbill/killbill/issues/326
+ if (!SessionUtils.accessedRecently(previousSession, session)) {
+ final DateTime lastAccessTime = new DateTime(session.getLastAccessTime(), DateTimeZone.UTC);
+ final Long sessionId = Long.valueOf(session.getId().toString());
+ jdbcSessionSqlDao.updateLastAccessTime(lastAccessTime, sessionId);
+ }
+ } else {
+ // Various fields were changed, update the full row
+ jdbcSessionSqlDao.update(new SessionModelDao(session));
+ }
}
@Override
@@ -72,6 +89,28 @@ public class JDBCSessionDao extends CachingSessionDAO {
}
@Override
+ public Session readSession(final Serializable sessionId) throws UnknownSessionException {
+ final Session session = super.readSession(sessionId);
+
+ // Clone the session to avoid making changes to the existing one in the cache.
+ // This is required for the lookup in doUpdate to work
+ final SimpleSession clonedSession = new SimpleSession();
+ clonedSession.setId(session.getId());
+ clonedSession.setStartTimestamp(session.getStartTimestamp());
+ clonedSession.setLastAccessTime(session.getLastAccessTime());
+ clonedSession.setTimeout(session.getTimeout());
+ clonedSession.setHost(session.getHost());
+ clonedSession.setAttributes(SessionUtils.getSessionAttributes(session));
+
+ if (session instanceof SimpleSession) {
+ clonedSession.setStopTimestamp(((SimpleSession) session).getStopTimestamp());
+ clonedSession.setExpired(((SimpleSession) session).isExpired());
+ }
+
+ return clonedSession;
+ }
+
+ @Override
protected Session doReadSession(final Serializable sessionId) {
// Shiro should not pass us a null sessionId, but be safe...
if (sessionId == null) {
diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java
new file mode 100644
index 0000000..9885e6b
--- /dev/null
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2015 Groupon, Inc
+ * Copyright 2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at:
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.security.shiro.dao;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import javax.annotation.Nullable;
+
+import org.apache.shiro.session.Session;
+import org.apache.shiro.session.mgt.AbstractSessionManager;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
+
+public abstract class SessionUtils {
+
+ public SessionUtils() {}
+
+ // Check if the session was recently accessed ("recently" means within 5% of the timeout length)
+ public static boolean accessedRecently(final Session previousSession, final Session session) {
+ final Long timeoutMillis = MoreObjects.firstNonNull(session.getTimeout(), AbstractSessionManager.DEFAULT_GLOBAL_SESSION_TIMEOUT);
+ final Long errorMarginInMillis = 5L * timeoutMillis / 100;
+ return accessedRecently(previousSession, session, errorMarginInMillis);
+ }
+
+ @VisibleForTesting
+ static boolean accessedRecently(final Session previousSession, final Session session, final Long errorMarginInMillis) {
+ return previousSession.getLastAccessTime() != null &&
+ session.getLastAccessTime() != null &&
+ session.getLastAccessTime().getTime() < previousSession.getLastAccessTime().getTime() + errorMarginInMillis;
+ }
+
+ public static boolean sameSession(final Session previousSession, final Session newSession) {
+ return (previousSession.getStartTimestamp() != null ? previousSession.getStartTimestamp().compareTo(newSession.getStartTimestamp()) == 0 : newSession.getStartTimestamp() == null) &&
+ (previousSession.getTimeout() == newSession.getTimeout()) &&
+ (previousSession.getHost() != null ? previousSession.getHost().equals(newSession.getHost()) : newSession.getHost() == null) &&
+ sameSessionAttributes(previousSession, newSession);
+ }
+
+ @VisibleForTesting
+ static boolean sameSessionAttributes(@Nullable final Session previousSession, @Nullable final Session newSession) {
+ final Map<Object, Object> previousSessionAttributes = getSessionAttributes(previousSession);
+ final Map<Object, Object> newSessionAttributes = getSessionAttributes(newSession);
+ return previousSessionAttributes != null ? previousSessionAttributes.equals(newSessionAttributes) : newSessionAttributes == null;
+ }
+
+ public static Map<Object, Object> getSessionAttributes(@Nullable final Session session) {
+ if (session == null || session.getAttributeKeys() == null) {
+ return null;
+ }
+
+ final Map<Object, Object> attributes = new LinkedHashMap<Object, Object>();
+ for (final Object attributeKey : session.getAttributeKeys()) {
+ attributes.put(attributeKey, session.getAttribute(attributeKey));
+ }
+ return attributes;
+ }
+}
diff --git a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
index df97ca8..b943b2d 100644
--- a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
+++ b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
@@ -56,8 +56,15 @@ public class TestJDBCSessionDao extends UtilTestSuiteWithEmbeddedDB {
final Session retrievedSession = jdbcSessionDao.doReadSession(sessionId);
Assert.assertEquals(retrievedSession, session);
- // Update
- final Date lastAccessTime = DateTime.now().withTimeAtStartOfDay().toDate(); // Milliseconds will be truncated
+ // Update too soon, the database state won't be updated
+ Date lastAccessTime = new Date(retrievedSession.getLastAccessTime().getTime() + 1000);
+ Assert.assertNotEquals(retrievedSession.getLastAccessTime(), lastAccessTime);
+ session.setLastAccessTime(lastAccessTime);
+ jdbcSessionDao.doUpdate(session);
+ Assert.assertEquals(jdbcSessionDao.doReadSession(sessionId).getLastAccessTime().compareTo(retrievedSession.getLastAccessTime()), 0);
+
+ // Actual database update
+ lastAccessTime = new Date(retrievedSession.getLastAccessTime().getTime() + 100000);
Assert.assertNotEquals(retrievedSession.getLastAccessTime(), lastAccessTime);
session.setLastAccessTime(lastAccessTime);
jdbcSessionDao.doUpdate(session);
@@ -70,8 +77,9 @@ public class TestJDBCSessionDao extends UtilTestSuiteWithEmbeddedDB {
private SimpleSession createSession() {
final SimpleSession simpleSession = new SimpleSession();
- simpleSession.setStartTimestamp(new Date(System.currentTimeMillis() - 5000));
- simpleSession.setLastAccessTime(new Date(System.currentTimeMillis()));
+ // Truncate milliseconds for MySQL
+ simpleSession.setStartTimestamp(DateTime.now().withTimeAtStartOfDay().minusSeconds(5).toDate());
+ simpleSession.setLastAccessTime(DateTime.now().withTimeAtStartOfDay().toDate());
simpleSession.setTimeout(493934L);
simpleSession.setHost(UUID.randomUUID().toString());
simpleSession.setAttribute(UUID.randomUUID().toString(), Short.MIN_VALUE);
diff --git a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java
new file mode 100644
index 0000000..387dfe6
--- /dev/null
+++ b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright 2015 Groupon, Inc
+ * Copyright 2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License. You may obtain a copy of the License at:
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.security.shiro.dao;
+
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import org.apache.shiro.session.mgt.SimpleSession;
+import org.killbill.billing.util.UtilTestSuiteNoDB;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import com.google.common.collect.ImmutableMap;
+
+public class TestSessionUtils extends UtilTestSuiteNoDB {
+
+ private static final long MINUTES_IN_MILLIS = 60 * 1000L;
+
+ @Test(groups = "fast")
+ public void testAccessedRecently() throws Exception {
+ final Long t2 = System.currentTimeMillis();
+ final Long t1 = t2 - (3 * MINUTES_IN_MILLIS);
+
+ final SimpleSession session1 = new SimpleSession();
+ final SimpleSession session2 = new SimpleSession();
+ session1.setLastAccessTime(null);
+ session2.setLastAccessTime(null);
+
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+ session1.setLastAccessTime(new Date(t1));
+ session2.setLastAccessTime(new Date(t2));
+
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+ // For a timeout of 1 hour, 5% is 3 minutes
+ session2.setTimeout(59 * MINUTES_IN_MILLIS);
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+ session2.setTimeout(60 * MINUTES_IN_MILLIS);
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+ session2.setTimeout(61 * MINUTES_IN_MILLIS);
+ Assert.assertTrue(SessionUtils.accessedRecently(session1, session2));
+ }
+
+ @Test(groups = "fast")
+ public void testAccessedRecentlyWithError() throws Exception {
+ final Long t2 = System.currentTimeMillis();
+ final Long t1 = t2 - (3 * MINUTES_IN_MILLIS);
+
+ final SimpleSession session1 = new SimpleSession();
+ final SimpleSession session2 = new SimpleSession();
+ session1.setLastAccessTime(null);
+ session2.setLastAccessTime(null);
+
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+
+ session1.setLastAccessTime(new Date(t1));
+
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+
+ session2.setLastAccessTime(new Date(t2));
+
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+ Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+ Assert.assertTrue(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+ }
+
+ @Test(groups = "fast")
+ public void testSameSession() throws Exception {
+ final SimpleSession session1 = new SimpleSession();
+ final SimpleSession session2 = new SimpleSession();
+
+ Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+ session1.setStartTimestamp(new Date(2 * System.currentTimeMillis()));
+ Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+ Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+ session2.setStartTimestamp(session1.getStartTimestamp());
+ Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+ session1.setTimeout(12345L);
+ Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+ Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+ session2.setTimeout(session1.getTimeout());
+ Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+ session1.setHost(UUID.randomUUID().toString());
+ Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+ Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+ session2.setHost(session1.getHost());
+ Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+ session1.setAttributes(buildAttributes(UUID.randomUUID()));
+ Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+ Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+ session2.setAttributes(session1.getAttributes());
+ Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+ }
+
+ @Test(groups = "fast")
+ public void testSameSessionAttributes() throws Exception {
+ final UUID oneKey = UUID.randomUUID();
+ final SimpleSession session1 = new SimpleSession();
+ final SimpleSession session2 = new SimpleSession();
+ final SimpleSession session3 = new SimpleSession();
+ final Map<Object, Object> attributes = buildAttributes(oneKey);
+ session1.setAttributes(attributes);
+ session2.setAttributes(new LinkedHashMap<Object, Object>(attributes));
+
+ Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, null));
+ Assert.assertFalse(SessionUtils.sameSessionAttributes(null, session1));
+ Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, session3));
+
+ Assert.assertTrue(SessionUtils.sameSessionAttributes(null, null));
+ Assert.assertTrue(SessionUtils.sameSessionAttributes(session1, session1));
+ Assert.assertTrue(SessionUtils.sameSessionAttributes(session1, session2));
+ Assert.assertTrue(SessionUtils.sameSessionAttributes(session2, session1));
+
+ session2.removeAttribute(oneKey);
+
+ Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, session2));
+ Assert.assertFalse(SessionUtils.sameSessionAttributes(session2, session1));
+ }
+
+ @Test(groups = "fast")
+ public void testGetSessionAttributes() throws Exception {
+ final SimpleSession session = new SimpleSession();
+ final Map<Object, Object> attributes = buildAttributes(UUID.randomUUID());
+ session.setAttributes(attributes);
+
+ Assert.assertEquals(SessionUtils.getSessionAttributes(session), attributes);
+ }
+
+ private Map<Object, Object> buildAttributes(final UUID oneKey) {
+ return ImmutableMap.<Object, Object>of(oneKey, 1L,
+ UUID.randomUUID(), "2",
+ UUID.randomUUID(), (short) 3,
+ UUID.randomUUID(), 4,
+ UUID.randomUUID(), UUID.randomUUID());
+ }
+}