// File: ast_expr_calculator.go (6.9 KB)

package astexpr

import (
	"fmt"
	"github.com/go-gota/gota/dataframe"
	"github.com/go-gota/gota/series"
	"github.com/shopspring/decimal"
	"gitlab.fjmaimaimai.com/allied-creation/character-library-metadata-bastion/pkg/domain"
	"gitlab.fjmaimaimai.com/allied-creation/character-library-metadata-bastion/pkg/infrastructure/utils"
	"strings"
)

// Calculator evaluates a parsed expression AST. Field references inside the
// expression are resolved against DataTable when the AST is evaluated with
// ExprASTResult.
type Calculator struct {
	// ExprAST is the root of the parsed expression (set by NewCalculator).
	ExprAST   ExprAST
	// DataTable supplies column values for field references (set via SetDataTable).
	DataTable *domain.DataTable
	// Result is not read or written anywhere in this file.
	// NOTE(review): confirm whether it is used elsewhere or can be removed.
	Result    []string
}

// NewCalculator parses expr into an expression AST and returns a Calculator
// holding it. Any tokenizer or parser error from NewExprAST is returned
// unchanged, together with a nil Calculator.
func NewCalculator(expr string) (*Calculator, error) {
	tree, err := NewExprAST(expr)
	if err != nil {
		return nil, err
	}
	return &Calculator{ExprAST: tree}, nil
}

// NewExprAST tokenizes expr and parses the tokens into an expression tree.
// It fails with the first error reported by ParseToken, by NewAST, or by the
// parser after ParseExpression runs.
func NewExprAST(expr string) (ExprAST, error) {
	tokens, err := ParseToken(expr)
	if err != nil {
		return nil, err
	}

	parser := NewAST(tokens, expr)
	if parser.Err != nil {
		return nil, parser.Err
	}

	// ParseExpression reports failures through parser.Err rather than a return value.
	root := parser.ParseExpression()
	if parser.Err != nil {
		return nil, parser.Err
	}
	return root, nil
}

// SetDataTable stores t as the table that field references are read from and
// returns the same Calculator so calls can be chained.
func (cal *Calculator) SetDataTable(t *domain.DataTable) *Calculator {
	cal.DataTable = t
	return cal
}

// Exec is currently a placeholder: it performs no work and always returns a
// nil error. Expression evaluation is done by ExprASTResult.
func (cal *Calculator) Exec() error {
	return nil
}

// ExprASTResult recursively evaluates ast and returns its value as a *param
// (a column of string values).
//
//   - BinaryExprAST: both operands are evaluated first. "+", "-", "*", "/" and
//     "%" are applied element-wise via OpCalc. Any other operator is an error.
//   - NumberExprAST / ValueExprAST: a single-element result holding the literal text.
//   - FieldExprAST: the column values read from cal.DataTable. An error is
//     returned when no DataTable has been set.
//   - FunCallerExprAST: the arguments are evaluated left to right and passed to callDef.
//
// Any other node type (including a nil ast) is reported as an error. It is not
// returned as a nil result, which callers such as OpCalc would dereference.
func (cal *Calculator) ExprASTResult(ast ExprAST) (*param, error) {
	switch node := ast.(type) {
	case BinaryExprAST:
		l, err := cal.ExprASTResult(node.Lhs)
		if err != nil {
			return nil, err
		}
		r, err := cal.ExprASTResult(node.Rhs)
		if err != nil {
			return nil, err
		}
		switch node.Op {
		case "+", "-", "*", "/", "%":
			return cal.OpCalc(node.Op, l, r), nil
		}
		return nil, fmt.Errorf("unsupported binary operator %q", node.Op)
	case NumberExprAST:
		return NewResult([]string{node.Str}), nil
	case ValueExprAST:
		return NewResult([]string{node.Val}), nil
	case FieldExprAST:
		// Without a table the lookup below would dereference a nil pointer.
		if cal.DataTable == nil {
			return nil, fmt.Errorf("cannot evaluate field expression: no data table set")
		}
		values := cal.DataTable.Values(&domain.Field{SQLName: node.Field.FieldSqlName})
		return NewResult(values), nil
	case FunCallerExprAST:
		args := make([]*param, 0, len(node.Args))
		for _, arg := range node.Args {
			v, err := cal.ExprASTResult(arg)
			if err != nil {
				return nil, err
			}
			args = append(args, v)
		}
		return cal.callDef(node.Name, args), nil
	}
	return nil, fmt.Errorf("unsupported expression node %T", ast)
}

