PairwiseSubMapperUtils.java

166 lines | 6.832 kB Blame History Raw Download
package org.keycloak.protocol.oidc.utils;

import org.jboss.logging.Logger;
import org.keycloak.protocol.oidc.mappers.AbstractPairwiseSubMapper;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.representations.idm.ProtocolMapperRepresentation;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Static helpers used when working with pairwise subject identifier mappers: resolving the
 * redirect URIs registered on a client, deriving a sector identifier (a host name) from them or
 * from a sector identifier URI, and checking redirect URIs against a set of sector redirects.
 *
 * <p>All methods are static; the class holds no state besides its logger.
 */
public class PairwiseSubMapperUtils {
    private static final Logger logger = Logger.getLogger(PairwiseSubMapperUtils.class);

    /**
     * Returns the set of usable redirect URIs built from the root url and the redirect URIs
     * registered on a client.
     *
     * <p>Entries starting with {@code "/"} are treated as relative and are prefixed with
     * {@code clientRootUrl}; they are dropped when no root url is available. {@code null} and
     * blank entries are dropped as well.
     *
     * @param clientRootUrl      root url registered on the client, may be {@code null} or empty.
     * @param clientRedirectUris redirect URIs registered on the client, may be {@code null}.
     * @return the resolved redirect URIs; an empty set if {@code clientRedirectUris} is {@code null}.
     */
    public static Set<String> resolveValidRedirectUris(String clientRootUrl, Collection<String> clientRedirectUris) {
        if (clientRedirectUris == null) {
            return Collections.emptySet();
        }

        Set<String> validRedirects = new HashSet<>();
        for (String redirectUri : clientRedirectUris) {
            if (redirectUri == null) {
                // A null entry cannot be resolved to anything; skip it instead of failing.
                continue;
            }

            String resolved = redirectUri;
            if (redirectUri.startsWith("/")) {
                resolved = relativeToAbsoluteURI(clientRootUrl, redirectUri);
                if (resolved == null) {
                    logger.debugv("ignoring relative redirect uri because no root url is set: {0}", redirectUri);
                    continue;
                }
                logger.debugv("replacing relative valid redirect with: {0}", resolved);
            }

            if (!resolved.trim().isEmpty()) {
                validRedirects.add(resolved);
            }
        }
        return validRedirects;
    }

    /**
     * Tries to resolve a valid sector identifier from a sector identifier URI.
     *
     * @param sectorIdentifierUri the sector identifier URI, may be {@code null}.
     * @return the host of the URI iff. {@code sectorIdentifierUri} is a syntactically valid URI that
     * contains a scheme and a host component, otherwise {@code null}.
     */
    public static String resolveValidSectorIdentifier(String sectorIdentifierUri) {
        if (sectorIdentifierUri == null) {
            // new URI(null) would throw a NullPointerException rather than a URISyntaxException.
            logger.debug("The sector identifier URI must not be null");
            return null;
        }

        URI uri;
        try {
            uri = new URI(sectorIdentifierUri);
        } catch (URISyntaxException e) {
            logger.debug("Invalid sector identifier URI", e);
            return null;
        }

        // Only the presence of a scheme is checked; an HTTPS-only restriction is not enforced here.
        if (uri.getScheme() == null) {
            logger.debugv("Invalid sector identifier URI: {0}", sectorIdentifierUri);
            return null;
        }

        String host = uri.getHost();
        if (host == null) {
            logger.debug("The sector identifier URI must specify a host");
            return null;
        }

        return host;
    }

    /**
     * Tries to resolve a valid sector identifier from the redirect URIs registered on a client.
     *
     * <p>Redirect URIs that are not syntactically valid are logged and skipped. A redirect URI that
     * parses but has no host component makes the result {@code null}.
     *
     * @param clientRootUrl      Root url registered on the client.
     * @param clientRedirectUris Redirect URIs registered on the client.
     * @return a sector identifier iff. all the registered redirect URIs are located at the same host, otherwise {@code null}.
     */
    public static String resolveValidSectorIdentifier(String clientRootUrl, Set<String> clientRedirectUris) {
        Set<String> hosts = new HashSet<>();
        for (String redirectUri : resolveValidRedirectUris(clientRootUrl, clientRedirectUris)) {
            String host;
            try {
                host = new URI(redirectUri).getHost();
            } catch (URISyntaxException e) {
                logger.debugv("client redirect uris contained an invalid uri: {0}", redirectUri);
                continue;
            }

            if (host == null) {
                // A host-less URI can never share a host with the others, so no single sector
                // identifier exists for this client.
                logger.debugv("client redirect uris contained a uri without a host: {0}", redirectUri);
                return null;
            }
            hosts.add(host);
        }

        if (hosts.isEmpty()) {
            logger.debug("could not infer any valid sector_identifiers from client redirect uris");
            return null;
        }
        if (hosts.size() > 1) {
            logger.debug("the client redirect uris contained multiple hosts");
            return null;
        }
        return hosts.iterator().next();
    }

