diff --git a/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java b/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java
index 14c3fe0..b90db1b 100755
--- a/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java
+++ b/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java
@@ -209,8 +209,6 @@ public class TransformerUtil {
if (!(outputTarget instanceof DOMResult))
throw logger.wrongTypeError("outputTarget should be a dom result");
- String rootTag = null;
-
StAXSource staxSource = (StAXSource) xmlSource;
XMLEventReader xmlEventReader = staxSource.getXMLEventReader();
if (xmlEventReader == null)
@@ -227,7 +225,6 @@ public class TransformerUtil {
throw new TransformerException(ErrorCodes.WRITER_SHOULD_START_ELEMENT);
StartElement rootElement = (StartElement) xmlEvent;
- rootTag = StaxParserUtil.getElementName(rootElement);
CustomHolder holder = new CustomHolder(doc, false);
Element docRoot = handleStartElement(xmlEventReader, rootElement, holder);
Node parent = doc.importNode(docRoot, true);
@@ -243,6 +240,8 @@ public class TransformerUtil {
while (xmlEventReader.hasNext()) {
xmlEvent = StaxParserUtil.getNextEvent(xmlEventReader);
int type = xmlEvent.getEventType();
+ Node top = null;
+
switch (type) {
case XMLEvent.START_ELEMENT:
StartElement startElement = (StartElement) xmlEvent;
@@ -250,13 +249,11 @@ public class TransformerUtil {
Element docStartElement = handleStartElement(xmlEventReader, startElement, holder);
Node el = doc.importNode(docStartElement, true);
- Node top = null;
-
- if (!stack.isEmpty()) {
+ if (! stack.isEmpty()) {
top = stack.peek();
}
- if (!holder.encounteredTextNode) {
+ if (! holder.encounteredTextNode) {
stack.push(el);
}
@@ -265,15 +262,15 @@ public class TransformerUtil {
else
top.appendChild(el);
break;
+
case XMLEvent.END_ELEMENT:
- EndElement endElement = (EndElement) xmlEvent;
- String endTag = StaxParserUtil.getElementName(endElement);
- if (rootTag.equals(endTag))
- return; // We are done with the dom parsing
- else {
- if (!stack.isEmpty())
- stack.pop();
+ top = stack.pop();
+
+ if (! (top instanceof Element)) {
+ throw new TransformerException(ErrorCodes.UNKNOWN_END_ELEMENT);
}
+ if (stack.isEmpty())
+ return; // We are done with the dom parsing
break;
}
}
diff --git a/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java b/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java
index 77bf1a8..14438c4 100644
--- a/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java
+++ b/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java
@@ -18,18 +18,23 @@ package org.keycloak.saml.common.util;
import org.keycloak.saml.common.exceptions.ParsingException;
import java.nio.charset.Charset;
+import java.util.NoSuchElementException;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.Characters;
+import javax.xml.stream.events.EndDocument;
import javax.xml.stream.events.EndElement;
import javax.xml.stream.events.StartDocument;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import org.apache.commons.io.IOUtils;
import org.hamcrest.Matcher;
+import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
+import org.w3c.dom.Element;
+import org.w3c.dom.Text;
import static org.junit.Assert.assertThat;
import static org.hamcrest.CoreMatchers.*;
@@ -175,4 +180,38 @@ public class StaxParserUtilTest {
reader.nextEvent();
}
+ @Test
+ public void testGetDOMElementSameElements() throws XMLStreamException, ParsingException {
+ String xml = "<root><test><test><a>b</a></test></test></root>";
+ XMLEventReader reader = StaxParserUtil.getXMLEventReader(IOUtils.toInputStream(xml, Charset.defaultCharset()));
+
+ assertThat(reader.nextEvent(), instanceOf(StartDocument.class));
+
+ assertStartTag(reader.nextEvent(), "root");
+
+ Element element = StaxParserUtil.getDOMElement(reader);
+
+ assertThat(element.getNodeName(), is("test"));
+ assertThat(element.getChildNodes().getLength(), is(1));
+
+ assertThat(element.getChildNodes().item(0), instanceOf(Element.class));
+ Element e = (Element) element.getChildNodes().item(0);
+ assertThat(e.getNodeName(), is("test"));
+
+ assertThat(e.getChildNodes().getLength(), is(1));
+ assertThat(e.getChildNodes().item(0), instanceOf(Element.class));
+ Element e1 = (Element) e.getChildNodes().item(0);
+ assertThat(e1.getNodeName(), is("a"));
+
+ assertThat(e1.getChildNodes().getLength(), is(1));
+ assertThat(e1.getChildNodes().item(0), instanceOf(Text.class));
+ assertThat(((Text) e1.getChildNodes().item(0)).getWholeText(), is("b"));
+
+ assertEndTag(reader.nextEvent(), "root");
+ assertThat(reader.nextEvent(), instanceOf(EndDocument.class));
+
+ expectedException.expect(NoSuchElementException.class);
+ Assert.fail(String.valueOf(reader.nextEvent()));
+ }
+
}