user_auth.go 3.6 KB
package userAuth

import (
	"encoding/json"
	"fmt"
	"github.com/tiptok/gocomm/pkg/redis"
	"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/utils"
	"strconv"
)

// Errors produced by the user-auth helpers in this package.
var (
	// errDataType is returned when the loaded auth value is not a usable
	// *RedisUserAuthData.
	errDataType = fmt.Errorf("auth:data type assert error")

	// errNoMatch is returned by Check when no token was supplied to compare.
	errNoMatch = fmt.Errorf("auth:user auth no match")

	// errNotFound reports that the user's hash field is missing.
	errNotFound = func(field string) error { return fmt.Errorf("auth:hset field (%v) not found", field) }

	// errTokenNotEqual reports a token mismatch. The signature (stored token,
	// presented token) is kept for callers, but the values are deliberately
	// NOT included in the message: the first argument is the stored, valid
	// token, and echoing it into an error that may reach logs or API responses
	// would disclose a live credential.
	errTokenNotEqual = func(_, _ string) error {
		return fmt.Errorf("auth:token not equal")
	}
)

// UserAuthManager manages one user's cached auth record (user id plus tokens).
type UserAuthManager interface {
	// AddAuth stores the user's auth record.
	AddAuth() error
	// GetAuth loads the user's auth record.
	GetAuth() (interface{}, error)
	// RemoveAuth deletes the user's auth record.
	RemoveAuth() error
	// Check validates the given access/refresh token against the stored
	// record and returns nil when they agree.
	Check(*Options) error
	// Exist reports whether an auth record is present for the user
	// (true: present, false: absent).
	Exist() bool
}

// Compile-time proof that *RedisUserAuth implements UserAuthManager.
var _ UserAuthManager = (*RedisUserAuth)(nil)

// Options carries the user id and tokens a UserAuthManager operates on.
type Options struct {
	UserId       int64
	RefreshToken string
	AccessToken  string
}

// NewOptions returns a new Options with every Option applied in order;
// later options overwrite earlier ones. Nil entries are skipped so callers
// can pass conditionally-built options without guarding each one.
func NewOptions(options ...Option) *Options {
	// Named opts (not Options) so the local does not shadow the type.
	opts := &Options{}
	for _, apply := range options {
		if apply == nil {
			continue
		}
		apply(opts)
	}
	return opts
}

// Option mutates an Options value.
type Option func(options *Options)

// Option constructors.

// WithRefreshToken sets Options.RefreshToken.
func WithRefreshToken(token string) Option {
	return func(options *Options) {
		options.RefreshToken = token
	}
}

// WithAccessToken sets Options.AccessToken.
func WithAccessToken(token string) Option {
	return func(options *Options) {
		options.AccessToken = token
	}
}

// WithUserId sets Options.UserId.
func WithUserId(uid int64) Option {
	return func(options *Options) {
		options.UserId = uid
	}
}

// RedisUserAuth is a UserAuthManager that keeps user auth records in a
// Redis hash.
type RedisUserAuth struct {
	Options *Options
}