// callDef dispatches a function call by name (case-insensitive) to sumifs or
// countifs. Every other name, including "sum" itself, is evaluated with sum.
// NOTE(review): unknown function names silently fall back to sum rather than
// failing; confirm this is intended.
func (cal *Calculator) callDef(name string, args []*param) *param {
	fn := cal.sum
	switch strings.ToLower(name) {
	case "sumifs":
		fn = cal.sumifs
	case "countifs":
		fn = cal.countifs
	}
	return fn(args...)
}

// sum adds every value of every argument and returns the total, formatted
// by decimal.Decimal.String, as a single-element result. Values that
// decimal.NewFromString cannot parse contribute nothing to the total (they
// count as zero).
func (cal *Calculator) sum(params ...*param) *param {
	total := decimal.New(0, 0)
	for i := range params {
		for _, raw := range params[i].data {
			if d, err := decimal.NewFromString(raw); err == nil {
				total = total.Add(d)
			}
		}
	}
	return NewResult([]string{total.String()})
}

// sumifs implements a SUMIFS-style aggregation over the argument columns.
//
// params[0] is the range to sum; it becomes the Float column "A0". It is then
// followed by (criteria_range, criterion) pairs. Each criteria_range becomes a
// String column named A1, A3, ... (see colName). For each pair:
//   - a criterion with more than one value marks that column as a group-by key;
//   - a single-value criterion is turned into a row filter by resolverFilter
//     (criteria it rejects are ignored).
//
// All filters are combined with AND. When group-by keys exist, the result is
// the per-group sum of A0 ("A0_SUM"), passed through toArrayFloat. Otherwise
// the result is the filtered A0 values themselves.
// NOTE(review): the non-grouped branch returns the matching values rather than
// their sum — confirm callers aggregate them.
// NOTE(review): the row order of the grouped output depends on gota's
// GroupBy/Aggregation — verify it is deterministic for callers.
// NOTE(review): if the ranges have different lengths, dataframe.New sets
// df.Err; confirm how df.Col(...).Records() behaves in that case.
func (cal *Calculator) sumifs(params ...*param) *param {
	var list = make([]series.Series, 0)
	var filters = make([]dataframe.F, 0)
	var groupBy = make([]string, 0)
	// The loop advances i twice per criteria pair (the i++ below plus the loop's
	// own i++), so after index 0 it only visits odd indices 1, 3, 5, ...
	for i := 0; i < len(params)-1; i++ {
		col := colName(i)
		if i == 0 {
			list = append(list, series.New(params[i].Data(), series.Float, col))
			continue
		}
		if i%2 == 1 {
			list = append(list, series.New(params[i].Data(), series.String, col))
			// TODO: when the type is a row field, treat it as grouping by row.
			if params[i+1].Len() > 1 {
				groupBy = append(groupBy, col)
			} else {
				if f, ok := cal.resolverFilter(col, params[i+1]); ok {
					filters = append(filters, f)
				}
			}
			i++
		}
	}
	df := dataframe.New(list...)
	df = df.FilterAggregation(dataframe.And, filters...)
	if len(groupBy) > 0 {
		groups := df.GroupBy(groupBy...)
		df = groups.Aggregation([]dataframe.AggregationType{dataframe.Aggregation_SUM}, []string{"A0"})
		s := df.Col("A0_SUM")
		return NewResult(toArrayFloat(s.Records())) // e.g. 4000.00 – the trailing .00 needs to be formatted away
	}

	s := df.Col("A0")
	return NewResult(s.Records())
}

