Blob Blame History Raw
package xmlutil

import (
	"bytes"
	"encoding/base64"
	"encoding/xml"
	"fmt"
	"io"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/private/protocol"
)

// UnmarshalXMLError unmarshals the XML error from the stream into the value
// type specified. The value must be a pointer. If the message fails to
// unmarshal, the message content will be included in the returned error as a
// awserr.UnmarshalError.
func UnmarshalXMLError(v interface{}, stream io.Reader) error {
	var errBuf bytes.Buffer
	body := io.TeeReader(stream, &errBuf)

	err := xml.NewDecoder(body).Decode(v)
	if err != nil && err != io.EOF {
		return awserr.NewUnmarshalError(err,
			"failed to unmarshal error message", errBuf.Bytes())
	}

	return nil
}

// UnmarshalXML deserializes an xml.Decoder into the container v. V
// needs to match the shape of the XML expected to be decoded.
// If the shape doesn't match unmarshaling will fail.
func UnmarshalXML(v interface{}, d *xml.Decoder, wrapper string) error {
	n, err := XMLToStruct(d, nil)
	if err != nil {
		return err
	}
	if n.Children != nil {
		for _, root := range n.Children {
			for _, c := range root {
				if wrappedChild, ok := c.Children[wrapper]; ok {
					c = wrappedChild[0] // pull out wrapped element
				}

				err = parse(reflect.ValueOf(v), c, "")
				if err != nil {
					if err == io.EOF {
						return nil
					}
					return err
				}
			}
		}
		return nil
	}
	return nil
}

// parse deserializes any value from the XMLNode. The type tag is used to infer the type, or reflect
// will be used to determine the type from r.
func parse(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	rtype := r.Type()
	if rtype.Kind() == reflect.Ptr {
		rtype = rtype.Elem() // check kind of actual element type
	}

	t := tag.Get("type")
	if t == "" {
		switch rtype.Kind() {
		case reflect.Struct:
			// also it can't be a time object
			if _, ok := r.Interface().(*time.Time); !ok {
				t = "structure"
			}
		case reflect.Slice:
			// also it can't be a byte slice
			if _, ok := r.Interface().([]byte); !ok {
				t = "list"
			}
		case reflect.Map:
			t = "map"
		}
	}

	switch t {
	case "structure":
		if field, ok := rtype.FieldByName("_"); ok {
			tag = field.Tag
		}
		return parseStruct(r, node, tag)
	case "list":
		return parseList(r, node, tag)
	case "map":
		return parseMap(r, node, tag)
	default:
		return parseScalar(r, node, tag)
	}
}

// parseStruct deserializes a structure and its fields from an XMLNode. Any nested
// types in the structure will also be deserialized.
func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	t := r.Type()
	if r.Kind() == reflect.Ptr {
		if r.IsNil() { // create the structure if it's nil
			s := reflect.New(r.Type().Elem())
			r.Set(s)
			r = s
		}

		r = r.Elem()
		t = t.Elem()
	}

	// unwrap any payloads
	if payload := tag.Get("payload"); payload != "" {
		field, _ := t.FieldByName(payload)
		return parseStruct(r.FieldByName(payload), node, field.Tag)
	}

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if c := field.Name[0:1]; strings.ToLower(c) == c {
			continue // ignore unexported fields
		}

		// figure out what this field is called
		name := field.Name
		if field.Tag.Get("flattened") != "" && field.Tag.Get("locationNameList") != "" {
			name = field.Tag.Get("locationNameList")
		} else if locName := field.Tag.Get("locationName"); locName != "" {
			name = locName
		}

		// try to find the field by name in elements
		elems := node.Children[name]

		if elems == nil { // try to find the field in attributes
			if val, ok := node.findElem(name); ok {
				elems = []*XMLNode{{Text: val}}
			}
		}

		member := r.FieldByName(field.Name)
		for _, elem := range elems {
			err := parse(member, elem, field.Tag)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

// parseList deserializes a list of values from an XML node. Each list entry
// will also be deserialized.
func parseList(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	t := r.Type()

	if tag.Get("flattened") == "" { // look at all item entries
		mname := "member"
		if name := tag.Get("locationNameList"); name != "" {
			mname = name
		}

		if Children, ok := node.Children[mname]; ok {
			if r.IsNil() {
				r.Set(reflect.MakeSlice(t, len(Children), len(Children)))
			}

			for i, c := range Children {
				err := parse(r.Index(i), c, "")
				if err != nil {
					return err
				}
			}
		}
	} else { // flattened list means this is a single element
		if r.IsNil() {
			r.Set(reflect.MakeSlice(t, 0, 0))
		}

		childR := reflect.Zero(t.Elem())
		r.Set(reflect.Append(r, childR))
		err := parse(r.Index(r.Len()-1), node, "")
		if err != nil {
			return err
		}
	}

	return nil
}

// parseMap deserializes a map from an XMLNode. The direct children of the XMLNode
// will also be deserialized as map entries.
func parseMap(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	if r.IsNil() {
		r.Set(reflect.MakeMap(r.Type()))
	}

	if tag.Get("flattened") == "" { // look at all child entries
		for _, entry := range node.Children["entry"] {
			parseMapEntry(r, entry, tag)
		}
	} else { // this element is itself an entry
		parseMapEntry(r, node, tag)
	}

	return nil
}

// parseMapEntry deserializes a map entry from a XML node.
func parseMapEntry(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	kname, vname := "key", "value"
	if n := tag.Get("locationNameKey"); n != "" {
		kname = n
	}
	if n := tag.Get("locationNameValue"); n != "" {
		vname = n
	}

	keys, ok := node.Children[kname]
	values := node.Children[vname]
	if ok {
		for i, key := range keys {
			keyR := reflect.ValueOf(key.Text)
			value := values[i]
			valueR := reflect.New(r.Type().Elem()).Elem()

			parse(valueR, value, "")
			r.SetMapIndex(keyR, valueR)
		}
	}
	return nil
}

// parseScaller deserializes an XMLNode value into a concrete type based on the
// interface type of r.
//
// Error is returned if the deserialization fails due to invalid type conversion,
// or unsupported interface type.
func parseScalar(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
	switch r.Interface().(type) {
	case *string:
		r.Set(reflect.ValueOf(&node.Text))
		return nil
	case []byte:
		b, err := base64.StdEncoding.DecodeString(node.Text)
		if err != nil {
			return err
		}
		r.Set(reflect.ValueOf(b))
	case *bool:
		v, err := strconv.ParseBool(node.Text)
		if err != nil {
			return err
		}
		r.Set(reflect.ValueOf(&v))
	case *int64:
		v, err := strconv.ParseInt(node.Text, 10, 64)
		if err != nil {
			return err
		}
		r.Set(reflect.ValueOf(&v))
	case *float64:
		v, err := strconv.ParseFloat(node.Text, 64)
		if err != nil {
			return err
		}
		r.Set(reflect.ValueOf(&v))
	case *time.Time:
		format := tag.Get("timestampFormat")
		if len(format) == 0 {
			format = protocol.ISO8601TimeFormatName
		}

		t, err := protocol.ParseTime(format, node.Text)
		if err != nil {
			return err
		}
		r.Set(reflect.ValueOf(&t))
	default:
		return fmt.Errorf("unsupported value: %v (%s)", r.Interface(), r.Type())
	}
	return nil
}