package pg

import (
	"context"
	"fmt"
	"time"

	"github.com/go-pg/pg/v10/orm"
)

// Aliases for the hook types declared in package orm. Because these are type
// aliases (declared with =), each name here denotes exactly the same type as
// its orm counterpart, so the two can be used interchangeably.
type (
	BeforeScanHook   = orm.BeforeScanHook
	AfterScanHook    = orm.AfterScanHook
	AfterSelectHook  = orm.AfterSelectHook
	BeforeInsertHook = orm.BeforeInsertHook
	AfterInsertHook  = orm.AfterInsertHook
	BeforeUpdateHook = orm.BeforeUpdateHook
	AfterUpdateHook  = orm.AfterUpdateHook
	BeforeDeleteHook = orm.BeforeDeleteHook
	AfterDeleteHook  = orm.AfterDeleteHook
)

//------------------------------------------------------------------------------

// dummyFormatter is a pass-through query formatter: it performs no parameter
// substitution and emits the query text exactly as given.
type dummyFormatter struct{}

// FormatQuery appends query to b verbatim and returns the extended slice.
// params is accepted but ignored.
func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte {
	b = append(b, query...)
	return b
}

// QueryEvent describes a single query as seen by the registered QueryHooks.
// An event is created in beforeQuery and passed to each hook's BeforeQuery;
// afterQuery passes the event it is given to the AfterQuery hooks.
//
// StartTime is set to time.Now() when the event is created. DB, Model, Query
// and Params are the values handed to beforeQuery. Query is expected to be
// either a string or an orm.TemplateAppender; UnformattedQuery returns an
// error for any other type. fmtedQuery holds the bytes that FormattedQuery
// returns. Result and Err are filled in by afterQuery, so they still hold
// their zero values while BeforeQuery hooks run.
type QueryEvent struct {
	StartTime  time.Time
	DB         orm.DB
	Model      interface{}
	Query      interface{}
	Params     []interface{}
	fmtedQuery []byte
	Result     Result
	Err        error

	// Stash is not populated by beforeQuery, so it is nil unless a hook
	// allocates it. NOTE(review): presumably meant for hooks to carry data
	// from BeforeQuery to AfterQuery — confirm intended use.
	Stash map[interface{}]interface{}
}

// QueryHook is the interface implemented by values registered through
// AddQueryHook. Hooks are stored in registration order.
type QueryHook interface {
	// BeforeQuery is called by beforeQuery for each hook in registration
	// order. The returned context replaces the context passed to the hooks
	// that follow. Returning a non-nil error stops the iteration; AfterQuery
	// is then called on this hook and on every hook registered before it.
	BeforeQuery(context.Context, *QueryEvent) (context.Context, error)
	// AfterQuery is called in reverse registration order. The first non-nil
	// error stops the iteration, so hooks registered earlier are not called.
	AfterQuery(context.Context, *QueryEvent) error
}

// UnformattedQuery renders the event's Query without parameter substitution.
// It fails when Query is neither a string nor an orm.TemplateAppender.
// The returned bytes must not be used after the query Result has been
// returned to the user.
func (e *QueryEvent) UnformattedQuery() ([]byte, error) {
	raw, err := queryString(e.Query)
	return raw, err
}

// queryString produces the unformatted byte form of query.
//
// An orm.TemplateAppender is asked to append its template to a nil buffer; a
// plain string is copied as-is. Any other dynamic type (including a nil
// interface) yields an error naming that type.
func queryString(query interface{}) ([]byte, error) {
	if appender, ok := query.(orm.TemplateAppender); ok {
		return appender.AppendTemplate(nil)
	}
	if text, ok := query.(string); ok {
		return dummyFormatter{}.FormatQuery(nil, text), nil
	}
	return nil, fmt.Errorf("pg: can't append %T", query)
}

// FormattedQuery returns the formatted query bytes that were recorded on the
// event when it was created. The error result is always nil.
// The returned bytes must not be used after the query Result has been
// returned to the user.
func (e *QueryEvent) FormattedQuery() ([]byte, error) {
	formatted := e.fmtedQuery
	return formatted, nil
}

// AddQueryHook registers hook at the end of the hook list, so it runs after
// all previously registered hooks in beforeQuery and before them in
// afterQuery.
func (db *baseDB) AddQueryHook(hook QueryHook) {
	hooks := append(db.queryHooks, hook)
	db.queryHooks = hooks
}

// beforeQuery builds a QueryEvent for the query and runs every registered
// hook's BeforeQuery in registration order, threading the context returned
// by each hook into the next one.
//
// When no hooks are registered it returns ctx unchanged with a nil event and
// a nil error. When a hook fails, AfterQuery is run for that hook and all
// hooks before it (in reverse order) and a nil event is returned; the error
// is the AfterQuery error if one occurred, otherwise the BeforeQuery error.
// Note that event.Err is not set on this failure path.
func (db *baseDB) beforeQuery(
	ctx context.Context,
	ormDB orm.DB,
	model, query interface{},
	params []interface{},
	fmtedQuery []byte,
) (context.Context, *QueryEvent, error) {
	if len(db.queryHooks) == 0 {
		return ctx, nil, nil
	}

	ev := new(QueryEvent)
	ev.StartTime = time.Now()
	ev.DB = ormDB
	ev.Model = model
	ev.Query = query
	ev.Params = params
	ev.fmtedQuery = fmtedQuery

	for idx, h := range db.queryHooks {
		nextCtx, hookErr := h.BeforeQuery(ctx, ev)
		// The hook's context is adopted even when it reports an error.
		ctx = nextCtx
		if hookErr == nil {
			continue
		}

		// Unwind: give the failing hook and its predecessors their AfterQuery.
		if unwindErr := db.afterQueryFromIndex(ctx, ev, idx); unwindErr != nil {
			return ctx, nil, unwindErr
		}
		return ctx, nil, hookErr
	}

	return ctx, ev, nil
}

// afterQuery records the outcome of a query on event and runs every
// registered hook's AfterQuery, starting from the most recently registered
// one. A nil event is a no-op that returns nil. The returned error is the
// first error reported by an AfterQuery hook, if any.
func (db *baseDB) afterQuery(
	ctx context.Context,
	event *QueryEvent,
	res Result,
	err error,
) error {
	if event == nil {
		return nil
	}

	event.Err, event.Result = err, res

	last := len(db.queryHooks) - 1
	return db.afterQueryFromIndex(ctx, event, last)
}

// afterQueryFromIndex calls AfterQuery on the hooks at positions hookIndex
// down to 0, in that order. It stops at and returns the first error; a
// negative hookIndex calls nothing and returns nil.
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
	for i := hookIndex; i >= 0; i-- {
		hook := db.queryHooks[i]
		if err := hook.AfterQuery(ctx, event); err != nil {
			return err
		}
	}
	return nil
}

// copyQueryHooks returns s re-sliced so that its capacity equals its length.
// No elements are copied: the result shares s's backing array, but because
// it has no spare capacity, appending to it allocates a new array instead of
// writing into memory reachable through s.
func copyQueryHooks(s []QueryHook) []QueryHook {
	n := len(s)
	return s[:n:n]
}