// countifs implements a COUNTIFS-style count. params are consumed as
// (criteria_range, criterion) pairs. Each criteria_range becomes a String
// column A0, A2, ... (see colName). Its criterion is turned into a row filter
// by resolverFilter; criteria that resolverFilter rejects (e.g. multi-valued
// ones) are ignored. A trailing unpaired argument is ignored.
//
// The result is a single-element column holding the number of rows that pass
// every filter (AND). If no complete pair was supplied, or the ranges have
// different lengths, gota marks the DataFrame as errored and the count is 0.
func (cal *Calculator) countifs(params ...*param) *param {
	list := make([]series.Series, 0, len(params)/2)
	filters := make([]dataframe.F, 0, len(params)/2)
	for i := 0; i+1 < len(params); i += 2 {
		col := colName(i)
		list = append(list, series.New(params[i].Data(), series.String, col))
		if f, ok := cal.resolverFilter(col, params[i+1]); ok {
			filters = append(filters, f)
		}
	}
	df := dataframe.New(list...)
	df = df.FilterAggregation(dataframe.And, filters...)
	// Nrow is safe on an errored DataFrame. df.Col("A0") would instead return a
	// Series without elements, and calling Len on that is not safe.
	return NewResult([]string{fmt.Sprintf("%d", df.Nrow())})
}

// resolverFilter converts a single-value criterion into a gota filter on the
// column named key. It returns ok=false when param does not hold exactly one
// value, or when the criterion's first token is of an unhandled type.
//
// After surrounding double quotes are trimmed (formatTok), the criterion is
// interpreted as follows:
//   - leading "*" operator, or an identifier/literal that starts or ends with
//     "*": a substring ("contains") match on the text with "*" trimmed;
//   - leading operator or comparison operator followed by an operand: a
//     comparison using that operator against the operand text;
//   - plain identifier/literal/string: an exact (Eq) match on the text.
//
// A criterion that yields no tokens, or an operator with no operand (e.g.
// ">"), falls back to an exact match on the text instead of indexing past the
// end of the token list.
func (cal *Calculator) resolverFilter(key string, param *param) (dataframe.F, bool) {
	if len(param.Data()) != 1 {
		return dataframe.F{}, false
	}
	condition := formatTok(param.Data()[0])

	exact := dataframe.F{Colname: key, Comparator: series.Eq, Comparando: condition}
	contains := func() dataframe.F {
		needle := strings.Trim(condition, "*")
		return dataframe.F{Colname: key, Comparator: series.CompFunc, Comparando: func(el series.Element) bool {
			return strings.Contains(el.String(), needle)
		}}
	}

	// Best effort: a tokenizer error is not fatal here. Any tokens produced are
	// still used, and an empty token list falls back to an exact match.
	tokens, _ := ParseToken(condition)
	if len(tokens) == 0 {
		return exact, true
	}

	switch tokens[0].Type {
	case Operator, CompareOperator:
		if tokens[0].Tok == "*" {
			return contains(), true
		}
		if len(tokens) < 2 {
			return exact, true
		}
		// NOTE(review): Comparando is a string. Against the series.String columns
		// built by sumifs/countifs, gota compares as strings — confirm numeric
		// criteria such as ">100" behave as intended.
		return dataframe.F{Colname: key, Comparator: series.Comparator(tokens[0].Tok), Comparando: formatTok(tokens[1].Tok)}, true
	case Identifier, Literal, StringArgs:
		if tokens[len(tokens)-1].Tok == "*" || tokens[0].Tok == "*" {
			return contains(), true
		}
		return exact, true
	}
	return dataframe.F{}, false
}

// colName returns the synthetic DataFrame column name for argument index i:
// "A" followed by the decimal index (0 -> "A0", 12 -> "A12").
func colName(i int) string {
	return fmt.Sprintf("A%d", i)
}

