keycloak-aplcache

KEYCLOAK-1881 Include key ID in <ds:KeyInfo> in SAML assertions

10/24/2016 6:43:48 PM

Details

diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
index 305ffeb..693e06e 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
@@ -79,7 +79,7 @@ public abstract class AbstractInitiateLogin implements AuthChallenge {
                 binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
             }
 
-            binding.signWith(keypair);
+            binding.signWith(null, keypair);
             binding.signDocument();
         }
         return binding;
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
index 5c1454f..3581357 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
@@ -82,7 +82,7 @@ public class WebBrowserSsoAuthenticationHandler extends AbstractSamlAuthenticati
             if (deployment.getSignatureCanonicalizationMethod() != null)
                 binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
             binding.signatureAlgorithm(deployment.getSignatureAlgorithm())
-                    .signWith(deployment.getSigningKeyPair())
+                    .signWith(null, deployment.getSigningKeyPair())
                     .signDocument();
         }
 
@@ -113,7 +113,7 @@ public class WebBrowserSsoAuthenticationHandler extends AbstractSamlAuthenticati
             if (deployment.getSignatureCanonicalizationMethod() != null)
                 binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
             binding.signatureAlgorithm(deployment.getSignatureAlgorithm());
-            binding.signWith(deployment.getSigningKeyPair())
+            binding.signWith(null, deployment.getSigningKeyPair())
                     .signDocument();
         }
 
diff --git a/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java b/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
index 6d84c13..78a4480 100755
--- a/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
+++ b/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
@@ -55,6 +55,7 @@ import static org.keycloak.saml.common.util.StringUtil.isNotNull;
 public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
     protected static final Logger logger = Logger.getLogger(BaseSAML2BindingBuilder.class);
 
+    protected String signingKeyId;
     protected KeyPair signingKeyPair;
     protected X509Certificate signingCertificate;
     protected boolean sign;
@@ -82,23 +83,27 @@ public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
         return (T)this;
     }
 
