StaxWriterUtil.java

167 lines | 7.807 kB Blame History Raw Download
/*
 * Copyright 2016 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.saml.processing.core.saml.v2.util;


import org.keycloak.dom.xmlsec.w3.xmldsig.DSAKeyValueType;
import org.keycloak.dom.xmlsec.w3.xmldsig.KeyInfoType;
import org.keycloak.dom.xmlsec.w3.xmldsig.KeyValueType;
import org.keycloak.dom.xmlsec.w3.xmldsig.RSAKeyValueType;
import org.keycloak.dom.xmlsec.w3.xmldsig.X509CertificateType;
import org.keycloak.dom.xmlsec.w3.xmldsig.X509DataType;
import org.keycloak.saml.common.ErrorCodes;
import org.keycloak.saml.common.PicketLinkLogger;
import org.keycloak.saml.common.PicketLinkLoggerFactory;
import org.keycloak.saml.common.constants.GeneralConstants;
import org.keycloak.saml.common.constants.WSTrustConstants;
import org.keycloak.saml.common.exceptions.ProcessingException;
import org.keycloak.saml.common.util.StaxUtil;

import org.w3c.dom.Element;

import javax.xml.stream.XMLStreamWriter;

/**
 * Utility methods for writing XML Digital Signature key material ({@code KeyInfo},
 * {@code X509Data}, {@code KeyValue}, {@code RSAKeyValue}, {@code DSAKeyValue}) to a StAX
 * {@link XMLStreamWriter}.
 *
 * <p>All methods are static; the class holds no state besides its logger.
 *
 * @author anil saldhana
 * @since Jan 28, 2013
 */
public class StaxWriterUtil {

    private static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger();

    /**
     * Write the {@link org.keycloak.dom.xmlsec.w3.xmldsig.KeyInfoType} as a {@code KeyInfo} element.
     *
     * <p>Only the <em>first</em> entry of {@code keyInfo.getContent()} is written. Supported entry types are
     * a DOM {@link Element} (copied as-is), {@link X509DataType} and {@link KeyValueType}. For
     * {@link X509DataType} only the first data object is written; a data object that is neither an
     * {@link Element} nor an {@link X509CertificateType} yields an empty {@code X509Data} element. A
     * {@link KeyValueType} that is neither a {@link DSAKeyValueType} nor an {@link RSAKeyValueType} yields an
     * empty {@code KeyValue} element.
     *
     * @param writer the stream writer to emit the element to
     * @param keyInfo the key info to serialize; must be non-null and have at least one content entry
     *
     * @throws org.keycloak.saml.common.exceptions.ProcessingException if {@code keyInfo} is {@code null} or has
     *         no content, if an {@link X509DataType} entry has no data objects, if the first content entry is
     *         of an unsupported type, or if the underlying writer fails
     */
    public static void writeKeyInfo(XMLStreamWriter writer, KeyInfoType keyInfo) throws ProcessingException {
        // A null keyInfo is reported the same way as an empty one instead of surfacing as an NPE.
        if (keyInfo == null || keyInfo.getContent() == null || keyInfo.getContent().isEmpty()) {
            throw logger.writerInvalidKeyInfoNullContentError();
        }

        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.KEYINFO,
                WSTrustConstants.XMLDSig.DSIG_NS);
        StaxUtil.writeNameSpace(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.DSIG_NS);

        // write the keyInfo content (first entry only).
        Object content = keyInfo.getContent().get(0);
        if (content instanceof Element) {
            StaxUtil.writeDOMNode(writer, (Element) content);
        } else if (content instanceof X509DataType) {
            writeX509Data(writer, (X509DataType) content);
        } else if (content instanceof KeyValueType) {
            writeKeyValue(writer, (KeyValueType) content);
        } else {
            throw new ProcessingException(ErrorCodes.UNSUPPORTED_TYPE + content);
        }

