package starrocks

import (
	"database/sql"
	"fmt"
	"gitlab.fjmaimaimai.com/allied-creation/character-library-metadata-bastion/pkg/domain"
	"gitlab.fjmaimaimai.com/allied-creation/character-library-metadata-bastion/pkg/infrastructure/utils"
	"gitlab.fjmaimaimai.com/allied-creation/character-library-metadata-bastion/pkg/log"
	"gorm.io/gorm"
	"reflect"
	"strings"
)

// AssertString is a package-local alias for utils.AssertString. It is used
// throughout this file to turn arbitrary values (condition values, IN-list
// elements, range bounds) into strings while building SQL text.
var AssertString = utils.AssertString

// Query executes queryFunc(params) and scans every returned row into a new
// domain.DataTable via ScanRows. The rows are always closed before returning.
//
// Errors from queryFunc are returned with a nil table. If scanning fails, or
// the driver reports an error that interrupted iteration (rows.Err), the
// possibly partially filled table is returned together with that error, so
// callers needing all-or-nothing semantics must check err.
func Query(params QueryOptions, queryFunc func(params QueryOptions) (*sql.Rows, error)) (*domain.DataTable, error) {
	rows, err := queryFunc(params)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	dataTable := &domain.DataTable{}
	if dataTable.Data, err = ScanRows(rows); err != nil {
		return dataTable, err
	}
	// database/sql only surfaces errors that cut iteration short through
	// rows.Err; without this check a truncated result looks like success.
	return dataTable, rows.Err()
}

// QueryOptions describes one query against a single table. It holds the
// projection, the conditions (filters, DISTINCT, ordering), the paging, and an
// optional company context. When Context is set, WrapQueryFuncWithDB and
// WrapQueryCountWithDB add a companyId filter built from Context.CompanyId.
type QueryOptions struct {
	Table     *domain.Table   // when non-nil, SetDefaultOrder orders by Table.PK
	TableName string          // table name passed to gorm's Table()
	Select    []*domain.Field // projection; when empty no explicit Select is issued
	Where     []Condition     // conditions applied in order via SetWhere
	Offset    int             // rows to skip; only applied when > 0
	Limit     int             // max rows; only applied when > 0
	Context   *domain.Context // optional company scope (see type comment)
}

// SetOffsetLimit converts 1-based page coordinates into Offset and Limit.
// A pageNumber of 0 means the first page and a pageSize of 0 means 20 rows.
func (o *QueryOptions) SetOffsetLimit(pageNumber, pageSize int) {
	page, size := pageNumber, pageSize
	if page == 0 {
		page = 1
	}
	if size == 0 {
		size = 20
	}
	o.Limit = size
	o.Offset = size * (page - 1)
}

// SetCondition appends every domain condition to o.Where, wrapped in a
// package Condition with Distinct left false. It returns o so calls can be
// chained.
func (o *QueryOptions) SetCondition(conditions []domain.Condition) *QueryOptions {
	for i := range conditions {
		o.Where = append(o.Where, Condition{Condition: conditions[i]})
	}
	return o
}

// SetDefaultOrder makes sure the query has an ordering so that it can be
// paged. If no condition in o.Where carries an Order, it appends an ascending
// order on the primary key: o.Table.PK when o.Table is set, otherwise
// domain.PK(). It returns o so calls can be chained.
func (o *QueryOptions) SetDefaultOrder() *QueryOptions {
	for _, existing := range o.Where {
		if len(existing.Order) > 0 {
			return o
		}
	}

	// No ordering present: add one, otherwise paging is not possible.
	byPK := domain.Condition{Order: "ASC"}
	if o.Table != nil {
		byPK.Field = o.Table.PK
	} else {
		byPK.Field = domain.PK()
	}
	o.Where = append(o.Where, Condition{Condition: byPK})
	return o
}

// Condition is a domain.Condition extended with query-builder flags that are
// local to this package; SetWhere turns it into gorm clauses.
type Condition struct {
	domain.Condition
	Distinct bool // when true, SetWhere adds DISTINCT on Field.SQLName
}

