package msgpack

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sync"
	"time"

	"github.com/vmihailenco/msgpack/v5/codes"
)

// Bit flags stored in Decoder.flags.
const (
	// looseIfaceFlag makes decodeInterfaceCond call DecodeInterfaceLoose
	// instead of DecodeInterface.
	looseIfaceFlag uint32 = 1 << iota
	// disallowUnknownFieldsFlag is toggled by Decoder.DisallowUnknownFields.
	disallowUnknownFieldsFlag
)

// Size limits used while decoding. bytesAllocLimit caps how many bytes readN
// allocates in a single step; the other two limits are not referenced in this
// part of the file (presumably they bound slice and map pre-allocation —
// confirm against their users).
const (
	bytesAllocLimit = 1e6 // 1mb
	sliceAllocLimit = 1e4
	maxMapSize      = 1e6
)

// bufReader is a reader that can also read and unread single bytes.
// Decoder.Reset uses a reader implementing it directly instead of wrapping
// it in a bufio.Reader.
type bufReader interface {
	io.Reader
	io.ByteScanner
}

//------------------------------------------------------------------------------

// decPool holds reusable decoders for GetDecoder and PutDecoder. When the pool
// is empty a decoder is created with NewDecoder(nil).
var decPool = sync.Pool{
	New: func() interface{} { return NewDecoder(nil) },
}

// GetDecoder takes a Decoder from the package-level pool. Call Reset or
// ResetBytes on it before decoding, and hand it back with PutDecoder.
func GetDecoder() *Decoder {
	dec := decPool.Get().(*Decoder)
	return dec
}

// PutDecoder drops dec's references to its reader and returns dec to the
// package-level pool. dec must not be used after this call.
func PutDecoder(dec *Decoder) {
	// Clear the reader fields so the pooled decoder does not keep the
	// caller's input alive.
	dec.r, dec.s = nil, nil
	decPool.Put(dec)
}

//------------------------------------------------------------------------------

// Unmarshal decodes the MessagePack-encoded data and stores the result
// in the value pointed to by v.
//
// It borrows a pooled Decoder for the duration of the call.
func Unmarshal(data []byte, v interface{}) error {
	dec := GetDecoder()
	dec.ResetBytes(data)

	decodeErr := dec.Decode(v)
	PutDecoder(dec)
	return decodeErr
}

// A Decoder reads and decodes MessagePack values from an input stream.
//
// r and s are set together by Reset: r is used for bulk reads and s for
// reading and unreading single bytes. buf is a scratch buffer reused by
// readN. extLen is cleared by readCode each time a new code is read. intern
// is truncated on Reset. flags holds the looseIfaceFlag and
// disallowUnknownFieldsFlag bits. structTag is set by UseCustomStructTag and
// decodeMapFunc by SetDecodeMapFunc.
type Decoder struct {
	r   io.Reader
	s   io.ByteScanner
	buf []byte

	extLen int
	rec    []byte // accumulates read data if not nil

	intern        []string
	flags         uint32
	structTag     string
	decodeMapFunc func(*Decoder) (interface{}, error)
}

// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may read data from r
// beyond the MessagePack values requested. Buffering can be disabled
// by passing a reader that implements io.ByteScanner interface.
func NewDecoder(r io.Reader) *Decoder {
	dec := &Decoder{}
	dec.Reset(r)
	return dec
}

// resetter is implemented by readers that can be re-pointed at a new byte
// slice in place, such as *bytes.Reader. ResetBytes uses it to avoid
// allocating a new reader.
type resetter interface {
	Reset([]byte)
}

// ResetBytes points the decoder at data.
//
// When the current reader has a Reset([]byte) method (as *bytes.Reader does),
// only that reader is re-pointed at data; the decoder's flags, intern table
// and map decode function are left as they are. Otherwise a new bytes.Reader
// is installed through Reset, which also clears that state.
func (d *Decoder) ResetBytes(data []byte) {
	rs, reusable := d.r.(resetter)
	if !reusable {
		d.Reset(bytes.NewReader(data))
		return
	}
	rs.Reset(data)
}

