package pg

import (
	"context"
	"fmt"
	"time"

	"github.com/go-pg/pg/v10/orm"
)

// Model hook interfaces. These are aliases for the corresponding orm package
// interfaces, made available under package pg.
type (
	BeforeScanHook   = orm.BeforeScanHook
	AfterScanHook    = orm.AfterScanHook
	AfterSelectHook  = orm.AfterSelectHook
	BeforeInsertHook = orm.BeforeInsertHook
	AfterInsertHook  = orm.AfterInsertHook
	BeforeUpdateHook = orm.BeforeUpdateHook
	AfterUpdateHook  = orm.AfterUpdateHook
	BeforeDeleteHook = orm.BeforeDeleteHook
	AfterDeleteHook  = orm.AfterDeleteHook
)

//------------------------------------------------------------------------------

// dummyFormatter provides a FormatQuery method that performs no formatting at
// all: the query text is copied through verbatim.
type dummyFormatter struct{}

// FormatQuery appends query to b byte-for-byte and returns the extended slice.
// params is ignored, so any placeholders in query are left untouched.
func (dummyFormatter) FormatQuery(b []byte, query string, params ...interface{}) []byte {
	b = append(b, query...)
	return b
}

// QueryEvent describes a single query as seen by QueryHook implementations.
// It is created before the BeforeQuery hooks run, and the query's outcome is
// recorded on it before the AfterQuery hooks run.
type QueryEvent struct {
	// StartTime is set to time.Now() when the event is created, just before
	// the BeforeQuery hooks are called.
	StartTime  time.Time
	// DB is the orm.DB the event was created for.
	DB         orm.DB
	// Model is the model value supplied when the event was created.
	Model      interface{}
	// Query is the query as supplied by the caller. UnformattedQuery can only
	// render it when it is a string or an orm.TemplateAppender.
	Query      interface{}
	// Params are the query parameters as supplied by the caller.
	Params     []interface{}
	// fmtedQuery is the formatted query; FormattedQuery returns it directly
	// without copying.
	fmtedQuery []byte
	// Result and Err hold the query's outcome. They are filled in before the
	// AfterQuery hooks are called and are zero values while the BeforeQuery
	// hooks run.
	Result     Result
	Err        error

	// Stash is not read or written by the code in this file. It is presumably
	// scratch space for hooks, e.g. to carry state from BeforeQuery to
	// AfterQuery; confirm. It is not initialized when the event is created
	// (it starts out nil), so a hook must make() it before storing values.
	Stash map[interface{}]interface{}
}

// QueryHook is notified before and after a query is processed. Hooks are
// registered with AddQueryHook.
type QueryHook interface {
	// BeforeQuery is called before the query is processed. The context it
	// returns replaces ctx for the hooks that follow. A non-nil error stops
	// the hook chain and is propagated back to the caller.
	BeforeQuery(ctx context.Context, evt *QueryEvent) (context.Context, error)

	// AfterQuery is called once processing is done, with evt.Result and
	// evt.Err populated. A non-nil error stops the remaining AfterQuery
	// hooks and is propagated back to the caller.
	AfterQuery(ctx context.Context, evt *QueryEvent) error
}

// UnformattedQuery renders e.Query without substituting any parameters.
// A string query is returned as-is, and an orm.TemplateAppender is asked for
// its template. Any other query type yields an error.
// The query is only valid until the query Result is returned to the user.
func (e *QueryEvent) UnformattedQuery() ([]byte, error) {
	rendered, err := queryString(e.Query)
	return rendered, err
}

// queryString turns a query value into raw bytes without formatting it.
//
// An orm.TemplateAppender is checked first and asked for its template via
// AppendTemplate(nil). A plain string is copied through dummyFormatter, which
// appends it verbatim. Anything else, including a nil interface, is
// rejected with an error naming the dynamic type.
func queryString(query interface{}) ([]byte, error) {
	if tmpl, ok := query.(orm.TemplateAppender); ok {
		return tmpl.AppendTemplate(nil)
	}
	if text, ok := query.(string); ok {
		return dummyFormatter{}.FormatQuery(nil, text), nil
	}
	return nil, fmt.Errorf("pg: can't append %T", query)
}