// SetWhere adds this condition's clauses to q. Filters compare against
// FormatIfNull(c.Field):
//   - Like adds `field like '%term%'`. The term has backslashes and single
//     quotes escaped, so it cannot terminate the quoted literal.
//   - In / Ex add `field in (...)` / `field not in (...)`, rendered by InArgs.
//   - Each Range item whose Op is found in opMap adds `field <op> value`,
//     after converting the value with domain.ValueToType. Conversion failures
//     are logged and the item is skipped; unknown or empty operators are ignored.
//   - Distinct adds DISTINCT on c.Field.SQLName; Order adds an ORDER BY on it.
//
// NOTE(review): the return values of q's builder methods are discarded, which
// relies on q sharing its Statement with those calls (as the instance returned
// by db.Table does in gorm v2).
func (c Condition) SetWhere(q *gorm.DB) {
	if len(c.Like) > 0 {
		// The term is embedded in a quoted literal. Escape it first: otherwise a
		// term like O'Brien breaks the statement and a crafted term can inject SQL.
		// NOTE(review): assumes MySQL-style backslash escapes in string literals.
		term := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(fmt.Sprint(c.Like))
		q.Where(fmt.Sprintf("%v like '%%%s%%'", FormatIfNull(c.Field), term))
	}
	if len(c.In) > 0 {
		q.Where(fmt.Sprintf("%v in %v", FormatIfNull(c.Field), c.InArgs(c.In)))
	}
	if len(c.Ex) > 0 {
		q.Where(fmt.Sprintf("%v not in %v", FormatIfNull(c.Field), c.InArgs(c.Ex)))
	}
	for _, item := range c.Range {
		if item.Op == "" {
			continue
		}
		opVal, ok := opMap[item.Op]
		if !ok {
			continue
		}
		val, err := domain.ValueToType(AssertString(item.Val), c.Field.SQLType)
		if err != nil {
			log.Logger.Error(err.Error())
			continue
		}
		q.Where(fmt.Sprintf("%s %s %s",
			FormatIfNull(c.Field),
			opVal,
			c.formatByOp(item.Op, val),
		))
	}
	if c.Distinct {
		q.Distinct(c.Field.SQLName)
	}
	if len(c.Order) > 0 {
		// NOTE(review): c.Order is interpolated verbatim; confirm callers
		// restrict it to ASC/DESC.
		q.Order(fmt.Sprintf("%v %v", c.Field.SQLName, c.Order))
	}
}

// FormatIfNull returns the SQL expression used to compare field f. For
// string-typed fields the column is wrapped as ifnull(col,'') so NULLs compare
// like empty strings. Any other field is returned as its bare SQLName.
func FormatIfNull(f *domain.Field) string {
	if !domain.SQLType(f.SQLType).IsString() {
		return f.SQLName
	}
	return "ifnull(" + f.SQLName + ",'')"
}

// opMap whitelists the operators accepted in range conditions. It maps each
// incoming operator to the SQL operator that is emitted for it. SetWhere
// ignores range items whose operator is not a key here.
var opMap = map[string]string{
	// equality / inequality
	"=":  "=",
	"<>": "<>",
	// ordering comparisons
	"<":  "<",
	"<=": "<=",
	">":  ">",
	">=": ">=",
	// pattern matching (value is rendered as '%...%' by formatByOp)
	"like":     "like",
	"not like": "not like",
}

// formatByOp renders val as the right-hand side of a comparison using op.
// For "like" and "not like" the value is rendered with AssertString, escaped,
// and wrapped as a '%value%' substring pattern. Backslashes and single quotes
// are escaped so the value cannot terminate the string literal or inject SQL.
// Every other operator delegates to Arg.
func (c Condition) formatByOp(op string, val interface{}) string {
	if op == "like" || op == "not like" {
		// NOTE(review): assumes MySQL-style backslash escapes in string literals.
		pattern := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(AssertString(val))
		return "'%" + pattern + "%'"
	}
	return c.Arg(val)
}

// InArgs renders args as a parenthesised, comma-separated value list (for an
// IN / NOT IN clause) using appendIn. args must be a slice or array.
func (c Condition) InArgs(args interface{}) string {
	return string(appendIn(nil, reflect.ValueOf(args)))
}

// Arg renders a single value as an inline SQL literal using appendValue.
func (c Condition) Arg(args interface{}) string {
	return string(appendValue(nil, reflect.ValueOf(args)))
}

// appendIn appends slice to b as a parenthesised, comma-separated value list
// suitable for SQL IN / NOT IN, rendering each element with appendValue.
// Interface-typed elements are unwrapped first. Elements that are themselves
// slices are skipped, and no separator is written for them, so the list stays
// well-formed. An empty slice yields "()". slice must be a slice or array
// Value (reflect's Len panics otherwise).
func appendIn(b []byte, slice reflect.Value) []byte {
	b = append(b, '(')
	first := true
	for i, n := 0, slice.Len(); i < n; i++ {
		elem := slice.Index(i)
		if elem.Kind() == reflect.Interface {
			elem = elem.Elem()
		}
		if elem.Kind() == reflect.Slice {
			// Nested slices are not rendered; skip before writing a separator.
			continue
		}
		if !first {
			b = append(b, ',')
		}
		first = false
		b = appendValue(b, elem)
	}
	return append(b, ')')
}

