作者 yangfu

fix auth test case

version: v1
kind: Schema
metadata:
name: credentialAuth
description: 凭证认证
attributes:
- name: deviceType
description: Attribute描述
type:
primitive: string
- name: credential
description: Attribute描述
type:
primitive: string
- name: expire
description: Attribute描述
type:
primitive: string
... ...
... ... @@ -7,6 +7,8 @@ import (
)
type LoginByCompanyCommand struct {
// 设备类型
DeviceType int `json:"deviceType,omitempty"`
// 1.高管 2.合伙人 4:游客
UserType int `json:"userType" valid:"Required"`
// 凭证
... ...
... ... @@ -7,6 +7,8 @@ import (
)
type LoginQuery struct {
// 设备类型
DeviceType int `json:"deviceType,omitempty"`
// 手机号
Phone string `json:"phone,omitempty"`
// signInPassword(密码登录) 或 signInCaptcha(验证码登录)或signInCredentials(凭证登录)
... ...
... ... @@ -6,8 +6,6 @@ import (
"gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/auth/query"
"gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/application/factory"
"gitlab.fjmaimaimai.com/mmm-go-pp/partner01/pkg/domain"
"math/rand"
"strconv"
"strings"
)
... ... @@ -109,8 +107,8 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
)
UserAuthRepository, _ := factory.CreateUserAuthRepository(map[string]interface{}{"transactionContext": transactionContext})
UserRepository, _ := factory.CreateUserRepository(map[string]interface{}{"transactionContext": transactionContext})
switch loginQuery.GrantType {
case "signInPassword":
switch domain.LoginType(loginQuery.GrantType) {
case domain.SignInPassword:
if len(loginQuery.Password) == 0 {
return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码不能为空!")
}
... ... @@ -121,18 +119,20 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
if !strings.EqualFold(userAuth.PhoneAuth.Password, loginQuery.Password) {
return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "密码有误!")
}
case "signInCaptcha":
case "signInCredentials":
case domain.SignInCaptcha:
case domain.SignInCredentials:
default:
return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "undefined grantType:"+loginQuery.GrantType)
}
_, users, err = UserRepository.Find(map[string]interface{}{"inUserIds": userAuth.Users, "status": domain.StatusEnable})
if err := transactionContext.CommitTransaction(); err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
cred := domain.NewCredentialAuth(loginQuery.DeviceType, loginQuery.Credentials)
userAuth.BindCredentialAuth(cred)
if _, err = UserAuthRepository.Save(userAuth); err != nil {
return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, err.Error())
}
credentials := "cred:" + strconv.Itoa(rand.Int())
rspMapData["credentials"] = credentials
rspMapData["credentials"] = cred.Credential
CompanyRepository, _ := factory.CreateCompanyRepository(map[string]interface{}{"transactionContext": transactionContext})
for i := range users {
company, _ := CompanyRepository.FindOne(map[string]interface{}{"companyId": users[i].CompanyId, "status": domain.StatusEnable})
... ... @@ -145,7 +145,12 @@ func (authService *AuthService) Login(loginQuery *query.LoginQuery) (interface{}
}
userCompanies = append(userCompanies, item)
}
return map[string]interface{}{"userCompanies": userCompanies, "credentials": credentials}, nil
rspMapData["userCompanies"] = userCompanies
if err := transactionContext.CommitTransaction(); err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
}
return rspMapData, nil
}
// 用户按公司登录
... ... @@ -164,8 +169,9 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo
transactionContext.RollbackTransaction()
}()
//todo:valid/refresh credentials
if len(loginByCompanyCommand.Credentials) == 0 {
UserAuthRepository, _ := factory.CreateUserAuthRepository(map[string]interface{}{"transactionContext": transactionContext})
userAuth, _ := UserAuthRepository.FindOne(map[string]interface{}{"credential": loginByCompanyCommand.Credentials})
if userAuth == nil || !userAuth.CheckCredentialAuth(loginByCompanyCommand.DeviceType, loginByCompanyCommand.Credentials) {
return nil, application.ThrowError(application.INTERNAL_SERVER_ERROR, "credentials expire")
}
... ... @@ -180,10 +186,11 @@ func (authService *AuthService) LoginByCompany(loginByCompanyCommand *command.Lo
if err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
}
authCode := domain.SignToken(user.UserId, company.CompanyId)
if err := transactionContext.CommitTransaction(); err != nil {
return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
}
authCode := domain.SignToken(user.UserId, company.CompanyId)
return map[string]interface{}{"user": user, "company": company, "authCode": authCode}, nil
}
... ...
... ... @@ -14,11 +14,20 @@ const (
AccessTokenExpire = 3600
)
type UserTokenClaim struct {
jwt.StandardClaims
UserId int64 `json:"userId"`
CompanyId int64 `json:"companyId"`
}
var (
SignInPassword LoginType = "signInPassword"
SignInCaptcha LoginType = "signInCaptcha"
SignInCredentials LoginType = "signInCredentials"
)
type (
LoginType string // 登录类型
UserTokenClaim struct {
jwt.StandardClaims
UserId int64 `json:"userId"`
CompanyId int64 `json:"companyId"`
}
)
func NewUserTokenClaim(userId, companyId int64, expire int64) UserTokenClaim {
return UserTokenClaim{
... ...
package domain
import (
"fmt"
"math/rand"
"strconv"
"time"
)
var (
errorCredentialIsEmpty = fmt.Errorf("credential is empty")
errorCredentialExpired = fmt.Errorf("credential has expired")
)
const (
defaultCredExpireDuration = 3600 * time.Second
)
// 凭证认证
type CredentialAuth struct {
// 设备类型 0:ios 1:android 2:web
DeviceType int `json:"deviceType"`
// 凭证信息
Credential string `json:"credential"`
// 过期时间
Expire int64 `json:"expire"`
}
func (auth *CredentialAuth) Check() error {
if len(auth.Credential) == 0 {
return errorCredentialIsEmpty
}
if auth.Expire == 0 || auth.Expire < time.Now().Unix() {
return errorCredentialExpired
}
return nil
}
func NewCredentialAuth(deviceType int, cred string) *CredentialAuth {
if cred == "" {
cred = "cred:" + strconv.Itoa(rand.Int())
}
return &CredentialAuth{
DeviceType: deviceType,
Credential: cred,
Expire: time.Now().Add(defaultCredExpireDuration).Unix(),
}
}
... ...
package domain
import "time"
import (
"strings"
"time"
)
// 用户认证实体
type UserAuth struct {
... ... @@ -10,6 +13,8 @@ type UserAuth struct {
Users []int64 `json:"users"`
// 手机认证
PhoneAuth *PhoneAuth `json:"phoneAuth"`
// 凭证认证
CredentialAuths []*CredentialAuth `credentialAuths`
// 创建时间
CreateAt time.Time `json:"createAt"`
// 更新时间
... ... @@ -51,3 +56,51 @@ func (userAuth *UserAuth) Update(data map[string]interface{}) error {
}
return nil
}
// 绑定凭证
func (userAuth *UserAuth) BindCredentialAuth(cred *CredentialAuth) error {
var exists bool
if err := cred.Check(); err != nil {
return err
}
for i, v := range userAuth.CredentialAuths {
if v.DeviceType == cred.DeviceType {
userAuth.CredentialAuths[i] = cred
exists = true
}
}
if !exists {
userAuth.CredentialAuths = append(userAuth.CredentialAuths, cred)
}
return nil
}
// 解除凭证
func (userAuth *UserAuth) UnbindCredentialAuth(deviceType int) error {
length := len(userAuth.CredentialAuths)
for i, v := range userAuth.CredentialAuths {
if v.DeviceType == deviceType {
if length <= 1 {
userAuth.CredentialAuths = []*CredentialAuth{}
} else {
userAuth.CredentialAuths[i] = userAuth.CredentialAuths[length-1]
userAuth.CredentialAuths = userAuth.CredentialAuths[:length-1]
}
return nil
}
}
return nil
}
// 检查凭证
func (userAuth *UserAuth) CheckCredentialAuth(deviceType int, cred string) bool {
for _, v := range userAuth.CredentialAuths {
if v.DeviceType == deviceType && strings.EqualFold(v.Credential, cred) {
if err := v.Check(); err != nil {
return false
}
return true
}
}
return false
}
... ...
... ... @@ -13,6 +13,8 @@ type UserAuth struct {
Users []int64 `pg:",array"`
// 手机认证
PhoneAuth *domain.PhoneAuth
// 凭证认证
CredentialAuths []*domain.CredentialAuth
// 创建时间
CreateAt time.Time
// 更新时间
... ...
... ... @@ -7,10 +7,11 @@ import (
func TransformToUserAuthDomainModelFromPgModels(userAuthModel *models.UserAuth) (*domain.UserAuth, error) {
return &domain.UserAuth{
UserAuthId: userAuthModel.UserAuthId,
Users: userAuthModel.Users,
PhoneAuth: userAuthModel.PhoneAuth,
CreateAt: userAuthModel.CreateAt,
UpdateAt: userAuthModel.UpdateAt,
UserAuthId: userAuthModel.UserAuthId,
Users: userAuthModel.Users,
PhoneAuth: userAuthModel.PhoneAuth,
CredentialAuths: userAuthModel.CredentialAuths,
CreateAt: userAuthModel.CreateAt,
UpdateAt: userAuthModel.UpdateAt,
}, nil
}
... ...
... ... @@ -101,6 +101,9 @@ func (repository *UserAuthRepository) FindOne(queryOptions map[string]interface{
if v, ok := queryOptions["phone"]; ok {
query.Where(fmt.Sprintf(`user_auth.phone_auth @>'{"phone":"%v"}'`, v))
}
if v, ok := queryOptions["credential"]; ok {
query.Where(fmt.Sprintf(`user_auth.credential_auths @>'[{"credential":"%v"}]'`, v))
}
if err := query.First(); err != nil {
if err.Error() == "pg: no rows in result set" {
return nil, fmt.Errorf("没有此资源")
... ...
... ... @@ -19,11 +19,11 @@ var _ = Describe("用户按公司登录", func() {
)
_, err = pG.DB.QueryOne(
pg.Scan(&Id),
`INSERT INTO companys (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company","status":1}') RETURNING company_id`,
`INSERT INTO companies (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company","status":1}') RETURNING company_id`,
)
_, err = pG.DB.QueryOne(
pg.Scan(&Id),
`INSERT INTO user_auth (user_auth_id,users,phone_auth) VALUES (1,ARRAY [1],'{"phone":"18800000001","password":"password"}') RETURNING user_auth_id`,
`INSERT INTO user_auth (user_auth_id,users,phone_auth,credential_auths) VALUES (1,ARRAY [1],'{"phone":"18800000001","password":"password"}','[{"credential":"string","expire":999999999999999}]') RETURNING user_auth_id`,
)
Expect(err).NotTo(HaveOccurred())
})
... ... @@ -50,7 +50,7 @@ var _ = Describe("用户按公司登录", func() {
})
AfterEach(func() {
_, err := pG.DB.Exec("DELETE FROM users WHERE true")
_, err = pG.DB.Exec("DELETE FROM companys WHERE true")
_, err = pG.DB.Exec("DELETE FROM companies WHERE true")
_, err = pG.DB.Exec("DELETE FROM user_auth WHERE true")
Expect(err).NotTo(HaveOccurred())
})
... ...
... ... @@ -19,7 +19,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() {
)
_, err = pG.DB.QueryOne(
pg.Scan(&Id),
`INSERT INTO companys (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company"}') RETURNING company_id`,
`INSERT INTO companies (company_id,company_info) VALUES (1000,'{"company_id":1000,"name":"company"}') RETURNING company_id`,
)
_, err = pG.DB.QueryOne(
pg.Scan(&Id),
... ... @@ -52,7 +52,7 @@ var _ = Describe("用户登录 返回有权限的公司列表", func() {
})
AfterEach(func() {
_, err := pG.DB.Exec("DELETE FROM users WHERE true")
_, err = pG.DB.Exec("DELETE FROM companys WHERE true")
_, err = pG.DB.Exec("DELETE FROM companies WHERE true")
_, err = pG.DB.Exec("DELETE FROM user_auth WHERE true")
Expect(err).NotTo(HaveOccurred())
})
... ...