作者 唐旭辉

bug 修复

@@ -69,6 +69,8 @@ var AuthToken = func(ctx *context.Context) { @@ -69,6 +69,8 @@ var AuthToken = func(ctx *context.Context) {
69 redisdata.RefreshLoginTokenExpires(mtoken.UID, mtoken.CompanyID) 69 redisdata.RefreshLoginTokenExpires(mtoken.UID, mtoken.CompanyID)
70 ctx.Input.SetData(protocol.HeaderCompanyid, mtoken.CompanyID) 70 ctx.Input.SetData(protocol.HeaderCompanyid, mtoken.CompanyID)
71 ctx.Input.SetData(protocol.HeaderUserid, mtoken.UID) 71 ctx.Input.SetData(protocol.HeaderUserid, mtoken.UID)
  72 + ctx.Input.SetData(protocol.HeaderUCompanyid, mtoken.UserCompanyId)
  73 + log.Info("c=%d,u=%d,cu=%d", mtoken.CompanyID, mtoken.UID, mtoken.UserCompanyId)
72 return 74 return
73 } 75 }
74 if ok := serveauth.IsJwtErrorExpired(err); ok { 76 if ok := serveauth.IsJwtErrorExpired(err); ok {
@@ -91,6 +91,7 @@ func UpdateCompanyById(m *Company, col []string, om ...orm.Ormer) (err error) { @@ -91,6 +91,7 @@ func UpdateCompanyById(m *Company, col []string, om ...orm.Ormer) (err error) {
91 } 91 }
92 var num int64 92 var num int64
93 m.UpdateAt = time.Now() 93 m.UpdateAt = time.Now()
  94 + col = append(col, "UpdateAt")
94 if num, err = o.Update(m, col...); err == nil { 95 if num, err = o.Update(m, col...); err == nil {
95 fmt.Println("Number of records updated in database:", num) 96 fmt.Println("Number of records updated in database:", num)
96 } 97 }
1 package models 1 package models
2 2
3 import ( 3 import (
4 - "errors"  
5 "fmt" 4 "fmt"
6 "oppmg/common/log" 5 "oppmg/common/log"
7 "time" 6 "time"
@@ -94,19 +93,16 @@ func UpdateUserCompanyById(m *UserCompany, col []string, om ...orm.Ormer) (err e @@ -94,19 +93,16 @@ func UpdateUserCompanyById(m *UserCompany, col []string, om ...orm.Ormer) (err e
94 93
95 func GetUserCompanyBy(userid int64, companyId int64) (*UserCompany, error) { 94 func GetUserCompanyBy(userid int64, companyId int64) (*UserCompany, error) {
96 o := orm.NewOrm() 95 o := orm.NewOrm()
97 - var data []*UserCompany  
98 - _, err := o.QueryTable(&UserCompany{}). 96 + var data UserCompany
  97 + err := o.QueryTable(&UserCompany{}).
99 Filter("user_id", userid). 98 Filter("user_id", userid).
100 Filter("company_id", companyId). 99 Filter("company_id", companyId).
101 Filter("delete_at", 0). 100 Filter("delete_at", 0).
102 - All(&data) 101 + One(&data)
103 if err != nil { 102 if err != nil {
104 return nil, err 103 return nil, err
105 } 104 }
106 - if len(data) == 0 {  
107 - return nil, errors.New("UserCompany not found")  
108 - }  
109 - return data[0], nil 105 + return &data, nil
110 } 106 }
111 107
112 func ExistUserCompany(userid int64, companyId int64) bool { 108 func ExistUserCompany(userid int64, companyId int64) bool {
@@ -9,8 +9,9 @@ const ( @@ -9,8 +9,9 @@ const (
9 9
10 //用来存储从token中解析出来的内容对应的键名 10 //用来存储从token中解析出来的内容对应的键名
11 const ( 11 const (
12 - HeaderCompanyid string = "header_companyid"  
13 - HeaderUserid string = "header_userid" 12 + HeaderCompanyid string = "header_companyid"
  13 + HeaderUserid string = "header_userid"
  14 + HeaderUCompanyid string = "header_ucompanyid"
14 ) 15 )
15 16
16 //BaseHeader 请求的header数据 17 //BaseHeader 请求的header数据
@@ -436,6 +436,7 @@ func TemplateOperateCategory(uid, companyId int64, request *protocol.TemplateOpe @@ -436,6 +436,7 @@ func TemplateOperateCategory(uid, companyId int64, request *protocol.TemplateOpe
436 if chanceType.CompanyId != int(companyId) { 436 if chanceType.CompanyId != int(companyId) {
437 err = protocol.NewErrWithMessage("10027") 437 err = protocol.NewErrWithMessage("10027")
438 log.Error("template_id:%v companyId:%v want:%v not equal.", request.Id, companyId, chanceType.CompanyId) 438 log.Error("template_id:%v companyId:%v want:%v not equal.", request.Id, companyId, chanceType.CompanyId)
  439 + return
439 } 440 }
440 if err = utils.UpdateTableByMap(chanceType, map[string]interface{}{"Name": request.Name, "Icon": request.Icon, "SortNum": chanceType.SortNum, "UpdateAt": time.Now()}); err != nil { 441 if err = utils.UpdateTableByMap(chanceType, map[string]interface{}{"Name": request.Name, "Icon": request.Icon, "SortNum": chanceType.SortNum, "UpdateAt": time.Now()}); err != nil {
441 log.Error(err.Error()) 442 log.Error(err.Error())
@@ -134,7 +134,7 @@ func ChangeLoginToken(userid, companyid int64) (protocol.LoginAuthToken, error) @@ -134,7 +134,7 @@ func ChangeLoginToken(userid, companyid int64) (protocol.LoginAuthToken, error)
134 log.Debug("无效公司") 134 log.Debug("无效公司")
135 return logintoken, protocol.NewErrWithMessage("10027") 135 return logintoken, protocol.NewErrWithMessage("10027")
136 } 136 }
137 - logintoken, err = GenerateAuthToken(userid, companydata.Id) 137 + logintoken, err = GenerateAuthToken(userid, companydata.Id, usercompany.Id)
138 if err != nil { 138 if err != nil {
139 log.Error("GenerateAuthToken err:%s", err) 139 log.Error("GenerateAuthToken err:%s", err)
140 return logintoken, protocol.NewErrWithMessage("1") 140 return logintoken, protocol.NewErrWithMessage("1")
@@ -142,38 +142,38 @@ func ChangeLoginToken(userid, companyid int64) (protocol.LoginAuthToken, error) @@ -142,38 +142,38 @@ func ChangeLoginToken(userid, companyid int64) (protocol.LoginAuthToken, error)
142 return logintoken, nil 142 return logintoken, nil
143 } 143 }
144 144
145 -func RefreshLoginToken(refreshtoken string) (protocol.LoginAuthToken, error) {  
146 - var (  
147 - logintoken protocol.LoginAuthToken  
148 - mtoken *MyToken  
149 - err error  
150 - storetoken redisdata.RedisLoginToken  
151 - )  
152 - mtoken, err = ValidJWTToken(refreshtoken)  
153 - if err != nil {  
154 - log.Debug("token失效 err:%s", err)  
155 - return logintoken, protocol.NewErrWithMessage("10024")  
156 - }  
157 - storetoken, err = redisdata.GetLoginToken(mtoken.UID, mtoken.CompanyID)  
158 - if err != nil {  
159 - log.Error("redis err:%s", err)  
160 - return logintoken, protocol.NewErrWithMessage("10024")  
161 - }  
162 - if storetoken.RefreshToken != refreshtoken {  
163 - return logintoken, protocol.NewErrWithMessage("10024")  
164 - }  
165 - logintoken, _ = GenerateAuthToken(mtoken.UID, mtoken.CompanyID)  
166 - return logintoken, nil  
167 -} 145 +// func RefreshLoginToken(refreshtoken string) (protocol.LoginAuthToken, error) {
  146 +// var (
  147 +// logintoken protocol.LoginAuthToken
  148 +// mtoken *MyToken
  149 +// err error
  150 +// storetoken redisdata.RedisLoginToken
  151 +// )
  152 +// mtoken, err = ValidJWTToken(refreshtoken)
  153 +// if err != nil {
  154 +// log.Debug("token失效 err:%s", err)
  155 +// return logintoken, protocol.NewErrWithMessage("10024")
  156 +// }
  157 +// storetoken, err = redisdata.GetLoginToken(mtoken.UID, mtoken.CompanyID)
  158 +// if err != nil {
  159 +// log.Error("redis err:%s", err)
  160 +// return logintoken, protocol.NewErrWithMessage("10024")
  161 +// }
  162 +// if storetoken.RefreshToken != refreshtoken {
  163 +// return logintoken, protocol.NewErrWithMessage("10024")
  164 +// }
  165 +// logintoken, _ = GenerateAuthToken(mtoken.UID, mtoken.CompanyID)
  166 +// return logintoken, nil
  167 +// }
168 168
169 func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, error) { 169 func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, error) {
170 var ( 170 var (
171 - err error  
172 - logintoken protocol.LoginAuthToken  
173 -  
174 - companys []companybase  
175 - companyid int64  
176 - userdata *models.User 171 + err error
  172 + logintoken protocol.LoginAuthToken
  173 + usercompanyid int64
  174 + companys []companybase
  175 + companyid int64
  176 + userdata *models.User
177 ) 177 )
178 var uclientReturn *ucenter.ResponseLogin 178 var uclientReturn *ucenter.ResponseLogin
179 uclientReturn, err = ucenter.RequestUCenterLogin(account, password) 179 uclientReturn, err = ucenter.RequestUCenterLogin(account, password)
@@ -204,11 +204,28 @@ func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, erro @@ -204,11 +204,28 @@ func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, erro
204 //获取上一次登录的公司 204 //获取上一次登录的公司
205 uAuth, err := models.GetUserAuthByUser(userdata.Id) 205 uAuth, err := models.GetUserAuthByUser(userdata.Id)
206 if err == nil { 206 if err == nil {
207 - companyid = uAuth.CurrentCompanyId 207 + has := false
  208 + for _, v := range companys {
  209 + if v.Id == uAuth.CurrentCompanyId {
  210 + companyid = uAuth.CurrentCompanyId
  211 + has = true
  212 + break
  213 + }
  214 + }
  215 + if !has {
  216 + companyid = companys[0].Id
  217 + }
  218 +
208 } else { 219 } else {
209 companyid = companys[0].Id 220 companyid = companys[0].Id
210 } 221 }
211 - 222 + ucompany, err := models.GetUserCompanyBy(userdata.Id, companyid)
  223 + if err != nil {
  224 + log.Error("获取user_company失败;%s", err)
  225 + return logintoken, protocol.NewErrWithMessage("1")
  226 + }
  227 + usercompanyid = ucompany.Id
  228 + logintoken, _ = GenerateAuthToken(userdata.Id, companyid, usercompanyid)
212 //更新用户数据 229 //更新用户数据
213 userdata.Accid = uclientReturn.Data.Accid 230 userdata.Accid = uclientReturn.Data.Accid
214 userdata.Icon = uclientReturn.Data.Avatar 231 userdata.Icon = uclientReturn.Data.Avatar
@@ -220,7 +237,6 @@ func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, erro @@ -220,7 +237,6 @@ func LoginAuthByUCenter(account, password string) (protocol.LoginAuthToken, erro
220 if err != nil { 237 if err != nil {
221 log.Error("更新用户数据失败:%s", err) 238 log.Error("更新用户数据失败:%s", err)
222 } 239 }
223 - logintoken, _ = GenerateAuthToken(userdata.Id, companyid)  
224 return logintoken, err 240 return logintoken, err
225 } 241 }
226 242
@@ -329,11 +345,12 @@ func GetUserHasMenu(userid, companyid int64) ([]protocol.PermissionItem, error) @@ -329,11 +345,12 @@ func GetUserHasMenu(userid, companyid int64) ([]protocol.PermissionItem, error)
329 345
330 func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, error) { 346 func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, error) {
331 var ( 347 var (
332 - err error  
333 - logintoken protocol.LoginAuthToken  
334 - companys []companybase  
335 - companyid int64  
336 - userdata *models.User 348 + err error
  349 + logintoken protocol.LoginAuthToken
  350 + companys []companybase
  351 + companyid int64
  352 + usercompanyid int64
  353 + userdata *models.User
337 ) 354 )
338 var uclientReturn *ucenter.ResponseLoginSms 355 var uclientReturn *ucenter.ResponseLoginSms
339 uclientReturn, err = ucenter.RequestUCenterLoginSms(phone, code) 356 uclientReturn, err = ucenter.RequestUCenterLoginSms(phone, code)
@@ -364,11 +381,27 @@ func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, err @@ -364,11 +381,27 @@ func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, err
364 //获取上一次登录的公司 381 //获取上一次登录的公司
365 uAuth, err := models.GetUserAuthByUser(userdata.Id) 382 uAuth, err := models.GetUserAuthByUser(userdata.Id)
366 if err == nil { 383 if err == nil {
367 - companyid = uAuth.CurrentCompanyId 384 + has := false
  385 + for _, v := range companys {
  386 + if v.Id == uAuth.CurrentCompanyId {
  387 + companyid = uAuth.CurrentCompanyId
  388 + has = true
  389 + break
  390 + }
  391 + }
  392 + if !has {
  393 + companyid = companys[0].Id
  394 + }
368 } else { 395 } else {
369 companyid = companys[0].Id 396 companyid = companys[0].Id
370 } 397 }
371 - 398 + ucompany, err := models.GetUserCompanyBy(userdata.Id, companyid)
  399 + if err != nil {
  400 + log.Error("获取user_company失败;%s", err)
  401 + return logintoken, protocol.NewErrWithMessage("1")
  402 + }
  403 + usercompanyid = ucompany.Id
  404 + logintoken, _ = GenerateAuthToken(userdata.Id, companyid, usercompanyid)
372 //更新用户数据 405 //更新用户数据
373 userdata.Accid = uclientReturn.Data.CsAccountID 406 userdata.Accid = uclientReturn.Data.CsAccountID
374 userdata.Icon = uclientReturn.Data.Image.Path 407 userdata.Icon = uclientReturn.Data.Image.Path
@@ -379,7 +412,6 @@ func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, err @@ -379,7 +412,6 @@ func LoginAuthBySmsCode(phone string, code string) (protocol.LoginAuthToken, err
379 if err != nil { 412 if err != nil {
380 log.Error("更新用户数据失败:%s", err) 413 log.Error("更新用户数据失败:%s", err)
381 } 414 }
382 - logintoken, _ = GenerateAuthToken(userdata.Id, companyid)  
383 return logintoken, err 415 return logintoken, err
384 } 416 }
385 417
@@ -15,12 +15,13 @@ var ( @@ -15,12 +15,13 @@ var (
15 //MyToken ... 15 //MyToken ...
16 type MyToken struct { 16 type MyToken struct {
17 jwt.StandardClaims 17 jwt.StandardClaims
18 - UID int64 `json:"uid"`  
19 - CompanyID int64 `json:"company_id"` 18 + UID int64 `json:"uid"`
  19 + CompanyID int64 `json:"company_id"`
  20 + UserCompanyId int64 `json:"user_company_id"`
20 } 21 }
21 22
22 //CreateJWTToken ... 23 //CreateJWTToken ...
23 -func CreateJWTToken(uid int64, companyid int64, expires int64) (string, error) { 24 +func CreateJWTToken(uid int64, companyid int64, userCompanyId int64, expires int64) (string, error) {
24 nowTime := time.Now().Unix() 25 nowTime := time.Now().Unix()
25 claims := MyToken{ 26 claims := MyToken{
26 StandardClaims: jwt.StandardClaims{ 27 StandardClaims: jwt.StandardClaims{
@@ -29,8 +30,9 @@ func CreateJWTToken(uid int64, companyid int64, expires int64) (string, error) { @@ -29,8 +30,9 @@ func CreateJWTToken(uid int64, companyid int64, expires int64) (string, error) {
29 ExpiresAt: expires, //过期时间 30 ExpiresAt: expires, //过期时间
30 Issuer: "mmm_oppmg", 31 Issuer: "mmm_oppmg",
31 }, 32 },
32 - UID: uid,  
33 - CompanyID: companyid, 33 + UID: uid,
  34 + CompanyID: companyid,
  35 + UserCompanyId: userCompanyId,
34 } 36 }
35 37
36 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 38 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@@ -64,27 +66,19 @@ func IsJwtErrorExpired(err error) bool { @@ -64,27 +66,19 @@ func IsJwtErrorExpired(err error) bool {
64 return false 66 return false
65 } 67 }
66 68
67 -func GenerateAuthToken(uid int64, companyid int64) (protocol.LoginAuthToken, error) { 69 +func GenerateAuthToken(uid int64, companyid int64, usercompanyid int64) (protocol.LoginAuthToken, error) {
68 var ( 70 var (
69 authToken protocol.LoginAuthToken 71 authToken protocol.LoginAuthToken
70 accesstoken string //主token,请求用 72 accesstoken string //主token,请求用
71 expiresIn int64 = 60 * 60 * 6 //主token过期时间,6小时 73 expiresIn int64 = 60 * 60 * 6 //主token过期时间,6小时
72 - // refreshtoken string //副token,刷新主token用  
73 - // refreshExpires int64 = 60 * 60 * 2 //副token 过期时间 ,60分钟  
74 - err error  
75 - nowtime = time.Now() 74 + err error
  75 + nowtime = time.Now()
76 ) 76 )
77 - accesstoken, err = CreateJWTToken(uid, companyid, nowtime.Unix()+expiresIn+2) 77 + accesstoken, err = CreateJWTToken(uid, companyid, usercompanyid, nowtime.Unix()+expiresIn+1)
78 if err != nil { 78 if err != nil {
79 return authToken, err 79 return authToken, err
80 } 80 }
81 - // refreshtoken, err = CreateJWTToken(uid, companyid, nowtime.Unix()+refreshExpires+2)  
82 - // if err != nil {  
83 - // return authToken, err  
84 - // }  
85 authToken.AccessToken = accesstoken 81 authToken.AccessToken = accesstoken
86 authToken.ExpiresIn = nowtime.Unix() + expiresIn 82 authToken.ExpiresIn = nowtime.Unix() + expiresIn
87 - // authToken.RefreshToken = refreshtoken  
88 - // authToken.RefreshExpires = nowtime.Unix() + refreshExpires  
89 return authToken, err 83 return authToken, err
90 } 84 }
@@ -197,7 +197,22 @@ func addNewUser(name string, phone string, ucenterId int64, avatar string, accid @@ -197,7 +197,22 @@ func addNewUser(name string, phone string, ucenterId int64, avatar string, accid
197 return usrData.Id, nil 197 return usrData.Id, nil
198 } 198 }
199 if err == nil { 199 if err == nil {
200 - // 用户存在, 200 + usr := &models.User{
  201 + Id: usrData.Id,
  202 + Phone: phone,
  203 + NickName: name,
  204 + Icon: avatar,
  205 + Accid: accid,
  206 + CsAccount: customerAccout,
  207 + UserCenterId: ucenterId,
  208 + }
  209 + // 用户存在,更新用户
  210 + err = models.UpdateUserById(usr,
  211 + []string{"Phone", "NickName", "Icon", "Accid", "CsAccount", "UserCenterId"},
  212 + o)
  213 + if err != nil {
  214 + log.Error("更新用户失败;%s", err)
  215 + }
201 return usrData.Id, nil 216 return usrData.Id, nil
202 } 217 }
203 return 0, err 218 return 0, err