decode_map.go 6.1 KB
package msgpack

import (
	"errors"
	"fmt"
	"reflect"

	"github.com/vmihailenco/msgpack/v5/codes"
)

// Cached reflect types for *map[string]string and map[string]string.
// The pointer type is the conversion target used by
// decodeMapStringStringValue; the map type is its element type.
var (
	mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
	mapStringStringType    = mapStringStringPtrType.Elem()
)

// Cached reflect types for *map[string]interface{} and
// map[string]interface{}. The pointer type is the conversion target used by
// decodeMapStringInterfaceValue; the map type is its element type.
var (
	mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
)

// decodeMapValue reads a map from d and stores it in the map value v.
//
// A nil map on the wire (length -1) resets v to the zero value of its type.
// Otherwise a nil v is replaced by a freshly made map, and any entries on the
// wire are decoded into it by decodeMapValueSize. Entries already present in a
// non-nil v are kept.
func decodeMapValue(d *Decoder, v reflect.Value) error {
	n, err := d.DecodeMapLen()
	if err != nil {
		return err
	}

	mapType := v.Type()
	switch {
	case n == -1:
		// Wire value is nil: clear the destination.
		v.Set(reflect.Zero(mapType))
		return nil
	case v.IsNil():
		v.Set(reflect.MakeMap(mapType))
	}

	if n == 0 {
		return nil
	}
	return decodeMapValueSize(d, v, n)
}

// decodeMapValueSize decodes size key/value pairs from d into the map value v.
//
// Each key and each value is decoded into a fresh addressable value of the
// map's key and element type respectively, then inserted with SetMapIndex.
// The first decoding error aborts the loop and is returned as-is; pairs
// inserted before the error remain in v.
func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error {
	mapType := v.Type()
	kt, et := mapType.Key(), mapType.Elem()

	for remaining := size; remaining > 0; remaining-- {
		key := reflect.New(kt).Elem()
		if err := d.DecodeValue(key); err != nil {
			return err
		}

		elem := reflect.New(et).Elem()
		if err := d.DecodeValue(elem); err != nil {
			return err
		}

		v.SetMapIndex(key, elem)
	}

	return nil
}

// DecodeMapLen decodes map length. Length is -1 when map is nil.
//
// If the first code read is an ext code, the ext header is skipped and the
// code that follows it is interpreted as the map header instead.
func (d *Decoder) DecodeMapLen() (int, error) {
	code, err := d.readCode()
	if err != nil {
		return 0, err
	}

	if !codes.IsExt(code) {
		return d.mapLen(code)
	}

	// Step over the ext header, then read the code that follows it.
	if err := d.skipExtHeader(code); err != nil {
		return 0, err
	}
	code, err = d.readCode()
	if err != nil {
		return 0, err
	}
	return d.mapLen(code)
}

// mapLen returns the map length encoded by code c, reading any extra length
// bytes from d. It wraps _mapLen and passes its error through
// expandInvalidCodeMapLenError so that an invalid code is reported with a
// descriptive message.
func (d *Decoder) mapLen(c codes.Code) (int, error) {
	n, err := d._mapLen(c)
	return n, expandInvalidCodeMapLenError(c, err)
}

// _mapLen interprets c as a map header and returns the number of entries.
//
// It returns -1 for the nil code, the length embedded in the code for fixed
// maps, and the length read from d (16 or 32 bits) for Map16 and Map32.
// Any other code yields errInvalidCode.
func (d *Decoder) _mapLen(c codes.Code) (int, error) {
	switch {
	case c == codes.Nil:
		return -1, nil
	case c >= codes.FixedMapLow && c <= codes.FixedMapHigh:
		// The length is stored in the low bits of the code itself.
		return int(c & codes.FixedMapMask), nil
	case c == codes.Map16:
		n, err := d.uint16()
		return int(n), err
	case c == codes.Map32:
		n, err := d.uint32()
		return int(n), err
	}
	return 0, errInvalidCode
}

// errInvalidCode is the sentinel returned by _mapLen when the code is neither
// nil nor one of the map codes. expandInvalidCodeMapLenError compares against
// it to produce a descriptive error.
var errInvalidCode = errors.New("invalid code")