// Reset discards any buffered data, resets all state, and switches the buffered
// reader to read from r.
//
// If r implements both io.Reader and io.ByteScanner it is used directly.
// Otherwise r is wrapped in a bufio.Reader, reusing the decoder's existing
// bufio.Reader when it has one.
//
// Besides the reader, Reset clears the intern table, all flags, the custom
// struct tag, the map decode function and any in-progress raw recording, so
// that options set by a previous user of a pooled decoder cannot leak into
// the next one.
func (d *Decoder) Reset(r io.Reader) {
	if br, ok := r.(bufReader); ok {
		d.r = br
		d.s = br
	} else if br, ok := d.r.(*bufio.Reader); ok {
		br.Reset(r)
		// Keep the byte scanner in sync with the reused buffered reader.
		d.s = br
	} else {
		br := bufio.NewReader(r)
		d.r = br
		d.s = br
	}

	if d.intern != nil {
		// Keep the backing array; only forget the interned strings.
		d.intern = d.intern[:0]
	}

	d.extLen = 0
	d.rec = nil
	d.flags = 0
	d.structTag = ""
	d.decodeMapFunc = nil
}

// SetDecodeMapFunc stores fn as the decoder's map decode function. The stored
// function is cleared again by Reset.
func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
	d.decodeMapFunc = fn
}

// UseDecodeInterfaceLoose causes decoder to use DecodeInterfaceLoose
// to decode msgpack value into Go interface{}.
// Passing false restores the default, DecodeInterface.
func (d *Decoder) UseDecodeInterfaceLoose(on bool) {
	if !on {
		d.flags &^= looseIfaceFlag
		return
	}
	d.flags |= looseIfaceFlag
}

// UseJSONTag causes the Decoder to use json struct tag as fallback option
// if there is no msgpack tag.
// Passing false clears the fallback tag.
func (d *Decoder) UseJSONTag(on bool) {
	tag := ""
	if on {
		tag = "json"
	}
	d.UseCustomStructTag(tag)
}

// UseCustomStructTag causes the decoder to use the supplied tag as a fallback option
// if there is no msgpack tag.
// An empty tag disables the fallback (this is what UseJSONTag(false) passes).
func (d *Decoder) UseCustomStructTag(tag string) {
	d.structTag = tag
}

// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
// Passing false turns the check off again.
func (d *Decoder) DisallowUnknownFields(on bool) {
	if !on {
		d.flags &^= disallowUnknownFieldsFlag
		return
	}
	d.flags |= disallowUnknownFieldsFlag
}

// Buffered returns a reader of the data remaining in the Decoder's buffer.
// The reader is valid until the next call to Decode.
//
// The returned value is the decoder's own reader (the one installed by Reset),
// so reading from it advances the decoder's position as well.
func (d *Decoder) Buffered() io.Reader {
	return d.r
}