    /**
     * Checks if the registered client redirect URIs match the set of redirect URIs from the sector identifier URI.
     *
     * @param clientRootUrl      root url registered on the client.
     * @param clientRedirectUris redirect URIs registered on the client.
     * @param sectorRedirects    value of the sector identifier URI; {@code null} is treated as an empty set.
     * @return {@code true} iff. all the redirect URIs can be described by the {@code sectorRedirects}, i.e. if the
     * registered redirect URIs are a subset of the {@code sectorRedirects}, otherwise {@code false}. When the client
     * has no usable redirect URIs the result is {@code true}.
     */
    public static boolean matchesRedirects(String clientRootUrl, Set<String> clientRedirectUris, Set<String> sectorRedirects) {
        for (String redirect : resolveValidRedirectUris(clientRootUrl, clientRedirectUris)) {
            if (!matchesRedirect(sectorRedirects, redirect)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks a single redirect URI against a set of allowed redirects.
     *
     * <p>An allowed redirect ending in {@code *} (and containing no {@code ?}) is a prefix wildcard:
     * the query component of {@code redirect} is ignored and the remainder must start with the
     * wildcard's prefix, or equal that prefix without its trailing {@code /}. Any other allowed
     * redirect must equal {@code redirect} exactly.
     *
     * @param validRedirects allowed redirects, possibly containing wildcards; may be {@code null}.
     * @param redirect       the redirect URI to check.
     * @return {@code true} if at least one allowed redirect matches.
     */
    private static boolean matchesRedirect(Set<String> validRedirects, String redirect) {
        if (validRedirects == null) {
            return false;
        }

        for (String validRedirect : validRedirects) {
            if (validRedirect == null) {
                continue;
            }

            boolean wildcard = validRedirect.endsWith("*") && !validRedirect.contains("?");
            if (!wildcard) {
                if (validRedirect.equals(redirect)) {
                    return true;
                }
                continue;
            }

            // strip off the query component - we don't check them when wildcards are effective
            int queryStart = redirect.indexOf('?');
            String candidate = queryStart >= 0 ? redirect.substring(0, queryStart) : redirect;

            // strip off *
            String prefix = validRedirect.substring(0, validRedirect.length() - 1);
            if (candidate.startsWith(prefix)) {
                return true;
            }

            // strip off trailing '/' so that "http://host/*" also matches "http://host"
            if (prefix.length() > 1 && prefix.endsWith("/")
                    && prefix.substring(0, prefix.length() - 1).equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prefixes a relative path with the root url.
     *
     * @param rootUrl  root url registered on the client.
     * @param relative relative redirect URI.
     * @return the plain concatenation of {@code rootUrl} and {@code relative}, or {@code null} if
     * {@code rootUrl} is {@code null} or empty.
     */
    private static String relativeToAbsoluteURI(String rootUrl, String relative) {
        if (rootUrl == null || rootUrl.isEmpty()) {
            return null;
        }
        return rootUrl + relative;
    }

    /**
     * Collects the protocol mappers of a client whose provider id ends with
     * {@link AbstractPairwiseSubMapper#PROVIDER_ID_SUFFIX}.
     *
     * @param client the client representation to inspect.
     * @return a new, mutable list of the matching mappers; empty if the client has no protocol mappers.
     */
    public static List<ProtocolMapperRepresentation> getPairwiseSubMappers(ClientRepresentation client) {
        List<ProtocolMapperRepresentation> mappers = client.getProtocolMappers();
        if (mappers == null) {
            return new LinkedList<>();
        }

        return mappers.stream()
                .filter(PairwiseSubMapperUtils::isPairwiseSubMapper)
                .collect(Collectors.toCollection(LinkedList::new));
    }

    /**
     * Tells whether a mapper representation refers to a pairwise sub mapper; mappers without a
     * provider id never match.
     */
    private static boolean isPairwiseSubMapper(ProtocolMapperRepresentation mapping) {
        if (mapping == null) {
            return false;
        }
        String providerId = mapping.getProtocolMapper();
        return providerId != null && providerId.endsWith(AbstractPairwiseSubMapper.PROVIDER_ID_SUFFIX);
    }
}