作者 yangfu

登录修改

@@ -117,3 +117,14 @@ func GetUserCompanyIdAll(companyId int) (v []int64, err error) { @@ -117,3 +117,14 @@ func GetUserCompanyIdAll(companyId int) (v []int64, err error) {
117 } 117 }
118 return nil, err 118 return nil, err
119 } 119 }
  120 +
  121 +//获取用户所有的公司列表
  122 +//@uid 表user.id
  123 +func GetUserAllCompany(uid int64) (v []*UserCompany, err error) {
  124 + o := orm.NewOrm()
  125 + sql := "select * from user_company where user_id=? and enable=1" //and enable=1
  126 + if _, err = o.Raw(sql, uid).QueryRows(&v); err == nil {
  127 + return v, nil
  128 + }
  129 + return nil, err
  130 +}
@@ -275,25 +275,23 @@ func SwitchCompany(header *protocol.RequestHeader, request *protocol.SwitchCompa @@ -275,25 +275,23 @@ func SwitchCompany(header *protocol.RequestHeader, request *protocol.SwitchCompa
275 //用户信息 275 //用户信息
276 func UserInfo(header *protocol.RequestHeader, request *protocol.UserInfoRequest) (rsp *protocol.UserInfoResponse, err error) { 276 func UserInfo(header *protocol.RequestHeader, request *protocol.UserInfoRequest) (rsp *protocol.UserInfoResponse, err error) {
277 var ( 277 var (
278 - companyId int64  
279 - userCompany *models.UserCompany  
280 - userAuth *models.UserAuth  
281 - userBaseAgg *protocol.UserBaseInfoAggregation  
282 - companys []*models.Company 278 + companyId int64
  279 + userCompany *models.UserCompany
  280 + userAuth *models.UserAuth
  281 + userBaseAgg *protocol.UserBaseInfoAggregation
  282 + companys []*models.Company
  283 + userCompanys []*models.UserCompany
283 ) 284 )
284 if companys, err = models.GetCompanyByPermission(header.Uid); err != nil { 285 if companys, err = models.GetCompanyByPermission(header.Uid); err != nil {
285 log.Error(err) 286 log.Error(err)
286 return 287 return
287 } 288 }
288 -  
289 - //保证用户登录期间公司有权限  
290 - //if len(companys) == 0 {  
291 - // //无权限  
292 - // err = protocol.NewErrWithMessage(2002)  
293 - // return  
294 - //}  
295 - for i := range companys {  
296 - if companys[i].Id == header.CompanyId { 289 + if userCompanys, err = models.GetUserAllCompany(header.Uid); err != nil {
  290 + log.Error(err)
  291 + return
  292 + }
  293 + for i := range userCompanys {
  294 + if userCompanys[i].Id == header.UserId && userCompanys[i].CompanyId == header.CompanyId {
297 companyId = header.CompanyId 295 companyId = header.CompanyId
298 break 296 break
299 } 297 }