// Decode reads the next MessagePack value and stores it in the value pointed
// to by v.
//
// Pointers to common types (strings, byte slices, integers, floats, bools,
// string slices, string-keyed maps, time.Duration and time.Time) are handled
// by a type switch without reflection. Everything else goes through
// DecodeValue.
//
// A typed nil pointer never reaches a fast-path helper: it falls through to
// the reflection checks below and is reported as non-settable. Decode also
// returns an error when v is nil or not a pointer, or when v points to a
// non-nil interface that holds a non-pointer value.
//
//nolint:gocyclo
func (d *Decoder) Decode(v interface{}) error {
	var err error
	switch v := v.(type) {
	case *string:
		if v != nil {
			*v, err = d.DecodeString()
			return err
		}
	case *[]byte:
		if v != nil {
			return d.decodeBytesPtr(v)
		}
	case *int:
		if v != nil {
			*v, err = d.DecodeInt()
			return err
		}
	case *int8:
		if v != nil {
			*v, err = d.DecodeInt8()
			return err
		}
	case *int16:
		if v != nil {
			*v, err = d.DecodeInt16()
			return err
		}
	case *int32:
		if v != nil {
			*v, err = d.DecodeInt32()
			return err
		}
	case *int64:
		if v != nil {
			*v, err = d.DecodeInt64()
			return err
		}
	case *uint:
		if v != nil {
			*v, err = d.DecodeUint()
			return err
		}
	case *uint8:
		if v != nil {
			*v, err = d.DecodeUint8()
			return err
		}
	case *uint16:
		if v != nil {
			*v, err = d.DecodeUint16()
			return err
		}
	case *uint32:
		if v != nil {
			*v, err = d.DecodeUint32()
			return err
		}
	case *uint64:
		if v != nil {
			*v, err = d.DecodeUint64()
			return err
		}
	case *bool:
		if v != nil {
			*v, err = d.DecodeBool()
			return err
		}
	case *float32:
		if v != nil {
			*v, err = d.DecodeFloat32()
			return err
		}
	case *float64:
		if v != nil {
			*v, err = d.DecodeFloat64()
			return err
		}
	case *[]string:
		// Guard against a typed nil pointer like every other case does.
		if v != nil {
			return d.decodeStringSlicePtr(v)
		}
	case *map[string]string:
		if v != nil {
			return d.decodeMapStringStringPtr(v)
		}
	case *map[string]interface{}:
		if v != nil {
			return d.decodeMapStringInterfacePtr(v)
		}
	case *time.Duration:
		if v != nil {
			*v, err = d.DecodeDuration()
			return err
		}
	case *time.Time:
		if v != nil {
			*v, err = d.DecodeTime()
			return err
		}
	}

	// Slow path: validate the destination with reflection.
	vv := reflect.ValueOf(v)
	if !vv.IsValid() {
		return errors.New("msgpack: Decode(nil)")
	}
	if vv.Kind() != reflect.Ptr {
		return fmt.Errorf("msgpack: Decode(non-pointer %T)", v)
	}
	if vv.IsNil() {
		return fmt.Errorf("msgpack: Decode(non-settable %T)", v)
	}

	vv = vv.Elem()
	if vv.Kind() == reflect.Interface {
		// A non-nil interface is decoded into the value it holds, which
		// therefore has to be a pointer.
		if !vv.IsNil() {
			vv = vv.Elem()
			if vv.Kind() != reflect.Ptr {
				return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String())
			}
		}
	}

	return d.DecodeValue(vv)
}

// DecodeMulti decodes consecutive values into each element of v in order,
// stopping at and returning the first error.
func (d *Decoder) DecodeMulti(v ...interface{}) error {
	for i := range v {
		err := d.Decode(v[i])
		if err != nil {
			return err
		}
	}
	return nil
}

// decodeInterfaceCond decodes the next value with DecodeInterfaceLoose when
// looseIfaceFlag is set, and with DecodeInterface otherwise.
func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
	if d.flags&looseIfaceFlag == 0 {
		return d.DecodeInterface()
	}
	return d.DecodeInterfaceLoose()
}

// DecodeValue decodes the next value into v using the decode function that
// getDecoder returns for v's type.
func (d *Decoder) DecodeValue(v reflect.Value) error {
	return getDecoder(v.Type())(d, v)
}

// DecodeNil reads the next code and returns an error unless it is codes.Nil.
func (d *Decoder) DecodeNil() error {
	c, err := d.readCode()
	switch {
	case err != nil:
		return err
	case c != codes.Nil:
		return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
	default:
		return nil
	}
}

