loginstatuscheck_middleware.go 1.7 KB
package middleware

import (
	"net/http"
	"strings"

	"github.com/zeromicro/go-zero/core/mathx"
	"github.com/zeromicro/go-zero/rest/httpx"
	"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
	"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
)

// LoginStatusCheckMiddleware checks the token carried in the Authorization
// header with contextdata.ParseToken and, on a random subset of requests
// (controlled by kProba), additionally validates it through compareFunc.
type LoginStatusCheckMiddleware struct {
	// compareFunc receives the user ID read from the parsed token and the raw
	// token string; a non-nil error makes the middleware reply 401.
	// Presumably it checks the token against the user's current login
	// session — TODO confirm with callers.
	compareFunc func(int64, string) error
	// secret is passed to contextdata.ParseToken; presumably the token
	// signing key — confirm against contextdata.
	secret      string
	// proba decides, per request, whether compareFunc is consulted.
	proba       *mathx.Proba
	// kProba is the probability handed to proba.TrueOnProba for that decision.
	kProba      float64
}

// NewLoginStatusCheckMiddleware creates the login-status check middleware.
//
// fn is the per-user token comparer (called with the user ID and the raw
// token), secret is handed to contextdata.ParseToken, and kProba is the
// probability that a request's token is also checked with fn. Keep kProba
// below 1 so that not every API call pays for the extra check; the smaller
// it is, the fewer requests are checked.
func NewLoginStatusCheckMiddleware(fn func(int64, string) error, secret string, kProba float64) *LoginStatusCheckMiddleware {
	m := new(LoginStatusCheckMiddleware)
	m.compareFunc = fn
	m.secret = secret
	m.kProba = kProba
	m.proba = mathx.NewProba()
	return m
}

// Handle wraps next with the login-status check.
//
// For each request it:
//  1. reads the Authorization header and strips a case-insensitive "Bearer "
//     prefix when present (the same rule golang-jwt's
//     AuthorizationHeaderExtractor applies), otherwise uses the header as-is;
//  2. replies 401 with xerr.TokenExpireError when no token is present or
//     contextdata.ParseToken rejects it;
//  3. with probability kProba, and only if compareFunc is set, calls
//     compareFunc with the token's user ID and the raw token, replying 401
//     (carrying the comparer's error text) when it fails;
//  4. otherwise calls next with the original, unmodified request.
func (m *LoginStatusCheckMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	const bearerPrefix = "Bearer "
	return func(w http.ResponseWriter, r *http.Request) {
		reject := func(msg string) {
			httpx.WriteJson(w, http.StatusUnauthorized, xerr.Error(xerr.TokenExpireError, msg))
		}

		auth := r.Header.Get("Authorization")
		token := auth
		// The auth scheme is case-insensitive (RFC 7235); only strip it when it
		// is actually there instead of blindly dropping the first 7 bytes.
		if len(auth) >= len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
			token = auth[len(bearerPrefix):]
		}
		if token == "" {
			reject("")
			return
		}

		tokenCtx, err := contextdata.ParseToken(r.Context(), m.secret, token)
		if err != nil {
			reject("")
			return
		}

		// compareFunc is optional: without it only the ParseToken check above
		// applies (the same as kProba == 0), rather than dropping the request
		// with an empty response.
		if m.compareFunc != nil && m.proba.TrueOnProba(m.kProba) {
			userToken := contextdata.GetUserTokenFromCtx(tokenCtx)
			if err := m.compareFunc(userToken.UserId, token); err != nil {
				reject(err.Error())
				return
			}
		}
		next(w, r)
	}
}