作者 yangfu

修改:用户登录修改

... ... @@ -6,6 +6,7 @@ require (
github.com/astaxie/beego v1.12.1
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gin-gonic/gin v1.5.0
github.com/go-pg/pg v8.0.6+incompatible
github.com/go-pg/pg/v10 v10.0.0-beta.2
github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9
github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 // indirect
... ...
... ... @@ -108,8 +108,14 @@ func AccessToken(request *protocol.AccessTokenRequest) (rsp *protocol.AccessToke
err = protocol.NewErrWithMessage(1, fmt.Errorf("jwt authCode (%v) valid", request.AuthCode))
return
}
rsp.AccessToken, _ = utils.GenerateToken(claim.UserId, claim.Phone, protocol.TokenExpire*time.Second)
rsp.RefreshToken, _ = utils.GenerateToken(claim.UserId, claim.Phone, protocol.RefreshTokenExipre*time.Second)
userClaims := utils.UserTokenClaims{
UserId: claim.UserId,
CompanyId: claim.CompanyId,
AdminType: claim.AdminType,
Phone: claim.Phone,
}
rsp.AccessToken, _ = utils.GenerateTokenWithClaim(userClaims, protocol.TokenExpire*time.Second)
rsp.RefreshToken, _ = utils.GenerateTokenWithClaim(userClaims, protocol.RefreshTokenExipre*time.Second)
rsp.ExpiresIn = protocol.TokenExpire
//auth := userAuth.NewRedisUserAuth(userAuth.WithUserId(claim.UserId),
... ... @@ -124,13 +130,16 @@ func AccessToken(request *protocol.AccessTokenRequest) (rsp *protocol.AccessToke
func RefreshToken(request *protocol.RefreshTokenRequest) (rsp *protocol.RefreshTokenResponse, err error) {
var (
claim *utils.UserTokenClaims
transactionContext, _ = factory.CreateTransactionContext(nil)
PartnerInfoService, _ = factory.CreatePartnerInfoRepositoryIn(transactionContext)
PartnerSubAccountRepository, _ = factory.CreatePartnerSubAccountRepository(transactionContext)
claim *utils.UserTokenClaims
transactionContext, _ = factory.CreateTransactionContext(nil)
PartnerInfoService, _ = factory.CreatePartnerInfoRepositoryIn(transactionContext)
//PartnerSubAccountRepository, _ = factory.CreatePartnerSubAccountRepository(transactionContext)
UsersRepository, _ = factory.CreateUsersRepository(transactionContext)
partnerInfo *domain.PartnerInfo
partnerSubAccount *domain.PartnerSubAccount
partnerInfo *domain.PartnerInfo
//partnerSubAccount *domain.PartnerSubAccount
user *domain.Users
userId int64
)
if err = transactionContext.StartTransaction(); err != nil {
... ... @@ -151,14 +160,29 @@ func RefreshToken(request *protocol.RefreshTokenRequest) (rsp *protocol.RefreshT
}
//验证用户有效
var e error
if partnerSubAccount, e = PartnerSubAccountRepository.FindOne(map[string]interface{}{"account": claim.Phone}); e == nil {
partnerInfo, e = PartnerInfoService.FindOne(map[string]interface{}{"id": partnerSubAccount.PartnerId})
} else {
partnerInfo, e = PartnerInfoService.FindOne(map[string]interface{}{"account": claim.Phone})
}
if e != nil || partnerInfo == nil || !partnerInfo.IsEnable() || partnerInfo.Id != claim.UserId {
err = protocol.NewErrWithMessage(4140) //账号禁用
//var e error
//if partnerSubAccount, e = PartnerSubAccountRepository.FindOne(map[string]interface{}{"account": claim.Phone}); e == nil {
// partnerInfo, e = PartnerInfoService.FindOne(map[string]interface{}{"id": partnerSubAccount.PartnerId})
//} else {
// partnerInfo, e = PartnerInfoService.FindOne(map[string]interface{}{"account": claim.Phone})
//}
switch claim.AdminType {
case 1:
if user, err = UsersRepository.FindOne(map[string]interface{}{"phone": claim.Phone, "companyId": claim.CompanyId, "status": 1}); err != nil || user != nil {
err = protocol.NewErrWithMessage(4140, err)
return
}
userId = user.Id
break
case 2:
if partnerInfo, err = PartnerInfoService.FindOne(map[string]interface{}{"account": claim.Id, "companyId": claim.CompanyId, "status": 1}); err != nil || partnerInfo == nil {
err = protocol.NewErrWithMessage(4140, err)
return
}
userId = partnerInfo.Id
break
default:
err = protocol.NewErrWithMessage(4140, err)
return
}
... ... @@ -168,8 +192,8 @@ func RefreshToken(request *protocol.RefreshTokenRequest) (rsp *protocol.RefreshT
// err = protocol.NewErrWithMessage(4140, err)
// return
//}
rsp.AccessToken, _ = utils.GenerateToken(claim.UserId, claim.Phone, protocol.TokenExpire*time.Second)
rsp.RefreshToken, _ = utils.GenerateToken(claim.UserId, claim.Phone, protocol.RefreshTokenExipre*time.Second)
rsp.AccessToken, _ = utils.GenerateTokenWithAdminType(userId, claim.Phone, claim.AdminType, protocol.TokenExpire*time.Second)
rsp.RefreshToken, _ = utils.GenerateTokenWithAdminType(userId, claim.Phone, claim.AdminType, protocol.RefreshTokenExipre*time.Second)
rsp.ExpiresIn = protocol.TokenExpire
//newAuth := userAuth.NewRedisUserAuth(userAuth.WithUserId(claim.UserId),
... ... @@ -251,6 +275,10 @@ func CenterCompanys(header *protocol.RequestHeader, request *protocolx.CenterCom
}
switch request.GrantType {
case protocol.LoginByPassword:
if len(request.Password) == 0 {
err = protocol.NewCustomMessage(1, "密码不能为空!")
return
}
if loginSvr.ManagerLogin(request.Phone, request.Password) != nil && loginSvr.PartnerLogin(request.Phone, request.Password) != nil {
err = protocol.NewCustomMessage(1, "密码输入有误!")
return
... ... @@ -342,12 +370,12 @@ func LoginV2(header *protocol.RequestHeader, request *protocol.LoginRequestV2) (
}
switch request.IdType {
case int(protocolx.AdminTypePartner):
if p, e := PartnerInfoRepository.FindOne(map[string]interface{}{"account": claim.Phone, "company_id": request.Cid, "status": 1}); e == nil {
if p, e := PartnerInfoRepository.FindOne(map[string]interface{}{"account": claim.Phone, "companyId": request.Cid, "status": 1}); e == nil {
userId = p.Id
}
break
case int(protocolx.AdminTypeManager):
if p, e := UsersRepository.FindOne(map[string]interface{}{"phone": claim.Phone, "company_id": request.Cid, "status": 1}); e == nil {
if p, e := UsersRepository.FindOne(map[string]interface{}{"phone": claim.Phone, "companyId": request.Cid, "status": 1}); e == nil {
userId = p.Id
}
break
... ... @@ -360,7 +388,13 @@ func LoginV2(header *protocol.RequestHeader, request *protocol.LoginRequestV2) (
return
}
//根据simnum + cid
rsp.AuthCode, _ = utils.GenerateTokenWithAdminType(userId, claim.Phone, request.IdType, protocol.AuthCodeExpire*time.Second)
userClaims := utils.UserTokenClaims{
UserId: userId,
CompanyId: claim.CompanyId,
AdminType: claim.AdminType,
Phone: claim.Phone,
}
rsp.AuthCode, _ = utils.GenerateTokenWithClaim(userClaims, protocol.AuthCodeExpire*time.Second)
err = transactionContext.CommitTransaction()
return
... ...
... ... @@ -9,7 +9,7 @@ type BusinessBonus struct {
// 公司编号
CompanyId int64 `json:"companyId"`
// 合伙人信息Id
PartnerInfoId string `json:"partnerInfoId"`
PartnerInfoId int64 `json:"partnerInfoId"`
// 应收分红
Bonus float64 `json:"bonus"`
// 未收分红
... ...
... ... @@ -2,6 +2,7 @@ package dao
import (
"fmt"
"github.com/go-pg/pg/v10"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/domain"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/pg/models"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/pg/transaction"
... ... @@ -46,14 +47,14 @@ func (dao *OrderBaseDao) OrderStatics(option *domain.OrderStaticQuery) (count in
//订单分红统计
func (dao *OrderBaseDao) OrderBonusStatics(option domain.OrderBonusQuery) (rsp domain.OrderBonusResponse, err error) {
rsp = domain.OrderBonusResponse{}
if option.PartnerId == 0 && option.CompanyId == 0 {
if option.PartnerId == 0 && option.CompanyId == 0 && len(option.InPartnerIds) == 0 {
return
}
tx := dao.transactionContext.PgTx
q := tx.Model(new(models.OrderBase))
q.ExcludeColumn("count(*) count")
q.ExcludeColumn("sum(plan_partner_bonus) bonus")
q.ExcludeColumn("sum(bonus_expense) bonus_expense")
q.ColumnExpr("count(*) count")
q.ColumnExpr("sum(plan_partner_bonus) bonus")
q.ColumnExpr("sum(partner_bonus_expense) bonus_expense")
if option.PartnerId > 0 {
q.Where(`"order_base".partner_id =?`, option.PartnerId)
}
... ... @@ -61,7 +62,7 @@ func (dao *OrderBaseDao) OrderBonusStatics(option domain.OrderBonusQuery) (rsp d
q.Where(`"order_base".company_id =?`, option.CompanyId)
}
if len(option.InPartnerIds) > 0 {
q.Where(`"order_base".partner_id in (?)`, option.InPartnerIds)
q.Where(`"order_base".partner_id in (?)`, pg.In(option.InPartnerIds))
}
err = q.Select(&rsp.Total, &rsp.Bonus, &rsp.BonusExpense)
return
... ...
... ... @@ -7,6 +7,7 @@ import (
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/pg/transaction"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/repository"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/utils"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/log"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/protocol"
protocolx "gitlab.fjmaimaimai.com/mmm-go/partner/pkg/protocol/auth"
"strings"
... ... @@ -26,8 +27,8 @@ func (svr *PgLoginService) Init(phone string) (err error) {
UsersRepository, _ = repository.NewUsersRepository(svr.transactionContext)
)
svr.Phone = phone
_, svr.PartnerInfo, _ = PartnerInfoService.Find(map[string]interface{}{"account": phone, "status": 1, "sortByCreateTime": "ASC"})
_, svr.Users, _ = UsersRepository.Find(map[string]interface{}{"phone": phone, "status": 1, "sortByCreateTime": "ASC"})
_, svr.PartnerInfo, err = PartnerInfoService.Find(map[string]interface{}{"account": phone, "status": 1, "sortByCreateTime": "ASC"})
_, svr.Users, err = UsersRepository.Find(map[string]interface{}{"phone": phone, "status": 1, "sortByCreateTime": "ASC"})
return nil
}
... ... @@ -64,7 +65,7 @@ func (svr *PgLoginService) PartnerStaticInfo() (interface{}, error) {
OrderDao, _ = dao.NewOrderBaseDao(svr.transactionContext)
PartnerCategoryInfoRepository, _ = repository.NewPartnerCategoryInfoRepository(svr.transactionContext)
companyList []*domain.Company
partnerCategory []*domain.PartnerCategoryInfo
allPartnerCategory []*domain.PartnerCategoryInfo
)
doGetCompanyIds := func() []int64 {
var companies []int64
... ... @@ -81,12 +82,18 @@ func (svr *PgLoginService) PartnerStaticInfo() (interface{}, error) {
return array
}
companyList = svr.GetCompanyList(doGetCompanyIds)
if len(companyList) == 0 {
return nil, nil
}
totalBonus, e := OrderDao.OrderBonusStatics(domain.OrderBonusQuery{InPartnerIds: doGetPartnerIds()})
if e != nil {
return nil, e
}
_, partnerCategory, _ = PartnerCategoryInfoRepository.Find(map[string]interface{}{"sortById": domain.ASC})
_, allPartnerCategory, e = PartnerCategoryInfoRepository.Find(map[string]interface{}{"sortById": domain.ASC})
if e != nil {
log.Error(e)
return nil, e
}
var companys = make([]*Company, 0)
for i := range companyList {
c := companyList[i]
... ... @@ -103,9 +110,15 @@ func (svr *PgLoginService) PartnerStaticInfo() (interface{}, error) {
bonus, _ := OrderDao.OrderBonusStatics(domain.OrderBonusQuery{PartnerId: partner.Id})
item := &Company{
CompanyBase: newCompanyBase(c),
IncomePercent: computeBonusPercent(totalBonus.Bonus, bonus.Bonus),
IncomePercent: computeBonusPercent(totalBonus.Bonus*100, bonus.Bonus),
DividendMoney: bonus.Bonus,
JoinWays: svr.GetJoinWays(partnerCategory, partner, bonus.Bonus),
JoinWays: svr.GetJoinWays(allPartnerCategory, partner, bonus.Bonus),
}
//当所有公司的总收入都为0时,
//初始总收入=公司数*1
//否则计算的比例始终都为0
if totalBonus.Bonus == 0 {
item.IncomePercent = computeBonusPercent(totalBonus.Bonus+float64(len(companyList)), bonus.Bonus+1) * 100
}
companys = append(companys, item)
}
... ... @@ -125,8 +138,8 @@ func (svr *PgLoginService) ManagerStaticInfo() (interface{}, error) {
)
doGetCompanyIds := func() []int64 {
var companies []int64
for i := range svr.PartnerInfo {
companies = append(companies, svr.PartnerInfo[i].CompanyId)
for i := range svr.Users {
companies = append(companies, svr.Users[i].CompanyId)
}
return companies
}
... ... @@ -146,7 +159,7 @@ func (svr *PgLoginService) ManagerStaticInfo() (interface{}, error) {
}
response := make(map[string]interface{})
response["id"] = protocolx.AdminTypePartner
response["id"] = protocolx.AdminTypeManager
response["name"] = svr.PartnerInfo[0].PartnerName
response["companys"] = companys
return response, nil
... ... @@ -160,7 +173,8 @@ func (svr *PgLoginService) GetCompanyList(funcGetCompanyIds func() []int64) []*d
if len(companies) == 0 {
return companyList
}
if _, v, e := CompanyRepository.Find(map[string]interface{}{"companies": companies, "status": 1, "sortByCreateTime": domain.ASC}); e != nil {
if _, v, e := CompanyRepository.Find(map[string]interface{}{"inCompanyIds": companies, "status": 1, "sortByCreateTime": domain.ASC}); e != nil {
log.Error(e)
return companyList
} else {
companyList = v
... ... @@ -176,23 +190,23 @@ func (svr *PgLoginService) GetJoinWays(partnerCategory []*domain.PartnerCategory
}
return nil
}
var (
totalBonus float64
businessBonus float64
BusinessBonusRepository, _ = repository.NewBusinessBonusRepository(svr.transactionContext)
)
for i := range partnerInfo.PartnerCategoryInfos {
c := partnerInfo.PartnerCategoryInfos[i]
switch c.Id {
case 1:
totalBonus += bonus
case 2:
if one, e := BusinessBonusRepository.FindOne(map[string]interface{}{"partner_id": partnerInfo.Id}); e == nil {
businessBonus = one.Bonus
totalBonus += businessBonus
}
}
}
//var (
// totalBonus float64
// businessBonus float64
// BusinessBonusRepository, _ = repository.NewBusinessBonusRepository(svr.transactionContext)
//)
//for i := range partnerInfo.PartnerCategoryInfos {
// c := partnerInfo.PartnerCategoryInfos[i]
// switch c.Id {
// case 1:
// totalBonus += bonus
// case 2:
// if one, e := BusinessBonusRepository.FindOne(map[string]interface{}{"partner_id": partnerInfo.Id}); e == nil {
// businessBonus = one.Bonus
// totalBonus += businessBonus
// }
// }
//}
var joinWays []joinWay
for i := range partnerCategory {
c := partnerCategory[i]
... ... @@ -202,14 +216,18 @@ func (svr *PgLoginService) GetJoinWays(partnerCategory []*domain.PartnerCategory
Type: int(c.Id),
Name: c.Name,
}
if c.Id == 1 {
item.Percent = computeBonusPercent(totalBonus, bonus) * 100
} else if c.Id == 2 {
item.Percent = computeBonusPercent(totalBonus, businessBonus) * 100
}
//if c.Id == 1 {
// item.Percent = computeBonusPercent(totalBonus, bonus) * 100
//} else if c.Id == 2 {
// item.Percent = computeBonusPercent(totalBonus, businessBonus) * 100
//}
joinWays = append(joinWays, item)
}
}
for i := range joinWays {
joinWays[i].Percent = computeBonusPercent(float64(len(joinWays)), 1) * 100
}
return joinWays
}
func newCompanyBase(company *domain.Company) protocol.CompanyBase {
... ...
... ... @@ -10,7 +10,7 @@ type BusinessBonus struct {
// 公司编号
CompanyId int64
// 合伙人信息Id
PartnerInfoId string
PartnerInfoId int64
// 应收分红
Bonus float64
// 未收分红
... ...
... ... @@ -49,7 +49,7 @@ func (repository *BusinessBonusRepository) FindOne(queryOptions map[string]inter
BusinessBonusModel := new(models.BusinessBonus)
query := NewQuery(tx.Model(BusinessBonusModel), queryOptions)
query.SetWhere("id = ?", "id")
query.SetWhere("partner_id = ?", "partner_id")
query.SetWhere("partner_info_id = ?", "partner_id")
if err := query.First(); err != nil {
return nil, fmt.Errorf("query row not found")
}
... ...
package repository
import (
"github.com/go-pg/pg/v10"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/domain"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/pg/models"
"gitlab.fjmaimaimai.com/mmm-go/partner/pkg/infrastructure/pg/transaction"
... ... @@ -62,8 +63,9 @@ func (repository *CompanyRepository) Find(queryOptions map[string]interface{}) (
var CompanyModels []*models.Company
Companys := make([]*domain.Company, 0)
query := NewQuery(tx.Model(&CompanyModels), queryOptions)
if companies, ok := queryOptions["companies"]; ok {
query.WhereIn("id in (?)", companies)
if companies, ok := queryOptions["inCompanyIds"]; ok {
companyIds, _ := companies.([]int64)
query.Where("id in (?)", pg.In(companyIds))
}
query.SetWhere("status = ?", "status")
query.SetOrder(`create_at`, "sortByCreateTime")
... ...
... ... @@ -53,6 +53,8 @@ func (repository *PartnerInfoRepository) FindOne(queryOptions map[string]interfa
query := NewQuery(tx.Model(PartnerInfoModel), queryOptions)
query.SetWhere("partner_info.id = ?", "id")
query.SetWhere("partner_info.account = ?", "account")
query.SetWhere("partner_info.status = ?", "status")
query.SetWhere("partner_info.company_id = ?", "companyId")
if err := query.First(); err != nil {
return nil, query.HandleError(err, "没有此合伙人")
}
... ...
... ... @@ -48,7 +48,10 @@ func (repository *UsersRepository) FindOne(queryOptions map[string]interface{})
tx := repository.transactionContext.PgTx
UsersModel := new(models.Users)
query := NewQuery(tx.Model(UsersModel), queryOptions)
query.SetWhere("id = ?", "id")
query.SetWhere("status = ?", "status")
query.SetWhere("phone = ?", "phone")
query.SetWhere("company_id = ?", "companyId")
if err := query.First(); err != nil {
return nil, fmt.Errorf("query row not found")
}
... ...
... ... @@ -72,3 +72,22 @@ func GenerateTokenWithAdminType(uid int64, phone string, adminType int, expire t
token, err := tokenClaims.SignedString(jwtSecret)
return token, err
}
func GenerateTokenWithClaim(claim UserTokenClaims, expire time.Duration) (string, error) {
now := time.Now()
expireTime := now.Add(expire)
claims := UserTokenClaims{
UserId: claim.UserId,
Phone: claim.Phone,
CompanyId: claim.CompanyId,
AdminType: claim.AdminType,
StandardClaims: jwt.StandardClaims{
ExpiresAt: expireTime.Unix(),
Issuer: "jwt",
},
}
tokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token, err := tokenClaims.SignedString(jwtSecret)
return token, err
}
... ...
... ... @@ -181,10 +181,6 @@ func (this *AuthController) CenterCompanys() {
msg = m
return
}
if len(request.Password) == 0 {
msg = protocol.NewResponseMessage(1, "密码不能为空!")
return
}
header := this.GetRequestHeader(this.Ctx)
data, err := auth.CenterCompanys(header, request)
if err != nil {
... ... @@ -209,6 +205,10 @@ func (this *AuthController) Companys() {
msg = m
return
}
if request.ClientId != clientId {
msg = protocol.NewResponseMessage(101, "clientId无效")
return
}
header := this.GetRequestHeader(this.Ctx)
data, err := auth.Companys(header, request)
if err != nil {
... ...
... ... @@ -16,7 +16,9 @@ func CheckJWTToken(ctx *context.Context) {
if strings.HasSuffix(ctx.Request.RequestURI, "login") ||
strings.HasSuffix(ctx.Request.RequestURI, "accessToken") ||
strings.HasSuffix(ctx.Request.RequestURI, "refreshToken") ||
strings.HasSuffix(ctx.Request.RequestURI, "smsCode") {
strings.HasSuffix(ctx.Request.RequestURI, "smsCode") ||
strings.HasSuffix(ctx.Request.RequestURI, "centerCompanys") ||
strings.HasSuffix(ctx.Request.RequestURI, "centerCompanys") {
return
}
defer func() {
... ...
... ... @@ -32,10 +32,12 @@ func init() {
nsV1.Router("/dividend/statistics", &controllers.DividendController{}, "Post:DividendStatistics")
nsV1.Router("/dividend/orders", &controllers.DividendController{}, "Post:DividendOrders")
beego.AddNamespace(nsV1)
InitV2()
}
func InitV2() {
nsV2 := beego.NewNamespace("v2", beego.NSBefore(middleware.CheckJWTToken))
nsV2 := beego.NewNamespace("v2") // , beego.NSBefore(middleware.CheckJWTToken)
nsV2.Router("/auth/login", &controllers.AuthController{}, "Post:LoginV2")
beego.AddNamespace(nsV2)
}
... ...