keycloak-aplcache

Details

diff --git a/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java b/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java
index b90a165..c0be2ba 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/profile/ecp/SamlEcpProfileService.java
@@ -43,6 +43,7 @@ import javax.xml.soap.SOAPException;
 import javax.xml.soap.SOAPHeaderElement;
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.Map;
 
 /**
  * @author <a href="mailto:psilva@redhat.com">Pedro Igor</a>
@@ -53,8 +54,8 @@ public class SamlEcpProfileService extends SamlService {
     private static final String NS_PREFIX_SAML_PROTOCOL = "samlp";
     private static final String NS_PREFIX_SAML_ASSERTION = "saml";
 
-    public SamlEcpProfileService(RealmModel realm, EventBuilder event) {
-        super(realm, event);
+    public SamlEcpProfileService(RealmModel realm, EventBuilder event, Map<String, Integer> knownPorts, Map<Integer, String> knownProtocols) {
+        super(realm, event, knownPorts, knownProtocols);
     }
 
     public Response authenticate(InputStream inputStream) {
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java
index 21ccc81..87d6615 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlProtocolFactory.java
@@ -39,7 +39,11 @@ import org.keycloak.saml.processing.core.saml.v2.constants.X500SAMLProfileConsta
 
 import javax.xml.crypto.dsig.CanonicalizationMethod;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 /**
  * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
@@ -47,9 +51,26 @@ import java.util.List;
  */
 public class SamlProtocolFactory extends AbstractLoginProtocolFactory {
 
+    private static final Pattern PROTOCOL_MAP_PATTERN = Pattern.compile("\\s*([a-zA-Z][a-zA-Z\\d+-.]*)\\s*=\\s*(\\d+)\\s*");
+    private static final String[] DEFAULT_PROTOCOL_TO_PORT_MAP = new String[] { "http=80", "https=443" };
+
+    private final Map<Integer, String> knownPorts = new HashMap<>();
+    private final Map<String, Integer> knownProtocols = new HashMap<>();
+
+    private void addToProtocolPortMaps(String protocolMapping) {
+        Matcher m = PROTOCOL_MAP_PATTERN.matcher(protocolMapping);
+        if (m.matches()) {
+            Integer port = Integer.valueOf(m.group(2));
+            String proto = m.group(1);
+
+            knownPorts.put(port, proto);
+            knownProtocols.put(proto, port);
+        }
+    }
+
     @Override
     public Object createProtocolEndpoint(RealmModel realm, EventBuilder event) {
-        return new SamlService(realm, event);
+        return new SamlService(realm, event, knownProtocols, knownPorts);
     }
 
     @Override
@@ -61,6 +82,15 @@ public class SamlProtocolFactory extends AbstractLoginProtocolFactory {
     public void init(Config.Scope config) {
         //PicketLinkCoreSTS sts = PicketLinkCoreSTS.instance();
         //sts.installDefaultConfiguration();
+
+        String[] protocolMappings = config.getArray("knownProtocols");
+        if (protocolMappings == null) {
+            protocolMappings = DEFAULT_PROTOCOL_TO_PORT_MAP;
+        }
+
+        for (String protocolMapping : protocolMappings) {
+            addToProtocolPortMaps(protocolMapping);
+        }
     }
 
     @Override
diff --git a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
index 9a6790b..e0ac524 100755
--- a/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
+++ b/services/src/main/java/org/keycloak/protocol/saml/SamlService.java
@@ -67,7 +67,6 @@ import javax.ws.rs.Path;
 import javax.ws.rs.PathParam;
 import javax.ws.rs.Produces;
 import javax.ws.rs.QueryParam;
-import javax.ws.rs.core.Context;
 import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
 import javax.ws.rs.core.UriInfo;
@@ -87,6 +86,7 @@ import org.keycloak.rotation.KeyLocator;
 import org.keycloak.saml.SPMetadataDescriptor;
 import org.keycloak.saml.processing.core.util.KeycloakKeySamlExtensionGenerator;
 import org.keycloak.sessions.AuthenticationSessionModel;
+import java.util.Map;
 
 /**
  * Resource class for the saml connect token service
@@ -98,8 +98,13 @@ public class SamlService extends AuthorizationEndpointBase {
 
     protected static final Logger logger = Logger.getLogger(SamlService.class);
 
-    public SamlService(RealmModel realm, EventBuilder event) {
+    private final Map<String, Integer> knownPorts;
+    private final Map<Integer, String> knownProtocols;
+
+    public SamlService(RealmModel realm, EventBuilder event, Map<String, Integer> knownPorts, Map<Integer, String> knownProtocols) {
         super(realm, event);
+        this.knownPorts = knownPorts;
+        this.knownProtocols = knownProtocols;
     }
 
     public abstract class BindingProtocol {
@@ -239,7 +244,7 @@ public class SamlService extends AuthorizationEndpointBase {
         protected Response loginRequest(String relayState, AuthnRequestType requestAbstractType, ClientModel client) {
             SamlClient samlClient = new SamlClient(client);
             // validate destination
-            if (requestAbstractType.getDestination() != null && !uriInfo.getAbsolutePath().equals(requestAbstractType.getDestination())) {
+            if (! isValidDestination(requestAbstractType.getDestination())) {
                 event.detail(Details.REASON, "invalid_destination");
                 event.error(Errors.INVALID_SAML_AUTHN_REQUEST);
                 return ErrorPage.error(session, Messages.INVALID_REQUEST);
@@ -341,7 +346,7 @@ public class SamlService extends AuthorizationEndpointBase {
         protected Response logoutRequest(LogoutRequestType logoutRequest, ClientModel client, String relayState) {
             SamlClient samlClient = new SamlClient(client);
             // validate destination
-            if (logoutRequest.getDestination() != null && !uriInfo.getAbsolutePath().equals(logoutRequest.getDestination())) {
+            if (! isValidDestination(logoutRequest.getDestination())) {
                 event.detail(Details.REASON, "invalid_destination");
                 event.error(Errors.INVALID_SAML_LOGOUT_REQUEST);
                 return ErrorPage.error(session, Messages.INVALID_REQUEST);
@@ -683,11 +688,35 @@ public class SamlService extends AuthorizationEndpointBase {
     @NoCache
     @Consumes({"application/soap+xml",MediaType.TEXT_XML})
     public Response soapBinding(InputStream inputStream) {
-        SamlEcpProfileService bindingService = new SamlEcpProfileService(realm, event);
+        SamlEcpProfileService bindingService = new SamlEcpProfileService(realm, event, knownPorts, knownProtocols);
 
         ResteasyProviderFactory.getInstance().injectProperties(bindingService);
 
         return bindingService.authenticate(inputStream);
     }
 
+    private boolean isValidDestination(URI destination) {
+        if (destination == null) {
+            return false;
+        }
+
+        URI expected = uriInfo.getAbsolutePath();
+
+        if (Objects.equals(expected, destination)) {
+            return true;
+        }
+
+        Integer portByScheme = knownPorts.get(expected.getScheme());
+        if (expected.getPort() < 0 && portByScheme != null) {
+            return Objects.equals(uriInfo.getRequestUriBuilder().port(portByScheme).build(), destination);
+        }
+
+        String protocolByPort = knownProtocols.get(expected.getPort());
+        if (expected.getPort() >= 0 && Objects.equals(protocolByPort, expected.getScheme())) {
+            return Objects.equals(uriInfo.getRequestUriBuilder().port(-1).build(), destination);
+        }
+
+        return false;
+    }
+
 }
diff --git a/services/src/main/java/org/keycloak/services/resources/IdentityBrokerService.java b/services/src/main/java/org/keycloak/services/resources/IdentityBrokerService.java
index 530fce2..eed2858 100755
--- a/services/src/main/java/org/keycloak/services/resources/IdentityBrokerService.java
+++ b/services/src/main/java/org/keycloak/services/resources/IdentityBrokerService.java
@@ -57,6 +57,8 @@ import org.keycloak.models.RoleModel;
 import org.keycloak.models.UserModel;
 import org.keycloak.models.UserSessionModel;
 import org.keycloak.models.utils.FormMessage;
+import org.keycloak.protocol.LoginProtocol;
+import org.keycloak.protocol.LoginProtocolFactory;
 import org.keycloak.protocol.oidc.OIDCLoginProtocol;
 import org.keycloak.protocol.oidc.TokenManager;
 import org.keycloak.protocol.oidc.utils.RedirectUtils;
@@ -1027,7 +1029,8 @@ public class IdentityBrokerService implements IdentityProvider.AuthenticationCal
             return ParsedCodeContext.response(redirectToErrorPage(Messages.CLIENT_NOT_FOUND));
         }
 
-        SamlService samlService = new SamlService(realmModel, event);
+        LoginProtocolFactory factory = (LoginProtocolFactory) session.getKeycloakSessionFactory().getProviderFactory(LoginProtocol.class, SamlProtocol.LOGIN_PROTOCOL);
+        SamlService samlService = (SamlService) factory.createProtocolEndpoint(realmModel, event);
         ResteasyProviderFactory.getInstance().injectProperties(samlService);
         AuthenticationSessionModel authSession = samlService.getOrCreateLoginSessionForIdpInitiatedSso(session, realmModel, oClient.get(), null);
 
diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/saml/BasicSamlTest.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/saml/BasicSamlTest.java
index 0bd0793..78cf93d 100644
--- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/saml/BasicSamlTest.java
+++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/saml/BasicSamlTest.java
@@ -2,19 +2,32 @@ package org.keycloak.testsuite.saml;
 
 import org.junit.Test;
 import org.keycloak.dom.saml.v2.protocol.AuthnRequestType;
+import org.keycloak.protocol.saml.SamlProtocol;
 import org.keycloak.saml.common.exceptions.ConfigurationException;
 import org.keycloak.saml.common.exceptions.ParsingException;
 import org.keycloak.saml.common.exceptions.ProcessingException;
 import org.keycloak.saml.processing.api.saml.v2.request.SAML2Request;
 import org.keycloak.saml.processing.core.saml.v2.common.SAMLDocumentHolder;
+import org.keycloak.services.resources.RealmsResource;
 import org.keycloak.testsuite.util.SamlClient;
+import org.keycloak.testsuite.util.SamlClient.Binding;
+import org.keycloak.testsuite.util.SamlClient.RedirectStrategyWithSwitchableFollowRedirect;
+import javax.ws.rs.core.Response;
+import javax.ws.rs.core.UriBuilder;
+import org.apache.http.client.methods.CloseableHttpResponse;
+import org.apache.http.client.methods.HttpUriRequest;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.apache.http.util.EntityUtils;
+import org.hamcrest.Matcher;
 import org.w3c.dom.Document;
 
-import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.Matchers.containsString;
 import static org.junit.Assert.assertThat;
 import static org.keycloak.testsuite.util.IOUtil.documentToString;
 import static org.keycloak.testsuite.util.IOUtil.setDocElementAttributeValue;
+import static org.keycloak.testsuite.util.Matchers.statusCodeIsHC;
 import static org.keycloak.testsuite.util.SamlClient.login;
 
 /**
@@ -35,4 +48,34 @@ public class BasicSamlTest extends AbstractSamlTest {
 
         assertThat(documentToString(document.getSamlDocument()), not(containsString("InResponseTo=\"" + System.getProperty("java.version") + "\"")));
     }
+
+    @Test
+    public void testNoPortInDestination() throws Exception {
+        // note that this test relies on settings of the login-protocol.saml.knownProtocols configuration option
+        testWithOverriddenPort(-1, Response.Status.OK, containsString("login"));
+    }
+
+    @Test
+    public void testExplicitPortInDestination() throws Exception {
+        testWithOverriddenPort(Integer.valueOf(System.getProperty("auth.server.http.port")), Response.Status.OK, containsString("login"));
+    }
+
+    @Test
+    public void testWrongPortInDestination() throws Exception {
+        testWithOverriddenPort(123, Response.Status.INTERNAL_SERVER_ERROR, containsString("Invalid Request"));
+    }
+
+    private void testWithOverriddenPort(int port, Response.Status expectedHttpCode, Matcher<String> pageTextMatcher) throws Exception {
+        AuthnRequestType loginRep = SamlClient.createLoginRequestDocument(SAML_CLIENT_ID_SALES_POST, SAML_ASSERTION_CONSUMER_URL_SALES_POST,
+          RealmsResource.protocolUrl(UriBuilder.fromUri(getAuthServerRoot()).port(port)).build(REALM_NAME, SamlProtocol.LOGIN_PROTOCOL));
+
+        Document doc = SAML2Request.convert(loginRep);
+        HttpUriRequest post = Binding.POST.createSamlUnsignedRequest(getAuthServerSamlEndpoint(REALM_NAME), null, doc);
+
+        try (CloseableHttpClient client = HttpClientBuilder.create().setRedirectStrategy(new RedirectStrategyWithSwitchableFollowRedirect()).build();
+          CloseableHttpResponse response = client.execute(post)) {
+            assertThat(response, statusCodeIsHC(expectedHttpCode));
+            assertThat(EntityUtils.toString(response.getEntity(), "UTF-8"), pageTextMatcher);
+        }
+    }
 }
diff --git a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/util/SamlClient.java b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/util/SamlClient.java
index 5d5675f..8c20b26 100644
--- a/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/util/SamlClient.java
+++ b/testsuite/integration-arquillian/tests/base/src/test/java/org/keycloak/testsuite/util/SamlClient.java
@@ -398,6 +398,14 @@ public class SamlClient {
         this.samlEndpoint = samlEndpoint;
     }
 
+    public HttpClientContext getContext() {
+        return context;
+    }
+
+    public URI getSamlEndpoint() {
+        return samlEndpoint;
+    }
+
     /**
      * Send request for login form and then login using user param. Check whether client requires consent and handle consent page.
      *
@@ -415,21 +423,22 @@ public class SamlClient {
                                     Document samlRequest, String relayState, Binding requestBinding, Binding expectedResponseBinding, boolean consentRequired, boolean consent) {
         return getSamlResponse(expectedResponseBinding, (client, context, strategy) -> {
             HttpUriRequest post = requestBinding.createSamlUnsignedRequest(samlEndpoint, relayState, samlRequest);
-            CloseableHttpResponse response = client.execute(post, context);
-
-            assertThat(response, statusCodeIsHC(Response.Status.OK));
-            String loginPageText = EntityUtils.toString(response.getEntity(), "UTF-8");
-            response.close();
+            String loginPageText;
 
-            assertThat(loginPageText, containsString("login"));
+            try (CloseableHttpResponse response = client.execute(post, context)) {
+                assertThat(response, statusCodeIsHC(Response.Status.OK));
+                loginPageText = EntityUtils.toString(response.getEntity(), "UTF-8");
+                assertThat(loginPageText, containsString("login"));
+            }
 
             HttpUriRequest loginRequest = handleLoginPage(user, loginPageText);
 
             if (consentRequired) {
                 // Client requires consent
-                response = client.execute(loginRequest, context);
-                String consentPageText = EntityUtils.toString(response.getEntity(), "UTF-8");
-                loginRequest = handleConsentPage(consentPageText, consent);
+                try (CloseableHttpResponse response = client.execute(loginRequest, context)) {
+                    String consentPageText = EntityUtils.toString(response.getEntity(), "UTF-8");
+                    loginRequest = handleConsentPage(consentPageText, consent);
+                }
             }
 
             strategy.setRedirectable(false);
diff --git a/testsuite/integration-arquillian/tests/base/src/test/resources/META-INF/keycloak-server.json b/testsuite/integration-arquillian/tests/base/src/test/resources/META-INF/keycloak-server.json
index d038877..9d801e5 100755
--- a/testsuite/integration-arquillian/tests/base/src/test/resources/META-INF/keycloak-server.json
+++ b/testsuite/integration-arquillian/tests/base/src/test/resources/META-INF/keycloak-server.json
@@ -133,5 +133,14 @@
             "enabled": true
         }
 
+    },
+
+    "login-protocol": {
+        "saml": {
+            "knownProtocols": [
+                "http=${auth.server.http.port}",
+                "https=${auth.server.https.port}"
+            ]
+        }
     }
 }