        StaxUtil.writeEndElement(writer);
    }

    /**
     * Write the {@link RSAKeyValueType} as an {@code RSAKeyValue} element containing {@code Modulus} and
     * {@code Exponent} children.
     *
     * <p>The modulus and exponent bytes are decoded with {@link GeneralConstants#SAML_CHARSET} and written as
     * character data without further transformation (presumably they already hold the base64 text form —
     * confirm against the producers of {@link RSAKeyValueType}).
     *
     * @param writer the stream writer to emit the element to
     * @param type the RSA key value; it, its modulus and its exponent must be non-null
     *
     * @throws ProcessingException if {@code type}, its modulus or its exponent is {@code null}, or if the
     *         underlying writer fails
     */
    public static void writeRSAKeyValueType(XMLStreamWriter writer, RSAKeyValueType type) throws ProcessingException {
        if (type == null) {
            throw logger.writerNullValueError(WSTrustConstants.XMLDSig.RSA_KEYVALUE);
        }

        // Validate both mandatory components before emitting anything, so a missing value fails with the
        // declared ProcessingException rather than an NPE after a partial element has been written.
        byte[] modulus = type.getModulus();
        if (modulus == null) {
            throw logger.writerNullValueError(WSTrustConstants.XMLDSig.MODULUS);
        }
        byte[] exponent = type.getExponent();
        if (exponent == null) {
            throw logger.writerNullValueError(WSTrustConstants.XMLDSig.EXPONENT);
        }

        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.RSA_KEYVALUE,
                WSTrustConstants.DSIG_NS);
        // write the rsa key modulus.
        writeBinaryElement(writer, WSTrustConstants.XMLDSig.MODULUS, modulus);
        // write the rsa key exponent.
        writeBinaryElement(writer, WSTrustConstants.XMLDSig.EXPONENT, exponent);
        StaxUtil.writeEndElement(writer);
    }

    /**
     * Write the {@link DSAKeyValueType} as a {@code DSAKeyValue} element.
     *
     * <p>The children {@code P}, {@code Q}, {@code G}, {@code Y}, {@code Seed} and {@code PgenCounter} are
     * written in that order; any component whose getter returns {@code null} is skipped. Component bytes are
     * decoded with {@link GeneralConstants#SAML_CHARSET} and written as character data without further
     * transformation.
     *
     * @param writer the stream writer to emit the element to
     * @param type the DSA key value; must be non-null
     *
     * @throws ProcessingException if {@code type} is {@code null} or the underlying writer fails
     */
    public static void writeDSAKeyValueType(XMLStreamWriter writer, DSAKeyValueType type) throws ProcessingException {
        if (type == null) {
            throw logger.writerNullValueError(WSTrustConstants.XMLDSig.DSA_KEYVALUE);
        }

        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.DSA_KEYVALUE,
                WSTrustConstants.DSIG_NS);

        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.P, type.getP());
        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.Q, type.getQ());
        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.G, type.getG());
        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.Y, type.getY());
        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.SEED, type.getSeed());
        writeOptionalBinaryElement(writer, WSTrustConstants.XMLDSig.PGEN_COUNTER, type.getPgenCounter());

        StaxUtil.writeEndElement(writer);
    }

    /** Writes an {@code X509Data} element holding the first data object of {@code x509Data}. */
    private static void writeX509Data(XMLStreamWriter writer, X509DataType x509Data) throws ProcessingException {
        if (x509Data.getDataObjects() == null || x509Data.getDataObjects().isEmpty()) {
            throw logger.writerNullValueError("X509Data");
        }
        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.X509DATA,
                WSTrustConstants.XMLDSig.DSIG_NS);

        Object dataObject = x509Data.getDataObjects().get(0);
        if (dataObject instanceof Element) {
            StaxUtil.writeDOMElement(writer, (Element) dataObject);
        } else if (dataObject instanceof X509CertificateType) {
            X509CertificateType cert = (X509CertificateType) dataObject;
            StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.X509CERT,
                    WSTrustConstants.XMLDSig.DSIG_NS);
            StaxUtil.writeCharacters(writer, new String(cert.getEncodedCertificate(), GeneralConstants.SAML_CHARSET));
            StaxUtil.writeEndElement(writer);
        }
        // Any other data object type is silently ignored, leaving X509Data empty (existing behavior).

        StaxUtil.writeEndElement(writer);
    }

    /** Writes a {@code KeyValue} element wrapping the DSA or RSA representation of {@code keyValue}. */
    private static void writeKeyValue(XMLStreamWriter writer, KeyValueType keyValue) throws ProcessingException {
        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, WSTrustConstants.XMLDSig.KEYVALUE,
                WSTrustConstants.XMLDSig.DSIG_NS);
        if (keyValue instanceof DSAKeyValueType) {
            writeDSAKeyValueType(writer, (DSAKeyValueType) keyValue);
        } else if (keyValue instanceof RSAKeyValueType) {
            writeRSAKeyValueType(writer, (RSAKeyValueType) keyValue);
        }
        // Any other KeyValueType is silently ignored, leaving KeyValue empty (existing behavior).
        StaxUtil.writeEndElement(writer);
    }

    /** Writes {@code value} via {@link #writeBinaryElement} unless it is {@code null}, in which case nothing is written. */
    private static void writeOptionalBinaryElement(XMLStreamWriter writer, String localName, byte[] value)
            throws ProcessingException {
        if (value != null) {
            writeBinaryElement(writer, localName, value);
        }
    }

    /** Writes a {@code dsig}-prefixed element named {@code localName} whose text is {@code value} decoded with the SAML charset. */
    private static void writeBinaryElement(XMLStreamWriter writer, String localName, byte[] value)
            throws ProcessingException {
        StaxUtil.writeStartElement(writer, WSTrustConstants.XMLDSig.DSIG_PREFIX, localName, WSTrustConstants.DSIG_NS);
        StaxUtil.writeCharacters(writer, new String(value, GeneralConstants.SAML_CHARSET));
        StaxUtil.writeEndElement(writer);
    }
}