// expandInvalidCodeMapLenError replaces the errInvalidCode sentinel with an
// error that names the offending code c. Every other error, including nil,
// is returned unchanged.
func expandInvalidCodeMapLenError(c codes.Code, err error) error {
	if err != errInvalidCode {
		return err
	}
	return fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
}

// decodeMapStringStringValue decodes a map into v without per-entry
// reflection. It takes v's address, converts it to *map[string]string and
// hands it to decodeMapStringStringPtr. v must be addressable.
func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
	converted := v.Addr().Convert(mapStringStringPtrType)
	target := converted.Interface().(*map[string]string)
	return d.decodeMapStringStringPtr(target)
}

// decodeMapStringStringPtr decodes a map of string keys and string values
// into *ptr.
//
// A nil map on the wire sets *ptr to nil. If *ptr is nil it is allocated with
// a capacity hint of min(size, maxMapSize); an existing map is reused and the
// decoded entries are added to it. The first decoding error is returned and
// leaves already-decoded entries in place.
func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
	n, err := d.DecodeMapLen()
	if err != nil {
		return err
	}
	if n == -1 {
		*ptr = nil
		return nil
	}

	if *ptr == nil {
		*ptr = make(map[string]string, min(n, maxMapSize))
	}
	dst := *ptr

	for remaining := n; remaining > 0; remaining-- {
		key, err := d.DecodeString()
		if err != nil {
			return err
		}
		val, err := d.DecodeString()
		if err != nil {
			return err
		}
		dst[key] = val
	}

	return nil
}

// decodeMapStringInterfaceValue decodes a map into v without per-entry
// reflection. It takes v's address, converts it to *map[string]interface{}
// and hands it to decodeMapStringInterfacePtr. v must be addressable.
func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
	converted := v.Addr().Convert(mapStringInterfacePtrType)
	target := converted.Interface().(*map[string]interface{})
	return d.decodeMapStringInterfacePtr(target)
}

// decodeMapStringInterfacePtr decodes a map with string keys and arbitrary
// values into *ptr.
//
// A nil map on the wire sets *ptr to nil. If *ptr is nil it is allocated with
// a capacity hint of min(n, maxMapSize); an existing map is reused and the
// decoded entries are added to it. Keys are read with DecodeString and values
// with decodeInterfaceCond. The first decoding error is returned and leaves
// already-decoded entries in place.
func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
	count, err := d.DecodeMapLen()
	if err != nil {
		return err
	}
	if count == -1 {
		*ptr = nil
		return nil
	}

	if *ptr == nil {
		*ptr = make(map[string]interface{}, min(count, maxMapSize))
	}
	dst := *ptr

	for remaining := count; remaining > 0; remaining-- {
		key, err := d.DecodeString()
		if err != nil {
			return err
		}
		val, err := d.decodeInterfaceCond()
		if err != nil {
			return err
		}
		dst[key] = val
	}

	return nil
}

// DecodeMap decodes the next map and returns it as a Go map value.
//
// When d.decodeMapFunc is set, decoding is delegated to it entirely.
// Otherwise:
//   - a nil map on the wire yields (nil, nil);
//   - an empty map yields an empty map[string]interface{};
//   - when the code reported by PeekCode for the first key is a string or bin
//     code, the result is a map[string]interface{};
//   - in every other case the first entry is decoded with
//     decodeInterfaceCond and the result is a map whose key and element types
//     are the dynamic Go types of that entry; the remaining entries are
//     decoded into those same types by decodeMapValueSize.
//
// In the last case the map type is inferred from decoded data, so an error is
// returned (rather than letting reflect.MapOf panic) when the first key or
// value is nil, or when the first key's type is not comparable.
func (d *Decoder) DecodeMap() (interface{}, error) {
	if d.decodeMapFunc != nil {
		return d.decodeMapFunc(d)
	}

	size, err := d.DecodeMapLen()
	if err != nil {
		return nil, err
	}
	if size == -1 {
		return nil, nil
	}
	if size == 0 {
		return make(map[string]interface{}), nil
	}

	code, err := d.PeekCode()
	if err != nil {
		return nil, err
	}

	if codes.IsString(code) || codes.IsBin(code) {
		return d.decodeMapStringInterfaceSize(size)
	}

	key, err := d.decodeInterfaceCond()
	if err != nil {
		return nil, err
	}

	value, err := d.decodeInterfaceCond()
	if err != nil {
		return nil, err
	}

	// reflect.TypeOf returns nil for a nil interface, and reflect.MapOf panics
	// on nil types and on key types that do not support ==. The input comes
	// from the wire, so reject these cases with an error instead.
	keyType := reflect.TypeOf(key)
	if keyType == nil {
		return nil, errors.New("msgpack: unsupported map key: nil")
	}
	if !keyType.Comparable() {
		return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
	}
	valueType := reflect.TypeOf(value)
	if valueType == nil {
		return nil, errors.New("msgpack: unsupported map value: nil")
	}

	mapType := reflect.MapOf(keyType, valueType)
	mapValue := reflect.MakeMap(mapType)

	// The first entry has already been consumed to infer the types.
	mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
	size--

	err = decodeMapValueSize(d, mapValue, size)
	if err != nil {
		return nil, err
	}

	return mapValue.Interface(), nil
}

