pg_evaluation_project_repository.go 6.4 KB
package repository

import (
	"errors"
	"fmt"
	"time"

	"github.com/go-pg/pg/v10"
	"github.com/linmadan/egglib-go/persistent/pg/sqlbuilder"
	pgTransaction "github.com/linmadan/egglib-go/transaction/pg"
	"gitlab.fjmaimaimai.com/allied-creation/performance/pkg/domain"
	"gitlab.fjmaimaimai.com/allied-creation/performance/pkg/infrastructure/pg/models"
	"gitlab.fjmaimaimai.com/allied-creation/performance/pkg/utils"
)

// EvaluationProjectRepository persists domain.EvaluationProject entities in
// PostgreSQL (via go-pg) using the transaction held by transactionContext.
type EvaluationProjectRepository struct {
	transactionContext *pgTransaction.TransactionContext
}

// NewEvaluationProjectRepository builds a repository bound to the given
// transaction context; every query the repository issues runs on its PgTx.
func NewEvaluationProjectRepository(transactionContext *pgTransaction.TransactionContext) *EvaluationProjectRepository {
	repo := new(EvaluationProjectRepository)
	repo.transactionContext = transactionContext
	return repo
}

// TransformToDomain converts a persistence model into a domain entity.
//
// Historical data may carry NodeContents[].Required == 0; such entries are
// reported as domain.NodeRequiredYes (the original note says they default to 1).
// This normalization writes into the template referenced by m, so m itself is
// modified. CreatedAt/UpdatedAt are converted to local time; DeletedAt is
// passed through unchanged.
func (repo *EvaluationProjectRepository) TransformToDomain(m *models.EvaluationProject) domain.EvaluationProject {
	if tpl := m.Template; tpl != nil {
		for _, linkNode := range tpl.LinkNodes {
			// The contents slice shares its backing array with the template,
			// so assignments through it are visible on m.Template.
			contents := linkNode.NodeContents
			for k := range contents {
				if contents[k].Required == 0 {
					contents[k].Required = domain.NodeRequiredYes
				}
			}
		}
	}

	return domain.EvaluationProject{
		// identity and descriptive fields
		Id:       m.Id,
		Name:     m.Name,
		Describe: m.Describe,

		// ownership / scoping
		CompanyId:   m.CompanyId,
		CycleId:     m.CycleId,
		CreatorId:   m.CreatorId,
		PrincipalId: m.PrincipalId,
		HrBp:        m.HrBp,
		Pmp:         m.Pmp,
		PmpIds:      m.PmpIds,
		Recipients:  m.Recipients,

		// status and content
		State:        m.State,
		SummaryState: domain.ProjectSummaryState(m.SummaryState),
		Template:     m.Template,

		// schedule and audit timestamps
		BeginTime: m.BeginTime,
		EndTime:   m.EndTime,
		CreatedAt: m.CreatedAt.Local(),
		UpdatedAt: m.UpdatedAt.Local(),
		DeletedAt: m.DeletedAt,
	}
}

// TransformToModel copies a domain entity into its persistence model. Every
// field is copied verbatim except SummaryState, which is stored as a plain int.
func (repo *EvaluationProjectRepository) TransformToModel(d *domain.EvaluationProject) models.EvaluationProject {
	var out models.EvaluationProject

	out.Id = d.Id
	out.Name = d.Name
	out.Describe = d.Describe

	out.CompanyId = d.CompanyId
	out.CycleId = d.CycleId
	out.CreatorId = d.CreatorId
	out.PrincipalId = d.PrincipalId
	out.HrBp = d.HrBp
	out.Pmp = d.Pmp
	out.PmpIds = d.PmpIds
	out.Recipients = d.Recipients

	out.State = d.State
	out.SummaryState = int(d.SummaryState)
	out.Template = d.Template

	out.BeginTime = d.BeginTime
	out.EndTime = d.EndTime
	out.CreatedAt = d.CreatedAt
	out.UpdatedAt = d.UpdatedAt
	out.DeletedAt = d.DeletedAt

	return out
}

// nextIdentify produces the primary key for a not-yet-persisted project using
// utils.NewSnowflakeId, forwarding its error unchanged.
func (repo *EvaluationProjectRepository) nextIdentify() (int64, error) {
	id, err := utils.NewSnowflakeId()
	return id, err
}

