作者 yangfu

fix auth test case

  1 +version: v1
  2 +kind: Schema
  3 +metadata:
  4 + name: credentialAuth
  5 + description: 凭证认证
  6 + attributes:
  7 + - name: deviceType
  8 + description: Attribute描述
  9 + type:
  10 + primitive: string
  11 + - name: credential
  12 + description: Attribute描述
  13 + type:
  14 + primitive: string
  15 + - name: expire
  16 + description: Attribute描述
  17 + type:
  18 + primitive: string
@@ -7,6 +7,8 @@ import ( @@ -7,6 +7,8 @@ import (
7 ) 7 )
8 8
9 type LoginByCompanyCommand struct { 9 type LoginByCompanyCommand struct {
  10 + // 设备类型
  11 + DeviceType int `json:"deviceType,omitempty"`
10 // 1.高管 2.合伙人 4:游客 12 // 1.高管 2.合伙人 4:游客
11 UserType int `json:"userType" valid:"Required"` 13 UserType int `json:"userType" valid:"Required"`
12 // 凭证 14 // 凭证
@@ -7,6 +7,8 @@ import ( @@ -7,6 +7,8 @@ import (
7 ) 7 )
8 8
9 type LoginQuery struct { 9 type LoginQuery struct {
  10 + // 设备类型
  11 + DeviceType int `json:"deviceType,omitempty"`
10 // 手机号 12 // 手机号
11 Phone string `json:"phone,omitempty"` 13 Phone string `json:"phone,omitempty"`
12 // signInPassword(密码登录) 或 signInCaptcha(验证码登录)或signInCredentials(凭证登录) 14 // signInPassword(密码登录) 或 signInCaptcha(验证码登录)或signInCredentials(凭证登录)
@@ -6,8 +6,6 @@ import ( @@ -6,8 +6,6 @@ import (
6 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/auth/query" 6 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/auth/query"
7 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/factory" 7 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/factory"
8 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/domain" 8 "gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/domain"
9 - "math/rand"  
10 - "strconv"  
11 "strings" 9 "strings"
12 ) 10 )
13 11
@@ -109,8 +107,8 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{} @@ -109,8 +107,8 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
109 ) 107 )
110 UserAuthRepository, _ := factory.CreateUserAuthRepository(map[string]interface{}{"transactionContext": transactionContext}) 108 UserAuthRepository, _ := factory.CreateUserAuthRepository(map[string]interface{}{"transactionContext": transactionContext})
111 UserRepository, _ := factory.CreateUserRepository(map[string]interface{}{"transactionContext": transactionContext}) 109 UserRepository, _ := factory.CreateUserRepository(map[string]interface{}{"transactionContext": transactionContext})
112 - switch loginQuery.GrantType {  
113 - case "signInPassword": 110 + switch domain.LoginType(loginQuery.GrantType) {
  111 + case domain.SignInPassword:
114 if len(loginQuery.Password) == 0 { 112 if len(loginQuery.Password) == 0 {
115 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码不能为空!") 113 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码不能为空!")
116 } 114 }
@@ -121,18 +119,20 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{} @@ -121,18 +119,20 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
121 if !strings.EqualFold(userAuth.PhoneAuth.Password, loginQuery.Password) { 119 if !strings.EqualFold(userAuth.PhoneAuth.Password, loginQuery.Password) {
122 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码有误!") 120 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码有误!")
123 } 121 }
124 - case "signInCaptcha":  
125 - case "signInCredentials": 122 + case domain.SignInCaptcha:
  123 + case domain.SignInCredentials:
126 default: 124 default:
127 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "undefined grantType:"+loginQuery.GrantType) 125 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "undefined grantType:"+loginQuery.GrantType)
128 } 126 }
129 _, users, err = UserRepository.Find(map[string]interface{}{"inUserIds": userAuth.Users, "status": domain.StatusEnable}) 127 _, users, err = UserRepository.Find(map[string]interface{}{"inUserIds": userAuth.Users, "status": domain.StatusEnable})
130 128
131 - if err := transactionContext.CommitTransaction(); err != nil {  
132 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 129 + cred := domain.NewCredentialAuth(loginQuery.DeviceType, loginQuery.Credentials)
  130 + userAuth.BindCredentialAuth(cred)
  131 + if _, err = UserAuthRepository.Save(userAuth); err != nil {
  132 + return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, err.Error())
133 } 133 }
134 - credentials := "cred:" + strconv.Itoa(rand.Int())  
135 - rspMapData["credentials"] = credentials 134 +
  135 + rspMapData["credentials"] = cred.Credential