// FormattedQuery returns the formatted query of a query event.
// The slice is the event's own buffer and is not copied, so the query is only
// valid until the query Result is returned to the user. The error result is
// currently always nil.
func (e *QueryEvent) FormattedQuery() ([]byte, error) {
	var err error
	return e.fmtedQuery, err
}

// AddQueryHook registers hook so that it takes part in query processing.
// Hooks are kept in the order they are added, and their BeforeQuery methods
// are called in that order.
//
// NOTE(review): no locking is done here; presumably hooks are expected to be
// added before the DB is used concurrently. Confirm.
func (db *baseDB) AddQueryHook(hook QueryHook) {
	hooks := db.queryHooks
	hooks = append(hooks, hook)
	db.queryHooks = hooks
}

// beforeQuery runs every registered QueryHook's BeforeQuery, in registration
// order, ahead of executing a query.
//
// With no hooks registered it returns ctx unchanged and a nil event, which
// makes the matching afterQuery call a no-op. Otherwise it builds a
// QueryEvent from the arguments (StartTime is set to the current time). The
// context returned by each hook is threaded into the next one.
//
// If a hook's BeforeQuery fails, the hooks that already ran successfully are
// unwound first. Their AfterQuery is called in reverse order with event.Err
// set to the failure, so whatever they started in BeforeQuery can be
// finished. The function then returns the failure, the last valid context
// (never nil) and a nil event.
func (db *baseDB) beforeQuery(
	ctx context.Context,
	ormDB orm.DB,
	model, query interface{},
	params []interface{},
	fmtedQuery []byte,
) (context.Context, *QueryEvent, error) {
	if len(db.queryHooks) == 0 {
		return ctx, nil, nil
	}

	event := &QueryEvent{
		StartTime:  time.Now(),
		DB:         ormDB,
		Model:      model,
		Query:      query,
		Params:     params,
		fmtedQuery: fmtedQuery,
	}

	for i, hook := range db.queryHooks {
		hookCtx, err := hook.BeforeQuery(ctx, event)
		if err != nil {
			event.Err = err
			// Only hooks[0..i-1] completed BeforeQuery, so only they are
			// unwound. The BeforeQuery error is the root cause and takes
			// precedence over any error reported while unwinding.
			_ = db.runAfterQueryHooks(ctx, event, i-1)
			return ctx, nil, err
		}
		ctx = hookCtx
	}

	return ctx, event, nil
}

// afterQuery finishes a query that was started with beforeQuery. It records
// the query's result and error on event and then calls every hook's
// AfterQuery.
//
// Hooks run in reverse registration order: the hook whose BeforeQuery ran
// first (outermost) gets AfterQuery last, giving the same nesting as a
// middleware stack. This lets a hook that opens something in BeforeQuery
// (for example a tracing span) close it in LIFO order. A nil event means no
// hooks were registered and is a no-op. The first AfterQuery error stops the
// remaining hooks and is returned.
func (db *baseDB) afterQuery(
	ctx context.Context,
	event *QueryEvent,
	res Result,
	err error,
) error {
	if event == nil {
		return nil
	}

	event.Err = err
	event.Result = res

	return db.runAfterQueryHooks(ctx, event, len(db.queryHooks)-1)
}

// runAfterQueryHooks calls AfterQuery on db.queryHooks[last], last-1, ...
// down to index 0. It stops at and returns the first error. A negative last
// is a no-op.
func (db *baseDB) runAfterQueryHooks(ctx context.Context, event *QueryEvent, last int) error {
	for i := last; i >= 0; i-- {
		if err := db.queryHooks[i].AfterQuery(ctx, event); err != nil {
			return err
		}
	}
	return nil
}

// copyQueryHooks returns s with its capacity clipped to its length.
//
// The hooks themselves are not duplicated; the result shares s's backing
// array. Because cap == len, any append to the result must allocate a fresh
// array. An append to s only writes past len(s), where the result cannot
// see. So appending hooks to either slice never changes what the other one
// contains.
func copyQueryHooks(s []QueryHook) []QueryHook {
	n := len(s)
	return s[0:n:n]
}