// Insert saves d. When d.Id is zero a new row is inserted under an id from
// nextIdentify and both CreatedAt and UpdatedAt are set to now. Otherwise the
// row with primary key d.Id is updated (all columns are written) and UpdatedAt
// is set to now.
//
// The new Id/CreatedAt/UpdatedAt values are written back to d only after the
// statement succeeds. If they were written earlier, a failed insert would leave
// d with an Id. A retry would then take the update path and silently match
// nothing. An update that affects no row is reported as an error. On any error
// the returned entity is nil.
func (repo *EvaluationProjectRepository) Insert(d *domain.EvaluationProject) (*domain.EvaluationProject, error) {
	m := repo.TransformToModel(d)
	now := time.Now()
	isCreate := m.Id == 0
	if isCreate {
		id, err := repo.nextIdentify()
		if err != nil {
			return nil, err
		}
		m.Id = id
		m.CreatedAt = now
		m.UpdatedAt = now
	} else {
		m.UpdatedAt = now
	}

	tx := repo.transactionContext.PgTx
	if isCreate {
		if _, err := tx.Model(&m).Returning("id").Insert(); err != nil {
			return nil, err
		}
	} else {
		// Updates and deletes must always carry a condition (here: the primary key).
		res, err := tx.Model(&m).Returning("id").WherePK().Update()
		if err != nil {
			return nil, err
		}
		if res.RowsAffected() == 0 {
			// Same message FindOne uses when the project does not exist.
			return nil, fmt.Errorf("没有此资源")
		}
	}

	// Commit the generated/updated values to the caller's entity only on success.
	d.Id = m.Id
	d.CreatedAt = m.CreatedAt
	d.UpdatedAt = m.UpdatedAt
	return d, nil
}

// Remove soft-deletes d. It stamps deleted_at with the current time and writes
// the row back by primary key; all columns are written from d.
//
// On success d.DeletedAt is set to the same timestamp, so the returned entity
// reflects the deletion. On failure d is returned unchanged together with the
// error.
func (repo *EvaluationProjectRepository) Remove(d *domain.EvaluationProject) (*domain.EvaluationProject, error) {
	tx := repo.transactionContext.PgTx
	nowTime := time.Now()
	m := repo.TransformToModel(d)
	m.DeletedAt = &nowTime
	if _, err := tx.Model(&m).WherePK().Update(); err != nil {
		return d, err
	}
	d.DeletedAt = &nowTime
	return d, nil
}

// FindOne loads a single non-deleted project. When queryOptions contains "id"
// the lookup is restricted to that id; otherwise the first non-deleted row
// returned by go-pg's First() is used. A missing row yields the error
// "没有此资源" ("resource not found"); other database errors are returned as-is.
func (repo *EvaluationProjectRepository) FindOne(queryOptions map[string]interface{}) (*domain.EvaluationProject, error) {
	record := new(models.EvaluationProject)
	q := repo.transactionContext.PgTx.Model(record).Where("deleted_at isnull")
	if id, found := queryOptions["id"]; found {
		q.Where("id=?", id)
	}

	err := q.First()
	switch {
	case errors.Is(err, pg.ErrNoRows):
		return nil, fmt.Errorf("没有此资源")
	case err != nil:
		return nil, err
	}

	project := repo.TransformToDomain(record)
	return &project, nil
}