// decodeNilValue consumes a nil code and, when v is not already nil, sets v
// (or the value v points to, if v is a pointer) to its zero value. The error
// from DecodeNil is returned either way. v must be of a kind that supports
// IsNil.
func (d *Decoder) decodeNilValue(v reflect.Value) error {
	err := d.DecodeNil()
	if !v.IsNil() {
		target := v
		if target.Kind() == reflect.Ptr {
			target = target.Elem()
		}
		target.Set(reflect.Zero(target.Type()))
	}
	return err
}

// DecodeBool reads the next code and interprets it as a boolean.
func (d *Decoder) DecodeBool() (bool, error) {
	code, err := d.readCode()
	if err == nil {
		return d.bool(code)
	}
	return false, err
}

// bool converts an already-read code to a boolean. Any code other than
// codes.False or codes.True is an error.
func (d *Decoder) bool(c codes.Code) (bool, error) {
	switch c {
	case codes.False:
		return false, nil
	case codes.True:
		return true, nil
	default:
		return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
	}
}

// DecodeDuration decodes an int64 and converts it to a time.Duration.
// It returns 0 together with the error when decoding fails.
func (d *Decoder) DecodeDuration() (time.Duration, error) {
	raw, err := d.DecodeInt64()
	if err == nil {
		return time.Duration(raw), nil
	}
	return 0, err
}

// DecodeInterface decodes value into interface. It returns following types:
//   - nil,
//   - bool,
//   - int8, int16, int32, int64,
//   - uint8, uint16, uint32, uint64,
//   - float32 and float64,
//   - string,
//   - []byte,
//   - slices of any of the above,
//   - maps of any of the above.
//
// DecodeInterface should be used only when you don't know the type of value
// you are decoding. For example, if you are decoding number it is better to use
// DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
//
// An unrecognised code yields a nil value together with an error.
func (d *Decoder) DecodeInterface() (interface{}, error) {
	c, err := d.readCode()
	if err != nil {
		return nil, err
	}

	if codes.IsFixedNum(c) {
		// The value is carried by the code byte itself.
		return int8(c), nil
	}
	if codes.IsFixedMap(c) {
		// DecodeMap reads the code itself, so put it back first.
		err = d.s.UnreadByte()
		if err != nil {
			return nil, err
		}
		return d.DecodeMap()
	}
	if codes.IsFixedArray(c) {
		return d.decodeSlice(c)
	}
	if codes.IsFixedString(c) {
		return d.string(c)
	}

	switch c {
	case codes.Nil:
		return nil, nil
	case codes.False, codes.True:
		return d.bool(c)
	case codes.Float:
		return d.float32(c)
	case codes.Double:
		return d.float64(c)
	case codes.Uint8:
		return d.uint8()
	case codes.Uint16:
		return d.uint16()
	case codes.Uint32:
		return d.uint32()
	case codes.Uint64:
		return d.uint64()
	case codes.Int8:
		return d.int8()
	case codes.Int16:
		return d.int16()
	case codes.Int32:
		return d.int32()
	case codes.Int64:
		return d.int64()
	case codes.Bin8, codes.Bin16, codes.Bin32:
		return d.bytes(c, nil)
	case codes.Str8, codes.Str16, codes.Str32:
		return d.string(c)
	case codes.Array16, codes.Array32:
		return d.decodeSlice(c)
	case codes.Map16, codes.Map32:
		// DecodeMap reads the code itself, so put it back first.
		err = d.s.UnreadByte()
		if err != nil {
			return nil, err
		}
		return d.DecodeMap()
	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
		codes.Ext8, codes.Ext16, codes.Ext32:
		return d.extInterface(c)
	}

	// Return a nil interface, not a boxed 0, so the value cannot be mistaken
	// for a decoded integer.
	return nil, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
}