// NewRedisUserAuth builds a RedisUserAuth from the given options.
func NewRedisUserAuth(options ...Option) *RedisUserAuth {
	return &RedisUserAuth{Options: NewOptions(options...)}
}
// AddAuth writes the user id and tokens from auth.Options into the user's
// field of the auth hash, serialized via utils.JsonAssertString.
func (auth RedisUserAuth) AddAuth() error {
	key, field := auth.redisKey(), auth.field()
	payload := utils.JsonAssertString(NewRedisUserAuthData(auth.Options))
	// NOTE(review): the trailing 0 is passed straight to gocomm's redis.Hset;
	// presumably "no expiry" — confirm against that package.
	return redis.Hset(key, field, payload, 0)
}
// RemoveAuth deletes the user's field from the auth hash. It returns nil
// without touching Redis further when the field does not exist.
func (auth RedisUserAuth) RemoveAuth() error {
	if auth.Exist() {
		return redis.Hdel(auth.redisKey(), auth.field())
	}
	return nil
}
// GetAuth loads and JSON-decodes the user's stored auth record.
//
// On success the returned interface{} holds a non-nil *RedisUserAuthData.
// It returns errNotFound when the user's field is absent from the hash,
// the error from redis.Hget, or the json.Unmarshal error.
func (auth RedisUserAuth) GetAuth() (interface{}, error) {
	if !auth.Exist() {
		return nil, errNotFound(auth.field())
	}
	raw, err := redis.Hget(auth.redisKey(), auth.field())
	if err != nil {
		return nil, err
	}
	// Decode into an allocated struct rather than through a nil pointer:
	// with a nil pointer a stored JSON `null` would leave it nil, and a typed
	// nil inside a non-nil interface{} (with a nil error) would later be
	// dereferenced by callers such as Check.
	authData := &RedisUserAuthData{}
	if err := json.Unmarshal([]byte(raw), authData); err != nil {
		return nil, err
	}
	return authData, nil
}
// Check compares a caller-supplied token with the stored auth record.
//
// AccessToken takes precedence: when options.AccessToken is non-empty only
// the access token is compared; otherwise a non-empty RefreshToken is
// compared. It returns nil on a match and errTokenNotEqual on a mismatch.
// Errors from GetAuth are returned unchanged. errDataType is returned when
// the loaded value is not a usable *RedisUserAuthData. errNoMatch is
// returned when options is nil or carries neither token.
func (auth RedisUserAuth) Check(options *Options) error {
	data, err := auth.GetAuth()
	if err != nil {
		return err
	}
	authData, ok := data.(*RedisUserAuthData)
	// A typed nil pointer passes the assertion, so reject it explicitly
	// instead of dereferencing it below.
	if !ok || authData == nil {
		return errDataType
	}
	// Checked after loading so error precedence matches the non-nil case.
	if options == nil {
		return errNoMatch
	}
	switch {
	case options.AccessToken != "":
		if authData.AccessToken != options.AccessToken {
			return errTokenNotEqual(authData.AccessToken, options.AccessToken)
		}
		return nil
	case options.RefreshToken != "":
		if authData.RefreshToken != options.RefreshToken {
			return errTokenNotEqual(authData.RefreshToken, options.RefreshToken)
		}
		return nil
	}
	return errNoMatch
}
// Exist reports whether the user's field is present in the auth hash.
func (auth RedisUserAuth) Exist() bool {
	key, field := auth.redisKey(), auth.field()
	return redis.Hexists(key, field)
}
// redisKey returns the Redis hash key holding auth records. The key is the
// same for every user (users are distinguished by field); an empty string
// is returned when no user id is set.
// NOTE(review): with UserId == 0 callers still issue commands against key "".
func (auth RedisUserAuth) redisKey() string {
	if auth.Options.UserId != 0 {
		return utils.RedisKey("user_auth")
	}
	return ""
}
// field returns the hash field for this user: the user id in base 10.
func (auth RedisUserAuth) field() string {
	// FormatInt works on the int64 directly; the previous int(...) conversion
	// truncated ids above 2^31-1 on 32-bit platforms.
	return strconv.FormatInt(auth.Options.UserId, 10)
}

// RedisUserAuthData is the per-user record stored in the Redis auth hash;
// the json tags define its serialized field names.
type RedisUserAuthData struct {
	UserId       int64  `json:"userId"`
	RefreshToken string `json:"refreshToken"`
	AccessToken  string `json:"accessToken"`
}

// NewRedisUserAuthData copies the user id and both tokens out of options.
func NewRedisUserAuthData(options *Options) RedisUserAuthData {
	var data RedisUserAuthData
	data.UserId = options.UserId
	data.AccessToken = options.AccessToken
	data.RefreshToken = options.RefreshToken
	return data
}