keycloak-uncached

KEYCLOAK-1881 Include key ID for REDIRECT and use it for validation Contrary

11/2/2016 5:46:06 AM

Changes

Details

diff --git a/adapters/saml/core/nbproject/project.properties b/adapters/saml/core/nbproject/project.properties
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/adapters/saml/core/nbproject/project.properties
diff --git a/adapters/saml/core/pom.xml b/adapters/saml/core/pom.xml
index 16dce33..b01061b 100755
--- a/adapters/saml/core/pom.xml
+++ b/adapters/saml/core/pom.xml
@@ -34,6 +34,7 @@
         <timestamp>${maven.build.timestamp}</timestamp>
         <maven.build.timestamp.format>yyyy-MM-dd HH:mm</maven.build.timestamp.format>
     </properties>
+
     <dependencies>
         <dependency>
             <groupId>org.keycloak</groupId>
@@ -70,6 +71,11 @@
             <artifactId>junit</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.httpcomponents</groupId>
+            <artifactId>httpclient</artifactId>
+            <scope>provided</scope>
+        </dependency>
     </dependencies>
     <build>
         <plugins>
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
index 693e06e..6ddf52c 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/AbstractInitiateLogin.java
@@ -80,6 +80,8 @@ public abstract class AbstractInitiateLogin implements AuthChallenge {
             }
 
             binding.signWith(null, keypair);
+            // TODO: As part of KEYCLOAK-3810, add KeyID to the SAML document
+            //   <related DocumentBuilder>.addExtension(new KeycloakKeySamlExtensionGenerator(<key ID>));
             binding.signDocument();
         }
         return binding;
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/config/parsers/DeploymentBuilder.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/config/parsers/DeploymentBuilder.java
index d6e4bce..ee21620 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/config/parsers/DeploymentBuilder.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/config/parsers/DeploymentBuilder.java
@@ -202,6 +202,7 @@ public class DeploymentBuilder {
             }
         }
 
+        idp.refreshKeyLocatorConfiguration();
 
         return deployment;
     }
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/DefaultSamlDeployment.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/DefaultSamlDeployment.java
index ee753ad..fcbe1e9 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/DefaultSamlDeployment.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/DefaultSamlDeployment.java
@@ -24,6 +24,12 @@ import java.security.KeyPair;
 import java.security.PrivateKey;
 import java.security.PublicKey;
 import java.util.Set;