-    public T signWith(KeyPair keyPair) {
+    public T signWith(String signingKeyId, KeyPair keyPair) {
+        this.signingKeyId = signingKeyId;
         this.signingKeyPair = keyPair;
         return (T)this;
     }
 
-    public T signWith(PrivateKey privateKey, PublicKey publicKey) {
+    public T signWith(String signingKeyId, PrivateKey privateKey, PublicKey publicKey) {
+        this.signingKeyId = signingKeyId;
         this.signingKeyPair = new KeyPair(publicKey, privateKey);
         return (T)this;
     }
 
-    public T signWith(KeyPair keyPair, X509Certificate cert) {
+    public T signWith(String signingKeyId, KeyPair keyPair, X509Certificate cert) {
+        this.signingKeyId = signingKeyId;
         this.signingKeyPair = keyPair;
         this.signingCertificate = cert;
         return (T)this;
     }
 
-    public T signWith(PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) {
+    public T signWith(String signingKeyId, PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) {
+        this.signingKeyId = signingKeyId;
         this.signingKeyPair = new KeyPair(publicKey, privateKey);
         this.signingCertificate = cert;
         return (T)this;
@@ -263,7 +268,7 @@ public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
             samlSignature.setX509Certificate(signingCertificate);
         }
 
-        samlSignature.signSAMLDocument(samlDocument, signingKeyPair, canonicalizationMethodType);
+        samlSignature.signSAMLDocument(samlDocument, signingKeyId, signingKeyPair, canonicalizationMethodType);
     }
 
     public void signAssertion(Document samlDocument) throws ProcessingException {
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java b/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
index 5ac8ce1..57777ab 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
@@ -121,7 +121,7 @@ public class SAML2Signature {
      * @throws MarshalException
      * @throws GeneralSecurityException
      */
-    public Document sign(Document doc, String referenceID, KeyPair keyPair, String canonicalizationMethodType) throws ParserConfigurationException,
+    public Document sign(Document doc, String referenceID, String keyId, KeyPair keyPair, String canonicalizationMethodType) throws ParserConfigurationException,
             GeneralSecurityException, MarshalException, XMLSignatureException {
         String referenceURI = "#" + referenceID;
 
@@ -130,6 +130,7 @@ public class SAML2Signature {
         if (sibling != null) {
             SignatureUtilTransferObject dto = new SignatureUtilTransferObject();
             dto.setDocumentToBeSigned(doc);
+            dto.setKeyId(keyId);
             dto.setKeyPair(keyPair);
             dto.setDigestMethod(digestMethod);
             dto.setSignatureMethod(signatureMethod);
@@ -142,7 +143,7 @@ public class SAML2Signature {
 
             return XMLSignatureUtil.sign(dto, canonicalizationMethodType);
         }
-        return XMLSignatureUtil.sign(doc, keyPair, digestMethod, signatureMethod, referenceURI, canonicalizationMethodType);
+        return XMLSignatureUtil.sign(doc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, canonicalizationMethodType);
     }
 
     /**
@@ -153,11 +154,11 @@ public class SAML2Signature {
      *
      * @throws org.keycloak.saml.common.exceptions.ProcessingException
      */
-    public void signSAMLDocument(Document samlDocument, KeyPair keypair, String canonicalizationMethodType) throws ProcessingException {
+    public void signSAMLDocument(Document samlDocument, String keyId, KeyPair keypair, String canonicalizationMethodType) throws ProcessingException {
         // Get the ID from the root
         String id = samlDocument.getDocumentElement().getAttribute(ID_ATTRIBUTE_NAME);
         try {
-            sign(samlDocument, id, keypair, canonicalizationMethodType);
+            sign(samlDocument, id, keyId, keypair, canonicalizationMethodType);
         } catch (Exception e) {
             throw new ProcessingException(logger.signatureError(e));
         }
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/SignatureUtilTransferObject.java b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/SignatureUtilTransferObject.java
index 19924e9..f8181fe 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/SignatureUtilTransferObject.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/SignatureUtilTransferObject.java
@@ -32,6 +32,9 @@ public class SignatureUtilTransferObject {
     private X509Certificate x509Certificate;
 
     private Document documentToBeSigned;
+
+    private String keyId;
+
     private KeyPair keyPair;
 
     private Node nextSibling;
@@ -111,4 +114,12 @@ public class SignatureUtilTransferObject {
     public void setX509Certificate(X509Certificate x509Certificate) {
         this.x509Certificate = x509Certificate;
     }
+
+    public String getKeyId() {
+        return keyId;
+    }
+
+    public void setKeyId(String keyId) {
+        this.keyId = keyId;
+    }
 }
\ No newline at end of file
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
index 98635b7..7000075 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
@@ -79,7 +79,9 @@ import java.security.interfaces.DSAPublicKey;
 import java.security.interfaces.RSAPublicKey;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.LinkedList;
 import java.util.List;
+import javax.xml.crypto.dsig.keyinfo.KeyName;
 
 /**
  * Utility for XML Signature <b>Note:</b> You can change the canonicalization method type by using the system property
@@ -157,7 +159,7 @@ public class XMLSignatureUtil {
      * @throws MarshalException
      * @throws GeneralSecurityException
      */
-    public static Document sign(Document doc, Node nodeToBeSigned, KeyPair keyPair, String digestMethod,
+    public static Document sign(Document doc, Node nodeToBeSigned, String keyId, KeyPair keyPair, String digestMethod,
                                 String signatureMethod, String referenceURI, X509Certificate x509Certificate,
                                 String canonicalizationMethodType) throws ParserConfigurationException, GeneralSecurityException,
             MarshalException, XMLSignatureException {
@@ -179,7 +181,7 @@ public class XMLSignatureUtil {
         if (!referenceURI.isEmpty()) {
             propagateIDAttributeSetup(nodeToBeSigned, newDoc.getDocumentElement());
         }
-        newDoc = sign(newDoc, keyPair, digestMethod, signatureMethod, referenceURI, x509Certificate, canonicalizationMethodType);
+        newDoc = sign(newDoc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, x509Certificate, canonicalizationMethodType);
 
         // if the signed element is a SAMLv2.0 assertion we need to move the signature element to the position
         // specified in the schema (before the assertion subject element).
@@ -220,10 +222,10 @@ public class XMLSignatureUtil {
      * @throws MarshalException
      * @throws XMLSignatureException
      */
-    public static void sign(Element elementToSign, Node nextSibling, KeyPair keyPair, String digestMethod,
+    public static void sign(Element elementToSign, Node nextSibling, String keyId, KeyPair keyPair, String digestMethod,
                             String signatureMethod, String referenceURI, String canonicalizationMethodType)
             throws GeneralSecurityException, MarshalException, XMLSignatureException {
-        sign(elementToSign, nextSibling, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
+        sign(elementToSign, nextSibling, keyId, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
     }
 
     /**
@@ -242,7 +244,7 @@ public class XMLSignatureUtil {
      * @throws XMLSignatureException
      * @since 2.5.0
      */
-    public static void sign(Element elementToSign, Node nextSibling, KeyPair keyPair, String digestMethod,
+    public static void sign(Element elementToSign, Node nextSibling, String keyId, KeyPair keyPair, String digestMethod,
                             String signatureMethod, String referenceURI, X509Certificate x509Certificate, String canonicalizationMethodType)
             throws GeneralSecurityException, MarshalException, XMLSignatureException {
         PrivateKey signingKey = keyPair.getPrivate();
@@ -250,7 +252,7 @@ public class XMLSignatureUtil {
 
         DOMSignContext dsc = new DOMSignContext(signingKey, elementToSign, nextSibling);
 
-        signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, x509Certificate, canonicalizationMethodType);
+        signImpl(dsc, digestMethod, signatureMethod, referenceURI, keyId, publicKey, x509Certificate, canonicalizationMethodType);
     }
 
     /**
@@ -284,9 +286,9 @@ public class XMLSignatureUtil {
      * @throws XMLSignatureException
      * @throws MarshalException
      */
-    public static Document sign(Document doc, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI, String canonicalizationMethodType)
+    public static Document sign(Document doc, String keyId, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI, String canonicalizationMethodType)
             throws GeneralSecurityException, MarshalException, XMLSignatureException {
-        return sign(doc, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
+        return sign(doc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
     }
 
     /**
@@ -304,7 +306,7 @@ public class XMLSignatureUtil {
      * @throws MarshalException
      * @since 2.5.0
      */
-    public static Document sign(Document doc, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI,
+    public static Document sign(Document doc, String keyId, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI,
                                 X509Certificate x509Certificate, String canonicalizationMethodType)
             throws GeneralSecurityException, MarshalException, XMLSignatureException {
         logger.trace("Document to be signed=" + DocumentUtil.asString(doc));
@@ -313,7 +315,7 @@ public class XMLSignatureUtil {
 
         DOMSignContext dsc = new DOMSignContext(signingKey, doc.getDocumentElement());
 
-        signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, x509Certificate, canonicalizationMethodType);
+        signImpl(dsc, digestMethod, signatureMethod, referenceURI, keyId, publicKey, x509Certificate, canonicalizationMethodType);
 
         return doc;
     }
@@ -344,7 +346,7 @@ public class XMLSignatureUtil {
 
         DOMSignContext dsc = new DOMSignContext(signingKey, doc.getDocumentElement(), nextSibling);
 
-        signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, dto.getX509Certificate(), canonicalizationMethodType);
+        signImpl(dsc, digestMethod, signatureMethod, referenceURI, dto.getKeyId(), publicKey, dto.getX509Certificate(), canonicalizationMethodType);
 
         return doc;
     }
@@ -594,7 +596,7 @@ public class XMLSignatureUtil {
         throw logger.unsupportedType(key.toString());
     }
 
-    private static void signImpl(DOMSignContext dsc, String digestMethod, String signatureMethod, String referenceURI, PublicKey publicKey,
+    private static void signImpl(DOMSignContext dsc, String digestMethod, String signatureMethod, String referenceURI, String keyId, PublicKey publicKey,
                                  X509Certificate x509Certificate, String canonicalizationMethodType)
             throws GeneralSecurityException, MarshalException, XMLSignatureException {
         dsc.setDefaultNamespacePrefix("dsig");
@@ -618,35 +620,32 @@ public class XMLSignatureUtil {
 
         KeyInfo ki = null;
         if (includeKeyInfoInSignature) {
-            ki = createKeyInfo(publicKey, x509Certificate);
+            ki = createKeyInfo(keyId, publicKey, x509Certificate);
+        } else {
+            ki = createKeyInfo(keyId, null, null);
         }
         XMLSignature signature = fac.newXMLSignature(si, ki);
 
         signature.sign(dsc);
     }
 
-    private static KeyInfo createKeyInfo(PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
+    private static KeyInfo createKeyInfo(String keyId, PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
         KeyInfoFactory keyInfoFactory = fac.getKeyInfoFactory();
-        KeyInfo keyInfo = null;
-        KeyValue keyValue = null;
-        //Just with public key
-        if (publicKey != null) {
-            keyValue = keyInfoFactory.newKeyValue(publicKey);
-            keyInfo = keyInfoFactory.newKeyInfo(Collections.singletonList(keyValue));
+
+        List<Object> items = new LinkedList<>();
+
+        if (keyId != null) {
+            items.add(keyInfoFactory.newKeyName(keyId));
         }
-        if (x509Certificate != null) {
-            List x509list = new ArrayList();
 
-            x509list.add(x509Certificate);
-            X509Data x509Data = keyInfoFactory.newX509Data(x509list);
-            List items = new ArrayList();
+        if (x509Certificate != null) {
+            items.add(keyInfoFactory.newX509Data(Collections.singletonList(x509Certificate)));
+        }
 
-            items.add(x509Data);
-            if (keyValue != null) {
-                items.add(keyValue);
-            }
-            keyInfo = keyInfoFactory.newKeyInfo(items);
+        if (publicKey != null) {
+            items.add(keyInfoFactory.newKeyValue(publicKey));
         }
-        return keyInfo;
+
+        return keyInfoFactory.newKeyInfo(items);
     }
 }
\ No newline at end of file
diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
index 997bc93..d88a34e 100755
--- a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
+++ b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
@@ -267,7 +267,7 @@ public class SAMLEndpoint {
                         .relayState(relayState);
             if (config.isWantAuthnRequestsSigned()) {
                 KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
-                binding.signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
+                binding.signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
                         .signatureAlgorithm(provider.getSignatureAlgorithm())
                         .signDocument();
             }
diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
index 104b8f8..6452c74 100755
--- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
+++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
@@ -103,7 +103,7 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider<SAMLIdentityP
 
                 KeyPair keypair = new KeyPair(keys.getPublicKey(), keys.getPrivateKey());
 
-                binding.signWith(keypair);
+                binding.signWith(keys.getKid(), keypair);
                 binding.signatureAlgorithm(getSignatureAlgorithm());
                 binding.signDocument();
             }
@@ -198,7 +198,7 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider<SAMLIdentityP
                 .relayState(userSession.getId());
         if (getConfig().isWantAuthnRequestsSigned()) {
             KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
-            binding.signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
+            binding.signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
                     .signatureAlgorithm(getSignatureAlgorithm())
                     .signDocument();
         }
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
index 14726d3..2b6b9f7 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
@@ -402,14 +402,14 @@ public class SamlProtocol implements LoginProtocol {
             if (canonicalization != null) {
                 bindingBuilder.canonicalizationMethod(canonicalization);
             }
-            bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
+            bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
         }
         if (samlClient.requiresAssertionSignature()) {
             String canonicalization = samlClient.getCanonicalizationMethod();
             if (canonicalization != null) {
                 bindingBuilder.canonicalizationMethod(canonicalization);
             }
-            bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signAssertions();
+            bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signAssertions();
         }
         if (samlClient.requiresEncryption()) {
             PublicKey publicKey = null;
@@ -541,7 +541,7 @@ public class SamlProtocol implements LoginProtocol {
                 binding.canonicalizationMethod(canonicalization);
             }
             KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
-            binding.signatureAlgorithm(algorithm).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
+            binding.signatureAlgorithm(algorithm).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
         }
 
         try {
@@ -639,7 +639,7 @@ public class SamlProtocol implements LoginProtocol {
         JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder();
         if (samlClient.requiresRealmSignature()) {
             KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
-            binding.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
+            binding.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
         }
         return binding;
     }
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
index ea01085..e02093c 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
@@ -400,7 +400,7 @@ public class SamlService extends AuthorizationEndpointBase {
             if (samlClient.requiresRealmSignature()) {
                 SignatureAlgorithm algorithm = samlClient.getSignatureAlgorithm();
                 KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
-                binding.signatureAlgorithm(algorithm).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
+                binding.signatureAlgorithm(algorithm).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
 
             }
             try {