作者 yangfu

token验证修改

@@ -152,12 +152,12 @@ func (svr AuthService) AuthLoginQrcodeBinding(bindingCmd *command.QrcodeBindingC @@ -152,12 +152,12 @@ func (svr AuthService) AuthLoginQrcodeBinding(bindingCmd *command.QrcodeBindingC
152 qrmsg := domain.QrcodeMessage{} 152 qrmsg := domain.QrcodeMessage{}
153 err := qrmsg.ParseToken(bindingCmd.Key) 153 err := qrmsg.ParseToken(bindingCmd.Key)
154 if err != nil { 154 if err != nil {
155 - return nil, application.ThrowError(application.TRANSACTION_ERROR, "二维码已失效,请重试") 155 + return nil, application.ThrowError(application.TRANSACTION_ERROR, "您扫描的二维码无效,请确认后重新扫描")
156 } 156 }
157 qrCache := cache.LoginQrcodeCache{} 157 qrCache := cache.LoginQrcodeCache{}
158 qrmsgCache, err := qrCache.Get(qrmsg.Id) 158 qrmsgCache, err := qrCache.Get(qrmsg.Id)
159 if err != nil { 159 if err != nil {
160 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 160 + return nil, application.ThrowError(application.TRANSACTION_ERROR, "您扫描的二维码无效,请确认后重新扫描")
161 } 161 }
162 if err := qrmsgCache.BindUser(bindingCmd.Operator); err != nil { 162 if err := qrmsgCache.BindUser(bindingCmd.Operator); err != nil {
163 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 163 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
@@ -240,14 +240,28 @@ func (svr AuthService) GetAuthAccessToken(accessTokenCommand *command.AccessToke @@ -240,14 +240,28 @@ func (svr AuthService) GetAuthAccessToken(accessTokenCommand *command.AccessToke
240 240
241 func (svr AuthService) RefreshAuthAccessToken(refreshTokenCommand *command.RefreshTokenCommand) (interface{}, error) { 241 func (svr AuthService) RefreshAuthAccessToken(refreshTokenCommand *command.RefreshTokenCommand) (interface{}, error) {
242 if err := refreshTokenCommand.ValidateCommand(); err != nil { 242 if err := refreshTokenCommand.ValidateCommand(); err != nil {
243 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 243 + return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
244 } 244 }
245 loginToken := domain.LoginToken{} 245 loginToken := domain.LoginToken{}
246 err := loginToken.ParseToken(refreshTokenCommand.RefreshToken) 246 err := loginToken.ParseToken(refreshTokenCommand.RefreshToken)
247 if err != nil { 247 if err != nil {
248 - return nil, application.ThrowError(application.TRANSACTION_ERROR, "refreshToken 不可用,"+err.Error()) 248 + return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
  249 + }
  250 + //redis缓存
  251 + tokenCache := cache.LoginTokenCache{}
  252 + refreshToken, err := tokenCache.GetRefreshToken(loginToken.Account, loginToken.Platform)
  253 + if err != nil {
  254 + log.Logger.Debug(err.Error())
  255 + return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
  256 + }
  257 + if refreshToken != refreshTokenCommand.RefreshToken {
  258 + return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
249 } 259 }
  260 +
250 token, err := svr.getToken(loginToken) 261 token, err := svr.getToken(loginToken)
  262 + if err != nil {
  263 + return nil, domain.NewApplicationError(domain.InvalidRefreshToken)
  264 + }
251 return map[string]interface{}{ 265 return map[string]interface{}{
252 "access": token["token"], 266 "access": token["token"],
253 }, err 267 }, err
@@ -540,14 +554,18 @@ loopUser1: @@ -540,14 +554,18 @@ loopUser1:
540 554
541 //redis缓存 555 //redis缓存
542 tokenCache := cache.LoginTokenCache{} 556 tokenCache := cache.LoginTokenCache{}
543 - tokenCache.RemoveAccessToken(currentAccess.Account, loginToken.Platform)  
544 - tokenCache.RemoveRefreshToken(currentAccess.Account, loginToken.Platform)  
545 - tokenCache.SaveAccessToken(currentAccess)  
546 - tokenCache.SaveRefreshToken(currentAccess) 557 +
  558 + //todo:error handler
  559 + if err = tokenCache.SaveAccessToken(currentAccess); err != nil {
  560 + return nil, err
  561 + }
  562 + if err = tokenCache.SaveRefreshToken(currentAccess); err != nil {
  563 + return nil, err
  564 + }
547 nowTime := time.Now().Unix() 565 nowTime := time.Now().Unix()
548 token := map[string]interface{}{ 566 token := map[string]interface{}{
549 - "refreshToken": accessTokenStr,  
550 - "accessToken": refreshTokenStr, 567 + "refreshToken": refreshTokenStr,
  568 + "accessToken": accessTokenStr,
551 "expiresIn": currentAccess.AccessExpired - nowTime, 569 "expiresIn": currentAccess.AccessExpired - nowTime,
552 } 570 }
553 return map[string]interface{}{ 571 return map[string]interface{}{
1 package domain 1 package domain
2 2
3 -import "time" 3 +import (
  4 + "github.com/linmadan/egglib-go/core/application"
  5 + "time"
  6 +)
4 7
5 //登录的平台 8 //登录的平台
6 const ( 9 const (
@@ -19,6 +22,22 @@ const ( @@ -19,6 +22,22 @@ const (
19 DeviceTypeWeb = "4" 22 DeviceTypeWeb = "4"
20 ) 23 )
21 24
  25 +const (
  26 + InvalidAccessToken = 901
  27 + InvalidRefreshToken = 902
  28 + InvalidSign = 903
  29 + InvalidClientId = 904
  30 + InvalidUUid = 905
  31 +)
  32 +
  33 +var codeMsg = map[int]string{
  34 + InvalidAccessToken: "access token 过期或无效,需刷新令牌",
  35 + InvalidRefreshToken: "refresh token 过期或失效,需重新进行登录认证操作",
  36 + InvalidSign: "sign 签名无效,需重新登录手机 APP",
  37 + InvalidClientId: "client id 或 client secret 无效,需强制更新手机 APP",
  38 + InvalidUUid: "uuid 无效",
  39 +}
  40 +
22 // 登录凭证存储 41 // 登录凭证存储
23 type LoginAccess struct { 42 type LoginAccess struct {
24 LoginAccessId int64 `json:"loginAccessId"` 43 LoginAccessId int64 `json:"loginAccessId"`
@@ -93,3 +112,24 @@ func (loginAccess *LoginAccess) Update(data map[string]interface{}) error { @@ -93,3 +112,24 @@ func (loginAccess *LoginAccess) Update(data map[string]interface{}) error {
93 } 112 }
94 return nil 113 return nil
95 } 114 }
  115 +
  116 +func ParsePlatform(deviceType string) string {
  117 + if deviceType == DeviceTypeWeb {
  118 + return LoginPlatformWeb
  119 + }
  120 + return LoginPlatformApp
  121 +}
  122 +
  123 +func ParseCodeMsg(code int) string {
  124 + if v, ok := codeMsg[code]; ok {
  125 + return v
  126 + }
  127 + return ""
  128 +}
  129 +
  130 +func NewApplicationError(code int) error {
  131 + return &application.ServiceError{
  132 + Code: code,
  133 + Message: ParseCodeMsg(code),
  134 + }
  135 +}
@@ -2,6 +2,7 @@ package domain @@ -2,6 +2,7 @@ package domain
2 2
3 import ( 3 import (
4 "fmt" 4 "fmt"
  5 + "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/util"
5 "time" 6 "time"
6 7
7 jwt "github.com/dgrijalva/jwt-go" 8 jwt "github.com/dgrijalva/jwt-go"
@@ -12,6 +13,9 @@ const ( @@ -12,6 +13,9 @@ const (
12 qrcodeCodeExpire int64 = 60 * 30 //15分钟过期 13 qrcodeCodeExpire int64 = 60 * 30 //15分钟过期
13 ) 14 )
14 15
  16 +var aecSecret = []byte("mmm.qrcode.(%^&)")
  17 +var loginHost = "https://api.fjmaimaimai.com/app/auth/login/qrcode?key="
  18 +
15 type QrcodeMessage struct { 19 type QrcodeMessage struct {
16 jwt.StandardClaims 20 jwt.StandardClaims
17 Id string `json:"id"` 21 Id string `json:"id"`
@@ -42,10 +46,18 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) { @@ -42,10 +46,18 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) {
42 if err != nil { 46 if err != nil {
43 return nil, err 47 return nil, err
44 } 48 }
  49 + key := loginHost + str
  50 + encryptedData, err := util.AesEncrypt([]byte(key), aecSecret)
45 //初始化数据 51 //初始化数据
46 - qrmsg.Token = str 52 + qrmsg.Token = string(encryptedData)
47 qrmsg.IsLogin = false 53 qrmsg.IsLogin = false
48 54
  55 + // 输入日志
  56 + //decrypted,_:= util.AesDecrypt(encryptedData,aecSecret)
  57 + //if string(decrypted)==key{
  58 + // log.Println("token:",str,"\n encrypt:",key,"\n decrypt:",string(decrypted))
  59 + //}
  60 +
49 //qrCode, err := qr.Encode(str, qr.M, qr.Auto) 61 //qrCode, err := qr.Encode(str, qr.M, qr.Auto)
50 //if err != nil { 62 //if err != nil {
51 // return nil, err 63 // return nil, err
@@ -61,7 +73,7 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) { @@ -61,7 +73,7 @@ func (qrmsg *QrcodeMessage) GenerateImageBase64() ([]byte, error) {
61 //} 73 //}
62 //var result []byte 74 //var result []byte
63 //base64.StdEncoding.Encode(result, buf.Bytes()) 75 //base64.StdEncoding.Encode(result, buf.Bytes())
64 - return []byte(str), err 76 + return encryptedData, err
65 } 77 }
66 78
67 func (qrmsg *QrcodeMessage) ParseToken(str string) error { 79 func (qrmsg *QrcodeMessage) ParseToken(str string) error {
@@ -11,12 +11,12 @@ type LoginTokenCache struct { @@ -11,12 +11,12 @@ type LoginTokenCache struct {
11 } 11 }
12 12
13 func (ca LoginTokenCache) keyAccessToken(account string, platform string) string { 13 func (ca LoginTokenCache) keyAccessToken(account string, platform string) string {
14 - str := KEY_PREFIX + "accesstoken:" + account + ":" + platform 14 + str := KEY_PREFIX + "access-token:" + account + ":" + platform
15 return str 15 return str
16 } 16 }
17 17
18 func (ca LoginTokenCache) keyRefreshToken(account string, platform string) string { 18 func (ca LoginTokenCache) keyRefreshToken(account string, platform string) string {
19 - str := KEY_PREFIX + "refreshtoken" + account + ":" + platform 19 + str := KEY_PREFIX + "refresh-token:" + account + ":" + platform
20 return str 20 return str
21 } 21 }
22 22
@@ -27,7 +27,7 @@ func (ca LoginTokenCache) SaveAccessToken(access *domain.LoginAccess) error { @@ -27,7 +27,7 @@ func (ca LoginTokenCache) SaveAccessToken(access *domain.LoginAccess) error {
27 exp = 60 * 60 * 2 27 exp = 60 * 60 * 2
28 } 28 }
29 key := ca.keyAccessToken(access.Account, access.Platform) 29 key := ca.keyAccessToken(access.Account, access.Platform)
30 - result := clientRedis.Set(key, access.AccessToken, time.Duration(exp)) 30 + result := clientRedis.Set(key, access.AccessToken, time.Duration(exp)*time.Second)
31 return result.Err() 31 return result.Err()
32 } 32 }
33 33
@@ -49,8 +49,8 @@ func (ca LoginTokenCache) SaveRefreshToken(access *domain.LoginAccess) error { @@ -49,8 +49,8 @@ func (ca LoginTokenCache) SaveRefreshToken(access *domain.LoginAccess) error {
49 if exp <= 0 { 49 if exp <= 0 {
50 exp = 60 * 60 * 2 50 exp = 60 * 60 * 2
51 } 51 }
52 - key := ca.keyAccessToken(access.Account, access.Platform)  
53 - result := clientRedis.Set(key, access.RefreshToken, time.Duration(exp)) 52 + key := ca.keyRefreshToken(access.Account, access.Platform)
  53 + result := clientRedis.Set(key, access.RefreshToken, time.Duration(exp)*time.Second)
54 return result.Err() 54 return result.Err()
55 } 55 }
56 56
@@ -4,6 +4,7 @@ import ( @@ -4,6 +4,7 @@ import (
4 "encoding/json" 4 "encoding/json"
5 "github.com/beego/beego/v2/server/web/context" 5 "github.com/beego/beego/v2/server/web/context"
6 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/constant" 6 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/constant"
  7 + "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/port/beego/middleware"
7 "os" 8 "os"
8 "strconv" 9 "strconv"
9 10
@@ -31,6 +32,8 @@ func init() { @@ -31,6 +32,8 @@ func init() {
31 } 32 }
32 } 33 }
33 filters.SecureKeyMap["token"] = "x-mmm-accesstoken" 34 filters.SecureKeyMap["token"] = "x-mmm-accesstoken"
  35 + //TODO:token验证改为 /v1
  36 + web.InsertFilterChain("/v1/app/*", middleware.CheckAccessToken)
34 web.InsertFilter("/*", web.BeforeRouter, filters.AllowCors()) 37 web.InsertFilter("/*", web.BeforeRouter, filters.AllowCors())
35 web.InsertFilter("/*", web.BeforeRouter, filters.CreateRequstLogFilter(log.Logger)) 38 web.InsertFilter("/*", web.BeforeRouter, filters.CreateRequstLogFilter(log.Logger))
36 web.InsertFilter("/*", web.AfterExec, filters.CreateResponseLogFilter(log.Logger), web.WithReturnOnOutput(false)) 39 web.InsertFilter("/*", web.AfterExec, filters.CreateResponseLogFilter(log.Logger), web.WithReturnOnOutput(false))
@@ -7,6 +7,7 @@ import ( @@ -7,6 +7,7 @@ import (
7 "github.com/linmadan/egglib-go/web/beego/utils" 7 "github.com/linmadan/egglib-go/web/beego/utils"
8 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain" 8 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain"
9 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/log" 9 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/log"
  10 + "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/port/beego/middleware"
10 ) 11 )
11 12
12 type BaseController struct { 13 type BaseController struct {
@@ -77,6 +78,10 @@ func (controller *BaseController) GetOperator() domain.Operator { @@ -77,6 +78,10 @@ func (controller *BaseController) GetOperator() domain.Operator {
77 err := loginToken.ParseToken(token) 78 err := loginToken.ParseToken(token)
78 if err != nil { 79 if err != nil {
79 log.Logger.Error(err.Error()) 80 log.Logger.Error(err.Error())
  81 + *loginToken, _ = middleware.FormCtxLoginToken(controller.Ctx)
  82 + }
  83 + if tmpToken, ok := middleware.FormCtxLoginToken(controller.Ctx); ok {
  84 + log.Logger.Debug(json.MarshalToString(tmpToken))
80 } 85 }
81 op := domain.Operator{ 86 op := domain.Operator{
82 UserId: loginToken.UserId, 87 UserId: loginToken.UserId,
@@ -94,7 +99,7 @@ func (controller *BaseController) GetOperator() domain.Operator { @@ -94,7 +99,7 @@ func (controller *BaseController) GetOperator() domain.Operator {
94 op.UserBaseId = 1 99 op.UserBaseId = 1
95 } 100 }
96 // TODO:打印测试日志 101 // TODO:打印测试日志
97 - log.Logger.Debug("operator " + json.MarshalToString(op)) 102 + //log.Logger.Debug("operator " + json.MarshalToString(op))
98 return op 103 return op
99 } 104 }
100 105
1 package middleware 1 package middleware
2 2
3 import ( 3 import (
  4 + "fmt"
  5 + "github.com/beego/beego/v2/server/web"
4 "github.com/beego/beego/v2/server/web/context" 6 "github.com/beego/beego/v2/server/web/context"
5 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain" 7 "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/domain"
  8 + "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/infrastructure/cache"
  9 + "gitlab.fjmaimaimai.com/allied-creation/allied-creation-gateway/pkg/log"
  10 + log1 "log"
  11 + "net/url"
6 ) 12 )
7 13
8 type CtxKeyLoginToken struct{} 14 type CtxKeyLoginToken struct{}
9 15
10 func JWTAuth(ctx *context.Context) { 16 func JWTAuth(ctx *context.Context) {
11 - tokenStr := ctx.Input.Header("xxxx")  
12 - tk := domain.LoginToken{}  
13 - err := tk.ParseToken(tokenStr) 17 +
  18 +}
  19 +
  20 +func CheckAccessToken(next web.FilterFunc) web.FilterFunc {
  21 + return func(ctx *context.Context) {
  22 + tokenStr := ctx.Input.Header("x-mmm-accesstoken")
  23 + filterMap := map[string]string{
  24 + "/v1/auth/login/pwd": "",
  25 + "/v1/auth/login/sms": "",
  26 + "/v1/auth/login/qrcode": "",
  27 + "/v1/auth/captcha-init": "",
  28 + "/v1/auth/qrcode-init": "",
  29 + "/v1/auth/sms-code": "",
  30 + "/v1/auth/check-sms-code": "",
  31 + "/v1/auth/company-sign-up": "",
  32 + "/v1/auth/reset-password": "",
  33 + "/v1/auth/refresh-token": "",
  34 + }
  35 + var err error
  36 + if filterUrl, err := url.Parse(ctx.Request.RequestURI); err == nil {
  37 + // 不需要验证的接口
  38 + if _, ok := filterMap[filterUrl.Path]; ok {
  39 + next(ctx)
  40 + return
  41 + }
  42 + } else {
  43 + log.Logger.Error("parse url error:" + err.Error())
  44 + }
  45 + defer func() {
14 if err != nil { 46 if err != nil {
15 - // 47 + ctx.Output.JSON(map[string]interface{}{
  48 + "msg": domain.ParseCodeMsg(domain.InvalidAccessToken),
  49 + "code": domain.InvalidAccessToken,
  50 + "data": struct{}{},
  51 + }, false, false)
  52 + }
  53 + }()
  54 +
  55 + tk := &domain.LoginToken{}
  56 + err = tk.ParseToken(tokenStr)
  57 + if err != nil {
  58 + log.Logger.Error(err.Error())
16 return 59 return
17 } 60 }
18 - ctx.Input.SetData(CtxKeyLoginToken{}, domain.LoginToken{}) 61 + platform := domain.ParsePlatform(ctx.Input.Header("x-mmm-devicetype"))
  62 + //redis缓存
  63 + tokenCache := cache.LoginTokenCache{}
  64 + token, err := tokenCache.GetAccessToken(tk.Account, platform)
  65 + if err != nil {
  66 + log.Logger.Error(err.Error())
  67 + return
  68 + }
  69 + if token != tokenStr {
  70 + log1.Println("token not equal \n" + tk.Account + "\n" + tokenStr + "\n" + token)
  71 + err = fmt.Errorf("access token not exists")
  72 + return
  73 + }
  74 + ctx.Input.SetData(CtxKeyLoginToken{}, tk)
  75 + next(ctx)
  76 + }
19 } 77 }
20 78
21 func NewCtxLoginToken(ctx *context.Context, tk domain.LoginToken) { 79 func NewCtxLoginToken(ctx *context.Context, tk domain.LoginToken) {
  1 +package util
  2 +
  3 +import (
  4 + "bytes"
  5 + "crypto/aes"
  6 + "crypto/cipher"
  7 +)
  8 +
  9 +func PKCS5Padding(plaintext []byte, blockSize int) []byte {
  10 + padding := blockSize - len(plaintext)%blockSize
  11 + padtext := bytes.Repeat([]byte{byte(padding)}, padding)
  12 + return append(plaintext, padtext...)
  13 +}
  14 +
  15 +//@brief:去除填充数据
  16 +func PKCS5UnPadding(origData []byte) []byte {
  17 + length := len(origData)
  18 + unpadding := int(origData[length-1])
  19 + return origData[:(length - unpadding)]
  20 +}
  21 +
  22 +//@brief:AES加密
  23 +func AesEncrypt(origData, key []byte) ([]byte, error) {
  24 + block, err := aes.NewCipher(key)
  25 + if err != nil {
  26 + return nil, err
  27 + }
  28 +
  29 + //AES分组长度为128位,所以blockSize=16,单位字节
  30 + blockSize := block.BlockSize()
  31 + origData = PKCS5Padding(origData, blockSize)
  32 + blockMode := cipher.NewCBCEncrypter(block, key[:blockSize]) //初始向量的长度必须等于块block的长度16字节
  33 + crypted := make([]byte, len(origData))
  34 + blockMode.CryptBlocks(crypted, origData)
  35 + return crypted, nil
  36 +}
  37 +
  38 +//@brief:AES解密
  39 +func AesDecrypt(crypted, key []byte) ([]byte, error) {
  40 + block, err := aes.NewCipher(key)
  41 + if err != nil {
  42 + return nil, err
  43 + }
  44 +
  45 + //AES分组长度为128位,所以blockSize=16,单位字节
  46 + blockSize := block.BlockSize()
  47 + blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) //初始向量的长度必须等于块block的长度16字节
  48 + origData := make([]byte, len(crypted))
  49 + blockMode.CryptBlocks(origData, crypted)
  50 + origData = PKCS5UnPadding(origData)
  51 + return origData, nil
  52 +}
  1 +package util
  2 +
  3 +import (
  4 + "encoding/base64"
  5 + "fmt"
  6 + "testing"
  7 +)
  8 +
  9 +func Test_Aes(t *testing.T) {
  10 + //key的长度必须是16、24或者32字节,分别用于选择AES-128, AES-192, or AES-256
  11 + var aeskey = []byte("12345678abcdefgh")
  12 + pass := []byte("vdncloud123456")
  13 + xpass, err := AesEncrypt(pass, aeskey)
  14 + if err != nil {
  15 + fmt.Println(err)
  16 + return
  17 + }
  18 +
  19 + pass64 := base64.StdEncoding.EncodeToString(xpass)
  20 + fmt.Printf("加密后:%v\n", pass64)
  21 +
  22 + bytesPass, err := base64.StdEncoding.DecodeString(pass64)
  23 + if err != nil {
  24 + fmt.Println(err)
  25 + return
  26 + }
  27 +
  28 + tpass, err := AesDecrypt(bytesPass, aeskey)
  29 + if err != nil {
  30 + fmt.Println(err)
  31 + return
  32 + }
  33 + fmt.Printf("解密后:%s\n", tpass)
  34 +}