+import org.apache.http.client.HttpClient;
+import org.keycloak.adapters.HttpClientBuilder;
+import org.keycloak.adapters.saml.rotation.SamlDescriptorPublicKeyLocator;
+import org.keycloak.rotation.CompositeKeyLocator;
+import org.keycloak.rotation.HardcodedKeyLocator;
+import org.keycloak.rotation.KeyLocator;
 
 /**
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@@ -179,10 +185,14 @@ public class DefaultSamlDeployment implements SamlDeployment {
 
     public static class DefaultIDP implements IDP {
 
+        private static final int DEFAULT_CACHE_TTL = 24 * 60 * 60;
+
         private String entityID;
-        private PublicKey signatureValidationKey;
+        private final CompositeKeyLocator signatureValidationKeyLocator = new CompositeKeyLocator();
         private SingleSignOnService singleSignOnService;
         private SingleLogoutService singleLogoutService;
+        private HardcodedKeyLocator hardcodedKeyLocator;
+        private int minTimeBetweenDescriptorRequests;
 
         @Override
         public String getEntityID() {
@@ -200,8 +210,17 @@ public class DefaultSamlDeployment implements SamlDeployment {
         }
 
         @Override
-        public PublicKey getSignatureValidationKey() {
-            return signatureValidationKey;
+        public KeyLocator getSignatureValidationKeyLocator() {
+            return this.signatureValidationKeyLocator;
+        }
+
+        @Override
+        public int getMinTimeBetweenDescriptorRequests() {
+            return minTimeBetweenDescriptorRequests;
+        }
+
+        public void setMinTimeBetweenDescriptorRequests(int minTimeBetweenDescriptorRequests) {
+            this.minTimeBetweenDescriptorRequests = minTimeBetweenDescriptorRequests;
         }
 
         public void setEntityID(String entityID) {
@@ -209,16 +228,35 @@ public class DefaultSamlDeployment implements SamlDeployment {
         }
 
         public void setSignatureValidationKey(PublicKey signatureValidationKey) {
-            this.signatureValidationKey = signatureValidationKey;
+            this.hardcodedKeyLocator = signatureValidationKey == null ? null : new HardcodedKeyLocator(signatureValidationKey);
+            refreshKeyLocatorConfiguration();
         }
 
         public void setSingleSignOnService(SingleSignOnService singleSignOnService) {
             this.singleSignOnService = singleSignOnService;
+            refreshKeyLocatorConfiguration();
         }
 
         public void setSingleLogoutService(SingleLogoutService singleLogoutService) {
             this.singleLogoutService = singleLogoutService;
         }
+
+        public void refreshKeyLocatorConfiguration() {
+            this.signatureValidationKeyLocator.clear();
+
+            // When key is set, use that (and only that), otherwise configure dynamic key locator
+            if (this.hardcodedKeyLocator != null) {
+                this.signatureValidationKeyLocator.add(this.hardcodedKeyLocator);
+            } else if (this.singleSignOnService != null) {
+                String samlDescriptorUrl = singleSignOnService.getRequestBindingUrl() + "/descriptor";
+                // TODO
+                HttpClient httpClient = new HttpClientBuilder().build();
+                SamlDescriptorPublicKeyLocator samlDescriptorPublicKeyLocator =
+                  new SamlDescriptorPublicKeyLocator(
+                    samlDescriptorUrl, this.minTimeBetweenDescriptorRequests, DEFAULT_CACHE_TTL, httpClient);
+                this.signatureValidationKeyLocator.add(samlDescriptorPublicKeyLocator);
+            }
+        }
     }
 
     private IDP idp;
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java
index e9247b3..429d610 100644
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/AbstractSamlAuthenticationHandler.java
@@ -64,11 +64,20 @@ import org.w3c.dom.Node;
 
 import java.io.IOException;
 import java.net.URI;
+import java.security.InvalidKeyException;
+import java.security.Key;
+import java.security.KeyManagementException;
 import java.security.PublicKey;
 import java.security.Signature;
+import java.security.SignatureException;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import org.keycloak.dom.saml.v2.SAML2Object;
+import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
+import org.keycloak.rotation.KeyLocator;
+import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
+import org.w3c.dom.Element;
 
 /**
  *
@@ -257,13 +266,44 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
     }
 
     private void validateSamlSignature(SAMLDocumentHolder holder, boolean postBinding, String paramKey) throws VerificationException {
+        KeyLocator signatureValidationKey = deployment.getIDP().getSignatureValidationKeyLocator();
         if (postBinding) {
-            verifyPostBindingSignature(holder.getSamlDocument(), deployment.getIDP().getSignatureValidationKey());
+            verifyPostBindingSignature(holder.getSamlDocument(), signatureValidationKey);
         } else {
-            verifyRedirectBindingSignature(deployment.getIDP().getSignatureValidationKey(), paramKey);
+            String keyId = getMessageSigningKeyId(holder.getSamlObject());
+            verifyRedirectBindingSignature(paramKey, signatureValidationKey, keyId);
         }
     }
 
+    private String getMessageSigningKeyId(SAML2Object doc) {
+        final ExtensionsType extensions;
+        if (doc instanceof RequestAbstractType) {
+            extensions = ((RequestAbstractType) doc).getExtensions();
+        } else if (doc instanceof StatusResponseType) {
+            extensions = ((StatusResponseType) doc).getExtensions();
+        } else {
+            return null;
+        }
+
+        if (extensions == null) {
+            return null;
+        }
+
+        for (Object ext : extensions.getAny()) {
+            if (! (ext instanceof Element)) {
+                continue;
+            }
+
+            String res = KeycloakKeySamlExtensionGenerator.getMessageSigningKeyIdFromElement((Element) ext);
+
+            if (res != null) {
+                return res;
+            }
+        }
+
+        return null;
+    }
+
     private boolean checkStatusCodeValue(StatusCodeType statusCode, String expectedValue){
         if(statusCode != null && statusCode.getValue()!=null){
             String v = statusCode.getValue().toString();
@@ -473,10 +513,10 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
         return false;
     }
 
-    public void verifyPostBindingSignature(Document document, PublicKey publicKey) throws VerificationException {
+    public void verifyPostBindingSignature(Document document, KeyLocator keyLocator) throws VerificationException {
         SAML2Signature saml2Signature = new SAML2Signature();
         try {
-            if (!saml2Signature.validate(document, publicKey)) {
+            if (!saml2Signature.validate(document, keyLocator)) {
                 throw new VerificationException("Invalid signature on document");
             }
         } catch (ProcessingException e) {
@@ -484,7 +524,7 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
         }
     }
 
-    public void verifyRedirectBindingSignature(PublicKey publicKey, String paramKey) throws VerificationException {
+    private void verifyRedirectBindingSignature(String paramKey, KeyLocator keyLocator, String keyId) throws VerificationException {
         String request = facade.getRequest().getQueryParamValue(paramKey);
         String algorithm = facade.getRequest().getQueryParamValue(GeneralConstants.SAML_SIG_ALG_REQUEST_KEY);
         String signature = facade.getRequest().getQueryParamValue(GeneralConstants.SAML_SIGNATURE_REQUEST_KEY);
@@ -511,16 +551,80 @@ public abstract class AbstractSamlAuthenticationHandler implements SamlAuthentic
         try {
             //byte[] decodedSignature = RedirectBindingUtil.urlBase64Decode(signature);
             byte[] decodedSignature = Base64.decode(signature);
+            byte[] rawQueryBytes = rawQuery.getBytes("UTF-8");
 
             SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.getFromXmlMethod(decodedAlgorithm);
-            Signature validator = signatureAlgorithm.createSignature(); // todo plugin signature alg
-            validator.initVerify(publicKey);
-            validator.update(rawQuery.getBytes("UTF-8"));
-            if (!validator.verify(decodedSignature)) {
+
+            if (! validateRedirectBindingSignature(signatureAlgorithm, rawQueryBytes, decodedSignature, keyLocator, keyId)) {
                 throw new VerificationException("Invalid query param signature");
             }
         } catch (Exception e) {
             throw new VerificationException(e);
         }
     }
+
+    private boolean validateRedirectBindingSignature(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature, KeyLocator locator, String keyId)
+      throws KeyManagementException, VerificationException {
+        try {
+            Key key;
+            try {
+                key = locator.getKey(keyId);
+                boolean keyLocated = key != null;
+
+                if (validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key)) {
+                    return true;
+                }
+
+                if (keyLocated) {
+                    return false;
+                }
+            } catch (KeyManagementException ex) {
+            }
+        } catch (SignatureException ex) {
+            log.debug("Verification failed for key %s: %s", keyId, ex);
+            log.trace(ex);
+        }
+
+        if (locator instanceof Iterable) {
+            Iterable<Key> availableKeys = (Iterable<Key>) locator;
+
+            log.trace("Trying hard to validate XML signature using all available keys.");
+
+            for (Key key : availableKeys) {
+                try {
+                    if (validateRedirectBindingSignatureForKey(sigAlg, rawQueryBytes, decodedSignature, key)) {
+                        return true;
+                    }
+                } catch (SignatureException ex) {
+                    log.debug("Verification failed: %s", ex);
+                }
+            }
+        }
+
+        return false;
+    }
+
+    private boolean validateRedirectBindingSignatureForKey(SignatureAlgorithm sigAlg, byte[] rawQueryBytes, byte[] decodedSignature, Key key)
+      throws SignatureException {
+        if (key == null) {
+            return false;
+        }
+
+        if (! (key instanceof PublicKey)) {
+            log.warnf("Unusable key for signature validation: %s", key);
+            return false;
+        }
+
+        Signature signature = sigAlg.createSignature(); // todo plugin signature alg
+        try {
+            signature.initVerify((PublicKey) key);
+        } catch (InvalidKeyException ex) {
+            log.warnf(ex, "Unusable key for signature validation: %s", key);
+            return false;
+        }
+
+        signature.update(rawQueryBytes);
+
+        return signature.verify(decodedSignature);
+    }
 }
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
index 3581357..231c425 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/profile/webbrowsersso/WebBrowserSsoAuthenticationHandler.java
@@ -84,6 +84,8 @@ public class WebBrowserSsoAuthenticationHandler extends AbstractSamlAuthenticati
             binding.signatureAlgorithm(deployment.getSignatureAlgorithm())
                     .signWith(null, deployment.getSigningKeyPair())
                     .signDocument();
+            // TODO: As part of KEYCLOAK-3810, add KeyID to the SAML document
+            //   <related DocumentBuilder>.addExtension(new KeycloakKeySamlExtensionGenerator(<key ID>));
         }
 
 
@@ -115,6 +117,8 @@ public class WebBrowserSsoAuthenticationHandler extends AbstractSamlAuthenticati
             binding.signatureAlgorithm(deployment.getSignatureAlgorithm());
             binding.signWith(null, deployment.getSigningKeyPair())
                     .signDocument();
+            // TODO: As part of KEYCLOAK-3810, add KeyID to the SAML document
+            //   <related DocumentBuilder>.addExtension(new KeycloakKeySamlExtensionGenerator(<key ID>));
         }
 
         binding.relayState("logout");
diff --git a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/SamlDeployment.java b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/SamlDeployment.java
index 0b82ff2..f01b6a1 100755
--- a/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/SamlDeployment.java
+++ b/adapters/saml/core/src/main/java/org/keycloak/adapters/saml/SamlDeployment.java
@@ -22,14 +22,17 @@ import org.keycloak.saml.SignatureAlgorithm;
 
 import java.security.KeyPair;
 import java.security.PrivateKey;
-import java.security.PublicKey;
 import java.util.Set;
+import org.keycloak.rotation.KeyLocator;
 
 /**
+ * Represents SAML deployment configuration.
+ * 
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
  * @version $Revision: 1 $
  */
 public interface SamlDeployment {
+
     enum Binding {
         POST,
         REDIRECT;
@@ -41,20 +44,62 @@ public interface SamlDeployment {
     }
 
     public interface IDP {
+        /**
+         * Returns entity identifier of this IdP.
+         * @return see description.
+         */
         String getEntityID();
 
+        /**
+         * Returns Single sign on service configuration for this IdP.
+         * @return see description.
+         */
         SingleSignOnService getSingleSignOnService();
+
+        /**
+         * Returns Single logout service configuration for this IdP.
+         * @return see description.
+         */
         SingleLogoutService getSingleLogoutService();
-        PublicKey getSignatureValidationKey();
+
+        /**
+         * Returns {@link KeyLocator} looking up public keys used for validation of IdP signatures.
+         * @return see description.
+         */
+        KeyLocator getSignatureValidationKeyLocator();
+
+        /**
+         * Returns minimum time (in seconds) between issuing requests to IdP SAML descriptor.
+         * Used e.g. by {@link KeyLocator} looking up public keys for validation of IdP signatures
+         * to prevent too frequent requests.
+         *
+         * @return see description.
+         */
+        int getMinTimeBetweenDescriptorRequests();
 
         public interface SingleSignOnService {
+            /**
+             * Returns {@code true} if the requests to IdP need to be signed by SP key.
+             * @return see dscription
+             */
             boolean signRequest();
+            /**
+             * Returns {@code true} if the complete response message from IdP should
+             * be checked for valid signature.
+             * @return see dscription
+             */
             boolean validateResponseSignature();
+            /**
+             * Returns {@code true} if individual assertions in response from IdP should
+             * be checked for valid signature.
+             * @return see dscription
+             */
             boolean validateAssertionSignature();
             Binding getRequestBinding();
             Binding getResponseBinding();
             String getRequestBindingUrl();
         }
+
         public interface SingleLogoutService {
             boolean validateRequestSignature();
             boolean validateResponseSignature();
@@ -67,10 +112,19 @@ public interface SamlDeployment {
         }
     }
 
+    /**
+     * Returns Identity Provider configuration for this SAML deployment.
+     * @return see description.
+     */
     public IDP getIDP();
 
     public boolean isConfigured();
     SslRequired getSslRequired();
+
+    /**
+     * Returns entity identifier of this SP.
+     * @return see description.
+     */
     String getEntityID();
     String getNameIDPolicyFormat();
     boolean isForceAuthentication();
diff --git a/saml-core/nbproject/project.properties b/saml-core/nbproject/project.properties
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/saml-core/nbproject/project.properties
diff --git a/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java b/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
index 78a4480..f820a5e 100755
--- a/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
+++ b/saml-core/src/main/java/org/keycloak/saml/BaseSAML2BindingBuilder.java
@@ -38,11 +38,14 @@ import javax.crypto.spec.SecretKeySpec;
 import javax.xml.crypto.dsig.CanonicalizationMethod;
 import javax.xml.namespace.QName;
 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
 import java.net.URI;
+import java.security.InvalidKeyException;
 import java.security.KeyPair;
 import java.security.PrivateKey;
 import java.security.PublicKey;
 import java.security.Signature;
+import java.security.SignatureException;
 import java.security.cert.X509Certificate;
 
 import static org.keycloak.common.util.HtmlUtils.escapeAttribute;
@@ -338,7 +341,7 @@ public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
 
     public String base64Encoded(Document document) throws ConfigurationException, ProcessingException, IOException  {
         String documentAsString = DocumentUtil.getDocumentAsString(document);
-        logger.debugv("saml docment: {0}", documentAsString);
+        logger.debugv("saml document: {0}", documentAsString);
         byte[] responseBytes = documentAsString.getBytes("UTF-8");
 
         return RedirectBindingUtil.deflateBase64URLEncode(responseBytes);
@@ -363,7 +366,7 @@ public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
                 signature.initSign(signingKeyPair.getPrivate());
                 signature.update(rawQuery.getBytes("UTF-8"));
                 sig = signature.sign();
-            } catch (Exception e) {
+            } catch (InvalidKeyException | UnsupportedEncodingException | SignatureException e) {
                 throw new ProcessingException(e);
             }
             String encodedSig = RedirectBindingUtil.base64URLEncode(sig);
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java b/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
index 57777ab..49c8df8 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/api/saml/v2/sig/SAML2Signature.java
@@ -35,8 +35,8 @@ import javax.xml.crypto.dsig.XMLSignatureException;
 import javax.xml.parsers.ParserConfigurationException;
 import java.security.GeneralSecurityException;
 import java.security.KeyPair;
-import java.security.PublicKey;
 import java.security.cert.X509Certificate;
+import org.keycloak.rotation.KeyLocator;
 
 /**
  * Class that deals with SAML2 Signature
@@ -159,7 +159,7 @@ public class SAML2Signature {
         String id = samlDocument.getDocumentElement().getAttribute(ID_ATTRIBUTE_NAME);
         try {
             sign(samlDocument, id, keyId, keypair, canonicalizationMethodType);
-        } catch (Exception e) {
+        } catch (ParserConfigurationException | GeneralSecurityException | MarshalException | XMLSignatureException e) {
             throw new ProcessingException(logger.signatureError(e));
         }
     }
@@ -168,20 +168,18 @@ public class SAML2Signature {
      * Validate the SAML2 Document
      *
      * @param signedDocument
-     * @param publicKey
+     * @param keyLocator
      *
      * @return
      *
      * @throws ProcessingException
      */
-    public boolean validate(Document signedDocument, PublicKey publicKey) throws ProcessingException {
+    public boolean validate(Document signedDocument, KeyLocator keyLocator) throws ProcessingException {
         try {
             configureIdAttribute(signedDocument);
-            return XMLSignatureUtil.validate(signedDocument, publicKey);
-        } catch (MarshalException me) {
+            return XMLSignatureUtil.validate(signedDocument, keyLocator);
+        } catch (MarshalException | XMLSignatureException me) {
             throw new ProcessingException(logger.signatureError(me));
-        } catch (XMLSignatureException xse) {
-            throw new ProcessingException(logger.signatureError(xse));
         }
     }
 
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/core/saml/v2/util/AssertionUtil.java b/saml-core/src/main/java/org/keycloak/saml/processing/core/saml/v2/util/AssertionUtil.java
index 67fb78f..ed941a0 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/core/saml/v2/util/AssertionUtil.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/core/saml/v2/util/AssertionUtil.java
@@ -62,6 +62,7 @@ import java.security.PublicKey;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
+import org.keycloak.rotation.HardcodedKeyLocator;
 
 /**
  * Utility to deal with assertions
@@ -276,7 +277,7 @@ public class AssertionUtil {
             Node n = doc.importNode(assertionElement, true);
             doc.appendChild(n);
 
-            return new SAML2Signature().validate(doc, publicKey);
+            return new SAML2Signature().validate(doc, new HardcodedKeyLocator(publicKey));
         } catch (Exception e) {
             logger.signatureAssertionValidationError(e);
         }
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/KeycloakKeySamlExtensionGenerator.java b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/KeycloakKeySamlExtensionGenerator.java
new file mode 100644
index 0000000..1bb90ea
--- /dev/null
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/KeycloakKeySamlExtensionGenerator.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2016 Red Hat, Inc. and/or its affiliates
+ * and other contributors as indicated by the @author tags.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.keycloak.saml.processing.core.util;
+
+import java.util.Objects;
+import javax.xml.stream.XMLStreamWriter;
+import org.keycloak.saml.SamlProtocolExtensionsAwareBuilder;
+import org.keycloak.saml.common.exceptions.ProcessingException;
+import org.keycloak.saml.common.util.StaxUtil;
+import org.w3c.dom.Element;
+
+/**
+ *
+ * @author hmlnarik
+ */
+public class KeycloakKeySamlExtensionGenerator implements SamlProtocolExtensionsAwareBuilder.NodeGenerator {
+
+    public static final String NS_URI = "urn:keycloak:ext:key:1.0";
+
+    public static final String NS_PREFIX = "kckey";
+
+    public static final String KC_KEY_INFO_ELEMENT_NAME = "KeyInfo";
+
+    public static final String KEY_ID_ATTRIBUTE_NAME = "MessageSigningKeyId";
+
+    private final String keyId;
+
+    public KeycloakKeySamlExtensionGenerator(String keyId) {
+        this.keyId = keyId;
+    }
+
+    @Override
+    public void write(XMLStreamWriter writer) throws ProcessingException {
+        StaxUtil.writeStartElement(writer, NS_PREFIX, KC_KEY_INFO_ELEMENT_NAME, NS_URI);
+        StaxUtil.writeNameSpace(writer, NS_PREFIX, NS_URI);
+        if (this.keyId != null) {
+            StaxUtil.writeAttribute(writer, KEY_ID_ATTRIBUTE_NAME, this.keyId);
+        }
+        StaxUtil.writeEndElement(writer);
+        StaxUtil.flush(writer);
+    }
+
+    /**
+     * Checks that the given element is indeed a Keycloak extension {@code KeyInfo} element and
+     * returns a content of {@code MessageSigningKeyId} attribute in the given element.
+     * @param element Element to obtain the key info from.
+     * @return {@code null} if the element is unknown or there is {@code MessageSigningKeyId} attribute unset,
+     *   value of the {@code MessageSigningKeyId} attribute otherwise.
+     */
+    public static String getMessageSigningKeyIdFromElement(Element element) {
+        if (Objects.equals(element.getNamespaceURI(), NS_URI) &&
+          Objects.equals(element.getLocalName(), KC_KEY_INFO_ELEMENT_NAME) &&
+          element.hasAttribute(KEY_ID_ATTRIBUTE_NAME)) {
+            return element.getAttribute(KEY_ID_ATTRIBUTE_NAME);
+        }
+
+        return null;
+    }
+
+}
diff --git a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
index 7000075..193af19 100755
--- a/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
+++ b/saml-core/src/main/java/org/keycloak/saml/processing/core/util/XMLSignatureUtil.java
@@ -54,8 +54,6 @@ import javax.xml.crypto.dsig.dom.DOMSignContext;
 import javax.xml.crypto.dsig.dom.DOMValidateContext;
 import javax.xml.crypto.dsig.keyinfo.KeyInfo;
 import javax.xml.crypto.dsig.keyinfo.KeyInfoFactory;
-import javax.xml.crypto.dsig.keyinfo.KeyValue;
-import javax.xml.crypto.dsig.keyinfo.X509Data;
 import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec;
 import javax.xml.crypto.dsig.spec.TransformParameterSpec;
 import javax.xml.namespace.QName;
@@ -69,6 +67,7 @@ import java.io.OutputStream;
 import java.security.GeneralSecurityException;
 import java.security.Key;
 import java.security.KeyException;
+import java.security.KeyManagementException;
 import java.security.KeyPair;
 import java.security.NoSuchProviderException;
 import java.security.PrivateKey;
@@ -81,7 +80,14 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
+import javax.xml.crypto.AlgorithmMethod;
+import javax.xml.crypto.KeySelector;
+import javax.xml.crypto.KeySelectorException;
+import javax.xml.crypto.KeySelectorResult;
+import javax.xml.crypto.XMLCryptoContext;
 import javax.xml.crypto.dsig.keyinfo.KeyName;
+import org.keycloak.rotation.KeyLocator;
+import org.keycloak.saml.processing.api.util.KeyInfoTools;
 
 /**
  * Utility for XML Signature <b>Note:</b> You can change the canonicalization method type by using the system property
@@ -107,15 +113,66 @@ public class XMLSignatureUtil {
 
     ;
 
-    private static String canonicalizationMethodType = CanonicalizationMethod.EXCLUSIVE;
-
-    private static XMLSignatureFactory fac = getXMLSignatureFactory();
+    private static final XMLSignatureFactory fac = getXMLSignatureFactory();
 
     /**
      * By default, we include the keyinfo in the signature
      */
     private static boolean includeKeyInfoInSignature = true;
 
+    private static class KeySelectorUtilizingKeyNameHint extends KeySelector {
+
+        private final KeyLocator locator;
+
+        private boolean keyLocated = false;
+
+        private String keyName = null;
+
+        public KeySelectorUtilizingKeyNameHint(KeyLocator locator) {
+            this.locator = locator;
+        }
+
+        @Override
+        public KeySelectorResult select(KeyInfo keyInfo, KeySelector.Purpose purpose, AlgorithmMethod method, XMLCryptoContext context) throws KeySelectorException {
+            try {
+                KeyName keyNameEl = KeyInfoTools.getKeyName(keyInfo);
+                this.keyName = keyNameEl == null ? null : keyNameEl.getName();
+                final Key key = locator.getKey(keyName);
+                this.keyLocated = key != null;
+                return new KeySelectorResult() {
+                    @Override public Key getKey() {
+                        return key;
+                    }
+                };
+            } catch (KeyManagementException ex) {
+                throw new KeySelectorException(ex);
+            }
+
+        }
+
+        private boolean wasKeyLocated() {
+            return this.keyLocated;
+        }
+    }
+
+    private static class KeySelectorPresetKey extends KeySelector {
+
+        private final Key key;
+
+        public KeySelectorPresetKey(Key key) {
+            this.key = key;
+        }
+
+        @Override
+        public KeySelectorResult select(KeyInfo keyInfo, KeySelector.Purpose purpose, AlgorithmMethod method, XMLCryptoContext context) {
+            return new KeySelectorResult() {
+                @Override public Key getKey() {
+                    return key;
+                }
+            };
+        }
+    }
+
     private static XMLSignatureFactory getXMLSignatureFactory() {
         XMLSignatureFactory xsf = null;
 
@@ -333,6 +390,7 @@ public class XMLSignatureUtil {
     public static Document sign(SignatureUtilTransferObject dto, String canonicalizationMethodType) throws GeneralSecurityException, MarshalException,
             XMLSignatureException {
         Document doc = dto.getDocumentToBeSigned();
+        String keyId = dto.getKeyId();
         KeyPair keyPair = dto.getKeyPair();
         Node nextSibling = dto.getNextSibling();
         String digestMethod = dto.getDigestMethod();
@@ -346,13 +404,14 @@ public class XMLSignatureUtil {
 
         DOMSignContext dsc = new DOMSignContext(signingKey, doc.getDocumentElement(), nextSibling);
 
-        signImpl(dsc, digestMethod, signatureMethod, referenceURI, dto.getKeyId(), publicKey, dto.getX509Certificate(), canonicalizationMethodType);
+        signImpl(dsc, digestMethod, signatureMethod, referenceURI, keyId, publicKey, dto.getX509Certificate(), canonicalizationMethodType);
 
         return doc;
     }
 
     /**
-     * Validate a signed document with the given public key
+     * Validate a signed document with the given public key. All elements that contain a Signature are checked,
+     * this way both assertions and the containing document are verified when signed.
      *
      * @param signedDoc
      * @param publicKey
@@ -363,7 +422,7 @@ public class XMLSignatureUtil {
      * @throws XMLSignatureException
      */
     @SuppressWarnings("unchecked")
-    public static boolean validate(Document signedDoc, Key publicKey) throws MarshalException, XMLSignatureException {
+    public static boolean validate(Document signedDoc, final KeyLocator locator) throws MarshalException, XMLSignatureException {
         if (signedDoc == null)
             throw logger.nullArgumentError("Signed Document");
 
@@ -376,7 +435,7 @@ public class XMLSignatureUtil {
             return false;
         }
 
-        if (publicKey == null)
+        if (locator == null)
             throw logger.nullValueError("Public Key");
 
         int signedAssertions = 0;
@@ -392,24 +451,7 @@ public class XMLSignatureUtil {
                 }
             }
 
-            DOMValidateContext valContext = new DOMValidateContext(publicKey, nl.item(i));
-            XMLSignature signature = fac.unmarshalXMLSignature(valContext);
-
-            boolean coreValidity = signature.validate(valContext);
-
-            if (!coreValidity) {
-                if (logger.isTraceEnabled()) {
-                    boolean sv = signature.getSignatureValue().validate(valContext);
-                    logger.trace("Signature validation status: " + sv);
-
-                    List<Reference> references = signature.getSignedInfo().getReferences();
-                    for (Reference ref : references) {
-                        logger.trace("[Ref id=" + ref.getId() + ":uri=" + ref.getURI() + "]validity status:" + ref.validate(valContext));
-                    }
-                }
-
-                return false;
-            }
+            if (! validateSingleNode(signatureNode, locator)) return false;
         }
 
         NodeList assertions = signedDoc.getElementsByTagNameNS(assertionNameSpaceUri, JBossSAMLConstants.ASSERTION.get());
@@ -425,6 +467,62 @@ public class XMLSignatureUtil {
         return true;
     }
 
+    private static boolean validateSingleNode(Node signatureNode, final KeyLocator locator) throws MarshalException, XMLSignatureException {
+        KeySelectorUtilizingKeyNameHint sel = new KeySelectorUtilizingKeyNameHint(locator);
+        try {
+            if (validateUsingKeySelector(signatureNode, sel)) {
+                return true;
+            }
+            if (sel.wasKeyLocated()) {
+                return false;
+            }
+        } catch (XMLSignatureException ex) { // pass through MarshalException
+            logger.debug("Verification failed for key " + sel.keyName + ": " + ex);
+            logger.trace(ex);
+        }
+
+        logger.trace("Could not validate signature using ds:KeyInfo/ds:KeyName hint.");
+
+        if (locator instanceof Iterable) {
+            Iterable<Key> availableKeys = (Iterable<Key>) locator;
+
+            logger.trace("Trying hard to validate XML signature using all available keys.");
+
+            for (Key key : availableKeys) {
+                try {
+                    if (validateUsingKeySelector(signatureNode, new KeySelectorPresetKey(key))) {
+                        return true;
+                    }
+                } catch (XMLSignatureException ex) { // pass through MarshalException
+                    logger.debug("Verification failed: " + ex);
+                    logger.trace(ex);
+                }
+            }
+        }
+
+        return false;
+    }
+
+    private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector) throws XMLSignatureException, MarshalException {
+        DOMValidateContext valContext = new DOMValidateContext(validationKeySelector, signatureNode);
+        XMLSignature signature = fac.unmarshalXMLSignature(valContext);
+        boolean coreValidity = signature.validate(valContext);
+        
+        if (! coreValidity) {
+            if (logger.isTraceEnabled()) {
+                boolean sv = signature.getSignatureValue().validate(valContext);
+                logger.trace("Signature validation status: " + sv);
+
+                List<Reference> references = signature.getSignedInfo().getReferences();
+                for (Reference ref : references) {
+                    logger.trace("[Ref id=" + ref.getId() + ":uri=" + ref.getURI() + "]validity status:" + ref.validate(valContext));
+                }
+            }
+        }
+
+        return coreValidity;
+    }
+
     /**
      * Marshall a SignatureType to output stream
      *
@@ -605,7 +703,7 @@ public class XMLSignatureUtil {
         Transform transform1 = fac.newTransform(Transform.ENVELOPED, (TransformParameterSpec) null);
         Transform transform2 = fac.newTransform("http://www.w3.org/2001/10/xml-exc-c14n#", (TransformParameterSpec) null);
 
-        List<Transform> transformList = new ArrayList<Transform>();
+        List<Transform> transformList = new ArrayList<>();
         transformList.add(transform1);
         transformList.add(transform2);
 
@@ -618,7 +716,7 @@ public class XMLSignatureUtil {
         SignatureMethod signatureMethodObj = fac.newSignatureMethod(signatureMethod, null);
         SignedInfo si = fac.newSignedInfo(canonicalizationMethod, signatureMethodObj, referenceList);
 
-        KeyInfo ki = null;
+        KeyInfo ki;
         if (includeKeyInfoInSignature) {
             ki = createKeyInfo(keyId, publicKey, x509Certificate);
         } else {
diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
index d88a34e..3ee5b93 100755
--- a/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
+++ b/services/src/main/java/org/keycloak/broker/saml/SAMLEndpoint.java
@@ -76,6 +76,9 @@ import java.io.IOException;
 import java.security.PublicKey;
 import java.security.cert.X509Certificate;
 import java.util.List;
+import org.keycloak.rotation.HardcodedKeyLocator;
+import org.keycloak.rotation.KeyLocator;
+import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
 
 /**
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@@ -174,14 +177,17 @@ public class SAMLEndpoint {
         protected abstract void verifySignature(String key, SAMLDocumentHolder documentHolder) throws VerificationException;
         protected abstract SAMLDocumentHolder extractRequestDocument(String samlRequest);
         protected abstract SAMLDocumentHolder extractResponseDocument(String response);
-        protected PublicKey getIDPKey() {
+        
+        protected KeyLocator getIDPKeyLocator() {
+            // TODO !!!!!!!!!!!!!!!! Parse key from IDP's SAML descriptor
+
             X509Certificate certificate = null;
             try {
                 certificate = XMLSignatureUtil.getX509CertificateFromKeyInfoString(config.getSigningCertificate().replaceAll("\\s", ""));
             } catch (ProcessingException e) {
                 throw new RuntimeException(e);
             }
-            return certificate.getPublicKey();
+            return new HardcodedKeyLocator(certificate.getPublicKey());
         }
 
         public Response execute(String samlRequest, String samlResponse, String relayState) {
@@ -265,14 +271,18 @@ public class SAMLEndpoint {
             builder.issuer(issuerURL);
             JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder()
                         .relayState(relayState);
+            boolean postBinding = config.isPostBindingResponse();
             if (config.isWantAuthnRequestsSigned()) {
                 KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
                 binding.signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
                         .signatureAlgorithm(provider.getSignatureAlgorithm())
                         .signDocument();
+                if (! postBinding) {    // Only include extension if REDIRECT binding and signing whole SAML protocol message
+                    builder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+                }
             }
             try {
-                if (config.isPostBindingResponse()) {
+                if (postBinding) {
                     return binding.postBinding(builder.buildDocument()).response(config.getSingleLogoutServiceUrl());
                 } else {
                     return binding.redirectBinding(builder.buildDocument()).response(config.getSingleLogoutServiceUrl());
@@ -418,7 +428,7 @@ public class SAMLEndpoint {
     protected class PostBinding extends Binding {
         @Override
         protected void verifySignature(String key, SAMLDocumentHolder documentHolder) throws VerificationException {
-            SamlProtocolUtils.verifyDocumentSignature(documentHolder.getSamlDocument(), getIDPKey());
+            SamlProtocolUtils.verifyDocumentSignature(documentHolder.getSamlDocument(), getIDPKeyLocator());
         }
 
         @Override
@@ -440,8 +450,8 @@ public class SAMLEndpoint {
     protected class RedirectBinding extends Binding {
         @Override
         protected void verifySignature(String key, SAMLDocumentHolder documentHolder) throws VerificationException {
-            PublicKey publicKey = getIDPKey();
-            SamlProtocolUtils.verifyRedirectSignature(publicKey, uriInfo, key);
+            KeyLocator locator = getIDPKeyLocator();
+            SamlProtocolUtils.verifyRedirectSignature(documentHolder, locator, uriInfo, key);
         }
 
 
diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
index 6452c74..e1f8d16 100755
--- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
+++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProvider.java
@@ -50,8 +50,7 @@ import javax.ws.rs.core.Response;
 import javax.ws.rs.core.UriBuilder;
 import javax.ws.rs.core.UriInfo;
 import java.security.KeyPair;
-import java.security.PrivateKey;
-import java.security.PublicKey;
+import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
 
 /**
  * @author Pedro Igor
@@ -97,6 +96,7 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider<SAMLIdentityP
                     .nameIdPolicy(SAML2NameIDPolicyBuilder.format(nameIDPolicyFormat));
             JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder()
                     .relayState(request.getState());
+            boolean postBinding = getConfig().isPostBindingAuthnRequest();
 
             if (getConfig().isWantAuthnRequestsSigned()) {
                 KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
@@ -106,9 +106,12 @@ public class SAMLIdentityProvider extends AbstractIdentityProvider<SAMLIdentityP
                 binding.signWith(keys.getKid(), keypair);
                 binding.signatureAlgorithm(getSignatureAlgorithm());
                 binding.signDocument();
+                if (! postBinding) {    // Only include extension if REDIRECT binding and signing whole SAML protocol message
+                    authnRequestBuilder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+                }
             }
 
-            if (getConfig().isPostBindingAuthnRequest()) {
+            if (postBinding) {
                 return binding.postBinding(authnRequestBuilder.toDocument()).request(destinationUrl);
             } else {
                 return binding.redirectBinding(authnRequestBuilder.toDocument()).request(destinationUrl);
diff --git a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java
index 714c47e..9116b92 100755
--- a/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java
+++ b/services/src/main/java/org/keycloak/broker/saml/SAMLIdentityProviderFactory.java
@@ -121,6 +121,7 @@ public class SAMLIdentityProviderFactory extends AbstractIdentityProviderFactory
                             Element x509KeyInfo = DocumentUtil.getChildElement(keyInfo, new QName("dsig", "X509Certificate"));
 
                             if (KeyTypes.SIGNING.equals(keyDescriptorType.getUse())) {
+                                // TODO: CHECK
                                 samlIdentityProviderConfig.setSigningCertificate(x509KeyInfo.getTextContent());
                             } else if (KeyTypes.ENCRYPTION.equals(keyDescriptorType.getUse())) {
                                 samlIdentityProviderConfig.setEncryptionPublicKey(x509KeyInfo.getTextContent());
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlClient.java b/services/src/main/java/org/keycloak/protocol/saml/SamlClient.java
index 0415a72..336da7b 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlClient.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlClient.java
@@ -23,6 +23,8 @@ import org.keycloak.saml.SignatureAlgorithm;
 import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
 
 /**
+ * Configuration of a SAML-enabled client.
+ *
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
  * @version $Revision: 1 $
  */
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
index 2b6b9f7..7acb155 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocol.java
@@ -76,6 +76,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
 
 /**
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@@ -373,7 +374,15 @@ public class SamlProtocol implements LoginProtocol {
         }
 
         Document samlDocument = null;
+        KeyManager keyManager = session.keys();
+        KeyManager.ActiveKey keys = keyManager.getActiveKey(realm);
+        boolean postBinding = isPostBinding(clientSession);
+
         try {
+            if ((! postBinding) && samlClient.requiresRealmSignature()) {
+                builder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+            }
+
             ResponseType samlModel = builder.buildModel();
             final AttributeStatementType attributeStatement = populateAttributeStatements(attributeStatementMappers, session, userSession, clientSession);
             populateRoles(roleListMapper, session, userSession, clientSession, attributeStatement);
@@ -394,9 +403,6 @@ public class SamlProtocol implements LoginProtocol {
         JaxrsSAML2BindingBuilder bindingBuilder = new JaxrsSAML2BindingBuilder();
         bindingBuilder.relayState(relayState);
 
-        KeyManager keyManager = session.keys();
-        KeyManager.ActiveKey keys = keyManager.getActiveKey(realm);
-
         if (samlClient.requiresRealmSignature()) {
             String canonicalization = samlClient.getCanonicalizationMethod();
             if (canonicalization != null) {
@@ -496,12 +502,17 @@ public class SamlProtocol implements LoginProtocol {
             if (isLogoutPostBindingForClient(clientSession)) {
                 String bindingUri = getLogoutServiceUrl(uriInfo, client, SAML_POST_BINDING);
                 SAML2LogoutRequestBuilder logoutBuilder = createLogoutRequest(bindingUri, clientSession, client);
+                // This is POST binding, hence KeyID is included in dsig:KeyInfo/dsig:KeyName, no need to add <samlp:Extensions> element
                 JaxrsSAML2BindingBuilder binding = createBindingBuilder(samlClient);
                 return binding.postBinding(logoutBuilder.buildDocument()).request(bindingUri);
             } else {
                 logger.debug("frontchannel redirect binding");
                 String bindingUri = getLogoutServiceUrl(uriInfo, client, SAML_REDIRECT_BINDING);
                 SAML2LogoutRequestBuilder logoutBuilder = createLogoutRequest(bindingUri, clientSession, client);
+                if (samlClient.requiresRealmSignature()) {
+                    KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
+                    logoutBuilder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+                }
                 JaxrsSAML2BindingBuilder binding = createBindingBuilder(samlClient);
                 return binding.redirectBinding(logoutBuilder.buildDocument()).request(bindingUri);
             }
@@ -534,6 +545,7 @@ public class SamlProtocol implements LoginProtocol {
         JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder();
         binding.relayState(logoutRelayState);
         String signingAlgorithm = userSession.getNote(SAML_LOGOUT_SIGNATURE_ALGORITHM);
+        boolean postBinding = isLogoutPostBindingForInitiator(userSession);
         if (signingAlgorithm != null) {
             SignatureAlgorithm algorithm = SignatureAlgorithm.valueOf(signingAlgorithm);
             String canonicalization = userSession.getNote(SAML_LOGOUT_CANONICALIZATION);
@@ -542,6 +554,9 @@ public class SamlProtocol implements LoginProtocol {
             }
             KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
             binding.signatureAlgorithm(algorithm).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
+            if (! postBinding) {    // Only include extension if REDIRECT binding and signing whole SAML protocol message
+                builder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+            }
         }
 
         try {
@@ -577,6 +592,7 @@ public class SamlProtocol implements LoginProtocol {
         String logoutRequestString = null;
         try {
             JaxrsSAML2BindingBuilder binding = createBindingBuilder(samlClient);
+            // This is POST binding, hence KeyID is included in dsig:KeyInfo/dsig:KeyName, no need to add <samlp:Extensions> element
             logoutRequestString = binding.postBinding(logoutBuilder.buildDocument()).encoded();
         } catch (Exception e) {
             logger.warn("failed to send saml logout", e);
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolUtils.java b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolUtils.java
index e1a7c98..026a54a 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolUtils.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolUtils.java
@@ -17,6 +17,7 @@
 
 package org.keycloak.protocol.saml;
 
+import java.security.Key;
 import org.keycloak.common.VerificationException;
 import org.keycloak.common.util.PemUtils;
 import org.keycloak.models.ClientModel;
@@ -33,6 +34,15 @@ import javax.ws.rs.core.UriInfo;
 import java.security.PublicKey;
 import java.security.Signature;
 import java.security.cert.Certificate;
+import org.keycloak.dom.saml.v2.SAML2Object;
+import org.keycloak.dom.saml.v2.protocol.ExtensionsType;
+import org.keycloak.dom.saml.v2.protocol.RequestAbstractType;
+import org.keycloak.dom.saml.v2.protocol.StatusResponseType;
+import org.keycloak.rotation.HardcodedKeyLocator;
+import org.keycloak.rotation.KeyLocator;
+import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
+import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
+import org.w3c.dom.Element;
 
 /**
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@@ -40,20 +50,36 @@ import java.security.cert.Certificate;
  */
 public class SamlProtocolUtils {
 
-
+    /**
+     * Verifies a signature of the given SAML document using settings for the given client.
+     * Throws an exception if the client signature is expected to be present as per the client
+     * settings and it is invalid, otherwise returns back to the caller.
+     *
+     * @param client
+     * @param document
+     * @throws VerificationException
+     */
     public static void verifyDocumentSignature(ClientModel client, Document document) throws VerificationException {
         SamlClient samlClient = new SamlClient(client);
         if (!samlClient.requiresClientSignature()) {
             return;
         }
         PublicKey publicKey = getSignatureValidationKey(client);
-        verifyDocumentSignature(document, publicKey);
+        verifyDocumentSignature(document, new HardcodedKeyLocator(publicKey));
     }
 
-    public static void verifyDocumentSignature(Document document, PublicKey publicKey) throws VerificationException {
+    /**
+     * Verifies a signature of the given SAML document using keys obtained from the given key locator.
+     * Throws an exception if the client signature is invalid, otherwise returns back to the caller.
+     *
+     * @param document
+     * @param keyLocator
+     * @throws VerificationException
+     */
+    public static void verifyDocumentSignature(Document document, KeyLocator keyLocator) throws VerificationException {
         SAML2Signature saml2Signature = new SAML2Signature();
         try {
-            if (!saml2Signature.validate(document, publicKey)) {
+            if (!saml2Signature.validate(document, keyLocator)) {
                 throw new VerificationException("Invalid signature on document");
             }
         } catch (ProcessingException e) {
@@ -61,10 +87,22 @@ public class SamlProtocolUtils {
         }
     }
 
+    /**
+     * Returns public part of SAML signing key from the client settings.
+     * @param client
+     * @return Public key for signature validation.
+     * @throws VerificationException
+     */
     public static PublicKey getSignatureValidationKey(ClientModel client) throws VerificationException {
         return getPublicKey(new SamlClient(client).getClientSigningCertificate());
     }
 
+    /**
+     * Returns public part of SAML encryption key from the client settings.
+     * @param client
+     * @return Public key for encryption.
+     * @throws VerificationException
+     */
     public static PublicKey getEncryptionValidationKey(ClientModel client) throws VerificationException {
         return getPublicKey(client, SamlConfigAttributes.SAML_ENCRYPTION_CERTIFICATE_ATTRIBUTE);
     }
@@ -85,7 +123,7 @@ public class SamlProtocolUtils {
         return cert.getPublicKey();
     }
 
-    public static void verifyRedirectSignature(PublicKey publicKey, UriInfo uriInformation, String paramKey) throws VerificationException {
+    public static void verifyRedirectSignature(SAMLDocumentHolder documentHolder, KeyLocator locator, UriInfo uriInformation, String paramKey) throws VerificationException {
         MultivaluedMap<String, String> encodedParams = uriInformation.getQueryParameters(false);
         String request = encodedParams.getFirst(paramKey);
         String algorithm = encodedParams.getFirst(GeneralConstants.SAML_SIG_ALG_REQUEST_KEY);
@@ -96,10 +134,11 @@ public class SamlProtocolUtils {
         if (algorithm == null) throw new VerificationException("SigAlg was null");
         if (signature == null) throw new VerificationException("Signature was null");
 
+        String keyId = getMessageSigningKeyId(documentHolder.getSamlObject());
+
         // Shibboleth doesn't sign the document for redirect binding.
         // todo maybe a flag?
 
-
         UriBuilder builder = UriBuilder.fromPath("/")
                 .queryParam(paramKey, request);
         if (encodedParams.containsKey(GeneralConstants.RELAY_STATE)) {
@@ -113,8 +152,13 @@ public class SamlProtocolUtils {
 
             SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.getFromXmlMethod(decodedAlgorithm);
             Signature validator = signatureAlgorithm.createSignature(); // todo plugin signature alg
-            validator.initVerify(publicKey);
-            validator.update(rawQuery.getBytes("UTF-8"));
+            Key key = locator.getKey(keyId);
+            if (key instanceof PublicKey) {
+                validator.initVerify((PublicKey) key);
+                validator.update(rawQuery.getBytes("UTF-8"));
+            } else {
+                throw new VerificationException("Invalid key locator for signature verification");
+            }
             if (!validator.verify(decodedSignature)) {
                 throw new VerificationException("Invalid query param signature");
             }
@@ -123,5 +167,32 @@ public class SamlProtocolUtils {
         }
     }
 
+    private static String getMessageSigningKeyId(SAML2Object doc) {
+        final ExtensionsType extensions;
+        if (doc instanceof RequestAbstractType) {
+            extensions = ((RequestAbstractType) doc).getExtensions();
+        } else if (doc instanceof StatusResponseType) {
+            extensions = ((StatusResponseType) doc).getExtensions();
+        } else {
+            return null;
+        }
+
+        if (extensions == null) {
+            return null;
+        }
+
+        for (Object ext : extensions.getAny()) {
+            if (! (ext instanceof Element)) {
+                continue;
+            }
+
+            String res = KeycloakKeySamlExtensionGenerator.getMessageSigningKeyIdFromElement((Element) ext);
 
+            if (res != null) {
+                return res;
+            }
+        }
+
+        return null;
+    }
 }
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
index 40f615e..b3994c1 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
@@ -408,14 +408,17 @@ public class SamlService extends AuthorizationEndpointBase {
             builder.destination(logoutBindingUri);
             builder.issuer(RealmsResource.realmBaseUrl(uriInfo).build(realm.getName()).toString());
             JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder().relayState(logoutRelayState);
+            boolean postBinding = SamlProtocol.SAML_POST_BINDING.equals(logoutBinding);
             if (samlClient.requiresRealmSignature()) {
                 SignatureAlgorithm algorithm = samlClient.getSignatureAlgorithm();
                 KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
                 binding.signatureAlgorithm(algorithm).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
-
+                if (! postBinding) {    // Only include extension if REDIRECT binding and signing whole SAML protocol message
+                    builder.addExtension(new KeycloakKeySamlExtensionGenerator(keys.getKid()));
+                }
             }
             try {
-                if (SamlProtocol.SAML_POST_BINDING.equals(logoutBinding)) {
+                if (postBinding) {
                     return binding.postBinding(builder.buildDocument()).response(logoutBindingUri);
                 } else {
                     return binding.redirectBinding(builder.buildDocument()).response(logoutBindingUri);
@@ -477,7 +480,8 @@ public class SamlService extends AuthorizationEndpointBase {
                 return;
             }
             PublicKey publicKey = SamlProtocolUtils.getSignatureValidationKey(client);
-            SamlProtocolUtils.verifyRedirectSignature(publicKey, uriInfo, GeneralConstants.SAML_REQUEST_KEY);
+            KeyLocator clientKeyLocator = new HardcodedKeyLocator(publicKey);
+            SamlProtocolUtils.verifyRedirectSignature(documentHolder, clientKeyLocator, uriInfo, GeneralConstants.SAML_REQUEST_KEY);
         }
 
         @Override