作者 yangfu

登录优化

@@ -516,6 +516,73 @@ func (svr AuthService) getUserInfo(operator domain.Operator) (interface{}, error @@ -516,6 +516,73 @@ func (svr AuthService) getUserInfo(operator domain.Operator) (interface{}, error
516 } 516 }
517 517
518 func (svr AuthService) getToken(loginToken domain.LoginToken) (map[string]interface{}, error) { 518 func (svr AuthService) getToken(loginToken domain.LoginToken) (map[string]interface{}, error) {
  519 + // 1.匹配账号对应的用户
  520 + currentUser, err := svr.matchUser(&loginToken)
  521 + if err != nil {
  522 + return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())
  523 + }
  524 +
  525 + // 2. 更新LoginAccess
  526 + transactionContext, err := factory.CreateTransactionContext(nil)
  527 + if err != nil {
  528 + return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
  529 + }
  530 + if err := transactionContext.StartTransaction(); err != nil {
  531 + return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
  532 + }
  533 + defer func() {
  534 + transactionContext.RollbackTransaction()
  535 + }()
  536 + var loginAccessRepository domain.LoginAccessRepository
  537 + if loginAccessRepository, err = factory.CreateLoginAccessRepository(map[string]interface{}{
  538 + "transactionContext": transactionContext,
  539 + }); err != nil {
  540 + return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())
  541 + }
  542 + _, lAccess, err := loginAccessRepository.Find(map[string]interface{}{
  543 + "account": loginToken.Account,
  544 + "platform": loginToken.Platform,
  545 + })
  546 + if err != nil {
  547 + return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
  548 + }
  549 + var currentAccess = &domain.LoginAccess{CreatedTime: time.Now()}
  550 + if len(lAccess) > 0 {
  551 + currentAccess = lAccess[0]
  552 + }
  553 +
  554 + if _, err = currentAccess.ResetLoginAccess(loginToken); err != nil {
  555 + return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())
  556 + }
  557 + //存数据库
  558 + if _, err = loginAccessRepository.Save(currentAccess); err != nil {
  559 + return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())
  560 + }
  561 + //redis缓存
  562 + tokenCache := cache.LoginTokenCache{}
  563 + if err = tokenCache.SaveAccessToken(currentAccess); err != nil {
  564 + return nil, err
  565 + }
  566 + if err = tokenCache.SaveRefreshToken(currentAccess); err != nil {
  567 + return nil, err
  568 + }
  569 +
  570 + if err := transactionContext.CommitTransaction(); err != nil {
  571 + return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())
  572 + }
  573 + nowTime := time.Now().Unix()
  574 + token := map[string]interface{}{
  575 + "refreshToken": currentAccess.RefreshToken,
  576 + "accessToken": currentAccess.AccessToken,
  577 + "expiresIn": currentAccess.AccessExpired - nowTime,
  578 + }
  579 + return map[string]interface{}{
  580 + "token": token,
  581 + "userId": currentUser.UserId,
  582 + }, nil
  583 +}
  584 +
  585 +func (svr AuthService) matchUser(loginToken *domain.LoginToken) (*allied_creation_user.UserDetail, error) {
519 creationUserGateway := allied_creation_user.NewHttplibAlliedCreationUser(domain.Operator{}) 586 creationUserGateway := allied_creation_user.NewHttplibAlliedCreationUser(domain.Operator{})
520 userSearchResult, err := creationUserGateway.UserSearch(allied_creation_user.ReqUserSearch{ 587 userSearchResult, err := creationUserGateway.UserSearch(allied_creation_user.ReqUserSearch{
521 Phone: loginToken.Account, 588 Phone: loginToken.Account,
@@ -572,87 +639,21 @@ loopUser1: @@ -572,87 +639,21 @@ loopUser1:
572 return nil, application.ThrowError(application.TRANSACTION_ERROR, "账号不存在") 639 return nil, application.ThrowError(application.TRANSACTION_ERROR, "账号不存在")
573 } 640 }
574 loginToken.UserBaseId = int64(userBase.UserBaseID) 641 loginToken.UserBaseId = int64(userBase.UserBaseID)
575 - }  
576 -  
577 - // 2. 更新currentAccess信息  
578 - transactionContext, err := factory.CreateTransactionContext(nil)  
579 - if err != nil {  
580 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())  
581 - }  
582 - if err := transactionContext.StartTransaction(); err != nil {  
583 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())  
584 - }  
585 - defer func() {  
586 - transactionContext.RollbackTransaction()  
587 - }()  
588 - var loginAccessRepository domain.LoginAccessRepository  
589 - if loginAccessRepository, err = factory.CreateLoginAccessRepository(map[string]interface{}{  
590 - "transactionContext": transactionContext,  
591 - }); err != nil {  
592 - return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())  
593 - }  
594 - _, lAccess, err := loginAccessRepository.Find(map[string]interface{}{  
595 - "account": loginToken.Account,  
596 - "platform": loginToken.Platform, 642 + if userBase.UserBaseID > 0 {
  643 + cooperationUsers, _ := creationUserGateway.UserSearch(allied_creation_user.ReqUserSearch{
  644 + UserBaseId: int64(userBase.UserBaseID),
  645 + UserType: domain.UserTypeCooperation,
  646 + EnableStatus: domain.UserStatusEnable,
597 }) 647 })
598 - if err != nil {  
599 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error()) 648 + if len(cooperationUsers.Users) > 0 {
  649 + loginToken.CompanyId = int64(cooperationUsers.Users[0].Company.CompanyId)
  650 + loginToken.UserId = int64(cooperationUsers.Users[0].UserId)
  651 + loginToken.OrgId = int64(cooperationUsers.Users[0].Org.OrgId)
  652 + currentUser = cooperationUsers.Users[0]
600 } 653 }
601 - var currentAccess = &domain.LoginAccess{CreatedTime: time.Now()}  
602 - if len(lAccess) > 0 {  
603 - currentAccess = lAccess[0]  
604 } 654 }
605 - currentAccess.UserId = int64(loginToken.UserId)  
606 - currentAccess.UserBaseId = int64(loginToken.UserBaseId)  
607 - currentAccess.Account = loginToken.Account  
608 - currentAccess.Platform = loginToken.Platform  
609 - currentAccess.CompanyId = int64(loginToken.CompanyId)  
610 - currentAccess.OrganizationId = loginToken.OrgId  
611 - currentAccess.OrgIds = loginToken.OrgIds  
612 - currentAccess.UpdatedTime = time.Now()  
613 -  
614 - accessTokenStr, err := loginToken.GenerateAccessToken()  
615 - if err != nil {  
616 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())  
617 } 655 }
618 - currentAccess.AccessToken = accessTokenStr  
619 - currentAccess.AccessExpired = loginToken.ExpiresAt  
620 - refreshTokenStr, err := loginToken.GenerateRefreshToken()  
621 - if err != nil {  
622 - return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())  
623 - }  
624 - currentAccess.RefreshToken = refreshTokenStr  
625 - currentAccess.RefreshExpired = loginToken.ExpiresAt  
626 -  
627 - //存数据库  
628 - _, err = loginAccessRepository.Save(currentAccess)  
629 - if err != nil {  
630 - return nil, application.ThrowError(application.TRANSACTION_ERROR, err.Error())  
631 - }  
632 - if err := transactionContext.CommitTransaction(); err != nil {  
633 - return nil, application.ThrowError(application.BUSINESS_ERROR, err.Error())  
634 - }  
635 -  
636 - //redis缓存  
637 - tokenCache := cache.LoginTokenCache{}  
638 -  
639 - //todo:error handler  
640 - if err = tokenCache.SaveAccessToken(currentAccess); err != nil {  
641 - return nil, err  
642 - }  
643 - if err = tokenCache.SaveRefreshToken(currentAccess); err != nil {  
644 - return nil, err  
645 - }  
646 - nowTime := time.Now().Unix()  
647 - token := map[string]interface{}{  
648 - "refreshToken": refreshTokenStr,  
649 - "accessToken": accessTokenStr,  
650 - "expiresIn": currentAccess.AccessExpired - nowTime,  
651 - }  
652 - return map[string]interface{}{  
653 - "token": token,  
654 - "userId": currentUser.UserId,  
655 - }, nil 656 + return &currentUser, nil
656 } 657 }
657 658
658 //GetCompanyOrgsByUser 获取登录用户的公司组织列表 659 //GetCompanyOrgsByUser 获取登录用户的公司组织列表
@@ -128,6 +128,31 @@ func ParseCodeMsg(code int) string { @@ -128,6 +128,31 @@ func ParseCodeMsg(code int) string {
128 return "" 128 return ""
129 } 129 }
130 130
  131 +func (loginAccess *LoginAccess) ResetLoginAccess(loginToken LoginToken) (interface{}, error) {
  132 + loginAccess.UserId = int64(loginToken.UserId)
  133 + loginAccess.UserBaseId = int64(loginToken.UserBaseId)
  134 + loginAccess.Account = loginToken.Account
  135 + loginAccess.Platform = loginToken.Platform
  136 + loginAccess.CompanyId = int64(loginToken.CompanyId)
  137 + loginAccess.OrganizationId = loginToken.OrgId
  138 + loginAccess.OrgIds = loginToken.OrgIds
  139 + loginAccess.UpdatedTime = time.Now()
  140 +
  141 + accessTokenStr, err := loginToken.GenerateAccessToken()
  142 + if err != nil {
  143 + return nil, err
  144 + }
  145 + loginAccess.AccessToken = accessTokenStr
  146 + loginAccess.AccessExpired = loginToken.ExpiresAt
  147 + refreshTokenStr, err := loginToken.GenerateRefreshToken()
  148 + if err != nil {
  149 + return nil, err
  150 + }
  151 + loginAccess.RefreshToken = refreshTokenStr
  152 + loginAccess.RefreshExpired = loginToken.ExpiresAt
  153 + return nil, nil
  154 +}
  155 +
131 func NewApplicationError(code int) error { 156 func NewApplicationError(code int) error {
132 return &application.ServiceError{ 157 return &application.ServiceError{
133 Code: code, 158 Code: code,