keycloak-memoizeit

Merge pull request #2999 from stianst/KEYCLOAK-3189 KEYCLOAK-3189

7/5/2016 4:33:21 AM

Details

diff --git a/core/src/main/java/org/keycloak/jose/jwk/JWKParser.java b/core/src/main/java/org/keycloak/jose/jwk/JWKParser.java
index a503a3c..1bad9cf 100755
--- a/core/src/main/java/org/keycloak/jose/jwk/JWKParser.java
+++ b/core/src/main/java/org/keycloak/jose/jwk/JWKParser.java
@@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
 import org.keycloak.common.util.Base64Url;
 import org.keycloak.util.JsonSerialization;
 
+import java.io.InputStream;
 import java.math.BigInteger;
 import java.security.KeyFactory;
 import java.security.PublicKey;
diff --git a/core/src/main/java/org/keycloak/jose/jws/JWSBuilder.java b/core/src/main/java/org/keycloak/jose/jws/JWSBuilder.java
index e344389..e4a9805 100755
--- a/core/src/main/java/org/keycloak/jose/jws/JWSBuilder.java
+++ b/core/src/main/java/org/keycloak/jose/jws/JWSBuilder.java
@@ -33,6 +33,7 @@ import java.security.PrivateKey;
  */
 public class JWSBuilder {
     String type;
+    String kid;
     String contentType;
     byte[] contentBytes;
 
@@ -41,6 +42,11 @@ public class JWSBuilder {
         return this;
     }
 
+    public JWSBuilder kid(String kid) {
+        this.kid = kid;
+        return this;
+    }
+
     public JWSBuilder contentType(String type) {
         this.contentType = type;
         return this;
@@ -66,6 +72,7 @@ public class JWSBuilder {
         builder.append("\"alg\":\"").append(alg.toString()).append("\"");
 
         if (type != null) builder.append(",\"typ\" : \"").append(type).append("\"");
+        if (kid != null) builder.append(",\"kid\" : \"").append(kid).append("\"");
         if (contentType != null) builder.append(",\"cty\":\"").append(contentType).append("\"");
         builder.append("}");
         try {
diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/CachedRealm.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/CachedRealm.java
index b93da32..6a4ff4f 100755
--- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/CachedRealm.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/entities/CachedRealm.java
@@ -93,6 +93,7 @@ public class CachedRealm extends AbstractRevisioned {
     protected PasswordPolicy passwordPolicy;
     protected OTPPolicy otpPolicy;
 
+    protected transient String keyId;
     protected transient PublicKey publicKey;
     protected String publicKeyPem;
     protected transient PrivateKey privateKey;
@@ -189,6 +190,7 @@ public class CachedRealm extends AbstractRevisioned {
         passwordPolicy = model.getPasswordPolicy();
         otpPolicy = model.getOTPPolicy();
 
+        keyId = model.getKeyId();
         publicKeyPem = model.getPublicKeyPem();
         publicKey = model.getPublicKey();
         privateKeyPem = model.getPrivateKeyPem();
@@ -397,6 +399,10 @@ public class CachedRealm extends AbstractRevisioned {
         return accessCodeLifespanLogin;
     }
 
+    public String getKeyId() {
+        return keyId;
+    }
+
     public String getPublicKeyPem() {
         return publicKeyPem;
     }
diff --git a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmAdapter.java b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmAdapter.java
index be67c49..4ea71c7 100755
--- a/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmAdapter.java
+++ b/model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/RealmAdapter.java
@@ -403,6 +403,12 @@ public class RealmAdapter implements RealmModel {
     }
 
     @Override
+    public String getKeyId() {
+        if (isUpdated()) return updated.getKeyId();
+        return cached.getKeyId();
+    }
+
+    @Override
     public String getPublicKeyPem() {
         if (isUpdated()) return updated.getPublicKeyPem();
         return cached.getPublicKeyPem();
diff --git a/model/jpa/src/main/java/org/keycloak/models/jpa/RealmAdapter.java b/model/jpa/src/main/java/org/keycloak/models/jpa/RealmAdapter.java
index 4167ebc..8de1395 100755
--- a/model/jpa/src/main/java/org/keycloak/models/jpa/RealmAdapter.java
+++ b/model/jpa/src/main/java/org/keycloak/models/jpa/RealmAdapter.java
@@ -20,6 +20,7 @@ package org.keycloak.models.jpa;
 import org.jboss.logging.Logger;
 import org.keycloak.connections.jpa.util.JpaUtils;
 import org.keycloak.common.enums.SslRequired;
+import org.keycloak.jose.jwk.JWKBuilder;
 import org.keycloak.models.AuthenticationExecutionModel;
 import org.keycloak.models.AuthenticationFlowModel;
 import org.keycloak.models.AuthenticatorConfigModel;
@@ -460,6 +461,12 @@ public class RealmAdapter implements RealmModel, JpaModel<RealmEntity> {
     }
 
     @Override
+    public String getKeyId() {
+        PublicKey publicKey = getPublicKey();
+        return publicKey != null ? JWKBuilder.create().rs256(publicKey).getKeyId() : null;
+    }
+
+    @Override
     public String getPublicKeyPem() {
         return realm.getPublicKeyPem();
     }
diff --git a/model/mongo/src/main/java/org/keycloak/models/mongo/keycloak/adapters/RealmAdapter.java b/model/mongo/src/main/java/org/keycloak/models/mongo/keycloak/adapters/RealmAdapter.java
index 6fff8d5..2dcda48 100755
--- a/model/mongo/src/main/java/org/keycloak/models/mongo/keycloak/adapters/RealmAdapter.java
+++ b/model/mongo/src/main/java/org/keycloak/models/mongo/keycloak/adapters/RealmAdapter.java
@@ -22,6 +22,7 @@ import com.mongodb.QueryBuilder;
 
 import org.keycloak.connections.mongo.api.context.MongoStoreInvocationContext;
 import org.keycloak.common.enums.SslRequired;
+import org.keycloak.jose.jwk.JWKBuilder;
 import org.keycloak.models.AuthenticationExecutionModel;
 import org.keycloak.models.AuthenticationFlowModel;
 import org.keycloak.models.AuthenticatorConfigModel;
@@ -454,6 +455,12 @@ public class RealmAdapter extends AbstractMongoAdapter<MongoRealmEntity> impleme
     }
 
     @Override
+    public String getKeyId() {
+        PublicKey publicKey = getPublicKey();
+        return publicKey != null ? JWKBuilder.create().rs256(publicKey).getKeyId() : null;
+    }
+
+    @Override
     public String getPublicKeyPem() {
         return realm.getPublicKeyPem();
     }
diff --git a/server-spi/src/main/java/org/keycloak/models/RealmModel.java b/server-spi/src/main/java/org/keycloak/models/RealmModel.java
index 9fe36ac..65ab90e 100755
--- a/server-spi/src/main/java/org/keycloak/models/RealmModel.java
+++ b/server-spi/src/main/java/org/keycloak/models/RealmModel.java
@@ -151,6 +151,8 @@ public interface RealmModel extends RoleContainerModel {
 
     void setAccessCodeLifespanLogin(int seconds);
 
+    String getKeyId();
+
     String getPublicKeyPem();
 
     void setPublicKeyPem(String publicKeyPem);
diff --git a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java
index 75b3f46..fad4aeb 100644
--- a/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java
+++ b/services/src/main/java/org/keycloak/protocol/oidc/OIDCLoginProtocolService.java
@@ -99,6 +99,11 @@ public class OIDCLoginProtocolService {
         return uriBuilder.path(OIDCLoginProtocolService.class, "token");
     }
 
+    public static UriBuilder certsUrl(UriBuilder baseUriBuilder) {
+        UriBuilder uriBuilder = tokenServiceBaseUrl(baseUriBuilder);
+        return uriBuilder.path(OIDCLoginProtocolService.class, "certs");
+    }
+
     public static UriBuilder tokenIntrospectionUrl(UriBuilder baseUriBuilder) {
         return tokenUrl(baseUriBuilder).path(TokenEndpoint.class, "introspect");
     }
diff --git a/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java b/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java
index bbd41c0..2f05718 100755
--- a/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java
+++ b/services/src/main/java/org/keycloak/protocol/oidc/TokenManager.java
@@ -78,6 +78,7 @@ import java.util.Set;
  */
 public class TokenManager {
     protected static final ServicesLogger logger = ServicesLogger.ROOT_LOGGER;
+    private static final String JWT = "JWT";
 
     public static void applyScope(RoleModel role, RoleModel scope, Set<RoleModel> visited, Set<RoleModel> requested) {
         if (visited.contains(scope)) return;
@@ -570,6 +571,8 @@ public class TokenManager {
 
     public String encodeToken(RealmModel realm, Object token) {
         String encodedToken = new JWSBuilder()
+                .type(JWT)
+                .kid(realm.getKeyId())
                 .jsonContent(token)
                 .rsa256(realm.getPrivateKey());
         return encodedToken;
@@ -680,11 +683,11 @@ public class TokenManager {
 
             AccessTokenResponse res = new AccessTokenResponse();
             if (idToken != null) {
-                String encodedToken = new JWSBuilder().jsonContent(idToken).rsa256(realm.getPrivateKey());
+                String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(idToken).rsa256(realm.getPrivateKey());
                 res.setIdToken(encodedToken);
             }
             if (accessToken != null) {
-                String encodedToken = new JWSBuilder().jsonContent(accessToken).rsa256(realm.getPrivateKey());
+                String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(accessToken).rsa256(realm.getPrivateKey());
                 res.setToken(encodedToken);
                 res.setTokenType("bearer");
                 res.setSessionState(accessToken.getSessionState());
@@ -693,7 +696,7 @@ public class TokenManager {
                 }
             }
             if (refreshToken != null) {
-                String encodedToken = new JWSBuilder().jsonContent(refreshToken).rsa256(realm.getPrivateKey());
+                String encodedToken = new JWSBuilder().type(JWT).kid(realm.getKeyId()).jsonContent(refreshToken).rsa256(realm.getPrivateKey());
                 res.setRefreshToken(encodedToken);
                 if (refreshToken.getExpiration() != 0) {
                     res.setRefreshExpiresIn(refreshToken.getExpiration() - Time.currentTime());
diff --git a/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java b/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java
index c0088a1..b4b9f40 100644
--- a/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java
+++ b/testsuite/integration-arquillian/tests/base/src/main/java/org/keycloak/testsuite/util/OAuthClient.java
@@ -22,6 +22,8 @@ import org.apache.commons.io.output.ByteArrayOutputStream;
 import org.apache.http.HttpResponse;
 import org.apache.http.NameValuePair;
 import org.apache.http.client.entity.UrlEncodedFormEntity;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpGet;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.client.utils.URLEncodedUtils;
 import org.apache.http.impl.client.CloseableHttpClient;
@@ -34,9 +36,13 @@ import org.keycloak.admin.client.Keycloak;
 import org.keycloak.common.VerificationException;
 import org.keycloak.common.util.PemUtils;
 import org.keycloak.constants.AdapterConstants;
+import org.keycloak.jose.jwk.JWK;
+import org.keycloak.jose.jwk.JWKBuilder;
+import org.keycloak.jose.jwk.JWKParser;
 import org.keycloak.jose.jws.JWSInput;
 import org.keycloak.jose.jws.crypto.RSAProvider;
 import org.keycloak.protocol.oidc.OIDCLoginProtocolService;
+import org.keycloak.protocol.oidc.representations.JSONWebKeySet;
 import org.keycloak.representations.AccessToken;
 import org.keycloak.representations.RefreshToken;
 import org.keycloak.testsuite.arquillian.AuthServerTestEnricher;
@@ -279,6 +285,17 @@ public class OAuthClient {
         }
     }
 
+    public JSONWebKeySet doCertsRequest(String realm) throws Exception {
+        CloseableHttpClient client = new DefaultHttpClient();
+        try {
+            HttpGet get = new HttpGet(getCertsUrl(realm));
+            CloseableHttpResponse response = client.execute(get);
+            return JsonSerialization.readValue(response.getEntity().getContent(), JSONWebKeySet.class);
+        } finally {
+            closeClient(client);
+        }
+    }
+
     public AccessTokenResponse doClientCredentialsGrantAccessTokenRequest(String clientSecret) throws Exception {
         CloseableHttpClient client = new DefaultHttpClient();
         try {
@@ -503,6 +520,11 @@ public class OAuthClient {
         return b.build(realm).toString();
     }
 
+    public String getCertsUrl(String realm) {
+        UriBuilder b = OIDCLoginProtocolService.certsUrl(UriBuilder.fromUri(baseUrl));
+        return b.build(realm).toString();
+    }
+
     public String getServiceAccountUrl() {
         return getResourceOwnerPasswordCredentialGrantUrl();
     }
@@ -591,6 +613,7 @@ public class OAuthClient {
     public static class AccessTokenResponse {
         private int statusCode;
 
+        private String idToken;
         private String accessToken;
         private String tokenType;
         private int expiresIn;
@@ -610,6 +633,7 @@ public class OAuthClient {
             Map responseJson = JsonSerialization.readValue(s, Map.class);
 
             if (statusCode == 200) {
+                idToken = (String)responseJson.get("id_token");
                 accessToken = (String)responseJson.get("access_token");
                 tokenType = (String)responseJson.get("token_type");
                 expiresIn = (Integer)responseJson.get("expires_in");
@@ -624,6 +648,10 @@ public class OAuthClient {
             }
         }
 
+        public String getIdToken() {
+            return idToken;
+        }
+
         public String getAccessToken() {
             return accessToken;
         }
diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/AccessTokenTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/AccessTokenTest.java
index 48b0275..effeead 100755
--- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/AccessTokenTest.java
+++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/oauth/AccessTokenTest.java
@@ -32,8 +32,11 @@ import org.keycloak.admin.client.resource.ClientTemplateResource;
 import org.keycloak.admin.client.resource.RealmResource;
 import org.keycloak.admin.client.resource.UserResource;
 import org.keycloak.common.enums.SslRequired;
+import org.keycloak.common.util.PemUtils;
 import org.keycloak.events.Details;
 import org.keycloak.events.Errors;
+import org.keycloak.jose.jwk.JWKBuilder;
+import org.keycloak.jose.jws.JWSHeader;
 import org.keycloak.jose.jws.JWSInput;
 import org.keycloak.jose.jws.JWSInputException;
 import org.keycloak.models.ProtocolMapperModel;
@@ -155,6 +158,26 @@ public class AccessTokenTest extends AbstractKeycloakTest {
 
         assertEquals("bearer", response.getTokenType());
 
+        String expectedKid = oauth.doCertsRequest("test").getKeys()[0].getKeyId();
+
+        JWSHeader header = new JWSInput(response.getAccessToken()).getHeader();
+        assertEquals("RS256", header.getAlgorithm().name());
+        assertEquals("JWT", header.getType());
+        assertEquals(expectedKid, header.getKeyId());
+        assertNull(header.getContentType());
+
+        header = new JWSInput(response.getIdToken()).getHeader();
+        assertEquals("RS256", header.getAlgorithm().name());
+        assertEquals("JWT", header.getType());
+        assertEquals(expectedKid, header.getKeyId());
+        assertNull(header.getContentType());
+
+        header = new JWSInput(response.getRefreshToken()).getHeader();
+        assertEquals("RS256", header.getAlgorithm().name());
+        assertEquals("JWT", header.getType());
+        assertEquals(expectedKid, header.getKeyId());
+        assertNull(header.getContentType());
+
         AccessToken token = oauth.verifyToken(response.getAccessToken());
 
         assertEquals(findUserByUsername(adminClient.realm("test"), "test-user@localhost").getId(), token.getSubject());