// appendValue appends v to b as an inline SQL literal:
//   - an invalid reflect.Value (e.g. a nil element of a []interface{} after
//     Elem) and a nil pointer are written as NULL;
//   - int, int64 and float64 values are written unquoted via AssertString;
//   - everything else is rendered with AssertString and wrapped in single
//     quotes. Backslashes and single quotes are escaped so the value cannot
//     terminate the literal early or inject SQL.
func appendValue(b []byte, v reflect.Value) []byte {
	if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) {
		return append(b, "NULL"...)
	}
	switch v.Kind() {
	case reflect.Int, reflect.Int64, reflect.Float64:
		return append(b, AssertString(v.Interface())...)
	}

	s := AssertString(v.Interface())
	b = append(b, '\'')
	for i := 0; i < len(s); i++ {
		// NOTE(review): MySQL-style escaping; assumes the server treats
		// backslash as an escape character inside string literals.
		if ch := s[i]; ch == '\\' || ch == '\'' {
			b = append(b, '\\')
		}
		b = append(b, s[i])
	}
	return append(b, '\'')
}

// DefaultQueryFunc selects every row of params.TableName from the package-level
// DB. All other fields of params (Select, Where, paging, Context) are ignored.
// NOTE(review): confirm that ignoring filters and paging here is intended.
func DefaultQueryFunc(params QueryOptions) (*sql.Rows, error) {
	result, queryErr := DB.Table(params.TableName).Rows()
	if queryErr != nil {
		return nil, queryErr
	}
	return result, nil
}

// WrapQueryFuncWithDB returns a query function bound to db. Each call queries
// params.TableName with params' Select/Where and applies Offset and Limit when
// positive. When params.Context is set, it also adds a companyId filter.
// NOTE(review): `->>` is PostgreSQL JSON syntax; confirm the StarRocks version
// and column type in use accept it.
func WrapQueryFuncWithDB(db *gorm.DB) func(QueryOptions) (*sql.Rows, error) {
	return func(params QueryOptions) (*sql.Rows, error) {
		tx := db.Table(params.TableName)
		queryWithoutLimitOffset(tx, params)

		if skip := params.Offset; skip > 0 {
			tx.Offset(skip)
		}
		if take := params.Limit; take > 0 {
			tx.Limit(take)
		}
		if scope := params.Context; scope != nil {
			tx.Where(fmt.Sprintf("context->>'companyId'='%v'", scope.CompanyId))
		}

		result, queryErr := tx.Rows()
		if queryErr != nil {
			return nil, queryErr
		}
		return result, nil
	}
}

// SetTable overwrites the table name on query's Statement in place.
func SetTable(query *gorm.DB, tableName string) {
	stmt := query.Statement
	stmt.Table = tableName
}

// queryWithoutLimitOffset applies params' projection and conditions to query,
// leaving paging to the caller. Fields flagged domain.ManualField have no
// backing column, so an empty string literal is selected under their
// SQLName instead.
func queryWithoutLimitOffset(query *gorm.DB, params QueryOptions) {
	if n := len(params.Select); n > 0 {
		columns := make([]string, 0, n)
		for _, field := range params.Select {
			column := field.SQLName
			if field.Flag == domain.ManualField {
				column = "'' " + column
			}
			columns = append(columns, column)
		}
		query.Select(strings.Join(columns, ","))
	}
	for _, cond := range params.Where {
		cond.SetWhere(query)
	}
}

// QueryCount counts the rows of params.TableName on the package-level DB that
// match params' Select/Where. Offset, Limit and Context are not applied.
func QueryCount(params QueryOptions) (int64, error) {
	tx := DB.Table(params.TableName)
	queryWithoutLimitOffset(tx, params)

	var n int64
	tx.Count(&n)
	return n, tx.Error
}

// WrapQueryCountWithDB returns a closure that counts the rows of
// params.TableName on db matching params' Select/Where. Paging is ignored.
// When params.Context is set, the count is also filtered by companyId.
func WrapQueryCountWithDB(params QueryOptions, db *gorm.DB) func() (int64, error) {
	return func() (int64, error) {
		tx := db.Table(params.TableName)
		queryWithoutLimitOffset(tx, params)
		if scope := params.Context; scope != nil {
			tx.Where(fmt.Sprintf("context->>'companyId'='%v'", scope.CompanyId))
		}

		var n int64
		tx.Count(&n)
		return n, tx.Error
	}
}

// ArrayInterfaceToString converts each element of args with AssertString.
// The result is always non-nil; a nil or empty input yields an empty slice.
func ArrayInterfaceToString(args []interface{}) []string {
	out := make([]string, len(args))
	for i, a := range args {
		out[i] = AssertString(a)
	}
	return out
}