// decodeMapStringInterfaceSize decodes size entries into a newly allocated
// map[string]interface{}. Keys are read with DecodeString and values with
// decodeInterfaceCond. The initial capacity hint is min(size, maxMapSize).
// On the first decoding error it returns a nil map together with that error.
func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) {
	out := make(map[string]interface{}, min(size, maxMapSize))

	for remaining := size; remaining > 0; remaining-- {
		key, err := d.DecodeString()
		if err != nil {
			return nil, err
		}
		val, err := d.decodeInterfaceCond()
		if err != nil {
			return nil, err
		}
		out[key] = val
	}

	return out, nil
}

// skipMap discards a map whose header code c has already been read. It
// obtains the entry count from mapLen and calls Skip twice per entry, once
// for the key and once for the value. A nil map (length -1) skips nothing.
func (d *Decoder) skipMap(c codes.Code) error {
	entries, err := d.mapLen(c)
	if err != nil {
		return err
	}

	for i := 0; i < entries; i++ {
		// One Skip for the key, one for the value.
		for part := 0; part < 2; part++ {
			if err := d.Skip(); err != nil {
				return err
			}
		}
	}
	return nil
}

// decodeStructValue decodes a struct into v from either a map or an array.
//
// The header code is first interpreted as a map length; if that fails it is
// interpreted as an array length. If both fail, the map-length error is
// returned, expanded by expandInvalidCodeMapLenError. A length of -1 resets v
// to its zero value.
//
// Array form: elements are decoded into the fields of fields.List in order;
// elements beyond the number of known fields are skipped.
//
// Map form: each key is looked up by name in fields.Map. Known fields are
// decoded; unknown fields are skipped, or rejected with an error when
// disallowUnknownFieldsFlag is set on the decoder.
func decodeStructValue(d *Decoder, v reflect.Value) error {
	code, err := d.readCode()
	if err != nil {
		return err
	}

	asArray := false
	count, mapErr := d._mapLen(code)
	if mapErr != nil {
		arrLen, arrErr := d.arrayLen(code)
		if arrErr != nil {
			// Report the original map-length failure, not the array one.
			return expandInvalidCodeMapLenError(code, mapErr)
		}
		count, asArray = arrLen, true
	}

	if count == -1 {
		v.Set(reflect.Zero(v.Type()))
		return nil
	}

	fields := structs.Fields(v.Type(), d.structTag)

	if !asArray {
		for i := 0; i < count; i++ {
			name, err := d.bytesTemp()
			if err != nil {
				return err
			}

			field := fields.Map[string(name)]
			switch {
			case field != nil:
				if err := field.DecodeValue(d, v); err != nil {
					return err
				}
			case d.flags&disallowUnknownFieldsFlag != 0:
				return fmt.Errorf("msgpack: unknown field %q", name)
			default:
				if err := d.Skip(); err != nil {
					return err
				}
			}
		}
		return nil
	}

	for idx, field := range fields.List {
		if idx >= count {
			break
		}
		if err := field.DecodeValue(d, v); err != nil {
			return err
		}
	}

	// Skip extra values.
	for idx := len(fields.List); idx < count; idx++ {
		if err := d.Skip(); err != nil {
			return err
		}
	}

	return nil
}