Blob Blame History Raw
package runtime

import (
	"encoding/json"
	"fmt"
	"net/url"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"

	"github.com/deepmap/oapi-codegen/pkg/types"
)

// marshalDeepObject flattens a generic JSON value, as produced by decoding
// JSON into an interface{}, into deepObject fragments of the form
// "[p0][p1]...=value". path holds the subscripts that lead to in; the
// parameter name itself is not included.
//
// Arrays are subscripted by element index and objects by key, with keys
// visited in sorted order so the output is deterministic. A JSON null,
// like an empty array or object, produces no fragments. Values are not
// URL-escaped.
func marshalDeepObject(in interface{}, path []string) ([]string, error) {
	var result []string

	switch t := in.(type) {
	case []interface{}:
		for i, elem := range t {
			// The three-index slice caps path at its length, so append always
			// copies and sibling iterations never share a backing array.
			childPath := append(path[:len(path):len(path)], strconv.Itoa(i))
			fields, err := marshalDeepObject(elem, childPath)
			if err != nil {
				return nil, errors.Wrap(err, "error traversing array")
			}
			result = append(result, fields...)
		}
	case map[string]interface{}:
		keys := make([]string, 0, len(t))
		for k := range t {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		for _, k := range keys {
			childPath := append(path[:len(path):len(path)], k)
			fields, err := marshalDeepObject(t[k], childPath)
			if err != nil {
				return nil, errors.Wrap(err, "error traversing map")
			}
			result = append(result, fields...)
		}
	case nil:
		// JSON null has no deepObject representation; emit nothing rather
		// than the literal "<nil>".
	default:
		// The path [a, b, c] becomes the subscripts "[a][b][c]".
		prefix := "[" + strings.Join(path, "][") + "]"
		var value string
		if f, isFloat := t.(float64); isFloat {
			// %v switches to exponent form for large values (1000000 ->
			// "1e+06"), which integer destinations cannot parse back.
			value = strconv.FormatFloat(f, 'f', -1, 64)
		} else {
			value = fmt.Sprintf("%v", t)
		}
		result = []string{prefix + "=" + value}
	}
	return result, nil
}

// MarshalDeepObject serializes i as an OpenAPI deepObject query string for
// the parameter paramName, e.g. "p[a]=1&p[b][c]=x".
//
// i is first round-tripped through encoding/json, so json struct tags and
// MarshalJSON methods decide which fields appear and how they are named.
// Numbers are decoded as json.Number so they keep their exact JSON text.
// Re-rendering them from float64 would turn 1000000 into "1e+06" and lose
// precision on large integers.
//
// The fragments are joined with "&" and are not URL-escaped.
func MarshalDeepObject(i interface{}, paramName string) (string, error) {
	// Reflecting over i directly would avoid the JSON round trip, but it
	// would have to reimplement encoding/json's field rules.
	buf, err := json.Marshal(i)
	if err != nil {
		return "", errors.Wrap(err, "failed to marshal input to JSON")
	}

	var generic interface{}
	dec := json.NewDecoder(strings.NewReader(string(buf)))
	dec.UseNumber()
	if err := dec.Decode(&generic); err != nil {
		return "", errors.Wrap(err, "failed to unmarshal JSON")
	}

	fields, err := marshalDeepObject(generic, nil)
	if err != nil {
		return "", errors.Wrap(err, "error traversing JSON structure")
	}

	// Prefix the parameter name to each subscripted fragment.
	for idx := range fields {
		fields[idx] = paramName + fields[idx]
	}
	return strings.Join(fields, "&"), nil
}

// fieldOrValue is a node in the tree rebuilt from deepObject query
// parameters. Children live in fields, keyed by the subscript text between
// brackets. A leaf carries the raw query value in value. A node can hold
// both when the query uses the same path as a leaf and as a prefix, e.g.
// "p[a]=1&p[a][b]=2".
type fieldOrValue struct {
	fields map[string]fieldOrValue
	value  string
}

// appendPathValue inserts value into the tree at path, creating interior
// nodes as needed. path is the subscript list of one query parameter, e.g.
// ["a", "b"] for "param[a][b]". An empty path is ignored.
//
// Children are stored by value in the parent's map, so each modified child
// is written back to f.fields after it has been updated.
func (f *fieldOrValue) appendPathValue(path []string, value string) {
	if len(path) == 0 {
		return
	}
	if f.fields == nil {
		// A node first created as a leaf has no map yet; writing to a nil
		// map would panic.
		f.fields = make(map[string]fieldOrValue)
	}

	fieldName := path[0]
	child := f.fields[fieldName]
	if len(path) == 1 {
		// Set the leaf value but keep any children already recorded under
		// this name, so the result does not depend on insertion order.
		child.value = value
	} else {
		child.appendPathValue(path[1:], value)
	}
	f.fields[fieldName] = child
}

// makeFieldOrValue builds the root of a fieldOrValue tree from parallel
// slices: paths[i] is the subscript path whose value is values[i]. The two
// slices must have the same length.
func makeFieldOrValue(paths [][]string, values []string) fieldOrValue {
	f := fieldOrValue{
		fields: make(map[string]fieldOrValue),
	}
	for i, path := range paths {
		f.appendPathValue(path, values[i])
	}
	return f
}

// UnmarshalDeepObject binds the deepObject-style query parameters for
// paramName (keys of the form "paramName[a][b]") from params into dst,
// which must point to the destination value.
//
// Each matching key must carry exactly one value. Matching keys are
// processed in sorted order, so when several are invalid the same one is
// reported on every run.
func UnmarshalDeepObject(dst interface{}, paramName string, params url.Values) error {
	// Params are all the query args; keep only those that look like
	// "paramName[...". Sort them because map iteration order is random.
	searchStr := paramName + "["
	var keys []string
	for key := range params {
		if strings.HasPrefix(key, searchStr) {
			keys = append(keys, key)
		}
	}
	sort.Strings(keys)

	paths := make([][]string, 0, len(keys))
	fieldValues := make([]string, 0, len(keys))
	for _, key := range keys {
		pValues := params[key]
		if len(pValues) != 1 {
			// Report the full key so the caller can find it in the query.
			return fmt.Errorf("%s has multiple values", key)
		}
		// Reconstruct the subscript path:
		// "paramName[a][b]" -> "[a][b]" -> "a][b" -> ["a", "b"].
		subscripts := key[len(paramName):]
		subscripts = strings.TrimLeft(subscripts, "[")
		subscripts = strings.TrimRight(subscripts, "]")
		paths = append(paths, strings.Split(subscripts, "]["))
		fieldValues = append(fieldValues, pValues[0])
	}

	fieldPaths := makeFieldOrValue(paths, fieldValues)
	if err := assignPathValues(dst, fieldPaths); err != nil {
		return errors.Wrap(err, "error assigning value to destination")
	}
	return nil
}

// This returns a field name, either using the variable name, or the json
// annotation if that exists.
func getFieldName(f reflect.StructField) string {
	n := f.Name
	tag, found := f.Tag.Lookup("json")
	if found {
		// If we have a json field, and the first part of it before the
		// first comma is non-empty, that's our field name.
		parts := strings.Split(tag, ",")
		if parts[0] != "" {
			n = parts[0]
		}
	}
	return n
}

// Create a map of field names that we'll see in the deepObject to reflect
// field indices on the given type.
func fieldIndicesByJsonTag(i interface{}) (map[string]int, error) {
	t := reflect.TypeOf(i)
	if t.Kind() != reflect.Struct {
		return nil, errors.New("expected a struct as input")
	}

	n := t.NumField()
	fieldMap := make(map[string]int)
	for i := 0; i < n; i++ {
		field := t.Field(i)
		fieldName := getFieldName(field)
		fieldMap[fieldName] = i
	}
	return fieldMap, nil
}

// assignPathValues stores the tree pathValues into the value dst points
// to, converting leaf strings to the destination's type. dst must be a
// non-nil pointer. Every recursive call passes the address of a struct
// field, a slice element, or a freshly allocated value.
//
// Supported destinations:
//   - slices: children must be keyed "0".."n-1" (see assignSlice);
//   - types.Date: leaf parsed with types.DateFormat;
//   - time.Time: leaf parsed as RFC 3339 (fractional seconds optional),
//     falling back to types.DateFormat;
//   - other structs: each child is bound to the field whose deepObject
//     name matches (see fieldIndicesByJsonTag);
//   - pointers: a new value is allocated and filled, and the pointer is
//     stored only if that succeeds;
//   - bool, string, and signed/unsigned integer and float kinds, parsed
//     with the destination's bit size so out-of-range input is rejected
//     rather than silently truncated.
func assignPathValues(dst interface{}, pathValues fieldOrValue) error {
	v := reflect.ValueOf(dst)
	if v.Kind() != reflect.Ptr || v.IsNil() {
		return errors.New("destination must be a non-nil pointer")
	}
	iv := v.Elem()
	it := iv.Type()

	switch it.Kind() {
	case reflect.Slice:
		sliceLength := len(pathValues.fields)
		dstSlice := reflect.MakeSlice(it, sliceLength, sliceLength)
		if err := assignSlice(dstSlice, pathValues); err != nil {
			return errors.Wrap(err, "error assigning slice")
		}
		iv.Set(dstSlice)
		return nil
	case reflect.Struct:
		// Special struct types are bound from the leaf value alone.
		if _, isDate := iv.Interface().(types.Date); isDate {
			var date types.Date
			var err error
			date.Time, err = time.Parse(types.DateFormat, pathValues.value)
			if err != nil {
				return errors.Wrap(err, "invalid date format")
			}
			iv.Set(reflect.ValueOf(date))
			return nil
		}
		if _, isTime := iv.Interface().(time.Time); isTime {
			tm, err := time.Parse(time.RFC3339Nano, pathValues.value)
			if err != nil {
				// Fall back to a plain date, the only format accepted before.
				tm, err = time.Parse(types.DateFormat, pathValues.value)
				if err != nil {
					return errors.Wrap(err, "invalid date format")
				}
			}
			iv.Set(reflect.ValueOf(tm))
			return nil
		}

		fieldMap, err := fieldIndicesByJsonTag(iv.Interface())
		if err != nil {
			return errors.Wrap(err, "failed enumerating fields")
		}
		for _, fieldName := range sortedFieldOrValueKeys(pathValues.fields) {
			fieldIndex, found := fieldMap[fieldName]
			if !found {
				return fmt.Errorf("field [%s] is not present in destination object", fieldName)
			}
			field := iv.Field(fieldIndex)
			if !field.CanSet() {
				// Unexported fields cannot be set, and calling Interface() on
				// their address would panic.
				return fmt.Errorf("field [%s] is not settable in destination object", fieldName)
			}
			if err := assignPathValues(field.Addr().Interface(), pathValues.fields[fieldName]); err != nil {
				return errors.Wrapf(err, "error assigning field [%s]", fieldName)
			}
		}
		return nil
	case reflect.Ptr:
		// A pointer after dereferencing dst means an optional field such as
		// *string that was passed in as &foo. Allocate the pointee, fill it,
		// and only then store the pointer, so a failed bind leaves dst as is.
		dstVal := reflect.New(it.Elem())
		if err := assignPathValues(dstVal.Interface(), pathValues); err != nil {
			return err
		}
		iv.Set(dstVal)
		return nil
	case reflect.Bool:
		val, err := strconv.ParseBool(pathValues.value)
		if err != nil {
			return fmt.Errorf("expected a valid bool, got %s", pathValues.value)
		}
		iv.SetBool(val)
		return nil
	case reflect.Float32, reflect.Float64:
		val, err := strconv.ParseFloat(pathValues.value, it.Bits())
		if err != nil {
			return fmt.Errorf("expected a valid float, got %s", pathValues.value)
		}
		iv.SetFloat(val)
		return nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Parse at the destination's width so "300" into an int8 is an error
		// instead of silently wrapping.
		val, err := strconv.ParseInt(pathValues.value, 10, it.Bits())
		if err != nil {
			return fmt.Errorf("expected a valid int, got %s", pathValues.value)
		}
		iv.SetInt(val)
		return nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		val, err := strconv.ParseUint(pathValues.value, 10, it.Bits())
		if err != nil {
			return fmt.Errorf("expected a valid uint, got %s", pathValues.value)
		}
		iv.SetUint(val)
		return nil
	case reflect.String:
		iv.SetString(pathValues.value)
		return nil
	default:
		return errors.New("unhandled type: " + it.String())
	}
}

// assignSlice fills dst, a slice already sized to len(pathValues.fields),
// from the children of pathValues. Children must be keyed by the
// consecutive indices "0" through "n-1"; any other key set is rejected.
// Each element is bound with assignPathValues using its whole subtree, so
// elements may be scalars or nested objects (e.g. "p[0][name]=x").
func assignSlice(dst reflect.Value, pathValues fieldOrValue) error {
	n := len(pathValues.fields)

	// Check every index before binding anything, so a malformed index set
	// fails without partially filling dst.
	elems := make([]fieldOrValue, n)
	for i := range elems {
		fv, found := pathValues.fields[strconv.Itoa(i)]
		if !found {
			return errors.New("array deepObjects must have consecutive indices")
		}
		elems[i] = fv
	}

	// Reuse assignPathValues for each element rather than duplicating its
	// type dispatch here.
	for i, elem := range elems {
		if err := assignPathValues(dst.Index(i).Addr().Interface(), elem); err != nil {
			return errors.Wrap(err, "error binding array")
		}
	}
	return nil
}

// sortedFieldOrValueKeys returns the keys of m in ascending order, so
// callers walking a node's children visit them deterministically. The
// result is non-nil even when m is empty.
func sortedFieldOrValueKeys(m map[string]fieldOrValue) []string {
	sorted := make([]string, len(m))
	pos := 0
	for name := range m {
		sorted[pos] = name
		pos++
	}
	sort.Strings(sorted)
	return sorted
}