// Find lists non-deleted projects matching queryOptions. It returns the total
// number of matches (ignoring limit/offset) together with the selected page.
//
// Recognized options:
//   - "ids": id IN (...)
//   - "name" (string): name LIKE value; the value is used as-is, so any %
//     wildcards must be supplied by the caller. Ignored when empty.
//   - "companyId", "cycleId", "summaryState": equality filters.
//   - "state": equality filter; a negative int disables the filter.
//   - "beginTime": begin_time >= value.
//   - "endTime": end_time <= value.
//   - "pmpIds" ([]string): matches rows whose pmp_ids contains any of the ids
//     (presumably a JSON array column — uses the @> operator).
//   - "limit", "offset" (int64 or int): pagination.
//
// excludeColumns are omitted from the SELECT list. Results are ordered by
// created_at DESC, with id DESC as a tie-breaker so pages are stable.
func (repo *EvaluationProjectRepository) Find(queryOptions map[string]interface{}, excludeColumns ...string) (int64, []*domain.EvaluationProject, error) {
	tx := repo.transactionContext.PgTx
	var m []*models.EvaluationProject
	query := tx.Model(&m).Where("deleted_at isnull")

	if len(excludeColumns) > 0 {
		query.ExcludeColumn(excludeColumns...)
	}

	if v, ok := queryOptions["ids"]; ok {
		query.Where("id in (?)", pg.In(v))
	}

	if v, ok := queryOptions["name"].(string); ok && len(v) > 0 {
		query.Where("name LIKE ?", v)
	}

	if v, ok := queryOptions["companyId"]; ok {
		query.Where("company_id = ?", v)
	}

	if v, ok := queryOptions["cycleId"]; ok {
		query.Where("cycle_id = ?", v)
	}

	// A negative int means "any state". Values of other types (e.g. a named
	// state type) are applied as a filter instead of panicking on v.(int).
	if v, ok := queryOptions["state"]; ok {
		if s, isInt := v.(int); !isInt || s >= 0 {
			query.Where("state = ?", v)
		}
	}

	if v, ok := queryOptions["summaryState"]; ok {
		query.Where("summary_state=?", v)
	}

	// Time bounds are passed straight to go-pg. A time.Time formats exactly as
	// before; other types no longer cause a panic.
	if v, ok := queryOptions["beginTime"]; ok {
		query.Where("begin_time>=?", v)
	}

	if v, ok := queryOptions["endTime"]; ok {
		query.Where("end_time<=?", v)
	}

	if v, ok := queryOptions["pmpIds"].([]string); ok && len(v) > 0 {
		query.WhereGroup(func(query *pg.Query) (*pg.Query, error) {
			for i := range v {
				query.WhereOr("pmp_ids @> ?", `"`+v[i]+`"`)
			}
			return query, nil
		})
	}

	// Pagination accepts int64 (the historical contract) as well as plain int,
	// which was previously ignored silently.
	switch v := queryOptions["limit"].(type) {
	case int64:
		query.Limit(int(v))
	case int:
		query.Limit(v)
	}
	switch v := queryOptions["offset"].(type) {
	case int64:
		query.Offset(int(v))
	case int:
		query.Offset(v)
	}

	// Newest first; id breaks ties so limit/offset paging is deterministic.
	query.Order("created_at DESC", "id DESC")

	count, err := query.SelectAndCount()
	if err != nil {
		return 0, nil, err
	}
	var arrays []*domain.EvaluationProject
	for _, v := range m {
		d := repo.TransformToDomain(v)
		arrays = append(arrays, &d)
	}
	return int64(count), arrays, nil
}

// Count returns how many non-deleted projects match queryOptions. The options
// are first passed to sqlbuilder.BuildQuery. On top of that, each of the
// following filters is applied when its key is present:
//   - "id":        id = value
//   - "notId":     id != value
//   - "name":      name = value (exact match), only for a non-empty string
//   - "companyId": company_id = value
//   - "cycleId":   cycle_id = value
func (repo *EvaluationProjectRepository) Count(queryOptions map[string]interface{}) (int64, error) {
	model := new(models.EvaluationProject)
	q := sqlbuilder.BuildQuery(repo.transactionContext.PgTx.Model(model), queryOptions)
	q.Where("deleted_at isnull")

	// Applied in this fixed order; accept (when set) decides whether a present
	// value should produce a condition.
	filters := []struct {
		option string
		clause string
		accept func(v interface{}) bool
	}{
		{option: "id", clause: "id = ?"},
		{option: "notId", clause: "id != ?"},
		{option: "name", clause: "name = ?", accept: func(v interface{}) bool {
			s, isString := v.(string)
			return isString && s != ""
		}},
		{option: "companyId", clause: "company_id = ?"},
		{option: "cycleId", clause: "cycle_id = ?"},
	}
	for _, f := range filters {
		v, present := queryOptions[f.option]
		if !present || (f.accept != nil && !f.accept(v)) {
			continue
		}
		q.Where(f.clause, v)
	}

	total, err := q.Count()
	if err != nil {
		return 0, err
	}
	return int64(total), nil
}