killbill-aplcache

util: revisit optimization of session update query It is

5/30/2015 7:14:37 AM

Details

diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
index 0b0188b..01a8729 100644
--- a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/JDBCSessionDao.java
@@ -24,6 +24,8 @@ import java.io.Serializable;
 import javax.inject.Inject;
 
 import org.apache.shiro.session.Session;
+import org.apache.shiro.session.UnknownSessionException;
+import org.apache.shiro.session.mgt.SimpleSession;
 import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
@@ -46,10 +48,25 @@ public class JDBCSessionDao extends CachingSessionDAO {
 
     @Override
     protected void doUpdate(final Session session) {
-        // Assume only the last access time attribute was updated (see https://github.com/killbill/killbill/issues/326)
-        final DateTime lastAccessTime = new DateTime(session.getLastAccessTime(), DateTimeZone.UTC);
-        final Long sessionId = Long.valueOf(session.getId().toString());
-        jdbcSessionSqlDao.updateLastAccessTime(lastAccessTime, sessionId);
+        // The look-up should be cheap (most likely cached)
+        final Session previousSession = readSession(session.getId());
+
+        if (SessionUtils.sameSession(previousSession, session)) {
+            // Only the last access time attribute was updated.
+            // Avoid writing the state to disk for each request: we don't care so much about precision in the database,
+            // we just want to make sure the session doesn't timeout too early.
+            // Note also that in the case of a single node (or distributed cache), the timeout computation
+            // will be correct (because the cache value is correct).
+            // See https://github.com/killbill/killbill/issues/326
+            if (!SessionUtils.accessedRecently(previousSession, session)) {
+                final DateTime lastAccessTime = new DateTime(session.getLastAccessTime(), DateTimeZone.UTC);
+                final Long sessionId = Long.valueOf(session.getId().toString());
+                jdbcSessionSqlDao.updateLastAccessTime(lastAccessTime, sessionId);
+            }
+        } else {
+            // Various fields were changed, update the full row
+            jdbcSessionSqlDao.update(new SessionModelDao(session));
+        }
     }
 
     @Override
@@ -72,6 +89,28 @@ public class JDBCSessionDao extends CachingSessionDAO {
     }
 
     @Override
+    public Session readSession(final Serializable sessionId) throws UnknownSessionException {
+        final Session session = super.readSession(sessionId);
+
+        // Clone the session to avoid making changes to the existing one in the cache.
+        // This is required for the lookup in doUpdate to work
+        final SimpleSession clonedSession = new SimpleSession();
+        clonedSession.setId(session.getId());
+        clonedSession.setStartTimestamp(session.getStartTimestamp());
+        clonedSession.setLastAccessTime(session.getLastAccessTime());
+        clonedSession.setTimeout(session.getTimeout());
+        clonedSession.setHost(session.getHost());
+        clonedSession.setAttributes(SessionUtils.getSessionAttributes(session));
+
+        if (session instanceof SimpleSession) {
+            clonedSession.setStopTimestamp(((SimpleSession) session).getStopTimestamp());
+            clonedSession.setExpired(((SimpleSession) session).isExpired());
+        }
+
+        return clonedSession;
+    }
+
+    @Override
     protected Session doReadSession(final Serializable sessionId) {
         // Shiro should not pass us a null sessionId, but be safe...
         if (sessionId == null) {
diff --git a/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java
new file mode 100644
index 0000000..9885e6b
--- /dev/null
+++ b/util/src/main/java/org/killbill/billing/util/security/shiro/dao/SessionUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2015 Groupon, Inc
+ * Copyright 2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.security.shiro.dao;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import javax.annotation.Nullable;
+
+import org.apache.shiro.session.Session;
+import org.apache.shiro.session.mgt.AbstractSessionManager;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
+
+public abstract class SessionUtils {
+
+    public SessionUtils() {}
+
+    // Check if the session was recently accessed ("recently" means within 5% of the timeout length)
+    public static boolean accessedRecently(final Session previousSession, final Session session) {
+        final Long timeoutMillis = MoreObjects.firstNonNull(session.getTimeout(), AbstractSessionManager.DEFAULT_GLOBAL_SESSION_TIMEOUT);
+        final Long errorMarginInMillis = 5L * timeoutMillis / 100;
+        return accessedRecently(previousSession, session, errorMarginInMillis);
+    }
+
+    @VisibleForTesting
+    static boolean accessedRecently(final Session previousSession, final Session session, final Long errorMarginInMillis) {
+        return previousSession.getLastAccessTime() != null &&
+               session.getLastAccessTime() != null &&
+               session.getLastAccessTime().getTime() < previousSession.getLastAccessTime().getTime() + errorMarginInMillis;
+    }
+
+    public static boolean sameSession(final Session previousSession, final Session newSession) {
+        return (previousSession.getStartTimestamp() != null ? previousSession.getStartTimestamp().compareTo(newSession.getStartTimestamp()) == 0 : newSession.getStartTimestamp() == null) &&
+               (previousSession.getTimeout() == newSession.getTimeout()) &&
+               (previousSession.getHost() != null ? previousSession.getHost().equals(newSession.getHost()) : newSession.getHost() == null) &&
+               sameSessionAttributes(previousSession, newSession);
+    }
+
+    @VisibleForTesting
+    static boolean sameSessionAttributes(@Nullable final Session previousSession, @Nullable final Session newSession) {
+        final Map<Object, Object> previousSessionAttributes = getSessionAttributes(previousSession);
+        final Map<Object, Object> newSessionAttributes = getSessionAttributes(newSession);
+        return previousSessionAttributes != null ? previousSessionAttributes.equals(newSessionAttributes) : newSessionAttributes == null;
+    }
+
+    public static Map<Object, Object> getSessionAttributes(@Nullable final Session session) {
+        if (session == null || session.getAttributeKeys() == null) {
+            return null;
+        }
+
+        final Map<Object, Object> attributes = new LinkedHashMap<Object, Object>();
+        for (final Object attributeKey : session.getAttributeKeys()) {
+            attributes.put(attributeKey, session.getAttribute(attributeKey));
+        }
+        return attributes;
+    }
+}
diff --git a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
index df97ca8..b943b2d 100644
--- a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
+++ b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestJDBCSessionDao.java
@@ -56,8 +56,15 @@ public class TestJDBCSessionDao extends UtilTestSuiteWithEmbeddedDB {
         final Session retrievedSession = jdbcSessionDao.doReadSession(sessionId);
         Assert.assertEquals(retrievedSession, session);
 
-        // Update
-        final Date lastAccessTime = DateTime.now().withTimeAtStartOfDay().toDate(); // Milliseconds will be truncated
+        // Update too soon, the database state won't be updated
+        Date lastAccessTime = new Date(retrievedSession.getLastAccessTime().getTime() + 1000);
+        Assert.assertNotEquals(retrievedSession.getLastAccessTime(), lastAccessTime);
+        session.setLastAccessTime(lastAccessTime);
+        jdbcSessionDao.doUpdate(session);
+        Assert.assertEquals(jdbcSessionDao.doReadSession(sessionId).getLastAccessTime().compareTo(retrievedSession.getLastAccessTime()), 0);
+
+        // Actual database update
+        lastAccessTime = new Date(retrievedSession.getLastAccessTime().getTime() + 100000);
         Assert.assertNotEquals(retrievedSession.getLastAccessTime(), lastAccessTime);
         session.setLastAccessTime(lastAccessTime);
         jdbcSessionDao.doUpdate(session);
@@ -70,8 +77,9 @@ public class TestJDBCSessionDao extends UtilTestSuiteWithEmbeddedDB {
 
     private SimpleSession createSession() {
         final SimpleSession simpleSession = new SimpleSession();
-        simpleSession.setStartTimestamp(new Date(System.currentTimeMillis() - 5000));
-        simpleSession.setLastAccessTime(new Date(System.currentTimeMillis()));
+        // Truncate milliseconds for MySQL
+        simpleSession.setStartTimestamp(DateTime.now().withTimeAtStartOfDay().minusSeconds(5).toDate());
+        simpleSession.setLastAccessTime(DateTime.now().withTimeAtStartOfDay().toDate());
         simpleSession.setTimeout(493934L);
         simpleSession.setHost(UUID.randomUUID().toString());
         simpleSession.setAttribute(UUID.randomUUID().toString(), Short.MIN_VALUE);
diff --git a/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java
new file mode 100644
index 0000000..387dfe6
--- /dev/null
+++ b/util/src/test/java/org/killbill/billing/util/security/shiro/dao/TestSessionUtils.java
@@ -0,0 +1,176 @@
+/*
+ * Copyright 2015 Groupon, Inc
+ * Copyright 2015 The Billing Project, LLC
+ *
+ * The Billing Project licenses this file to you under the Apache License, version 2.0
+ * (the "License"); you may not use this file except in compliance with the
+ * License.  You may obtain a copy of the License at:
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.killbill.billing.util.security.shiro.dao;
+
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import org.apache.shiro.session.mgt.SimpleSession;
+import org.killbill.billing.util.UtilTestSuiteNoDB;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import com.google.common.collect.ImmutableMap;
+
+public class TestSessionUtils extends UtilTestSuiteNoDB {
+
+    private static final long MINUTES_IN_MILLIS = 60 * 1000L;
+
+    @Test(groups = "fast")
+    public void testAccessedRecently() throws Exception {
+        final Long t2 = System.currentTimeMillis();
+        final Long t1 = t2 - (3 * MINUTES_IN_MILLIS);
+
+        final SimpleSession session1 = new SimpleSession();
+        final SimpleSession session2 = new SimpleSession();
+        session1.setLastAccessTime(null);
+        session2.setLastAccessTime(null);
+
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+        session1.setLastAccessTime(new Date(t1));
+        session2.setLastAccessTime(new Date(t2));
+
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+        // For a timeout of 1 hour, 5% is 3 minutes
+        session2.setTimeout(59 * MINUTES_IN_MILLIS);
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+        session2.setTimeout(60 * MINUTES_IN_MILLIS);
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2));
+
+        session2.setTimeout(61 * MINUTES_IN_MILLIS);
+        Assert.assertTrue(SessionUtils.accessedRecently(session1, session2));
+    }
+
+    @Test(groups = "fast")
+    public void testAccessedRecentlyWithError() throws Exception {
+        final Long t2 = System.currentTimeMillis();
+        final Long t1 = t2 - (3 * MINUTES_IN_MILLIS);
+
+        final SimpleSession session1 = new SimpleSession();
+        final SimpleSession session2 = new SimpleSession();
+        session1.setLastAccessTime(null);
+        session2.setLastAccessTime(null);
+
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+
+        session1.setLastAccessTime(new Date(t1));
+
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+
+        session2.setLastAccessTime(new Date(t2));
+
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 0L));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS - 1));
+        Assert.assertFalse(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS));
+        Assert.assertTrue(SessionUtils.accessedRecently(session1, session2, 3 * MINUTES_IN_MILLIS + 1));
+    }
+
+    @Test(groups = "fast")
+    public void testSameSession() throws Exception {
+        final SimpleSession session1 = new SimpleSession();
+        final SimpleSession session2 = new SimpleSession();
+
+        Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+        session1.setStartTimestamp(new Date(2 * System.currentTimeMillis()));
+        Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+        Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+        session2.setStartTimestamp(session1.getStartTimestamp());
+        Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+        session1.setTimeout(12345L);
+        Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+        Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+        session2.setTimeout(session1.getTimeout());
+        Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+        session1.setHost(UUID.randomUUID().toString());
+        Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+        Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+        session2.setHost(session1.getHost());
+        Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+
+        session1.setAttributes(buildAttributes(UUID.randomUUID()));
+        Assert.assertFalse(SessionUtils.sameSession(session1, session2));
+        Assert.assertFalse(SessionUtils.sameSession(session2, session1));
+
+        session2.setAttributes(session1.getAttributes());
+        Assert.assertTrue(SessionUtils.sameSession(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSession(session2, session1));
+    }
+
+    @Test(groups = "fast")
+    public void testSameSessionAttributes() throws Exception {
+        final UUID oneKey = UUID.randomUUID();
+        final SimpleSession session1 = new SimpleSession();
+        final SimpleSession session2 = new SimpleSession();
+        final SimpleSession session3 = new SimpleSession();
+        final Map<Object, Object> attributes = buildAttributes(oneKey);
+        session1.setAttributes(attributes);
+        session2.setAttributes(new LinkedHashMap<Object, Object>(attributes));
+
+        Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, null));
+        Assert.assertFalse(SessionUtils.sameSessionAttributes(null, session1));
+        Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, session3));
+
+        Assert.assertTrue(SessionUtils.sameSessionAttributes(null, null));
+        Assert.assertTrue(SessionUtils.sameSessionAttributes(session1, session1));
+        Assert.assertTrue(SessionUtils.sameSessionAttributes(session1, session2));
+        Assert.assertTrue(SessionUtils.sameSessionAttributes(session2, session1));
+
+        session2.removeAttribute(oneKey);
+
+        Assert.assertFalse(SessionUtils.sameSessionAttributes(session1, session2));
+        Assert.assertFalse(SessionUtils.sameSessionAttributes(session2, session1));
+    }
+
+    @Test(groups = "fast")
+    public void testGetSessionAttributes() throws Exception {
+        final SimpleSession session = new SimpleSession();
+        final Map<Object, Object> attributes = buildAttributes(UUID.randomUUID());
+        session.setAttributes(attributes);
+
+        Assert.assertEquals(SessionUtils.getSessionAttributes(session), attributes);
+    }
+
+    private Map<Object, Object> buildAttributes(final UUID oneKey) {
+        return ImmutableMap.<Object, Object>of(oneKey, 1L,
+                                               UUID.randomUUID(), "2",
+                                               UUID.randomUUID(), (short) 3,
+                                               UUID.randomUUID(), 4,
+                                               UUID.randomUUID(), UUID.randomUUID());
+    }
+}