// formatTok removes every leading and trailing double-quote character from
// tok. Quotes in the interior are left untouched.
func formatTok(tok string) string {
	isQuote := func(r rune) bool { return r == '"' }
	return strings.TrimFunc(tok, isQuote)
}

// OpCalc applies the arithmetic operator op element-wise to lp and rp using
// opCalc, and returns the results as a new *param.
//
// Broadcasting rules:
//   - if either operand is empty, the result is empty;
//   - a single-value operand is paired with every value of the other operand;
//   - otherwise values are paired by index and the result is truncated to the
//     shorter operand.
//
// The left operand always stays on the left, so non-commutative operators
// ("-", "/", "%") keep their meaning even when rp is the longer operand.
func (cal *Calculator) OpCalc(op string, lp *param, rp *param) *param {
	l, r := lp.Data(), rp.Data()

	n := len(l)
	if len(r) < n {
		n = len(r)
	}
	if n > 0 {
		switch {
		case len(l) == 1:
			n = len(r)
		case len(r) == 1:
			n = len(l)
		}
	}

	// at returns vals[i], broadcasting a single-value operand to every index.
	at := func(vals []string, i int) string {
		if len(vals) == 1 {
			return vals[0]
		}
		return vals[i]
	}

	res := make([]string, 0, n)
	for i := 0; i < n; i++ {
		res = append(res, opCalc(op, at(l, i), at(r, i)))
	}
	return NewResult(res)
}

// opCalc applies op ("+", "-", "*", "/" or "%") to the decimal values of v1
// and v2 and returns the result formatted by decimal.Decimal.String, passed
// through utils.AssertString.
//
// Operands that decimal.NewFromString cannot parse (including "") are treated
// as zero. The empty string is returned for an unsupported operator, and for
// "/" or "%" with a zero divisor: decimal.Div and decimal.Mod panic on a zero
// divisor.
func opCalc(op, v1, v2 string) string {
	// Parse errors are deliberately ignored: the zero Decimal is used instead.
	fv1, _ := decimal.NewFromString(v1)
	fv2, _ := decimal.NewFromString(v2)
	switch op {
	case "+":
		return utils.AssertString(fv1.Add(fv2).String())
	case "-":
		return utils.AssertString(fv1.Sub(fv2).String())
	case "*":
		return utils.AssertString(fv1.Mul(fv2).String())
	case "/":
		if fv2.IsZero() {
			return ""
		}
		return utils.AssertString(fv1.Div(fv2).String())
	case "%":
		if fv2.IsZero() {
			return ""
		}
		return utils.AssertString(fv1.Mod(fv2).String())
	}
	return ""
}

// param is an evaluation result or argument: a column of values kept as
// strings. Scalar values are represented as a single-element column.
type param struct {
	// data holds the values; NewResult stores the caller's slice without copying.
	data []string
}

// Len returns the number of values in p. A nil *param is treated as empty
// (length 0) instead of panicking.
func (p *param) Len() int {
	if p == nil {
		return 0
	}
	return len(p.data)
}

// Data returns p's underlying value slice without copying, so callers share
// it. A nil *param yields a nil slice instead of panicking.
func (p *param) Data() []string {
	if p == nil {
		return nil
	}
	return p.data
}

// NewResult wraps data in a new *param. The slice is stored as-is, not
// copied, so later changes to data are visible through the result.
func NewResult(data []string) *param {
	p := new(param)
	p.data = data
	return p
}

// toArrayFloat rewrites each element of list in place. It parses the element
// with utils.NewNumberString(...).MustFloat64 and re-formats the float with
// utils.AssertString. The same slice is returned.
// NOTE(review): MustFloat64 presumably panics on unparseable input — confirm
// in utils.
func toArrayFloat(list []string) []string {
	for i, raw := range list {
		f := utils.NewNumberString(raw).MustFloat64()
		list[i] = utils.AssertString(f)
	}
	return list
}