136 CompanyRepository, _ := factory.CreateCompanyRepository(map[string]interface{}{"transactionContext": transactionContext}) 136 CompanyRepository, _ := factory.CreateCompanyRepository(map[string]interface{}{"transactionContext": transactionContext})
137 for i := range users { 137 for i := range users {
138 company, _ := CompanyRepository.FindOne(map[string]interface{}{"companyId": users[i].CompanyId, "status": domain.StatusEnable}) 138 company, _ := CompanyRepository.FindOne(map[string]interface{}{"companyId": users[i].CompanyId, "status": domain.StatusEnable})
@@ -145,7 +145,12 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{} @@ -145,7 +145,12 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
145 } 145 }
146 userCompanies = append(userCompanies, item) 146 userCompanies = append(userCompanies, item)
147 } 147 }
148 - return map[string]interface{}{"userCompanies": userCompanies, "credentials": credentials}, nil 148 + rspMapData["userCompanies"] = userCompanies
  149 +
  150 + if err := transactionContext.CommitTransaction(); err != nil {
  151 + return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
  152 + }
  153 + return rspMapData, nil
149 } 154 }
150 155
151 // 用户按公司登录 156 // 用户按公司登录
@@ -164,8 +169,9 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo @@ -164,8 +169,9 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo
164 transactionContext.RollbackTransaction() 169 transactionContext.RollbackTransaction()
165 }() 170 }()
166 171
167 - //todo:valid/refresh credentials  
168 - if len(loginByCompanyCommand.Credentials) == 0 { 172 + UserAuthRepository, _ := factory.CreateUserAuthRepository(map[string]interface{}{"transactionContext": transactionContext})
  173 + userAuth, _ := UserAuthRepository.FindOne(map[string]interface{}{"credential": loginByCompanyCommand.Credentials})
  174 + if userAuth == nil || !userAuth.CheckCredentialAuth(loginByCompanyCommand.DeviceType, loginByCompanyCommand.Credentials) {
169 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "credentials expire") 175 return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "credentials expire")
170 } 176 }
171 177
@@ -180,10 +186,11 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo @@ -180,10 +186,11 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo
180 if err != nil { 186 if err != nil {
181 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 187 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
182 } 188 }
  189 + authCode := domain.SignToken(user.UserId, company.CompanyId)
  190 +
183 if err := transactionContext.CommitTransaction(); err != nil { 191 if err := transactionContext.CommitTransaction(); err != nil {
184 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 192 return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
185 } 193 }
186 - authCode := domain.SignToken(user.UserId, company.CompanyId)  
187 return map[string]interface{}{"user": user, "company": company, "authCode": authCode}, nil 194 return map[string]interface{}{"user": user, "company": company, "authCode": authCode}, nil
188 } 195 }
189 196
@@ -14,11 +14,20 @@ const ( @@ -14,11 +14,20 @@ const (
14 AccessTokenExpire = 3600 14 AccessTokenExpire = 3600
15 ) 15 )
16 16
17 -type UserTokenClaim struct { 17 +var (
  18 + SignInPassword LoginType = "signInPassword"
  19 + SignInCaptcha LoginType = "signInCaptcha"
  20 + SignInCredentials LoginType = "signInCredentials"
  21 +)
  22 +
  23 +type (
  24 + LoginType string // 登录类型
  25 + UserTokenClaim struct {
18 jwt.StandardClaims 26 jwt.StandardClaims
19 UserId int64 `json:"userId"` 27 UserId int64 `json:"userId"`
20 CompanyId int64 `json:"companyId"` 28 CompanyId int64 `json:"companyId"`
21 -} 29 + }
  30 +)