// DecodeInterfaceLoose is like DecodeInterface except that:
//   - int8, int16, and int32 are converted to int64,
//   - uint8, uint16, and uint32 are converted to uint64,
//   - float32 is converted to float64.
//
// An unrecognised code yields a nil value together with an error.
func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
	c, err := d.readCode()
	if err != nil {
		return nil, err
	}

	if codes.IsFixedNum(c) {
		// The value is carried by the code byte itself; widen it to int64.
		return int64(int8(c)), nil
	}
	if codes.IsFixedMap(c) {
		// DecodeMap reads the code itself, so put it back first.
		err = d.s.UnreadByte()
		if err != nil {
			return nil, err
		}
		return d.DecodeMap()
	}
	if codes.IsFixedArray(c) {
		return d.decodeSlice(c)
	}
	if codes.IsFixedString(c) {
		return d.string(c)
	}

	switch c {
	case codes.Nil:
		return nil, nil
	case codes.False, codes.True:
		return d.bool(c)
	case codes.Float, codes.Double:
		return d.float64(c)
	case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
		return d.uint(c)
	case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
		return d.int(c)
	case codes.Bin8, codes.Bin16, codes.Bin32:
		return d.bytes(c, nil)
	case codes.Str8, codes.Str16, codes.Str32:
		return d.string(c)
	case codes.Array16, codes.Array32:
		return d.decodeSlice(c)
	case codes.Map16, codes.Map32:
		// DecodeMap reads the code itself, so put it back first.
		err = d.s.UnreadByte()
		if err != nil {
			return nil, err
		}
		return d.DecodeMap()
	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
		codes.Ext8, codes.Ext16, codes.Ext32:
		return d.extInterface(c)
	}

	// Return a nil interface, not a boxed 0, so the value cannot be mistaken
	// for a decoded integer.
	return nil, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
}

// Skip skips next value.
//
// It reads the value's code and then discards the payload that the code
// implies. An unrecognised code is an error.
func (d *Decoder) Skip() error {
	c, err := d.readCode()
	if err != nil {
		return err
	}

	// Fixed-size families are recognised by range checks, tried in this order.
	switch {
	case codes.IsFixedNum(c):
		return nil
	case codes.IsFixedMap(c):
		return d.skipMap(c)
	case codes.IsFixedArray(c):
		return d.skipSlice(c)
	case codes.IsFixedString(c):
		return d.skipBytes(c)
	}

	switch c {
	case codes.Nil, codes.False, codes.True:
		// Nothing follows the code.
		return nil
	case codes.Uint8, codes.Int8:
		return d.skipN(1)
	case codes.Uint16, codes.Int16:
		return d.skipN(2)
	case codes.Uint32, codes.Int32, codes.Float:
		return d.skipN(4)
	case codes.Uint64, codes.Int64, codes.Double:
		return d.skipN(8)
	case codes.Bin8, codes.Bin16, codes.Bin32,
		codes.Str8, codes.Str16, codes.Str32:
		return d.skipBytes(c)
	case codes.Array16, codes.Array32:
		return d.skipSlice(c)
	case codes.Map16, codes.Map32:
		return d.skipMap(c)
	case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
		codes.Ext8, codes.Ext16, codes.Ext32:
		return d.skipExt(c)
	}

	return fmt.Errorf("msgpack: unknown code %x", c)
}

// DecodeRaw skips over the next value and returns the bytes that made it up.
//
// Recording is switched on by making d.rec non-nil for the duration of the
// Skip call; readCode, readFull and readN append what they read to it.
// Recording is always switched off again before returning, including when
// Skip fails, so a failed call cannot leave the decoder accumulating every
// later read.
func (d *Decoder) DecodeRaw() (RawMessage, error) {
	d.rec = make([]byte, 0)
	err := d.Skip()
	msg := RawMessage(d.rec)
	d.rec = nil

	if err != nil {
		return nil, err
	}
	return msg, nil
}

// PeekCode returns the next MessagePack code without advancing the reader.
// Subpackage msgpack/codes defines list of available codes.
//
// The byte is read and then unread; if unreading fails the code is still
// returned alongside that error.
func (d *Decoder) PeekCode() (codes.Code, error) {
	b, err := d.s.ReadByte()
	if err != nil {
		return 0, err
	}
	if err := d.s.UnreadByte(); err != nil {
		return codes.Code(b), err
	}
	return codes.Code(b), nil
}

