package astexpr

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

type AST struct {
	Tokens []*Token

	source    string
	currTok   *Token
	currIndex int
	depth     int

	Err error
}

func NewAST(toks []*Token, s string) *AST {
	a := &AST{
		Tokens: toks,
		source: s,
	}
	if a.Tokens == nil || len(a.Tokens) == 0 {
		a.Err = errors.New("empty token")
	} else {
		a.currIndex = 0
		a.currTok = a.Tokens[0]
	}
	return a
}

// ParseExpression 解析表达式
func (a *AST) ParseExpression() ExprAST {
	a.depth++ // called depth
	lhs := a.parsePrimary()
	r := a.parseBinOpRHS(0, lhs)
	a.depth--
	if a.depth == 0 && a.currIndex != len(a.Tokens) && a.Err == nil {
		a.Err = errors.New(
			fmt.Sprintf("bad expression, reaching the end or missing the operator\n%s",
				ErrPos(a.source, a.currTok.Offset)))
	}
	return r
}

func (a *AST) getNextToken() *Token {
	a.currIndex++
	if a.currIndex < len(a.Tokens) {
		a.currTok = a.Tokens[a.currIndex]
		return a.currTok
	}
	return nil
}

func (a *AST) getTokPrecedence() int {
	if p, ok := precedence[a.currTok.Tok]; ok {
		return p
	}
	return -1
}

func (a *AST) parseNumber() NumberExprAST {
	f64, err := strconv.ParseFloat(a.currTok.Tok, 64)
	if err != nil {
		a.Err = errors.New(
			fmt.Sprintf("%v\nwant '(' or '0-9' but get '%s'\n%s",
				err.Error(),
				a.currTok.Tok,
				ErrPos(a.source, a.currTok.Offset)))
		return NumberExprAST{}
	}
	n := NumberExprAST{
		Val: f64,
		Str: a.currTok.Tok,
	}
	a.getNextToken()
	return n
}

func (a *AST) parseFunCallerOrConst() ExprAST {
	name := a.currTok.Tok
	a.getNextToken()
	// call func
	if a.currTok.Tok == "(" {
		f := FunCallerExprAST{}
		if _, ok := defFunc[strings.ToLower(name)]; !ok {
			a.Err = errors.New(
				fmt.Sprintf("function `%s` is undefined\n%s",
					name,
					ErrPos(a.source, a.currTok.Offset)))
			return f
		}
		a.getNextToken()
		exprs := make([]ExprAST, 0)
		if a.currTok.Tok == ")" {
			// function call without parameters
			// ignore the process of parameter resolution
		} else {
			exprs = append(exprs, a.ParseExpression())
			for a.currTok.Tok != ")" && a.getNextToken() != nil {
				if a.currTok.Type == COMMA {
					continue
				}
				exprs = append(exprs, a.ParseExpression())
			}
		}
		def := defFunc[strings.ToLower(name)]
		if def.argc >= 0 && len(exprs) != def.argc {
			a.Err = errors.New(
				fmt.Sprintf("wrong way calling function `%s`, parameters want %d but get %d\n%s",
					name,
					def.argc,
					len(exprs),
					ErrPos(a.source, a.currTok.Offset)))
		}
		a.getNextToken()
		f = NewFunCallerExprAST(name, exprs...)
		return f
	}
	// call const
	if v, ok := defConst[name]; ok {
		return NumberExprAST{
			Val: v,
			Str: strconv.FormatFloat(v, 'f', 0, 64),
		}
	} else {
		if strings.Contains(name, ".") {
			return NewFieldExprAST(name)
		}
		a.Err = errors.New(
			fmt.Sprintf("const `%s` is undefined\n%s",
				name,
				ErrPos(a.source, a.currTok.Offset)))
		return NumberExprAST{}
	}
}

func (a *AST) parsePrimary() ExprAST {
	switch a.currTok.Type {
	case Identifier:
		return a.parseFunCallerOrConst()
	case Literal:
		return a.parseNumber()
	case StringArgs:
		e := NewValueExprAST(a.currTok.Tok)
		a.getNextToken()
		return e
	case Operator:
		if a.currTok.Tok == "(" {
			t := a.getNextToken()
			if t == nil {
				a.Err = errors.New(
					fmt.Sprintf("want '(' or '0-9' but get EOF\n%s",
						ErrPos(a.source, a.currTok.Offset)))
				return nil
			}
			e := a.ParseExpression()
			if e == nil {
				return nil
			}
			if a.currTok.Tok != ")" {
				a.Err = errors.New(
					fmt.Sprintf("want ')' but get %s\n%s",
						a.currTok.Tok,
						ErrPos(a.source, a.currTok.Offset)))
				return nil
			}
			a.getNextToken()
			return e
		} else if a.currTok.Tok == "-" {
			if a.getNextToken() == nil {
				a.Err = errors.New(
					fmt.Sprintf("want '0-9' but get '-'\n%s",
						ErrPos(a.source, a.currTok.Offset)))
				return nil
			}
			bin := NewBinaryExprAST("-", NumberExprAST{}, a.parsePrimary())
			return bin
		} else {
			return a.parseNumber()
		}
	case COMMA:
		a.Err = errors.New(
			fmt.Sprintf("want '(' or '0-9' but get %s\n%s",
				a.currTok.Tok,
				ErrPos(a.source, a.currTok.Offset)))
		return nil
	default:
		return nil
	}
}

func (a *AST) parseBinOpRHS(execPrec int, lhs ExprAST) ExprAST {
	for {
		tokPrec := a.getTokPrecedence()
		if tokPrec < execPrec {
			return lhs
		}
		binOp := a.currTok.Tok
		if a.getNextToken() == nil {
			a.Err = errors.New(
				fmt.Sprintf("want '(' or '0-9' but get EOF\n%s",
					ErrPos(a.source, a.currTok.Offset)))
			return nil
		}
		rhs := a.parsePrimary()
		if rhs == nil {
			return nil
		}
		nextPrec := a.getTokPrecedence()
		if tokPrec < nextPrec {
			rhs = a.parseBinOpRHS(tokPrec+1, rhs)
			if rhs == nil {
				return nil
			}
		}
		lhs = NewBinaryExprAST(binOp, lhs, rhs)
	}
}