22 31
23 func NewUserTokenClaim(userId, companyId int64, expire int64) UserTokenClaim { 32 func NewUserTokenClaim(userId, companyId int64, expire int64) UserTokenClaim {
24 return UserTokenClaim{ 33 return UserTokenClaim{
  1 +package domain
  2 +
  3 +import (
  4 + "fmt"
  5 + "math/rand"
  6 + "strconv"
  7 + "time"
  8 +)
  9 +
  10 +var (
  11 + errorCredentialIsEmpty = fmt.Errorf("credential is empty")
  12 + errorCredentialExpired = fmt.Errorf("credential has expired")
  13 +)
  14 +
  15 +const (
  16 + defaultCredExpireDuration = 3600 * time.Second
  17 +)
  18 +
  19 +// 凭证认证
  20 +type CredentialAuth struct {
  21 + // 设备类型 0:ios 1:android 2:web
  22 + DeviceType int `json:"deviceType"`
  23 + // 凭证信息
  24 + Credential string `json:"credential"`
  25 + // 过期时间
  26 + Expire int64 `json:"expire"`
  27 +}
  28 +
  29 +func (auth *CredentialAuth) Check() error {
  30 + if len(auth.Credential) == 0 {
  31 + return errorCredentialIsEmpty
  32 + }
  33 + if auth.Expire == 0 || auth.Expire < time.Now().Unix() {
  34 + return errorCredentialExpired
  35 + }
  36 + return nil
  37 +}
  38 +
  39 +func NewCredentialAuth(deviceType int, cred string) *CredentialAuth {
  40 + if cred == "" {
  41 + cred = "cred:" + strconv.Itoa(rand.Int())
  42 + }
  43 + return &CredentialAuth{
  44 + DeviceType: deviceType,
  45 + Credential: cred,
  46 + Expire: time.Now().Add(defaultCredExpireDuration).Unix(),
  47 + }
  48 +}
1 package domain 1 package domain
2 2
3 -import "time" 3 +import (
  4 + "strings"
  5 + "time"
  6 +)
4 7
5 // 用户认证实体 8 // 用户认证实体
6 type UserAuth struct { 9 type UserAuth struct {
@@ -10,6 +13,8 @@ type UserAuth struct { @@ -10,6 +13,8 @@ type UserAuth struct {
10 Users []int64 `json:"users"` 13 Users []int64 `json:"users"`
11 // 手机认证 14 // 手机认证
12 PhoneAuth *PhoneAuth `json:"phoneAuth"` 15 PhoneAuth *PhoneAuth `json:"phoneAuth"`
  16 + // 凭证认证
  17 + CredentialAuths []*CredentialAuth `credentialAuths`
13 // 创建时间 18 // 创建时间
14 CreateAt time.Time `json:"createAt"` 19 CreateAt time.Time `json:"createAt"`
15 // 更新时间 20 // 更新时间
@@ -51,3 +56,51 @@ func (userAuth *UserAuth) Update(data map[string]interface{}) error { @@ -51,3 +56,51 @@ func (userAuth *UserAuth) Update(data map[string]interface{}) error {
51 } 56 }
52 return nil 57 return nil
53 } 58 }
  59 +
  60 +// 绑定凭证
  61 +func (userAuth *UserAuth) BindCredentialAuth(cred *CredentialAuth) error {
  62 + var exists bool
  63 + if err := cred.Check(); err != nil {
  64 + return err
  65 + }
  66 + for i, v := range userAuth.CredentialAuths {
  67 + if v.DeviceType == cred.DeviceType {
  68 + userAuth.CredentialAuths[i] = cred
  69 + exists = true
  70 + }
  71 + }
  72 + if !exists {
  73 + userAuth.CredentialAuths = append(userAuth.CredentialAuths, cred)
  74 + }
  75 + return nil
  76 +}
  77 +
  78 +// 解除凭证
  79 +func (userAuth *UserAuth) UnbindCredentialAuth(deviceType int) error {
  80 + length := len(userAuth.CredentialAuths)
  81 + for i, v := range userAuth.CredentialAuths {
  82 + if v.DeviceType == deviceType {
  83 + if length <= 1 {
  84 + userAuth.CredentialAuths = []*CredentialAuth{}
  85 + } else {
  86 + userAuth.CredentialAuths[i] = userAuth.CredentialAuths[length-1]
  87 + userAuth.CredentialAuths = userAuth.CredentialAuths[:length-1]
  88 + }
  89 + return nil
  90 + }
  91 + }
  92 + return nil
  93 +}
  94 +
  95 +// 检查凭证
  96 +func (userAuth *UserAuth) CheckCredentialAuth(deviceType int, cred string) bool {
  97 + for _, v := range userAuth.CredentialAuths {
  98 + if v.DeviceType == deviceType && strings.EqualFold(v.Credential, cred) {
  99 + if err := v.Check(); err != nil {
  100 + return false
  101 + }
  102 + return true
  103 + }
  104 + }
  105 + return false
  106 +}
@@ -13,6 +13,8 @@ type UserAuth struct { @@ -13,6 +13,8 @@ type UserAuth struct {
13 Users []int64 `pg:",array"` 13 Users []int64 `pg:",array"`
14 // 手机认证 14 // 手机认证
15 PhoneAuth *domain.PhoneAuth 15 PhoneAuth *domain.PhoneAuth
  16 + // 凭证认证
  17 + CredentialAuths []*domain.CredentialAuth
16 // 创建时间 18 // 创建时间
17 CreateAt time.Time 19 CreateAt time.Time
18 // 更新时间 20 // 更新时间
@@ -10,6 +10,7 @@ func TransformToUserAuthDomainModelFromPgModels(userAuthModel *models.UserAuth) @@ -10,6 +10,7 @@ func TransformToUserAuthDomainModelFromPgModels(userAuthModel *models.UserAuth)
10 UserAuthId: userAuthModel.UserAuthId, 10 UserAuthId: userAuthModel.UserAuthId,
11 Users: userAuthModel.Users, 11 Users: userAuthModel.Users,
12 PhoneAuth: userAuthModel.PhoneAuth, 12 PhoneAuth: userAuthModel.PhoneAuth,
  13 + CredentialAuths: userAuthModel.CredentialAuths,
13 CreateAt: userAuthModel.CreateAt, 14 CreateAt: userAuthModel.CreateAt,
14 UpdateAt: userAuthModel.UpdateAt, 15 UpdateAt: userAuthModel.UpdateAt,
15 }, nil 16 }, nil
@@ -101,6 +101,9 @@ func (repository *UserAuthRepository) FindOne(queryOptions map[string]interface{ @@ -101,6 +101,9 @@ func (repository *UserAuthRepository) FindOne(queryOptions map[string]interface{
101 if v, ok := queryOptions["phone"]; ok { 101 if v, ok := queryOptions["phone"]; ok {
102 query.Where(fmt.Sprintf(`user_auth.phone_auth @>'{"phone":"%v"}'`, v)) 102 query.Where(fmt.Sprintf(`user_auth.phone_auth @>'{"phone":"%v"}'`, v))
103 } 103 }
  104 + if v, ok := queryOptions["credential"]; ok {
  105 + query.Where(fmt.Sprintf(`user_auth.credential_auths @>'[{"credential":"%v"}]'`, v))
  106 + }
104 if err := query.First(); err != nil { 107 if err := query.First(); err != nil {
105 if err.Error() == "pg: no rows in result set" { 108 if err.Error() == "pg: no rows in result set" {
106 return nil, fmt.Errorf("没有此资源") 109 return nil, fmt.Errorf("没有此资源")
@@ -19,11 +19,11 @@ var _ = Describe("用户按公司登录", func() { @@ -19,11 +19,11 @@ var _ = Describe("用户按公司登录", func() {
19 ) 19 )
20 _, err = pG.DB.QueryOne( 20 _, err = pG.DB.QueryOne(
21 pg.Scan(&Id), 21 pg.Scan(&Id),
22 - `INSERT INTO companys (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company","status":1}') RETURNING company_id`, 22 + `INSERT INTO companies (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company","status":1}') RETURNING company_id`,
23 ) 23 )
24 _, err = pG.DB.QueryOne( 24 _, err = pG.DB.QueryOne(
25 pg.Scan(&Id), 25 pg.Scan(&Id),
26 - `INSERT INTO user_auth (user_auth_id,users,phone_auth) VALUES (1,ARRAY [1],'{"phone":"18800000001","password":"password"}') RETURNING user_auth_id`, 26 + `INSERT INTO user_auth (user_auth_id,users,phone_auth,credential_auths) VALUES (1,ARRAY [1],'{"phone":"18800000001","password":"password"}','[{"credential":"string","expire":999999999999999}]') RETURNING user_auth_id`,
27 ) 27 )
28 Expect(err).NotTo(HaveOccurred()) 28 Expect(err).NotTo(HaveOccurred())
29 }) 29 })
@@ -50,7 +50,7 @@ var _ = Describe("用户按公司登录", func() { @@ -50,7 +50,7 @@ var _ = Describe("用户按公司登录", func() {
50 }) 50 })
51 AfterEach(func() { 51 AfterEach(func() {
52 _, err := pG.DB.Exec("DELETE FROM users WHERE true") 52 _, err := pG.DB.Exec("DELETE FROM users WHERE true")
53 - _, err = pG.DB.Exec("DELETE FROM companys WHERE true") 53 + _, err = pG.DB.Exec("DELETE FROM companies WHERE true")
54 _, err = pG.DB.Exec("DELETE FROM user_auth WHERE true") 54 _, err = pG.DB.Exec("DELETE FROM user_auth WHERE true")
55 Expect(err).NotTo(HaveOccurred()) 55 Expect(err).NotTo(HaveOccurred())
56 }) 56 })
@@ -19,7 +19,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() { @@ -19,7 +19,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() {
19 ) 19 )
20 _, err = pG.DB.QueryOne( 20 _, err = pG.DB.QueryOne(
21 pg.Scan(&Id), 21 pg.Scan(&Id),
22 - `INSERT INTO companys (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company"}') RETURNING company_id`, 22 + `INSERT INTO companies (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company"}') RETURNING company_id`,
23 ) 23 )
24 _, err = pG.DB.QueryOne( 24 _, err = pG.DB.QueryOne(
25 pg.Scan(&Id), 25 pg.Scan(&Id),
@@ -52,7 +52,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() { @@ -52,7 +52,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() {
52 }) 52 })
53 AfterEach(func() { 53 AfterEach(func() {
54 _, err := pG.DB.Exec("DELETE FROM users WHERE true") 54 _, err := pG.DB.Exec("DELETE FROM users WHERE true")
55 - _, err = pG.DB.Exec("DELETE FROM companys WHERE true") 55 + _, err = pG.DB.Exec("DELETE FROM companies WHERE true")
56 _, err = pG.DB.Exec("DELETE FROM user_auth WHERE true") 56 _, err = pG.DB.Exec("DELETE FROM user_auth WHERE true")
57 Expect(err).NotTo(HaveOccurred()) 57 Expect(err).NotTo(HaveOccurred())
58 }) 58 })