// ReadFull reads exactly len(buf) bytes into the buf.
// The bytes come from the decoder's reader and are not added to a raw
// recording.
func (d *Decoder) ReadFull(buf []byte) error {
	if _, err := io.ReadFull(d.r, buf); err != nil {
		return err
	}
	return nil
}

// hasNilCode reports whether the next code can be peeked and is codes.Nil.
// A peek error is reported as false.
func (d *Decoder) hasNilCode() bool {
	code, err := d.PeekCode()
	if err != nil {
		return false
	}
	return code == codes.Nil
}

// readCode reads one byte and returns it as a code. It clears d.extLen first
// and, when a raw recording is active, appends the byte to d.rec.
func (d *Decoder) readCode() (codes.Code, error) {
	d.extLen = 0

	b, err := d.s.ReadByte()
	if err == nil && d.rec != nil {
		d.rec = append(d.rec, b)
	}
	if err != nil {
		return 0, err
	}
	return codes.Code(b), nil
}

// readFull fills b from the decoder's reader and, when a raw recording is
// active, appends the bytes read to d.rec.
func (d *Decoder) readFull(b []byte) error {
	if _, err := io.ReadFull(d.r, b); err != nil {
		return err
	}
	if d.rec == nil {
		return nil
	}
	d.rec = append(d.rec, b...)
	return nil
}

// readN reads n bytes into the decoder's scratch buffer and returns it. The
// returned slice is d.buf itself, so it is only valid until the next read
// that reuses the buffer. When a raw recording is active the bytes are also
// appended to d.rec.
func (d *Decoder) readN(n int) ([]byte, error) {
	buf, err := readN(d.r, d.buf, n)
	// Keep whatever buffer readN handed back, even on failure.
	d.buf = buf
	if err != nil {
		return nil, err
	}
	if d.rec != nil {
		//TODO: read directly into d.rec?
		d.rec = append(d.rec, buf...)
	}
	return buf, nil
}

// readN reads exactly n bytes from r and returns them, reusing b's backing
// array when it is large enough.
//
// When b is nil a new buffer is allocated: at least 64 bytes of capacity for
// small reads and never more than bytesAllocLimit up front. If n exceeds the
// buffer's capacity the buffer is grown in steps of at most bytesAllocLimit
// bytes, each step being filled from r before the next one is allocated, so
// a bogus length in the input cannot force one huge allocation.
//
// A negative n is rejected with an error instead of panicking on the reslice.
// On a read error the buffer is returned together with the error.
func readN(r io.Reader, b []byte, n int) ([]byte, error) {
	if n < 0 {
		return b, fmt.Errorf("msgpack: invalid read length %d", n)
	}

	if b == nil {
		if n == 0 {
			return make([]byte, 0), nil
		}
		switch {
		case n < 64:
			b = make([]byte, 0, 64)
		case n <= bytesAllocLimit:
			b = make([]byte, 0, n)
		default:
			b = make([]byte, 0, bytesAllocLimit)
		}
	}

	if n <= cap(b) {
		// Fast path: everything fits in the existing capacity.
		b = b[:n]
		_, err := io.ReadFull(r, b)
		return b, err
	}
	b = b[:cap(b)]

	// pos marks the start of the part of b that has not been filled yet.
	var pos int
	for {
		alloc := min(n-len(b), bytesAllocLimit)
		b = append(b, make([]byte, alloc)...)

		_, err := io.ReadFull(r, b[pos:])
		if err != nil {
			return b, err
		}

		if len(b) == n {
			break
		}
		pos = len(b)
	}

	return b, nil
}

// min returns the smaller of a and b.
func min(a, b int) int { //nolint:unparam
	if b < a {
		return b
	}
	return a
}