作者 yangfu

token验证修改

... ... @@ -152,12 +152,12 @@ func (svr AuthService) AuthLoginQrcodeBinding(bindingCmd *command.QrcodeBindingC
qrmsg := domain.QrcodeMessage{}
err := qrmsg.ParseToken(bindingCmd.Key)
if err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, "二维码已失效,请重试")
return nil, application.ThrowError(application.TRANSACTION_ERROR, "您扫描的二维码无效,请确认后重新扫描")
}
qrCache := cache.LoginQrcodeCache{}
qrmsgCache, err := qrCache.Get(qrmsg.Id)
if err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
return nil, application.ThrowError(application.TRANSACTION_ERROR, "您扫描的二维码无效,请确认后重新扫描")
}
if err := qrmsgCache.BindUser(bindingCmd.Operator); err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
... ... @@ -240,14 +240,28 @@ func (svr AuthService) GetAuthAccessToken(accessTokenCommand *command.AccessToke
func (svr AuthService) RefreshAuthAccessToken(refreshTokenCommand *command.RefreshTokenCommand) (interface{}, error) {
if err := refreshTokenCommand.ValidateCommand(); err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
}
loginToken := domain.LoginToken{}
err := loginToken.ParseToken(refreshTokenCommand.RefreshToken)
if err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, "refreshToken 不可用,"+err.Error())
return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
}
//redis缓存
tokenCache := cache.LoginTokenCache{}
refreshToken, err := tokenCache.GetRefreshToken(loginToken.Account, loginToken.Platform)
if err != nil {
log.Logger.Debug(err.Error())
return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
}
if refreshToken != refreshTokenCommand.RefreshToken {
return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
}
token, err := svr.getToken(loginToken)
if err != nil {
return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
}
return map[string]interface{}{
"access": token["token"],
}, err
... ... @@ -540,14 +554,18 @@ loopUser1:
//redis缓存
tokenCache := cache.LoginTokenCache{}
tokenCache.RemoveAccessToken(currentAccess.Account, loginToken.Platform)
tokenCache.RemoveRefreshToken(currentAccess.Account, loginToken.Platform)
tokenCache.SaveAccessToken(currentAccess)
tokenCache.SaveRefreshToken(currentAccess)
//todo:error handler
if err = tokenCache.SaveAccessToken(currentAccess); err != nil {
return nil, err
}
if err = tokenCache.SaveRefreshToken(currentAccess); err != nil {
return nil, err
}
nowTime := time.Now().Unix()
token := map[string]interface{}{
"refreshToken": accessTokenStr,
"accessToken": refreshTokenStr,
"refreshToken": refreshTokenStr,
"accessToken": accessTokenStr,
"expiresIn": currentAccess.AccessExpired - nowTime,
}
return map[string]interface{}{
... ...
package domain
import "time"
import (
"github.com/linmadan/egglib-go/core/application"
"time"
)
//登录的平台
const (
... ... @@ -19,6 +22,22 @@ const (
DeviceTypeWeb = "4"
)
const (
InvalidAccessToken = 901
InvalidRefreshToken = 902
InvalidSign = 903
InvalidClientId = 904
InvalidUUid = 905
)
var codeMsg = map[int]string{
InvalidAccessToken: "access token 过期或无效,需刷新令牌",
InvalidRefreshToken: "refresh token 过期或失效,需重新进行登录认证操作",
InvalidSign: "sign 签名无效,需重新登录手机 APP",
InvalidClientId: "client id 或 client secret 无效,需强制更新手机 APP",
InvalidUUid: "uuid 无效",
}
// 登录凭证存储
type LoginAccess struct {
LoginAccessId int64 `json:"loginAccessId"`
... ... @@ -93,3 +112,24 @@ func (loginAccess *LoginAccess) Update(data map[string]interface{}) error {
}
return nil
}
func ParsePlatform(deviceType string) string {
if deviceType == DeviceTypeWeb {
return LoginPlatformWeb
}
return LoginPlatformApp
}
func ParseCodeMsg(code int) string {
if v, ok := codeMsg[code]; ok {
return v
}
return ""
}
func NewApplicationError(code int) error {
return &application.ServiceError{
Code: code,
Message: ParseCodeMsg(code),
}
}
... ...
... ... @@ -2,6 +2,7 @@ package domain
import (
"fmt"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/util"
"time"
jwt "github.com/dgrijalva/jwt-go"
... ... @@ -12,6 +13,9 @@ const (
qrcodeCodeExpire int64 = 60 * 30 //15分钟过期
)
var aecSecret = []byte("mmm.qrcode.(%^&)")
var loginHost = "https://api.fjmaimaimai.com/app/auth/login/qrcode?key="
type QrcodeMessage struct {
jwt.StandardClaims
Id string `json:"id"`
... ... @@ -42,10 +46,18 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) {
if err != nil {
return nil, err
}
key := loginHost + str
encryptedData, err := util.AesEncrypt([]byte(key), aecSecret)
//初始化数据
qrmsg.Token = str
qrmsg.Token = string(encryptedData)
qrmsg.IsLogin = false
// 输入日志
//decrypted,_:= util.AesDecrypt(encryptedData,aecSecret)
//if string(decrypted)==key{
// log.Println("token:",str,"\n encrypt:",key,"\n decrypt:",string(decrypted))
//}
//qrCode, err := qr.Encode(str, qr.M, qr.Auto)
//if err != nil {
// return nil, err
... ... @@ -61,7 +73,7 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) {
//}
//var result []byte
//base64.StdEncoding.Encode(result, buf.Bytes())
return []byte(str), err
return encryptedData, err
}
func (qrmsg *QrcodeMessage) ParseToken(str string) error {
... ...
... ... @@ -11,12 +11,12 @@ type LoginTokenCache struct {
}
func (ca LoginTokenCache) keyAccessToken(account string, platform string) string {
str := KEY_PREFIX + "accesstoken:" + account + ":" + platform
str := KEY_PREFIX + "access-token:" + account + ":" + platform
return str
}
func (ca LoginTokenCache) keyRefreshToken(account string, platform string) string {
str := KEY_PREFIX + "refreshtoken" + account + ":" + platform
str := KEY_PREFIX + "refresh-token:" + account + ":" + platform
return str
}
... ... @@ -27,7 +27,7 @@ func (ca LoginTokenCache) SaveAccessToken(access *domain.LoginAccess) error {
exp = 60 * 60 * 2
}
key := ca.keyAccessToken(access.Account, access.Platform)
result := clientRedis.Set(key, access.AccessToken, time.Duration(exp))
result := clientRedis.Set(key, access.AccessToken, time.Duration(exp)*time.Second)
return result.Err()
}
... ... @@ -49,8 +49,8 @@ func (ca LoginTokenCache) SaveRefreshToken(access *domain.LoginAccess) error {
if exp <= 0 {
exp = 60 * 60 * 2
}
key := ca.keyAccessToken(access.Account, access.Platform)
result := clientRedis.Set(key, access.RefreshToken, time.Duration(exp))
key := ca.keyRefreshToken(access.Account, access.Platform)
result := clientRedis.Set(key, access.RefreshToken, time.Duration(exp)*time.Second)
return result.Err()
}
... ...
... ... @@ -4,6 +4,7 @@ import (
"encoding/json"
"github.com/beego/beego/v2/server/web/context"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/constant"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/port/beego/middleware"
"os"
"strconv"
... ... @@ -31,6 +32,8 @@ func init() {
}
}
filters.SecureKeyMap["token"] = "x-mmm-accesstoken"
//TODO:token验证改为 /v1
web.InsertFilterChain("/v1/app/*", middleware.CheckAccessToken)
web.InsertFilter("/*", web.BeforeRouter, filters.AllowCors())
web.InsertFilter("/*", web.BeforeRouter, filters.CreateRequstLogFilter(log.Logger))
web.InsertFilter("/*", web.AfterExec, filters.CreateResponseLogFilter(log.Logger), web.WithReturnOnOutput(false))
... ...
... ... @@ -7,6 +7,7 @@ import (
"github.com/linmadan/egglib-go/web/beego/utils"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/log"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/port/beego/middleware"
)
type BaseController struct {
... ... @@ -77,6 +78,10 @@ func (controller *BaseController) GetOperator() domain.Operator {
err := loginToken.ParseToken(token)
if err != nil {
log.Logger.Error(err.Error())
*loginToken, _ = middleware.FormCtxLoginToken(controller.Ctx)
}
if tmpToken, ok := middleware.FormCtxLoginToken(controller.Ctx); ok {
log.Logger.Debug(json.MarshalToString(tmpToken))
}
op := domain.Operator{
UserId: loginToken.UserId,
... ... @@ -94,7 +99,7 @@ func (controller *BaseController) GetOperator() domain.Operator {
op.UserBaseId = 1
}
// TODO:打印测试日志
log.Logger.Debug("operator " + json.MarshalToString(op))
//log.Logger.Debug("operator " + json.MarshalToString(op))
return op
}
... ...
package middleware
import (
"fmt"
"github.com/beego/beego/v2/server/web"
"github.com/beego/beego/v2/server/web/context"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/infrastructure/cache"
"gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/log"
log1 "log"
"net/url"
)
type CtxKeyLoginToken struct{}
func JWTAuth(ctx *context.Context) {
tokenStr := ctx.Input.Header("xxxx")
tk := domain.LoginToken{}
err := tk.ParseToken(tokenStr)
if err != nil {
//
return
}
func CheckAccessToken(next web.FilterFunc) web.FilterFunc {
return func(ctx *context.Context) {
tokenStr := ctx.Input.Header("x-mmm-accesstoken")
filterMap := map[string]string{
"/v1/auth/login/pwd": "",
"/v1/auth/login/sms": "",
"/v1/auth/login/qrcode": "",
"/v1/auth/captcha-init": "",
"/v1/auth/qrcode-init": "",
"/v1/auth/sms-code": "",
"/v1/auth/check-sms-code": "",
"/v1/auth/company-sign-up": "",
"/v1/auth/reset-password": "",
"/v1/auth/refresh-token": "",
}
var err error
if filterUrl, err := url.Parse(ctx.Request.RequestURI); err == nil {
// 不需要验证的接口
if _, ok := filterMap[filterUrl.Path]; ok {
next(ctx)
return
}
} else {
log.Logger.Error("parse url error:" + err.Error())
}
defer func() {
if err != nil {
ctx.Output.JSON(map[string]interface{}{
"msg": domain.ParseCodeMsg(domain.InvalidAccessToken),
"code": domain.InvalidAccessToken,
"data": struct{}{},
}, false, false)
}
}()
tk := &domain.LoginToken{}
err = tk.ParseToken(tokenStr)
if err != nil {
log.Logger.Error(err.Error())
return
}
platform := domain.ParsePlatform(ctx.Input.Header("x-mmm-devicetype"))
//redis缓存
tokenCache := cache.LoginTokenCache{}
token, err := tokenCache.GetAccessToken(tk.Account, platform)
if err != nil {
log.Logger.Error(err.Error())
return
}
if token != tokenStr {
log1.Println("token not equal \n" + tk.Account + "\n" + tokenStr + "\n" + token)
err = fmt.Errorf("access token not exists")
return
}
ctx.Input.SetData(CtxKeyLoginToken{}, tk)
next(ctx)
}
ctx.Input.SetData(CtxKeyLoginToken{}, domain.LoginToken{})
}
func NewCtxLoginToken(ctx *context.Context, tk domain.LoginToken) {
... ...
package util
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
func PKCS5Padding(plaintext []byte, blockSize int) []byte {
padding := blockSize - len(plaintext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(plaintext, padtext...)
}
//@brief:去除填充数据
func PKCS5UnPadding(origData []byte) []byte {
length := len(origData)
unpadding := int(origData[length-1])
return origData[:(length - unpadding)]
}
//@brief:AES加密
func AesEncrypt(origData, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
//AES分组长度为128位,所以blockSize=16,单位字节
blockSize := block.BlockSize()
origData = PKCS5Padding(origData, blockSize)
blockMode := cipher.NewCBCEncrypter(block, key[:blockSize]) //初始向量的长度必须等于块block的长度16字节
crypted := make([]byte, len(origData))
blockMode.CryptBlocks(crypted, origData)
return crypted, nil
}
//@brief:AES解密
func AesDecrypt(crypted, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
//AES分组长度为128位,所以blockSize=16,单位字节
blockSize := block.BlockSize()
blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) //初始向量的长度必须等于块block的长度16字节
origData := make([]byte, len(crypted))
blockMode.CryptBlocks(origData, crypted)
origData = PKCS5UnPadding(origData)
return origData, nil
}
... ...
package util
import (
"encoding/base64"
"fmt"
"testing"
)
func Test_Aes(t *testing.T) {
//key的长度必须是16、24或者32字节,分别用于选择AES-128, AES-192, or AES-256
var aeskey = []byte("12345678abcdefgh")
pass := []byte("vdncloud123456")
xpass, err := AesEncrypt(pass, aeskey)
if err != nil {
fmt.Println(err)
return
}
pass64 := base64.StdEncoding.EncodeToString(xpass)
fmt.Printf("加密后:%v\n", pass64)
bytesPass, err := base64.StdEncoding.DecodeString(pass64)
if err != nil {
fmt.Println(err)
return
}
tpass, err := AesDecrypt(bytesPass, aeskey)
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("解密后:%s\n", tpass)
}
... ...