// Source file: match_xml_matcher.go (3.6 KB)

package matchers

import (
	"bytes"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	"golang.org/x/net/html/charset"
)

// MatchXMLMatcher is a matcher that succeeds when the actual value and
// XMLToMatch parse into equal XML node trees.
type MatchXMLMatcher struct {
	// XMLToMatch is the expected XML document. It must be convertible by
	// toString; the matcher's error message describes the accepted kinds as
	// a string, stringer, or []byte.
	XMLToMatch interface{}
}

// Match reports whether actual and matcher.XMLToMatch describe the same XML.
//
// Both values are first converted to strings, then each string is parsed into
// a node tree and the two trees are compared with reflect.DeepEqual. An error
// is returned (with success == false) when either value cannot be converted to
// a string or when either string fails to parse; the actual value is parsed
// first, so its parse error takes precedence.
func (matcher *MatchXMLMatcher) Match(actual interface{}) (success bool, err error) {
	actualXML, expectedXML, err := matcher.formattedPrint(actual)
	if err != nil {
		return false, err
	}

	actualTree, actualErr := parseXmlContent(actualXML)
	if actualErr != nil {
		return false, fmt.Errorf("Actual '%s' should be valid XML, but it is not.\nUnderlying error:%s", actualXML, actualErr)
	}

	expectedTree, expectedErr := parseXmlContent(expectedXML)
	if expectedErr != nil {
		return false, fmt.Errorf("Expected '%s' should be valid XML, but it is not.\nUnderlying error:%s", expectedXML, expectedErr)
	}

	success = reflect.DeepEqual(actualTree, expectedTree)
	return success, nil
}

// FailureMessage builds the message shown when actual was expected to match
// matcher.XMLToMatch but did not.
//
// The conversion error from formattedPrint is deliberately discarded; when
// conversion fails both strings are empty and the message is rendered with
// empty placeholders.
func (matcher *MatchXMLMatcher) FailureMessage(actual interface{}) (message string) {
	got, want, _ := matcher.formattedPrint(actual)
	message = fmt.Sprintf("Expected\n%s\nto match XML of\n%s", got, want)
	return message
}

// NegatedFailureMessage builds the message shown when actual was expected not
// to match matcher.XMLToMatch but did.
//
// The conversion error from formattedPrint is deliberately discarded; when
// conversion fails both strings are empty and the message is rendered with
// empty placeholders.
func (matcher *MatchXMLMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	got, want, _ := matcher.formattedPrint(actual)
	message = fmt.Sprintf("Expected\n%s\nnot to match XML of\n%s", got, want)
	return message
}

// formattedPrint converts actual and matcher.XMLToMatch to strings via
// toString.
//
// It returns the two strings in (actual, expected) order. If either value is
// rejected by toString, both strings are returned empty together with an error
// naming the offending side; actual is checked before the expected value.
func (matcher *MatchXMLMatcher) formattedPrint(actual interface{}) (actualString, expectedString string, err error) {
	actualText, actualOK := toString(actual)
	if !actualOK {
		return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte.  Got actual:\n%s", format.Object(actual, 1))
	}

	expectedText, expectedOK := toString(matcher.XMLToMatch)
	if !expectedOK {
		return "", "", fmt.Errorf("MatchXMLMatcher matcher requires a string, stringer, or []byte.  Got expected:\n%s", format.Object(matcher.XMLToMatch, 1))
	}

	return actualText, expectedText, nil
}

// parseXmlContent decodes content token by token and assembles a tree of
// xmlNode values, returning the first element that was opened.
//
// Elements that are currently open are kept on a stack. A start tag pushes a
// new node; an end tag pops the top node and attaches it to its parent, except
// for the bottom-most node, which is never popped and is what gets returned.
// Character data, comments and processing instructions are recorded on the
// node at the top of the stack.
//
// It returns an error if the decoder fails on any token, or if the input
// contains no start element at all.
func parseXmlContent(content string) (*xmlNode, error) {
	var openNodes []*xmlNode
	decoder := newXmlDecoder(strings.NewReader(content))

	for {
		token, err := decoder.Token()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to decode next token: %v", err) // untested section
		}

		depth := len(openNodes)
		var current *xmlNode
		if depth == 0 {
			// No element is open yet: anything recorded here lands on a
			// scratch node that is never stored, so it is dropped.
			current = &xmlNode{}
		} else {
			current = openNodes[depth-1]
		}

		switch t := token.(type) {
		case xml.StartElement:
			// attributesSlice is defined elsewhere; presumably it is a
			// sortable view over t.Attr, so sorting it reorders t.Attr in
			// place before it is stored — verify against its definition.
			sortable := attributesSlice(t.Attr)
			sort.Sort(sortable)
			openNodes = append(openNodes, &xmlNode{XMLName: t.Name, XMLAttr: t.Attr})
		case xml.EndElement:
			// Only nested elements are popped; the bottom node stays on the
			// stack so it can be returned after the loop.
			if depth > 1 {
				parent := openNodes[depth-2]
				parent.Nodes = append(parent.Nodes, current)
				openNodes = openNodes[:depth-1]
			}
		case xml.CharData:
			current.Content = append(current.Content, t.Copy()...)
		case xml.Comment:
			current.Comments = append(current.Comments, t.Copy()) // untested section
		case xml.ProcInst:
			current.ProcInsts = append(current.ProcInsts, t.Copy())
		}
	}

	if len(openNodes) == 0 {
		return nil, errors.New("found no nodes")
	}

	root := openNodes[0]
	trimParentNodesContentSpaces(root)
	return root, nil
}

// newXmlDecoder returns an xml.Decoder reading from reader whose CharsetReader
// is set to charset.NewReaderLabel, so that the decoder has a hook for
// converting non-UTF-8 input declared in the document.
func newXmlDecoder(reader io.Reader) *xml.Decoder {
	decoder := xml.NewDecoder(reader)
	decoder.CharsetReader = charset.NewReaderLabel
	return decoder
}

// trimParentNodesContentSpaces strips leading and trailing whitespace from the
// Content of every node that has at least one child, walking the tree
// recursively. Nodes without children are left untouched, so whitespace in
// leaf text is preserved.
func trimParentNodesContentSpaces(node *xmlNode) {
	if len(node.Nodes) == 0 {
		return
	}

	node.Content = bytes.TrimSpace(node.Content)
	for i := range node.Nodes {
		trimParentNodesContentSpaces(node.Nodes[i])
	}
}

// xmlNode is one element of the tree built by parseXmlContent. Two documents
// are considered matching when their root xmlNode values are deeply equal.
type xmlNode struct {
	// XMLName is the element name taken from its start tag.
	XMLName xml.Name
	// Comments holds the comments recorded while this element was the
	// innermost open one.
	Comments []xml.Comment
	// ProcInsts holds the processing instructions recorded while this element
	// was the innermost open one.
	ProcInsts []xml.ProcInst
	// XMLAttr holds the attributes taken from the element's start tag.
	XMLAttr []xml.Attr
	// Content is the concatenated character data recorded for this element.
	Content []byte
	// Nodes holds the child elements, appended as each child's end tag is seen.
	Nodes []*xmlNode
}