作者 陈志颖

合并分支 'dev' 到 'test'

Dev



查看合并请求 !11
正在显示 61 个修改的文件 包含 2543 行增加1085 行删除

要显示太多修改。

为保证性能只显示 61 of 61+ 个文件。

... ... @@ -25,3 +25,4 @@
/*.exe~
/logs
download
\ No newline at end of file
... ...
... ... @@ -79,9 +79,13 @@ spec:
- name: BUSINESS_ADMIN_HOST
value: "http://suplus-business-admin-dev.fjmaimaimai.com"
- name: KAFKA_HOST
value: "192.168.0.250:9092;192.168.0.251:9092;192.168.0.252:9092"
value: ""
- name: KAFKA_CONSUMER_ID
value: "partnermg_dev"
- name: RUN_MODE
value: "dev"
- name: Log_PREFIX
value: "[partnermg_dev]"
volumes:
- name: accesslogs
emptyDir: {}
... ...
... ... @@ -38,9 +38,6 @@ spec:
- key: kubernetes.io/hostname
operator: In
values:
- cn-hangzhou.i-bp1djh1xn7taumbue1ze
- cn-hangzhou.i-bp1djh1xn7taumbue1zd
- cn-hangzhou.i-bp1euf5u1ph9kbhtndhb
- cn-hangzhou.i-bp1hyp5oips9cdwxxgxy
containers:
- name: mmm-partnermg
... ... @@ -57,11 +54,11 @@ spec:
- name: POSTGRESQL_USER
value: "postgres"
- name: POSTGRESQL_PASSWORD
value: "postgres_55_online"
value: "chJVQkg1sys"
- name: POSTGRESQL_HOST
value: "112.124.115.55"
value: "114.55.200.59"
- name: POSTGRESQL_PORT
value: "15432"
value: "31544"
- name: LOG_LEVEL
value: "info"
- name: ERROR_BASE_CODE
... ... @@ -82,6 +79,10 @@ spec:
value: "192.168.0.250:9092;192.168.0.251:9092;192.168.0.252:9092"
- name: KAFKA_CONSUMER_ID
value: "partnermg_prd"
- name: RUN_MODE
value: "dev"
- name: Log_PREFIX
value: "[partnermg_prd]"
volumes:
- name: accesslogs
emptyDir: {}
... ...
... ... @@ -79,6 +79,10 @@ spec:
value: "192.168.0.250:9092;192.168.0.251:9092;192.168.0.252:9092"
- name: KAFKA_CONSUMER_ID
value: "partnermg_test"
- name: RUN_MODE
value: "dev"
- name: Log_PREFIX
value: "[partnermg_test]"
volumes:
- name: accesslogs
emptyDir: {}
... ...
... ... @@ -7,12 +7,13 @@ require (
github.com/Shopify/sarama v1.23.1
github.com/ajg/form v1.5.1 // indirect
github.com/astaxie/beego v1.12.2
github.com/beego/beego/v2 v2.0.1
github.com/bsm/sarama-cluster v2.1.15+incompatible
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect
github.com/fatih/structs v1.1.0 // indirect
github.com/gavv/httpexpect v2.0.0+incompatible
github.com/go-pg/pg/v10 v10.0.0-beta.2
github.com/go-pg/pg/v10 v10.7.3
github.com/google/go-querystring v1.0.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/imkira/go-interpol v1.1.0 // indirect
... ... @@ -20,8 +21,9 @@ require (
github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9
github.com/mattn/go-colorable v0.1.6 // indirect
github.com/moul/http2curl v1.0.0 // indirect
github.com/onsi/ginkgo v1.13.0
github.com/onsi/gomega v1.10.1
github.com/onsi/ginkgo v1.14.2
github.com/onsi/gomega v1.10.3
github.com/sclevine/agouti v3.0.0+incompatible // indirect
github.com/sergi/go-diff v1.1.0 // indirect
github.com/shopspring/decimal v1.2.0
github.com/smartystreets/goconvey v1.6.4 // indirect
... ...
package command
import "errors"
//创建订单
type CreateOrderCommand struct {
//订单类型
... ... @@ -18,7 +20,43 @@ type CreateOrderCommand struct {
SalesmanBonusPercent float64 `json:"salesmanBonusPercent"`
//货品
Goods []OrderGoodData `json:"goods"`
//公司id
CompanyId int64 `json:"companyId"`
//合伙人类型
PartnerCategory int64 `json:"partner_category"`
//行号-错误信息返回
LineNumbers []int `json:"lineNumber"`
//合伙人姓名
PartnerName string `json:"partnerName"`
//编号-错误信息返回
Code string `json:"code"`
//合伙人类型名称-错误信息返回
PartnerCategoryName string `json:"partnerCategoryName"`
}
func (postData *CreateOrderCommand) Valid() error {
if len(postData.OrderCode) == 0 {
return errors.New("订单编号必填")
}
if len(postData.BuyerName) == 0 {
return errors.New("买家信息必填")
}
if postData.PartnerId == 0 {
return errors.New("合伙人信息必填")
}
if len(postData.OrderRegion) == 0 {
return errors.New("订单区域必填")
}
if len(postData.Goods) == 0 {
return errors.New("货品列表必填")
}
if len(postData.Goods) > 50 {
return errors.New("货品列表最多50项")
}
for i := range postData.Goods {
if err := postData.Goods[i].Valid(); err != nil {
return err
}
}
return nil
}
... ...
package command
import (
"errors"
"fmt"
"regexp"
"unicode/utf8"
)
type OrderGoodData struct {
//货品id
Id int64 `json:"id"`
... ... @@ -13,4 +20,33 @@ type OrderGoodData struct {
PartnerBonusPercent float64 `json:"partnerBonusPercent"`
//备注信息
Remark string `json:"remark"`
//行号-错误信息返回
LineNumber int `json:"lineNumber"`
}
func (postData OrderGoodData) Valid() error {
lenProductName := utf8.RuneCountInString(postData.GoodName)
if lenProductName == 0 {
return errors.New("商品名称必填")
}
if lenProductName > 50 {
return errors.New("商品名称最多50位")
}
if postData.PlanGoodNumber >= 1e16 {
return errors.New("商品数量最多16位")
}
if postData.Price >= 1e16 {
return errors.New("商品单价最多16位")
}
if postData.PartnerBonusPercent > 100 {
return errors.New("合伙人分红比例超额")
}
partnerRatio := fmt.Sprint(postData.PartnerBonusPercent)
regexpStr := `^(100|[1-9]\d|\d)(.\d{1,2})?$`
ok := regexp.MustCompile(regexpStr).MatchString(partnerRatio)
if !ok {
return errors.New("合伙人分红比例精确到小数点2位")
}
return nil
}
... ...
... ... @@ -19,7 +19,8 @@ type UpdateOrderCommand struct {
OrderType int `json:"orderType"`
//货品
Goods []OrderGoodData `json:"goods"`
//公司id
CompanyId int64 `json:"companyId"`
// 合伙人类型
PartnerCategory int64 `json:"partner_category"`
}
... ...
/**
@author: stevechan
@date: 2021/1/6
@note:
**/
package query
/**
* @Author SteveChan
* @Description //TODO 查询合伙人id
* @Date 23:18 2021/1/6
**/
type GetPartnerIdQuery struct {
Code string `json:"code"`
PartnerCategory int `json:"partnerCategory"`
CompanyId int64 `json:"companyId"`
}
... ...
/**
@author: stevechan
@date: 2021/1/6
@note:
**/
package query
/**
* @Author SteveChan
* @Description //TODO 查询产品id
* @Date 23:18 2021/1/6
**/
type GetProductIdQuery struct {
ProductName int64 `json:"productName"`
}
... ...
... ... @@ -7,14 +7,25 @@ type ListOrderBaseQuery struct {
// 查询限制
Limit int `json:"limit"`
//发货单号
PartnerOrCode string `json:"partnerOrCode"`
//PartnerOrCode string `json:"partnerOrCode"`
//合伙人姓名
PartnerName string `json:"partnerName"`
//订单号
OrderCode string `json:"orderCode"`
//发货单号
DeliveryCode string `json:"deliveryCode"`
//公司id
CompanyId int64 `json:"companyId"`
//订单类型
OrderType int `json:"orderType"`
//合伙人分类
PartnerCategory int `json:"partnerCategory"`
//更新时间开始
UpdateTimeBegin string `json:"updateTimeBegin"`
//更新时间截止
UpdateTimeEnd string `json:"updateTimeEnd"`
//创建时间开始
CreateTimeBegin string `json:"createTimeBegin"`
//创建时间截止
CreateTimeEnd string `json:"createTimeEnd"`
}
... ...
... ... @@ -26,7 +26,13 @@ func NewOrderInfoService(option map[string]interface{}) *OrderInfoService {
return newAdminUserService
}
// PageListOrderBase 获取订单列表
/**
* @Author SteveChan
* @Description // 获取订单列表
* @Date 22:05 2021/1/10
* @Param
* @return
**/
func (service OrderInfoService) PageListOrderBase(listOrderQuery query.ListOrderBaseQuery) ([]map[string]interface{}, int, error) {
var err error
transactionContext, err := factory.CreateTransactionContext(nil)
... ... @@ -53,7 +59,9 @@ func (service OrderInfoService) PageListOrderBase(listOrderQuery query.ListOrder
orders, cnt, err = orderDao.OrderListByCondition(
listOrderQuery.CompanyId,
listOrderQuery.OrderType,
listOrderQuery.PartnerOrCode,
listOrderQuery.PartnerName, // 合伙人姓名
listOrderQuery.OrderCode, // 订单号
listOrderQuery.DeliveryCode, // 发货单号
[2]string{listOrderQuery.UpdateTimeBegin, listOrderQuery.UpdateTimeEnd},
[2]string{listOrderQuery.CreateTimeBegin, listOrderQuery.CreateTimeEnd},
listOrderQuery.PartnerCategory,
... ... @@ -186,7 +194,9 @@ func (service OrderInfoService) CreateNewOrder(cmd command.CreateOrderCommand) (
transactionContext, _ = factory.CreateTransactionContext(nil)
err error
)
if err = cmd.Valid(); err != nil {
return nil, lib.ThrowError(lib.BUSINESS_ERROR, err.Error())
}
if err = transactionContext.StartTransaction(); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
... ... @@ -225,19 +235,21 @@ func (service OrderInfoService) CreateNewOrder(cmd command.CreateOrderCommand) (
}); err != nil {
return nil, lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
}
//检查order_code是否重复
// if ok, err := orderBaseDao.OrderCodeExist(cmd.OrderCode, cmd.PartnerCategory, cmd.PartnerId); err != nil {
// return nil, lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
// } else if ok {
// return nil, lib.ThrowError(lib.BUSINESS_ERROR, "订单号已存在")
// }
//检查delivery_code是否重复
if len(cmd.DeliveryCode) > 0 {
if ok, err := orderBaseDao.DeliveryCodeExist(cmd.DeliveryCode, cmd.CompanyId); err != nil {
if ok, err := orderBaseDao.CheckOrderExist(cmd.CompanyId, cmd.OrderCode, cmd.DeliveryCode,
cmd.PartnerCategory, cmd.PartnerId, 0); err != nil {
return nil, lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
} else if ok {
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "发货号已存在")
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "订单已存在")
}
//检查货品数据
var goodMap = map[string]int{}
for i := range cmd.Goods {
goodname := cmd.Goods[i].GoodName
if _, ok := goodMap[goodname]; ok {
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "订单中货品重复已存在")
}
goodMap[goodname] = 1
}
newOrder := &domain.OrderBase{
OrderType: cmd.OrderType, OrderCode: cmd.OrderCode,
... ... @@ -436,12 +448,19 @@ func (service OrderInfoService) UpdateOrderData(cmd command.UpdateOrderCommand)
// }
// }
//检查delivery_code是否重复
if cmd.DeliveryCode != oldOrderData.DeliveryCode {
if ok, err := orderBaseDao.DeliveryCodeExist(cmd.DeliveryCode, cmd.CompanyId, cmd.Id); err != nil {
if ok, err := orderBaseDao.CheckOrderExist(cmd.CompanyId, cmd.OrderCode, cmd.DeliveryCode, cmd.PartnerCategory, cmd.PartnerId, cmd.Id); err != nil {
return nil, lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
} else if ok {
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "发货号已存在")
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "订单已存在")
}
//检查货品数据
var goodMap = map[string]int{}
for i := range cmd.Goods {
goodname := cmd.Goods[i].GoodName
if _, ok := goodMap[goodname]; ok {
return nil, lib.ThrowError(lib.BUSINESS_ERROR, "订单中货品重复已存在")
}
goodMap[goodname] = 1
}
//获取旧的订单中的商品
oldOrderGoods, _, err = orderGoodRepository.Find(domain.OrderGoodFindQuery{
... ... @@ -854,6 +873,13 @@ func (service OrderInfoService) ListOrderBonusForExcel(listOrderQuery query.List
return resultMaps, column, nil
}
/**
* @Author SteveChan
* @Description // 导出订单数据
* @Date 22:05 2021/1/10
* @Param
* @return
**/
func (service OrderInfoService) ListOrderForExcel(listOrderQuery query.ListOrderBaseQuery) ([]map[string]string, [][2]string, error) {
transactionContext, err := factory.CreateTransactionContext(nil)
if err != nil {
... ... @@ -865,6 +891,7 @@ func (service OrderInfoService) ListOrderForExcel(listOrderQuery query.ListOrder
defer func() {
transactionContext.RollbackTransaction()
}()
var (
orderBaseDao *dao.OrderBaseDao
)
... ... @@ -876,7 +903,9 @@ func (service OrderInfoService) ListOrderForExcel(listOrderQuery query.ListOrder
}
ordersData, err := orderBaseDao.OrderListForExcel(
listOrderQuery.CompanyId,
listOrderQuery.PartnerOrCode,
listOrderQuery.PartnerName, // 合伙人姓名
listOrderQuery.OrderCode, // 订单号
listOrderQuery.DeliveryCode, // 发货单号
[2]string{listOrderQuery.UpdateTimeBegin, listOrderQuery.UpdateTimeEnd},
[2]string{listOrderQuery.CreateTimeBegin, listOrderQuery.CreateTimeEnd},
listOrderQuery.PartnerCategory,
... ... @@ -930,3 +959,330 @@ func (service OrderInfoService) ListOrderForExcel(listOrderQuery query.ListOrder
}
return resultMaps, column, nil
}
/**
* @Author SteveChan
* @Description //TODO 批量导入创建订单
* @Date 11:00 2021/1/7
* @Param
* @return
**/
func (service OrderInfoService) CreateNewOrderByImport(createOrderCommands []*command.CreateOrderCommand) ([]*domain.ImportInfo, error) {
// 事务初始化
var (
transactionContext, _ = factory.CreateTransactionContext(nil)
err error
errorDataList []*domain.ImportInfo // 错误数据返回
)
// 循环校验命令
for _, cmd := range createOrderCommands {
if err = cmd.Valid(); err != nil {
// 返回信息 0: 订单号, 1: 发货单号, 2: 客户名称, 3: 订单区域, 4: 编号, 5: 合伙人, 6: 类型, 7: 业务抽成比例, 8: 产品名称, 9: 数量, 10: 单价, 11: 合伙人分红比例
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.BUSINESS_ERROR, err.Error()), // 错误信息
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
}
// 开始事务
if err = transactionContext.StartTransaction(); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
defer func() {
transactionContext.RollbackTransaction()
}()
// 仓储、数据访问对象初始化
var (
orderBaseRepository domain.OrderBaseRepository
orderGoodRepository domain.OrderGoodRepository
PartnerInfoRepository domain.PartnerInfoRepository
categoryRepository domain.PartnerCategoryRepository
orderBaseDao *dao.OrderBaseDao
)
// 合伙人信息仓储初始化
if PartnerInfoRepository, err = factory.CreatePartnerInfoRepository(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
// 订单仓储初始化
if orderBaseRepository, err = factory.CreateOrderBaseRepository(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
// 订单产品仓储初始化
if orderGoodRepository, err = factory.CreateOrderGoodRepository(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
// 合伙人类型仓储初始化
if categoryRepository, err = factory.CreatePartnerCategoryRepository(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
// 订单数据访问对象初始化
if orderBaseDao, err = factory.CreateOrderBaseDao(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
}
// 批量创建订单
for _, cmd := range createOrderCommands {
// 批量校验合伙人信息
var partnerData *domain.PartnerInfo
partnerData, err = PartnerInfoRepository.FindOne(domain.PartnerFindOneQuery{UserId: cmd.PartnerId})
if err != nil {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("检索合伙人数据失败")),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
// 批量校验订单
if ok, err := orderBaseDao.CheckOrderExist(cmd.CompanyId, cmd.OrderCode, cmd.DeliveryCode,
cmd.PartnerCategory, cmd.PartnerId, 0); err != nil {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.TRANSACTION_ERROR, err.Error()),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
} else if ok {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.BUSINESS_ERROR, "订单已存在"),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
// 批量校验产品
var goodMap = map[string]int{}
goodErrMap := make(map[int]interface{}, 0)
for i := range cmd.Goods {
goodName := cmd.Goods[i].GoodName
if _, ok := goodMap[goodName]; ok {
goodErrMap[cmd.Goods[i].LineNumber] = lib.ThrowError(lib.BUSINESS_ERROR, "订单中货品重复已存在")
continue
}
goodMap[goodName] = 1
}
if len(goodErrMap) > 0 {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.BUSINESS_ERROR, "订单中货品重复已存在"),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: goodErrMap, // 错误产品行号记录
}
errorDataList = append(errorDataList, row)
continue
}
newOrder := &domain.OrderBase{
OrderType: cmd.OrderType, OrderCode: cmd.OrderCode,
DeliveryCode: cmd.DeliveryCode,
Buyer: domain.Buyer{
BuyerName: cmd.BuyerName,
},
RegionInfo: domain.RegionInfo{
RegionName: cmd.OrderRegion,
},
PartnerId: cmd.PartnerId,
PartnerInfo: partnerData.Partner,
SalesmanBonusPercent: cmd.SalesmanBonusPercent,
CompanyId: cmd.CompanyId,
}
// 批量校验合伙人分类数据
var cmdPartnerCategoryOk bool
for _, v := range partnerData.PartnerCategoryInfos {
if v.Id == cmd.PartnerCategory {
_, categories, err := categoryRepository.Find(domain.PartnerCategoryFindQuery{
Ids: []int64{v.Id},
})
if err != nil {
e := fmt.Sprintf("获取合伙人分类数据失败:%s", err)
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, e)
}
if len(categories) > 0 {
newOrder.PartnerCategory = categories[0]
cmdPartnerCategoryOk = true
}
break
}
}
if !cmdPartnerCategoryOk {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.BUSINESS_ERROR, "合伙人类型选择错误"),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
// 订单产品分红核算
var orderGoods []domain.OrderGood
orderGoodErrMap := make(map[int]interface{}, 0)
for i, good := range cmd.Goods {
m := domain.NewOrderGood()
m.OrderId = 0
m.GoodName = good.GoodName
m.PlanGoodNumber = good.PlanGoodNumber
m.Price = good.Price
m.PartnerBonusPercent = good.PartnerBonusPercent
m.Remark = good.Remark
m.CompanyId = cmd.CompanyId
err = m.Compute()
if err != nil {
orderGoodErrMap[cmd.Goods[i].LineNumber] = lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("核算订单中商品的数值失败:%s", err))
continue
}
err = m.CurrentBonusStatus.WartPayPartnerBonus(&m)
if err != nil {
orderGoodErrMap[cmd.Goods[i].LineNumber] = lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("核算订单中商品的分红数值失败:%s", err))
continue
}
orderGoods = append(orderGoods, m)
}
if len(orderGoodErrMap) > 0 {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.BUSINESS_ERROR, "核算订单中商品错误"),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: orderGoodErrMap, // 错误产品行号记录
}
errorDataList = append(errorDataList, row)
continue
}
newOrder.Goods = orderGoods
err = newOrder.Compute()
if err != nil {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("核算订单中合计的数值失败:%s", err)),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
// 保存订单数据
err = orderBaseRepository.Save(newOrder)
if err != nil {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("保存订单数据失败:%s", err)),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
for i := range newOrder.Goods {
newOrder.Goods[i].OrderId = newOrder.Id
}
// 保存订单产品
err = orderGoodRepository.Save(orderGoods)
if err != nil {
row := &domain.ImportInfo{
Error: lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("保存订单中的商品数据失败:%s", err)),
LineNumbers: cmd.LineNumbers, // 错误影响的行
GoodLine: map[int]interface{}{},
}
errorDataList = append(errorDataList, row)
continue
}
newOrder.Goods = orderGoods
}
if len(errorDataList) == 0 {
// 完成事务
err = transactionContext.CommitTransaction()
if err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
return errorDataList, nil
}
return errorDataList, nil
}
/**
* @Author SteveChan
* @Description // 根据合伙人编号和合伙人类型获取合伙人id
* @Date 23:15 2021/1/6
* @Param
* @return
**/
func (service OrderInfoService) GetPartnerIdByCodeAndCategory(getPartnerIdQuery query.GetPartnerIdQuery) (*domain.PartnerInfo, error) {
// 事务初始化
var (
transactionContext, _ = factory.CreateTransactionContext(nil)
err error
partnerData *domain.PartnerInfo
)
// 开始事务
if err = transactionContext.StartTransaction(); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
// 收尾
defer func() {
transactionContext.RollbackTransaction()
}()
// 仓储、数据访问对象初始化
var (
PartnerInfoRepository domain.PartnerInfoRepository
)
// 合伙人信息仓储初始化
if PartnerInfoRepository, err = factory.CreatePartnerInfoRepository(map[string]interface{}{
"transactionContext": transactionContext,
}); err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
//var partnerData *domain.PartnerInfo
partnerData, err = PartnerInfoRepository.FindOne(domain.PartnerFindOneQuery{
CompanyId: getPartnerIdQuery.CompanyId,
Code: getPartnerIdQuery.Code,
PartnerCategory: getPartnerIdQuery.PartnerCategory,
})
if err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, fmt.Sprintf("检索合伙人数据失败"))
}
// 完成事务
err = transactionContext.CommitTransaction()
if err != nil {
return nil, lib.ThrowError(lib.INTERNAL_SERVER_ERROR, err.Error())
}
return partnerData, nil
}
... ...
package command
import (
"errors"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/domain"
)
type EditUserPermissionCommand struct {
Id int64 `json:"id"`
CompanyId int64 `json:"-"`
PermissionType []int64 `json:"permissionType"` //权限数据
CheckedPartner []int64 `json:"checkedPartner"` //可查看合伙人列表合伙人
IsSenior int8 `json:"isSenior"`
}
func (cmd EditUserPermissionCommand) Validate() error {
if cmd.IsSenior <= 0 {
return errors.New("是否是高管必填")
}
if !(cmd.IsSenior == domain.UserIsSeniorNo || cmd.IsSenior == domain.UserIsSeniorYes) {
return errors.New("是否是高管必填")
}
return nil
}
... ...
... ... @@ -280,7 +280,13 @@ func (service UsersService) GetUserList(queryOption query.UserListQuery) (int, [
return cnt, result, nil
}
//buildGetUserList 组装构建前端需要的用户列表数据
/**
* @Author SteveChan
* @Description // 组装构建前端需要的用户列表数据
* @Date 00:22 2021/1/8
* @Param
* @return
**/
func (service UsersService) buildGetUserList(usersData []domain.Users, permissionData []domain.AdminPermission) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(usersData))
permissionMap := map[int64]domain.AdminPermission{}
... ... @@ -313,9 +319,11 @@ func (service UsersService) buildGetUserList(usersData []domain.Users, permissio
"permission": permissionTypes,
"isAdmin": 0,
"partnership": len(usersData[i].AccessPartners),
"isSenior": usersData[i].IsSenior,
}
if usersData[i].IsSuperAdmin() {
m["isAdmin"] = 1
m["name"] = m["name"].(string) + "(管理员)"
}
result = append(result, m)
}
... ... @@ -383,6 +391,7 @@ func (service UsersService) buildGetUserData(userData *domain.Users, partnerList
"isAdmin": 0,
"status": 0,
"checkedPartner": []map[string]interface{}{},
"isSenior": userData.IsSenior,
}
if userData.IsSuperAdmin() {
result["isAdmin"] = 1
... ... @@ -433,6 +442,9 @@ func (service UsersService) EditUserPermission(cmd command.EditUserPermissionCom
transactionContext, _ = factory.CreateTransactionContext(nil)
err error
)
if err = cmd.Validate(); err != nil {
return lib.ThrowError(lib.BUSINESS_ERROR, err.Error())
}
if err = transactionContext.StartTransaction(); err != nil {
return lib.ThrowError(lib.TRANSACTION_ERROR, err.Error())
}
... ... @@ -501,9 +513,6 @@ func (service UsersService) EditUserPermission(cmd command.EditUserPermissionCom
partners = append(partners, p)
}
for i := range permissionList {
// if permissionList[i].Code == domain.PERMINSSION_ADMIN_USER && !usersData.IsSuperAdmin() {
// return lib.ThrowError(lib.BUSINESS_ERROR, "操作异常")
// }
p := domain.AdminPermissionBase{
Id: permissionList[i].Id,
Code: permissionList[i].Code,
... ... @@ -512,6 +521,7 @@ func (service UsersService) EditUserPermission(cmd command.EditUserPermissionCom
}
updateMap := map[string]interface{}{
"AccessPartners": partners,
"IsSenior": cmd.IsSenior,
}
if !usersData.IsSuperAdmin() {
updateMap["Permission"] = permissionsBase
... ...
... ... @@ -6,7 +6,8 @@ const SERVICE_NAME = "partnermg"
var LOG_LEVEL = "debug"
var LOG_File = "./logs/partnermg.log"
var IMPORT_EXCEL = "./download/订单数据模板.xlsx"
var Log_PREFIX = "[partnermg_dev]"
var (
UCENTER_HOST = "https://suplus-ucenter-test.fjmaimaimai.com" //统一用户中心地址
UCENTER_SECRET = "cykbjnfqgctn"
... ... @@ -18,6 +19,8 @@ var (
BUSINESS_ADMIN_HOST = "http://suplus-business-admin-test.fjmaimaimai.com" //企业平台的地址
)
var EXCEL_COLUMN = 12
func init() {
if os.Getenv("LOG_LEVEL") != "" {
LOG_LEVEL = os.Getenv("LOG_LEVEL")
... ... @@ -37,4 +40,7 @@ func init() {
if os.Getenv("BUSINESS_ADMIN_HOST") != "" {
BUSINESS_ADMIN_HOST = os.Getenv("BUSINESS_ADMIN_HOST")
}
if os.Getenv("Log_PREFIX") != "" {
Log_PREFIX = os.Getenv("Log_PREFIX")
}
}
... ...
... ... @@ -15,7 +15,7 @@ var KafkaCfg KafkaConfig
func init() {
KafkaCfg = KafkaConfig{
Servers: []string{"127.0.0.1:9092"},
ConsumerId: "partnermg_local",
ConsumerId: "partnermg_dev",
}
if os.Getenv("KAFKA_HOST") != "" {
kafkaHost := os.Getenv("KAFKA_HOST")
... ...
... ... @@ -35,7 +35,7 @@ type AdminUserFindOneQuery struct {
type AdminUserRepository interface {
Save(AdminUser) (*AdminUser, error)
FindOne(qureyOptions AdminUserFindOneQuery) (*AdminUser, error)
FindOne(queryOptions AdminUserFindOneQuery) (*AdminUser, error)
Find(queryOptions AdminUserFindQuery) ([]AdminUser, error)
CountAll(queryOption AdminUserFindQuery) (int, error)
}
... ...
... ... @@ -273,6 +273,7 @@ func (order *OrderBase) Compute() error {
if hasUsePartnerBonus {
order.OrderCompute.UsePartnerBonus, _ = usePartnerBonus.Round(2).BigFloat().Float64()
} else {
//订单中的货品列表中合伙人分成没有调整值的情况下,对订单的调整值设置为负值用以标识
order.OrderCompute.UsePartnerBonus = -1
}
if hasUseOrderAmount {
... ... @@ -283,6 +284,7 @@ func (order *OrderBase) Compute() error {
Div(decimal.NewFromInt(100)).
Round(2).BigFloat().Float64()
} else {
//订单中的货品列表中货品总金额没有调整值的情况下,对订单的调整值设置为负值用以标识
order.OrderCompute.UseOrderAmount = -1
order.OrderCompute.SalesmanBonus, _ = planOrderAmount.
Mul(decimal.NewFromFloat(order.SalesmanBonusPercent)).
... ... @@ -312,6 +314,17 @@ type OrderBaseFindQuery struct {
CompanyId int64
}
// 导入错误信息
type ImportInfo struct {
Error error
LineNumbers []int
GoodLine map[int]interface{}
}
// 导入产品错误信息
type GoodErrInfo struct {
}
type OrderBaseRepository interface {
Save(order *OrderBase) error
FindOne(qureyOptions OrderBaseFindOneQuery) (*OrderBase, error)
... ...
... ... @@ -84,6 +84,7 @@ type OrderGood struct {
CompanyId int64 `json:"companyId"`
//原因备注
RemarkReason OrderGoodRemarkReason `json:"remarkReason"`
//数据来源
DataFrom OrderDataFrom `json:"data_from"`
}
... ... @@ -322,7 +323,7 @@ func (good *OrderGood) Compute() error {
good.GoodCompute.PlanAmount, _ = planamount.Round(2).BigFloat().Float64()
good.GoodCompute.PlanPartnerBonus, _ = planPartnerBonus.Round(2).BigFloat().Float64()
if good.UseGoodNumber < 0 {
//没有出现数量调整
//没有出现数量调整,使用负值进行标记
good.GoodCompute.UsePartnerBonus = -1
good.GoodCompute.UseAmount = -1
} else {
... ...
... ... @@ -64,6 +64,8 @@ type PartnerFindOneQuery struct {
UserId int64
AccountEqual string
CompanyId int64
Code string // 合伙人编码
PartnerCategory int // 合伙人类型
}
type PartnerFindQuery struct {
... ...
... ... @@ -2,24 +2,30 @@ package domain
import "time"
//用户是否可用状态:【1:正常】【 2:禁用】
//Users.Status用户是否可用状态:【1:正常】【 2:禁用】
const (
userStatusUsable int8 = 1
userStatusUnusable int8 = 2
)
//用户是否是主管 :【1:是主管】【 2:不是主管】
//Users.ChargeStatus用户是否是主管 :【1:是主管】【 2:不是主管】
const (
UserIsCompanyCharge int8 = 1
UserIsNotCompanyCharge int8 = 2
)
//用户类型 1普通用户 2主管理员
//Users.AdminType 用户类型 1普通用户 2主管理员
const (
UserIsNotAdmin int8 = 1
UserIsAdmin int8 = 2
)
//Users.IsSenior 用户是否是公司高管【1:是】【2:否】
const (
UserIsSeniorYes int8 = 1
UserIsSeniorNo int8 = 2
)
//Users 企业平台的用户
type Users struct {
Id int64 //用户id
... ... @@ -38,11 +44,12 @@ type Users struct {
Avatar string ///头像
Remarks string //备注
ChargeStatus int8 //是否为当前公司主管 【1:是】【2:否】
CreateAt time.Time
UpdateAt time.Time
CreateAt time.Time //
UpdateAt time.Time //
Permission []AdminPermissionBase //权限
AccessPartners []Partner
AccessPartners []Partner //
AdminType int8 //是否是公司负责人,即超级管理员 1普通用户 2主管理员
IsSenior int8 //是否是公司高管【1:是】【2:否】;用于确定是否可以拥有“可查看的合伙人”
}
//IsUsable 用户是否可用
... ... @@ -71,6 +78,17 @@ func (u Users) HasPermissionByCode(code string) bool {
return false
}
func (u *Users) SetIsSenior(senior int8) {
switch senior {
case UserIsSeniorYes:
u.IsSenior = senior
case UserIsSeniorNo:
u.IsSenior = senior
u.AccessPartners = make([]Partner, 0)
default:
}
}
func (u *Users) Update(m map[string]interface{}) error {
if v, ok := m["CompanyId"]; ok {
u.CompanyId = v.(int64)
... ... @@ -126,6 +144,10 @@ func (u *Users) Update(m map[string]interface{}) error {
if v, ok := m["AdminType"]; ok {
u.AdminType = v.(int8)
}
if v, ok := m["IsSenior"]; ok {
senior := v.(int8)
u.SetIsSenior(senior)
}
return nil
}
... ...
... ... @@ -23,26 +23,20 @@ func NewOrderBaseDao(transactionContext *transaction.TransactionContext) (*Order
}
}
//OrderCodeExist 检查order_code是否重复
//
func (dao OrderBaseDao) OrderCodeExist(code string, partnerCategory int64, partnerId int64) (bool, error) {
//CheckOrderUnique 检查订单的是否已存在
//@companyId 公司id
//@orderCode 订单号
//@deliveryCode 发货单号
//@partnerCategoryCode 合伙人类型编号
func (dao OrderBaseDao) CheckOrderExist(companyId int64, orderCode string,
deliveryCode string, partnerCategory int64, partnerId int64, notId int64) (bool, error) {
tx := dao.transactionContext.GetDB()
m := &models.OrderBase{}
query := tx.Model(m).
Where("order_code=?", code).
query := tx.Model(&models.OrderBase{}).
Where("company_id=?", companyId).
Where("order_code=?", orderCode).
Where("partner_id=?", partnerId).
Where(`partner_category @>'{"id":?}'`, partnerCategory)
ok, err := query.Exists()
return ok, err
}
func (dao OrderBaseDao) DeliveryCodeExist(code string, companyId int64, notId ...int64) (bool, error) {
tx := dao.transactionContext.GetDB()
m := &models.OrderBase{}
query := tx.Model(m).Where("delivery_code=?", code).Where("company_id=?", companyId)
if len(notId) > 0 {
query = query.WhereIn("id not in(?)", notId)
}
Where(`partner_category @>'{"id":?}'`, partnerCategory).
Where("id<>?", notId)
ok, err := query.Exists()
return ok, err
}
... ... @@ -192,7 +186,7 @@ func (dao OrderBaseDao) OrderBonusListForExcel(companyId int64, orderType int, p
//@param partnerCategory 合伙人类型id
//@param updateTime 订单更新时间范围"[开始时间,结束时间]",时间格式"2006-01-02 15:04:05+07"
//@param createTime 订单的创建时间范围"[开始时间,结束时间]" 时间格式"2006-01-02 15:04:05+07"
func (dao OrderBaseDao) OrderListByCondition(companyId int64, orderType int, partnerOrCode string,
func (dao OrderBaseDao) OrderListByCondition(companyId int64, orderType int, partnerName string, orderCode string, deliveryCode string,
updateTime [2]string, createTime [2]string, partnerCategory int, limit, offset int) ([]models.OrderBase, int, error) {
tx := dao.transactionContext.GetDB()
var orders []models.OrderBase
... ... @@ -217,16 +211,25 @@ func (dao OrderBaseDao) OrderListByCondition(companyId int64, orderType int, par
if len(createTime[1]) > 0 {
query = query.Where(`order_base.create_time<=?`, createTime[1])
}
if len(partnerOrCode) > 0 {
if len(partnerName) > 0 {
query = query.Join("LEFT JOIN partner_info as p ON order_base.partner_id=p.id").
WhereGroup(func(q *orm.Query) (*orm.Query, error) {
q = q.WhereOr("order_base.order_code like ? ", "%"+partnerOrCode+"%").
WhereOr("order_base.delivery_code like ? ", "%"+partnerOrCode+"%").
WhereOr("p.partner_name like ? ", "%"+partnerOrCode+"%")
return q, nil
})
Where("p.partner_name like ? ", "%"+partnerName+"%")
}
if len(orderCode) > 0 {
query = query.Where("order_base.order_code like ? ", "%"+orderCode+"%")
}
if len(deliveryCode) > 0 {
query = query.Where("order_base.delivery_code like ? ", "%"+deliveryCode+"%")
}
//if len(partnerOrCode) > 0 {
// query = query.Join("LEFT JOIN partner_info as p ON order_base.partner_id=p.id").
// WhereGroup(func(q *orm.Query) (*orm.Query, error) {
// q = q.WhereOr("order_base.order_code like ? ", "%"+partnerOrCode+"%").
// WhereOr("order_base.delivery_code like ? ", "%"+partnerOrCode+"%").
// WhereOr("p.partner_name like ? ", "%"+partnerOrCode+"%")
// return q, nil
// })
//}
query = query.Order("order_base.create_time DESC").
Offset(offset).
Limit(limit)
... ... @@ -259,7 +262,7 @@ type CustomOrderListForExcel struct {
//@param partnerCategory 合伙人类型id
//@param updateTime 订单更新时间范围"[开始时间,结束时间]",时间格式"2006-01-02 15:04:05+07"
//@param createTime 订单的创建时间范围"[开始时间,结束时间]" 时间格式"2006-01-02 15:04:05+07"
func (dao OrderBaseDao) OrderListForExcel(companyId int64, partnerOrCode string,
func (dao OrderBaseDao) OrderListForExcel(companyId int64, partnerName string, orderCode string, deliveryCode string,
updateTime [2]string, createTime [2]string, partnerCategory int) (
result []CustomOrderListForExcel, err error) {
sqlstr := `
... ... @@ -275,12 +278,26 @@ func (dao OrderBaseDao) OrderListForExcel(companyId int64, partnerOrCode string,
WHERE 1=1 AND t1.order_type = 1 AND t1.company_id=?
`
params := []interface{}{companyId}
if len(partnerOrCode) > 0 {
like := "%" + partnerOrCode + "%"
params = append(params, like, like, like)
sqlstr += " AND (t1.order_code like ? OR t1.delivery_code like ? OR t2.partner_name like ? ) "
//if len(partnerOrCode) > 0 {
// like := "%" + partnerOrCode + "%"
// params = append(params, like, like, like)
// sqlstr += " AND (t1.order_code like ? OR t1.delivery_code like ? OR t2.partner_name like ? ) "
//}
if len(partnerName) > 0 {
like := "%" + partnerName + "%"
params = append(params, like)
sqlstr += ` AND t2.partner_name like ? `
}
if len(orderCode) > 0 {
like := "%" + orderCode + "%"
params = append(params, like)
sqlstr += ` AND t1.order_code like ? `
}
if len(deliveryCode) > 0 {
like := "%" + deliveryCode + "%"
params = append(params, like)
sqlstr += ` AND t1.delivery_code like ? `
}
if partnerCategory > 0 {
params = append(params, partnerCategory)
sqlstr += ` AND t1.partner_category@>'{"id":?}' `
... ...
... ... @@ -43,6 +43,6 @@ type OrderGood struct {
CompanyId int64
//原因备注
RemarkReason domain.OrderGoodRemarkReason ``
//数据来源
DataFrom domain.OrderDataFrom ``
}
... ...
... ... @@ -29,6 +29,7 @@ type Users struct {
ChargeStatus int8 //是否为当前公司主管 【1:是】【2:否】
Permission []domain.AdminPermissionBase //权限
AccessPartners []domain.Partner //可查看的合伙人
IsSenior int8 //是否是公司高管【1:是】【2:否】;用于确定是否可以拥有“可查看的合伙人”
CreateAt time.Time
UpdateAt time.Time
DeleteAt time.Time
... ...
... ... @@ -3,6 +3,7 @@ package repository
import (
"errors"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/domain"
... ... @@ -47,7 +48,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error {
Remark: dm.Remark,
}
if m.Id == 0 {
err = tx.Insert(m)
_, err = tx.Model(m).Insert()
dm.Partner.Id = m.Id
if err != nil {
return err
... ... @@ -81,6 +82,10 @@ func (repository *PartnerInfoRepository) FindOne(queryOptions domain.PartnerFind
hasCondition = true
query = query.Where("company_id=?", queryOptions.CompanyId)
}
if queryOptions.PartnerCategory > 0 && queryOptions.Code != "" { // 合伙人类型和编码判断
hasCondition = true
query = query.Where(`partner_category_infos@> '[{"id":?,"code":?}]'`, queryOptions.PartnerCategory, pg.Ident(queryOptions.Code))
}
if !hasCondition {
return nil, errors.New("FindOne 必须要有查询条件")
}
... ...
... ... @@ -48,6 +48,7 @@ func (repository UsersRepository) transformPgModelToDomainModel(m *models.Users)
Permission: m.Permission,
AccessPartners: m.AccessPartners,
AdminType: m.AdminType,
IsSenior: m.IsSenior,
}, nil
}
... ... @@ -78,6 +79,7 @@ func (reponsitory UsersRepository) Add(u *domain.Users) error {
Permission: u.Permission,
AccessPartners: u.AccessPartners,
AdminType: u.AdminType,
IsSenior: u.IsSenior,
}
_, err = tx.Model(m).Insert()
return err
... ... @@ -110,6 +112,7 @@ func (reponsitory UsersRepository) Edit(u *domain.Users) error {
Permission: u.Permission,
AccessPartners: u.AccessPartners,
AdminType: u.AdminType,
IsSenior: u.IsSenior,
}
_, err = tx.Model(m).WherePK().Update()
return err
... ... @@ -180,7 +183,8 @@ func (reponsitory UsersRepository) Find(queryOption domain.UsersFindQuery) (int,
usersReturn = make([]domain.Users, 0)
cnt int
)
query = query.Order("id DESC")
//query = query.Order("id DESC")
query = query.Order("admin_type DESC")
cnt, err = query.SelectAndCount()
if err != nil {
return 0, usersReturn, err
... ...
... ... @@ -31,3 +31,19 @@ func GenerateRangeNum(min, max int) int {
randNum := rand.Intn(max-min) + min
return randNum
}
/**
* @Author SteveChan
* @Description // 判断数组是否包含
* @Date 14:30 2021/1/6
* @Param
* @return
**/
func IsContain(items []string, item string) bool {
for _, eachItem := range items {
if eachItem == item {
return true
}
}
return false
}
... ...
package exceltool
import (
"io"
excelize "github.com/360EntSecGroup-Skylar/excelize/v2"
)
// ExcelListReader 读取基础excel表格,
// 指定读取的列表区域的第一行作为表头字段处理,表头字段唯一
type ExcelListReader struct {
RowStart int //从第几行开始,零值做为起始
RowEnd func(index int, rowsData []string) bool //第几行结束,
ColStart int //第几列开始,零值做为起始
ColEnd int //第几列结束,
Sheet string //获取的表格
}
func NewExcelListReader() *ExcelListReader {
rowEnd := func(index int, rowsData []string) bool {
var allEmpty bool = true
for _, v := range rowsData {
if allEmpty && len(v) > 0 {
allEmpty = false
break
}
}
return allEmpty
}
return &ExcelListReader{
RowEnd: rowEnd,
}
}
func (eRead ExcelListReader) OpenReader(r io.Reader) ([]map[string]string, error) {
xlsxFile, err := excelize.OpenReader(r)
if err != nil {
return nil, err
}
rows, err := xlsxFile.Rows(eRead.Sheet)
if err != nil {
return nil, err
}
var (
datas = make([]map[string]string, 0) //数据列表
listHead = make(map[int]string) //map[索引数字]列表头字符串
rowIndex int = 0
)
for rows.Next() {
cols, err := rows.Columns()
if err != nil {
return nil, err
}
if readEnd := eRead.RowEnd(rowIndex, cols); readEnd {
break
}
if rowIndex < eRead.RowStart {
rowIndex++
continue
}
listRowData := make(map[string]string)
for colK, colV := range cols {
if eRead.ColEnd != 0 && colK > eRead.ColEnd {
break
}
if colK < eRead.ColStart {
continue
}
if rowIndex == eRead.RowStart {
//指定的数据列表第一行作为列表头处理
listHead[colK] = colV
}
if rowIndex > eRead.RowStart {
//指定的数据列表第二行开始作为列表数据内容处理
headK := listHead[colK]
listRowData[headK] = colV
}
}
if rowIndex > eRead.RowStart {
//指定的数据列表第二行开始作为列表数据内容处理
datas = append(datas, listRowData)
}
rowIndex++
}
return datas, nil
}
... ...
... ... @@ -8,10 +8,10 @@ import (
)
func init() {
logs.SetLevel(logLevel(constant.LOG_LEVEL))
logs.SetLogFuncCall(false)
logs.SetLogger("file", getlogFileConfig())
logs.SetPrefix(constant.Log_PREFIX)
logs.Async()
logs.Async(2 * 1e3)
}
... ...
package controllers
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"github.com/beego/beego/v2/client/httplib"
"path"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/360EntSecGroup-Skylar/excelize/v2"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/constant"
"github.com/astaxie/beego/logs"
orderCmd "gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/application/orderinfo/command"
orderQuery "gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/application/orderinfo/query"
orderService "gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/application/orderinfo/service"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/domain"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/infrastructure/utils"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/lib"
"gitlab.fjmaimaimai.com/mmm-go/partnermg/pkg/lib/exceltool"
)
... ... @@ -62,6 +70,7 @@ func (postData *postPurposeOrderDetail) Valid() error {
}
if postData.PartnerId == 0 {
return lib.ThrowError(lib.ARG_ERROR, "合伙人信息必填")
}
if len(postData.OrderDist) == 0 {
return lib.ThrowError(lib.ARG_ERROR, "订单区域必填")
... ... @@ -141,10 +150,19 @@ func (postData *postOrderPurposeDelivery) Valid() error {
return nil
}
//PageListOrderReal 获取实发订单列表
/**
* @Author SteveChan
* @Description // 获取实发订单列表,修改搜索条件
* @Date 20:23 2021/1/10
* @Param
* @return
**/
func (c *OrderInfoController) PageListOrderReal() {
type Parameter struct {
SearchText string `json:"searchText"`
//SearchText string `json:"searchText"`
PartnerName string `json:"partnerName"` // 合伙人姓名
OrderCode string `json:"orderCode"` // 订单号
DeliveryCode string `json:"deliveryCode"` // 发货单号
PartnerCategory int `json:"PartnerCategory"`
PageSize int `json:"pageSize"`
PageNumber int `json:"pageNumber"`
... ... @@ -222,7 +240,10 @@ func (c *OrderInfoController) PageListOrderReal() {
companyId := c.GetUserCompany()
orderSrv := orderService.NewOrderInfoService(nil)
orderinfos, cnt, err := orderSrv.PageListOrderBase(orderQuery.ListOrderBaseQuery{
PartnerOrCode: param.SearchText,
//PartnerOrCode: param.SearchText,
PartnerName: param.PartnerName,
OrderCode: param.OrderCode,
DeliveryCode: param.DeliveryCode,
OrderType: domain.OrderReal,
Limit: param.PageSize,
Offset: (param.PageNumber - 1) * param.PageSize,
... ... @@ -506,7 +527,10 @@ func (c *OrderInfoController) RemoveOrderReal() {
//ListOrderForExcel excel 导出实际订单的列表
func (c *OrderInfoController) ListOrderForExcel() {
type Parameter struct {
SearchText string `json:"searchText"`
//SearchText string `json:"searchText"`
PartnerName string `json:"partnerName"` // 合伙人姓名
OrderCode string `json:"orderCode"` // 订单号
DeliveryCode string `json:"deliveryCode"` // 发货单号
PartnerCategory int `json:"PartnerCategory"`
UpdateTime []string `json:"updateTime"`
CreateTime []string `json:"createTime"`
... ... @@ -576,7 +600,10 @@ func (c *OrderInfoController) ListOrderForExcel() {
companyId := c.GetUserCompany()
orderSrv := orderService.NewOrderInfoService(nil)
orderinfos, columns, err := orderSrv.ListOrderForExcel(orderQuery.ListOrderBaseQuery{
PartnerOrCode: param.SearchText,
//PartnerOrCode: param.SearchText,
PartnerName: param.PartnerName,
OrderCode: param.OrderCode,
DeliveryCode: param.DeliveryCode,
OrderType: domain.OrderReal,
CompanyId: companyId,
PartnerCategory: param.PartnerCategory,
... ... @@ -603,3 +630,502 @@ func (c *OrderInfoController) ListOrderForExcel() {
c.ResponseExcelByFile(c.Ctx, excelMaker)
return
}
/**
* @Author SteveChan
* @Description // 下载导入模板
* @Date 16:48 2021/1/8
* @Param
* @return
**/
func (c *OrderInfoController) DownloadTemplate() {
type Parameter struct {
TYPE string `json:"type"`
}
var (
param Parameter
err error
)
if err = c.BindJsonData(&param); err != nil {
logs.Error(err)
c.ResponseError(errors.New("json数据解析失败"))
return
}
// 校验类型编码
if param.TYPE != "PARTNER_ORDER_FILE" {
c.ResponseError(errors.New("类型编码错误"))
}
// 获取导入模板
req := httplib.Get("http://suplus-file-dev.fjmaimaimai.com/upload/file/2021010803305336443.xlsx")
err = req.ToFile(constant.IMPORT_EXCEL)
if err != nil {
logs.Error("could not save to file: ", err)
}
// 返回字段定义
ret := map[string]interface{}{}
resp, err := req.Response()
if err != nil {
logs.Error("could not get response: ", err)
} else {
logs.Info(resp)
ret = map[string]interface{}{
"url": "http://" + c.Ctx.Request.Host + "/download/订单数据模板.xlsx",
}
c.ResponseData(ret)
}
}
/**
* @Author SteveChan
* @Description //TODO 导入excel订单
* @Date 10:52 2021/1/6
* @Param
* @return
**/
func (c *OrderInfoController) ImportOrderFromExcel() {
// 获取参数
typeCode := c.GetString("type")
file, h, _ := c.GetFile("file")
companyId := c.GetUserCompany()
// Json数据解析
//jsonMap := make(map[string]interface{})
//err := json.Unmarshal([]byte(where), &jsonMap)
//if err != nil {
// logs.Error(err)
// c.ResponseError(errors.New("json数据解析失败"))
//}
if typeCode != "PARTNER_ORDER_IMPORT" {
c.ResponseError(errors.New("类型编码错误"))
}
// 返回字段定义
ret := map[string]interface{}{}
// 返回信息表头定义 0: 订单号, 1: 发货单号, 2: 客户名称, 3: 订单区域, 4: 编号, 5: 合伙人, 6: 类型, 7: 业务抽成比例, 8: 产品名称, 9: 数量, 10: 单价, 11: 合伙人分红比例
var tableHeader = []string{"错误详情", "行号", "订单号", "发货单号", "客户名称", "订单区域", "编号", "合伙人", "类型", "业务抽成比例", "产品名称", "数量", "单价", "合伙人分红比例"}
// 文件后缀名校验
ext := path.Ext(h.Filename)
AllowExtMap := map[string]bool{
".xlsx": true,
}
if _, ok := AllowExtMap[ext]; !ok {
c.ResponseError(errors.New("文件后缀名不符合上传要求,请上传正确格式的文件"))
return
}
// 打开文件
xlsx, err := excelize.OpenReader(file)
if err != nil {
c.ResponseError(errors.New("文件打开失败,请确定文件能够正常打开"))
return
}
// 文件行数校验
rows, _ := xlsx.GetRows("工作表1")
if len(rows) > 303 {
c.ResponseError(errors.New("导入文件的行数超过300行,请调整行数后重新导入"))
return
}
// 数据行计数
rowCnt := 0
// 空文件校验
if len(rows) < 3 {
c.ResponseError(errors.New("导入的excel文件为空文件,请上传正确的文件"))
}
// 必填项校验
nullLine := make([]interface{}, 0)
nullFlag := false
for i, row := range rows {
if i > 2 && row != nil {
rowCnt++
if len(row) == constant.EXCEL_COLUMN { // 中间空字符校验
var tmpRow = row
var myRow []string
for j, cell := range row {
if j != 8 { // 业务员抽成比例非必填
if cell == "" || cell == " " { // 空字符串填充
tmpRow[j] = "null"
nullFlag = true
}
}
}
if nullFlag {
myRow = append(myRow, "必填项不能为空") // 错误信息
s := strconv.Itoa(i + 1)
myRow = append(myRow, s) // 行号
myRow = append(myRow, tmpRow...) // 错误行数据
nullLine = append(nullLine, myRow)
nullFlag = false
}
} else if len(row) > 0 && len(row) < constant.EXCEL_COLUMN { // 尾部空字符校验
var myRow []string
for i := 0; i < constant.EXCEL_COLUMN-len(row); i++ { // null补位
myRow = append(myRow, "null")
}
myRow = append(myRow, "必填项不能为空") // 错误信息
s := strconv.Itoa(i + 1)
myRow = append(myRow, s) // 行号
myRow = append(myRow, row...) // 错误行数据
nullLine = append(nullLine, myRow)
}
}
}
// 空单元格返回
if len(nullLine) > 0 {
ret = map[string]interface{}{
"successCount": 0,
"fail": map[string]interface{}{
"tableHeader": tableHeader,
"tableData": nullLine,
},
}
c.ResponseData(ret)
return
}
// 内容校验
errorLine := make([]interface{}, 0)
var partnerType = []string{"事业合伙", "业务合伙", "研发合伙", "业务-产品应用合伙"}
for i, row := range rows {
if i > 2 && row != nil && len(row) == constant.EXCEL_COLUMN { // 数据行
var myRow []string
for j, cell := range row {
switch j {
case 0, 1, 2, 3, 4, 5, 8: // 订单号、发货单号、客户名称、订单区域、编号、合伙人、产品名称长度校验
{
cellStr := strings.TrimSpace(cell)
lenCellStr := utf8.RuneCountInString(cellStr)
if lenCellStr > 50 {
var tmpRow []string
tmpRow = append(tmpRow, tableHeader[j+2]+"长度超过50位,请重新输入") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
}
}
case 6: // 合伙人类型校验(事业合伙、业务合伙、研发合伙、业务-产品应用合伙)
{
if !utils.IsContain(partnerType, cell) {
var tmpRow []string
tmpRow = append(tmpRow, "合伙人类型须为以下类型:事业合伙、业务合伙、研发合伙、业务-产品应用合伙") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
}
}
case 7: // 业务员抽成比例,非必填,精确到小数点后两位
{
var (
typeErrFlag bool
lenErrFlag bool
ratioErrFlag bool
)
if len(cell) > 0 {
// 参数类型转换
shareRatio, err := strconv.ParseFloat(cell, 64)
if err != nil {
typeErrFlag = true
}
// 比例不能超过100%
if shareRatio > 100 {
ratioErrFlag = true
}
// 长度校验
regexpStr := `^(100|[1-9]\d|\d)(.\d{1,2})?$`
ok := regexp.MustCompile(regexpStr).MatchString(cell)
if !ok {
lenErrFlag = true
}
if typeErrFlag || lenErrFlag || ratioErrFlag {
var tmpRow []string
tmpRow = append(tmpRow, "业务员抽成比例格式错误,请输入正确的业务员抽成比例比例,保留两位小数") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
typeErrFlag = false
lenErrFlag = false
ratioErrFlag = false
}
}
}
case 9: // 数量不超过16位正整数
{
var (
typeErrFlag bool
lenErrFlag bool
)
//参数类型转换
orderNum, err := strconv.ParseInt(cell, 10, 64)
if err != nil {
typeErrFlag = true
}
// 长度校验
if orderNum > 1e16 {
lenErrFlag = true
}
if typeErrFlag || lenErrFlag {
var tmpRow []string
tmpRow = append(tmpRow, "数量长度超过最大限制十六位整数,请重新填写") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
typeErrFlag = false
lenErrFlag = false
}
}
case 10: // 单价,精确到小数点后两位,小数点左侧最多可输入16位数字
{
// 参数类型转换
univalent, err := strconv.ParseFloat(cell, 64)
if err != nil {
var tmpRow []string
tmpRow = append(tmpRow, "单价格式错误,请输入正确的单价,保留两位小数点,小数点前面不能超过十六位数字") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
}
// 长度校验
if univalent >= 1e16 {
var tmpRow []string
tmpRow = append(tmpRow, "单价格式错误,请输入正确的单价,保留两位小数点,小数点前面不能超过十六位数字") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
}
}
case 11: // 合伙人分红比例,精确到小数点后两位
{
var (
typeErrFlag bool
lenErrFlag bool
ratioErrFlag bool
)
//参数类型转换
partnerRatio, err := strconv.ParseFloat(cell, 64)
if err != nil {
typeErrFlag = true
}
// 合伙人分红比例超额
if partnerRatio > 100 {
ratioErrFlag = true
}
// 长度判断
regexpStr := `^(100|[1-9]\d|\d)(.\d{1,2})?$`
ok := regexp.MustCompile(regexpStr).MatchString(cell)
if !ok {
lenErrFlag = true
}
if typeErrFlag || lenErrFlag || ratioErrFlag {
var tmpRow []string
tmpRow = append(tmpRow, "合伙人分红比例格式错误,请输入正确的合伙人分红比例,保留两位小数") // 错误信息
s := strconv.Itoa(i + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, row...) // 错误行数据
myRow = tmpRow
typeErrFlag = false
lenErrFlag = false
ratioErrFlag = false
}
}
}
}
if myRow != nil {
errorLine = append(errorLine, myRow)
}
}
}
// 内容错误行返回
if len(errorLine) > 0 {
ret = map[string]interface{}{
"successCount": 0,
"fail": map[string]interface{}{
"tableHeader": tableHeader,
"tableData": errorLine,
},
}
c.ResponseData(ret)
return
}
// 创建订单服务
orderSrv := orderService.NewOrderInfoService(nil)
// 聚合订单产品
var orderCommands = make(map[string]*orderCmd.CreateOrderCommand, 0)
for i, row := range rows {
if i > 2 && len(row) == constant.EXCEL_COLUMN {
hashValue := md5.Sum([]byte(row[0] + row[1] + row[4] + row[6])) // 根据:订单号+发货单号+合伙人编号+合伙类型计算哈希值
hashString := hex.EncodeToString(hashValue[:])
if _, ok := orderCommands[hashString]; !ok {
//订单相关,0: 订单号, 1: 发货单号, 2: 客户名称, 3: 订单区域, 4: 编号, 5: 合伙人, 6: 类型, 7: 业务抽成比例,
sbPercent, _ := strconv.ParseFloat(row[7], 64) //业务抽成比例
//产品相关,8: 产品名称, 9: 数量, 10: 单价, 11: 合伙人分红比例
amount, _ := strconv.ParseInt(row[9], 10, 64) // 数量
price, _ := strconv.ParseFloat(row[10], 64) // 单价
percent, _ := strconv.ParseFloat(row[11], 64) // 合伙人分红比例
// 初始化建订单命令集
orderCommands[hashString] = &orderCmd.CreateOrderCommand{
OrderType: 0,
OrderCode: row[0],
DeliveryCode: row[1],
BuyerName: row[2],
OrderRegion: row[3],
PartnerId: 0, // 根据合伙人类型+合伙人编号查找合伙人id
SalesmanBonusPercent: sbPercent,
Goods: []orderCmd.OrderGoodData{
{
GoodName: row[8],
PlanGoodNumber: int(amount),
Price: price,
PartnerBonusPercent: percent,
LineNumber: i,
},
},
CompanyId: companyId,
PartnerCategory: 1,
LineNumbers: []int{i}, // 记录行号
}
// 获取partnerId
var partnerInfo *domain.PartnerInfo
partnerInfo, err = orderSrv.GetPartnerIdByCodeAndCategory(orderQuery.GetPartnerIdQuery{
Code: row[4],
PartnerCategory: 0,
CompanyId: companyId,
})
if err != nil {
}
if partnerInfo != nil {
orderCommands[hashString].PartnerId = partnerInfo.Partner.Id
// 1: 事业合伙、2: 业务合伙、3: 研发合伙、4: 业务-产品应用合伙
switch row[6] {
case "事业合伙":
partnerInfo.PartnerCategory = 1
case "业务合伙":
partnerInfo.PartnerCategory = 2
case "研发合伙":
partnerInfo.PartnerCategory = 3
case "业务-产品应用合伙":
partnerInfo.PartnerCategory = 4
}
}
} else {
//产品相关,8: 产品名称, 9: 数量, 10: 单价, 11: 合伙人分红比例
amount, _ := strconv.ParseInt(row[9], 10, 64) // 数量
price, _ := strconv.ParseFloat(row[10], 64) // 单价
percent, _ := strconv.ParseFloat(row[11], 64) // 合伙人分红比例
// 记录同一笔订单产品
orderCommands[hashString].Goods = append(orderCommands[hashString].Goods, orderCmd.OrderGoodData{
GoodName: row[8],
PlanGoodNumber: int(amount),
Price: price,
PartnerBonusPercent: percent,
LineNumber: i, // 记录行号
})
// 记录聚合行号
orderCommands[hashString].LineNumbers = append(orderCommands[hashString].LineNumbers, i)
}
}
}
// 批量创建订单命令集
var createOrderCommands []*orderCmd.CreateOrderCommand
for _, orderCommand := range orderCommands {
createOrderCommands = append(createOrderCommands, orderCommand)
}
// 新增失败记录
failureDataList := make([]interface{}, 0)
// 新增成功记录计数
var successDataCount int64
// 批量新增订单
errorDataList, createError := orderSrv.CreateNewOrderByImport(createOrderCommands)
if createError != nil {
c.ResponseError(createError)
return
} else {
if len(errorDataList) > 0 { // 导入失败返回
successDataCount = 0
// 错误记录处理
for _, errorData := range errorDataList {
if len(errorData.GoodLine) == 0 { // 订单错误
for _, line := range errorData.LineNumbers {
var tmpRow []string
tmpRow = append(tmpRow, errorData.Error.Error()) // 错误信息
s := strconv.Itoa(line + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, rows[line]...) // 错误行数据
failureDataList = append(failureDataList, tmpRow)
}
} else if len(errorData.GoodLine) > 0 { // 订单产品错误
for line := range errorData.GoodLine {
var tmpRow []string
tmpRow = append(tmpRow, errorData.Error.Error()) // 错误信息
s := strconv.Itoa(line + 1)
tmpRow = append(tmpRow, s) // 行号
tmpRow = append(tmpRow, rows[line]...) // 错误行数据
failureDataList = append(failureDataList, tmpRow)
}
}
}
ret = map[string]interface{}{
"successCount": successDataCount,
"fail": map[string]interface{}{
"tableHeader": tableHeader,
"tableData": failureDataList,
},
}
} else { // 导入成功返回
successDataCount = int64(rowCnt - len(failureDataList))
if successDataCount == int64(rowCnt) {
ret = map[string]interface{}{
"successCount": successDataCount,
"fail": nil,
}
}
}
}
// 返回错误详情
c.ResponseData(ret)
return
}
... ...
... ... @@ -122,6 +122,7 @@ func (c *UserController) EditUserPermission() {
Id int64 `json:"id"`
PermissionType []int64 `json:"permissionType"`
CheckedPartner []int64 `json:"checkedPartner"` //合伙人
IsSenior int8 `json:"isSenior"` //是否是高管【1:是】【2:否】
}
var (
param UserDetailParam
... ... @@ -139,6 +140,7 @@ func (c *UserController) EditUserPermission() {
CompanyId: companyId,
PermissionType: param.PermissionType,
CheckedPartner: param.CheckedPartner,
IsSenior: param.IsSenior,
})
if err != nil {
c.ResponseError(err)
... ...
... ... @@ -6,6 +6,10 @@ import (
)
func init() {
// 导入相关
beego.Router("/fileImportTemplate", &controllers.OrderInfoController{}, "POST:DownloadTemplate") // 下载导入模板
beego.Router("/fileImport", &controllers.OrderInfoController{}, "POST:ImportOrderFromExcel") // 导入订单数据
adminRouter := beego.NewNamespace("/v1",
beego.NSNamespace("/auth",
beego.NSRouter("/login", &controllers.AdminLoginController{}, "POST:Login"),
... ... @@ -35,20 +39,18 @@ func init() {
beego.NSRouter("/list/excel", &controllers.OrderDividendController{}, "POST:ListOrderBonusForExcel"),
),
beego.NSNamespace("/order",
beego.NSRouter("/actual/list", &controllers.OrderInfoController{}, "POST:PageListOrderReal"),
beego.NSRouter("/actual/list/excel", &controllers.OrderInfoController{}, "POST:ListOrderForExcel"),
beego.NSRouter("/actual/detail", &controllers.OrderInfoController{}, "POST:GetOrderReal"),
beego.NSRouter("/actual/del", &controllers.OrderInfoController{}, "POST:RemoveOrderReal"),
beego.NSRouter("/actual/update", &controllers.OrderInfoController{}, "POST:UpdateOrderReal"),
beego.NSRouter("/actual/list", &controllers.OrderInfoController{}, "POST:PageListOrderReal"), // 返归订单列表
beego.NSRouter("/actual/list/excel", &controllers.OrderInfoController{}, "POST:ListOrderForExcel"), // 导出订单记录
beego.NSRouter("/actual/detail", &controllers.OrderInfoController{}, "POST:GetOrderReal"), // 查看实际订单详情
beego.NSRouter("/actual/del", &controllers.OrderInfoController{}, "POST:RemoveOrderReal"), // 删除实际订单
beego.NSRouter("/actual/update", &controllers.OrderInfoController{}, "POST:UpdateOrderReal"), // 新增实际订单
beego.NSRouter("/actual/close", &controllers.OrderInfoController{}, "POST:OrderDisable"),
),
beego.NSNamespace("/common",
beego.NSRouter("/partner", &controllers.CommonController{}, "POST:GetPartnerList"),
beego.NSRouter("/partnerType", &controllers.CommonController{}, "POST:GetPartnerCategory"),
beego.NSRouter("/orderType", &controllers.CommonController{}, "POST:GetOrderType"),
),
beego.NSNamespace("/enterprises",
beego.NSRouter("/setPhone", &controllers.CompanyController{}, "POST:SetPhone"),
),
... ...
... ... @@ -20,4 +20,7 @@ func init() {
http.ServeFile(ctx.ResponseWriter, ctx.Request, constant.LOG_File)
return
})
// 静态文件路径映射
beego.SetStaticPath("/download", "download")
}
... ...
此 diff 太大无法显示。
Copyright 2014 astaxie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
... ...
# httplib
httplib is an libs help you to curl remote url.
# How to use?
## GET
you can use Get to crawl data.
import "github.com/beego/beego/v2/httplib"
str, err := httplib.Get("http://beego.me/").String()
if err != nil {
// error
}
fmt.Println(str)
## POST
POST data to remote url
req := httplib.Post("http://beego.me/")
req.Param("username","astaxie")
req.Param("password","123456")
str, err := req.String()
if err != nil {
// error
}
fmt.Println(str)
## Set timeout
The default timeout is `60` seconds, function prototype:
SetTimeout(connectTimeout, readWriteTimeout time.Duration)
Example:
// GET
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
// POST
httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
## Debug
If you want to debug the request info, set the debug on
httplib.Get("http://beego.me/").Debug(true)
## Set HTTP Basic Auth
str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String()
if err != nil {
// error
}
fmt.Println(str)
## Set HTTPS
If request url is https, You can set the client support TSL:
httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config
## Set HTTP Version
some servers need to specify the protocol version of HTTP
httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1")
## Set Cookie
some http request need setcookie. So set it like this:
cookie := &http.Cookie{}
cookie.Name = "username"
cookie.Value = "astaxie"
httplib.Get("http://beego.me/").SetCookie(cookie)
## Upload file
httplib support mutil file upload, use `req.PostFile()`
req := httplib.Post("http://beego.me/")
req.Param("username","astaxie")
req.PostFile("uploadfile1", "httplib.pdf")
str, err := req.String()
if err != nil {
// error
}
fmt.Println(str)
See godoc for further documentation and examples.
* [godoc.org/github.com/beego/beego/v2/httplib](https://godoc.org/github.com/beego/beego/v2/httplib)
... ...
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httplib
import (
"context"
"net/http"
)
type FilterChain func(next Filter) Filter
type Filter func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error)
... ...
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package httplib is used as http.Client
// Usage:
//
// import "github.com/beego/beego/v2/httplib"
//
// b := httplib.Post("http://beego.me/")
// b.Param("username","astaxie")
// b.Param("password","123456")
// b.PostFile("uploadfile1", "httplib.pdf")
// b.PostFile("uploadfile2", "httplib.txt")
// str, err := b.String()
// if err != nil {
// t.Fatal(err)
// }
// fmt.Println(str)
//
// more docs http://beego.me/docs/module/httplib.md
package httplib
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/json"
"encoding/xml"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httputil"
"net/url"
"os"
"path"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
)
var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second,
Gzip: true,
DumpBody: true,
}
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
// it will be the last filter and execute request.Do
var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return req.doRequest(ctx)
}
// createDefaultCookie creates a global cookiejar to store cookies.
func createDefaultCookie() {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultCookieJar, _ = cookiejar.New(nil)
}
// SetDefaultSetting overwrites default settings
func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
}
// NewBeegoRequest returns *BeegoHttpRequest with specific method
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response
u, err := url.Parse(rawurl)
if err != nil {
log.Println("Httplib:", err)
}
req := http.Request{
URL: u,
Method: method,
Header: make(http.Header),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
}
return &BeegoHTTPRequest{
url: rawurl,
req: &req,
params: map[string][]string{},
files: map[string]string{},
setting: defaultSetting,
resp: &resp,
}
}
// Get returns *BeegoHttpRequest with GET method.
func Get(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "GET")
}
// Post returns *BeegoHttpRequest with POST method.
func Post(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "POST")
}
// Put returns *BeegoHttpRequest with PUT method.
func Put(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "PUT")
}
// Delete returns *BeegoHttpRequest DELETE method.
func Delete(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "DELETE")
}
// Head returns *BeegoHttpRequest with HEAD method.
func Head(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "HEAD")
}
// BeegoHTTPSettings is the http.Client setting
type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
CheckRedirect func(req *http.Request, via []*http.Request) error
EnableCookie bool
Gzip bool
DumpBody bool
Retries int // if set to -1 means will retry forever
RetryDelay time.Duration
FilterChains []FilterChain
}
// BeegoHTTPRequest provides more useful methods than http.Request for requesting a url.
type BeegoHTTPRequest struct {
url string
req *http.Request
params map[string][]string
files map[string]string
setting BeegoHTTPSettings
resp *http.Response
body []byte
dump []byte
}
// GetRequest returns the request object
func (b *BeegoHTTPRequest) GetRequest() *http.Request {
return b.req
}
// Setting changes request settings
func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
b.setting = setting
return b
}
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
b.req.SetBasicAuth(username, password)
return b
}
// SetEnableCookie sets enable/disable cookiejar
func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
b.setting.EnableCookie = enable
return b
}
// SetUserAgent sets User-Agent header field
func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
b.setting.UserAgent = useragent
return b
}
// Debug sets show debug or not when executing request.
func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
b.setting.ShowDebug = isdebug
return b
}
// Retries sets Retries times.
// default is 0 (never retry)
// -1 retry indefinitely (forever)
// Other numbers specify the exact retry amount
func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
b.setting.Retries = times
return b
}
// RetryDelay sets the time to sleep between reconnection attempts
func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest {
b.setting.RetryDelay = delay
return b
}
// DumpBody sets the DumbBody field
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump
return b
}
// DumpRequest returns the DumpRequest
func (b *BeegoHTTPRequest) DumpRequest() []byte {
return b.dump
}
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
b.setting.ConnectTimeout = connectTimeout
b.setting.ReadWriteTimeout = readWriteTimeout
return b
}
// SetTLSClientConfig sets TLS connection configuration if visiting HTTPS url.
func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
b.setting.TLSClientConfig = config
return b
}
// Header adds header item string in request.
func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
b.req.Header.Set(key, value)
return b
}
// SetHost set the request host
func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
b.req.Host = host
return b
}
// SetProtocolVersion sets the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
}
major, minor, ok := http.ParseHTTPVersion(vers)
if ok {
b.req.Proto = vers
b.req.ProtoMajor = major
b.req.ProtoMinor = minor
}
return b
}
// SetCookie adds a cookie to the request.
func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
b.req.Header.Add("Cookie", cookie.String())
return b
}
// SetTransport sets the transport field
func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
b.setting.Transport = transport
return b
}
// SetProxy sets the HTTP proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
b.setting.Proxy = proxy
return b
}
// SetCheckRedirect specifies the policy for handling redirects.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest {
b.setting.CheckRedirect = redirect
return b
}
// SetFilters will use the filter as the invocation filters
func (b *BeegoHTTPRequest) SetFilters(fcs ...FilterChain) *BeegoHTTPRequest {
b.setting.FilterChains = fcs
return b
}
// AddFilters adds filter
func (b *BeegoHTTPRequest) AddFilters(fcs ...FilterChain) *BeegoHTTPRequest {
b.setting.FilterChains = append(b.setting.FilterChains, fcs...)
return b
}
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
if param, ok := b.params[key]; ok {
b.params[key] = append(param, value)
} else {
b.params[key] = []string{value}
}
return b
}
// PostFile adds a post file to the request
func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
b.files[formname] = filename
return b
}
// Body adds request raw body.
// Supports string and []byte.
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) {
case string:
bf := bytes.NewBufferString(t)
b.req.Body = ioutil.NopCloser(bf)
b.req.ContentLength = int64(len(t))
case []byte:
bf := bytes.NewBuffer(t)
b.req.Body = ioutil.NopCloser(bf)
b.req.ContentLength = int64(len(t))
}
return b
}
// XMLBody adds the request raw body encoded in XML.
func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := xml.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/xml")
}
return b, nil
}
// YAMLBody adds the request raw body encoded in YAML.
func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := yaml.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/x+yaml")
}
return b, nil
}
// JSONBody adds the request raw body encoded in JSON.
func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := json.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/json")
}
return b, nil
}
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Contains(b.url, "?") {
b.url += "&" + paramBody
} else {
b.url = b.url + "?" + paramBody
}
return
}
// build POST/PUT/PATCH url and body
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
// with files
if len(b.files) > 0 {
pr, pw := io.Pipe()
bodyWriter := multipart.NewWriter(pw)
go func() {
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
log.Println("Httplib:", err)
}
fh, err := os.Open(filename)
if err != nil {
log.Println("Httplib:", err)
}
// iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
log.Println("Httplib:", err)
}
}
for k, v := range b.params {
for _, vv := range v {
bodyWriter.WriteField(k, vv)
}
}
bodyWriter.Close()
pw.Close()
}()
b.Header("Content-Type", bodyWriter.FormDataContentType())
b.req.Body = ioutil.NopCloser(pr)
b.Header("Transfer-Encoding", "chunked")
return
}
// with params
if len(paramBody) > 0 {
b.Header("Content-Type", "application/x-www-form-urlencoded")
b.Body(paramBody)
}
}
}
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 {
return b.resp, nil
}
resp, err := b.DoRequest()
if err != nil {
return nil, err
}
b.resp = resp
return resp, nil
}
// DoRequest executes client.Do
func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
return b.DoRequestWithCtx(context.Background())
}
func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) {
root := doRequestFilter
if len(b.setting.FilterChains) > 0 {
for i := len(b.setting.FilterChains) - 1; i >= 0; i-- {
root = b.setting.FilterChains[i](root)
}
}
return root(ctx, b)
}
func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
for _, vv := range v {
buf.WriteString(url.QueryEscape(k))
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(vv))
buf.WriteByte('&')
}
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
b.buildURL(paramBody)
urlParsed, err := url.Parse(b.url)
if err != nil {
return nil, err
}
b.req.URL = urlParsed
trans := b.setting.Transport
if trans == nil {
// create default transport
trans = &http.Transport{
TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
MaxIdleConnsPerHost: 100,
}
} else {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
}
if t.Dial == nil {
t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
}
}
}
var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
}
jar = defaultCookieJar
}
client := &http.Client{
Transport: trans,
Jar: jar,
}
if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" {
b.req.Header.Set("User-Agent", b.setting.UserAgent)
}
if b.setting.CheckRedirect != nil {
client.CheckRedirect = b.setting.CheckRedirect
}
if b.setting.ShowDebug {
dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
if err != nil {
log.Println(err.Error())
}
b.dump = dump
}
// retries default value is 0, it will run once.
// retries equal to -1, it will run forever until success
// retries is setted, it will retries fixed times.
// Sleeps for a 400ms between calls to reduce spam
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
resp, err = client.Do(b.req)
if err == nil {
break
}
time.Sleep(b.setting.RetryDelay)
}
return resp, err
}
// String returns the body string in response.
// Calls Response inner.
func (b *BeegoHTTPRequest) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
}
return string(data), nil
}
// Bytes returns the body []byte in response.
// Calls Response inner.
func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.body != nil {
return b.body, nil
}
resp, err := b.getResponse()
if err != nil {
return nil, err
}
if resp.Body == nil {
return nil, nil
}
defer resp.Body.Close()
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
reader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
b.body, err = ioutil.ReadAll(reader)
return b.body, err
}
b.body, err = ioutil.ReadAll(resp.Body)
return b.body, err
}
// ToFile saves the body data in response to one file.
// Calls Response inner.
func (b *BeegoHTTPRequest) ToFile(filename string) error {
resp, err := b.getResponse()
if err != nil {
return err
}
if resp.Body == nil {
return nil
}
defer resp.Body.Close()
err = pathExistAndMkdir(filename)
if err != nil {
return err
}
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(f, resp.Body)
return err
}
// Check if the file directory exists. If it doesn't then it's created
func pathExistAndMkdir(filename string) (err error) {
filename = path.Dir(filename)
_, err = os.Stat(filename)
if err == nil {
return nil
}
if os.IsNotExist(err) {
err = os.MkdirAll(filename, os.ModePerm)
if err == nil {
return nil
}
}
return err
}
// ToJSON returns the map that marshals from the body bytes as json in response.
// Calls Response inner.
func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// ToXML returns the map that marshals from the body bytes as xml in response .
// Calls Response inner.
func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return xml.Unmarshal(data, v)
}
// ToYAML returns the map that marshals from the body bytes as yaml in response .
// Calls Response inner.
func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return yaml.Unmarshal(data, v)
}
// Response executes request client gets response manually.
func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse()
}
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
return func(netw, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(netw, addr, cTimeout)
if err != nil {
return nil, err
}
err = conn.SetDeadline(time.Now().Add(rwTimeout))
return conn, err
}
}
... ...
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
The MIT License (MIT)
Copyright (c) 2015 codemodus
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# kace
go get "github.com/codemodus/kace"
Package kace provides common case conversion functions which take into
consideration common initialisms.
## Usage
```go
func Camel(s string) string
func Kebab(s string) string
func KebabUpper(s string) string
func Pascal(s string) string
func Snake(s string) string
func SnakeUpper(s string) string
type Kace
func New(initialisms map[string]bool) (*Kace, error)
func (k *Kace) Camel(s string) string
func (k *Kace) Kebab(s string) string
func (k *Kace) KebabUpper(s string) string
func (k *Kace) Pascal(s string) string
func (k *Kace) Snake(s string) string
func (k *Kace) SnakeUpper(s string) string
```
### Setup
```go
import (
"fmt"
"github.com/codemodus/kace"
)
func main() {
s := "this is a test sql."
fmt.Println(kace.Camel(s))
fmt.Println(kace.Pascal(s))
fmt.Println(kace.Snake(s))
fmt.Println(kace.SnakeUpper(s))
fmt.Println(kace.Kebab(s))
fmt.Println(kace.KebabUpper(s))
customInitialisms := map[string]bool{
"THIS": true,
}
k, err := kace.New(customInitialisms)
if err != nil {
// handle error
}
fmt.Println(k.Camel(s))
fmt.Println(k.Pascal(s))
fmt.Println(k.Snake(s))
fmt.Println(k.SnakeUpper(s))
fmt.Println(k.Kebab(s))
fmt.Println(k.KebabUpper(s))
// Output:
// thisIsATestSQL
// ThisIsATestSQL
// this_is_a_test_sql
// THIS_IS_A_TEST_SQL
// this-is-a-test-sql
// THIS-IS-A-TEST-SQL
// thisIsATestSql
// THISIsATestSql
// this_is_a_test_sql
// THIS_IS_A_TEST_SQL
// this-is-a-test-sql
// THIS-IS-A-TEST-SQL
}
```
## More Info
### TODO
#### Test Trie
Test the current trie.
## Documentation
View the [GoDoc](http://godoc.org/github.com/codemodus/kace)
## Benchmarks
benchmark iter time/iter bytes alloc allocs
--------- ---- --------- ----------- ------
BenchmarkCamel4 2000000 947.00 ns/op 112 B/op 3 allocs/op
BenchmarkSnake4 2000000 696.00 ns/op 128 B/op 2 allocs/op
BenchmarkSnakeUpper4 2000000 679.00 ns/op 128 B/op 2 allocs/op
BenchmarkKebab4 2000000 691.00 ns/op 128 B/op 2 allocs/op
BenchmarkKebabUpper4 2000000 677.00 ns/op 128 B/op 2 allocs/op
// Package kace provides common case conversion functions which take into
// consideration common initialisms.
package kace
import (
"fmt"
"strings"
"unicode"
"github.com/codemodus/kace/ktrie"
)
const (
kebabDelim = '-'
snakeDelim = '_'
none = rune(-1)
)
var (
ciTrie *ktrie.KTrie
)
func init() {
var err error
if ciTrie, err = ktrie.NewKTrie(ciMap); err != nil {
panic(err)
}
}
// Camel returns a camelCased string.
func Camel(s string) string {
return camelCase(ciTrie, s, false)
}
// Pascal returns a PascalCased string.
func Pascal(s string) string {
return camelCase(ciTrie, s, true)
}
// Kebab returns a kebab-cased string with all lowercase letters.
func Kebab(s string) string {
return delimitedCase(s, kebabDelim, false)
}
// KebabUpper returns a KEBAB-CASED string with all upper case letters.
func KebabUpper(s string) string {
return delimitedCase(s, kebabDelim, true)
}
// Snake returns a snake_cased string with all lowercase letters.
func Snake(s string) string {
return delimitedCase(s, snakeDelim, false)
}
// SnakeUpper returns a SNAKE_CASED string with all upper case letters.
func SnakeUpper(s string) string {
return delimitedCase(s, snakeDelim, true)
}
// Kace provides common case conversion methods which take into
// consideration common initialisms set by the user.
type Kace struct {
t *ktrie.KTrie
}
// New returns a pointer to an instance of kace loaded with a common
// initialsms trie based on the provided map. Before conversion to a
// trie, the provided map keys are all upper cased.
func New(initialisms map[string]bool) (*Kace, error) {
ci := initialisms
if ci == nil {
ci = map[string]bool{}
}
ci = sanitizeCI(ci)
t, err := ktrie.NewKTrie(ci)
if err != nil {
return nil, fmt.Errorf("kace: cannot create trie: %s", err)
}
k := &Kace{
t: t,
}
return k, nil
}
// Camel returns a camelCased string.
func (k *Kace) Camel(s string) string {
return camelCase(k.t, s, false)
}
// Pascal returns a PascalCased string.
func (k *Kace) Pascal(s string) string {
return camelCase(k.t, s, true)
}
// Snake returns a snake_cased string with all lowercase letters.
func (k *Kace) Snake(s string) string {
return delimitedCase(s, snakeDelim, false)
}
// SnakeUpper returns a SNAKE_CASED string with all upper case letters.
func (k *Kace) SnakeUpper(s string) string {
return delimitedCase(s, snakeDelim, true)
}
// Kebab returns a kebab-cased string with all lowercase letters.
func (k *Kace) Kebab(s string) string {
return delimitedCase(s, kebabDelim, false)
}
// KebabUpper returns a KEBAB-CASED string with all upper case letters.
func (k *Kace) KebabUpper(s string) string {
return delimitedCase(s, kebabDelim, true)
}
func camelCase(t *ktrie.KTrie, s string, ucFirst bool) string {
rs := []rune(s)
offset := 0
prev := none
for i := 0; i < len(rs); i++ {
r := rs[i]
switch {
case unicode.IsLetter(r):
ucCurr := isToBeUpper(r, prev, ucFirst)
if ucCurr || isSegmentStart(r, prev) {
prv, skip := updateRunes(rs, i, offset, t, ucCurr)
if skip > 0 {
i += skip - 1
prev = prv
continue
}
}
prev = updateRune(rs, i, offset, ucCurr)
continue
case unicode.IsNumber(r):
prev = updateRune(rs, i, offset, false)
continue
default:
prev = r
offset--
}
}
return string(rs[:len(rs)+offset])
}
func isToBeUpper(curr, prev rune, ucFirst bool) bool {
if prev == none {
return ucFirst
}
return isSegmentStart(curr, prev)
}
func isSegmentStart(curr, prev rune) bool {
if !unicode.IsLetter(prev) || unicode.IsUpper(curr) && unicode.IsLower(prev) {
return true
}
return false
}
func updateRune(rs []rune, i, offset int, upper bool) rune {
r := rs[i]
dest := i + offset
if dest < 0 || i > len(rs)-1 {
panic("this function has been used or designed incorrectly")
}
fn := unicode.ToLower
if upper {
fn = unicode.ToUpper
}
rs[dest] = fn(r)
return r
}
func updateRunes(rs []rune, i, offset int, t *ktrie.KTrie, upper bool) (rune, int) {
r := rs[i]
ns := nextSegment(rs, i)
ct := len(ns)
if ct < t.MinDepth() || ct > t.MaxDepth() || !t.FindAsUpper(ns) {
return r, 0
}
for j := i; j < i+ct; j++ {
r = updateRune(rs, j, offset, upper)
}
return r, ct
}
func nextSegment(rs []rune, i int) []rune {
for j := i; j < len(rs); j++ {
if !unicode.IsLetter(rs[j]) && !unicode.IsNumber(rs[j]) {
return rs[i:j]
}
if j == len(rs)-1 {
return rs[i : j+1]
}
}
return nil
}
func delimitedCase(s string, delim rune, upper bool) string {
buf := make([]rune, 0, len(s)*2)
for i := len(s); i > 0; i-- {
switch {
case unicode.IsLetter(rune(s[i-1])):
if i < len(s) && unicode.IsUpper(rune(s[i])) {
if i > 1 && unicode.IsLower(rune(s[i-1])) || i < len(s)-2 && unicode.IsLower(rune(s[i+1])) {
buf = append(buf, delim)
}
}
buf = appendCased(buf, upper, rune(s[i-1]))
case unicode.IsNumber(rune(s[i-1])):
if i == len(s) || i == 1 || unicode.IsNumber(rune(s[i])) {
buf = append(buf, rune(s[i-1]))
continue
}
buf = append(buf, delim, rune(s[i-1]))
default:
if i == len(s) {
continue
}
buf = append(buf, delim)
}
}
reverse(buf)
return string(buf)
}
func appendCased(rs []rune, upper bool, r rune) []rune {
if upper {
rs = append(rs, unicode.ToUpper(r))
return rs
}
rs = append(rs, unicode.ToLower(r))
return rs
}
func reverse(s []rune) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
var (
// github.com/golang/lint/blob/master/lint.go
ciMap = map[string]bool{
"ACL": true,
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SQL": true,
"SSH": true,
"TCP": true,
"TLS": true,
"TTL": true,
"UDP": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XMPP": true,
"XSRF": true,
"XSS": true,
}
)
func sanitizeCI(m map[string]bool) map[string]bool {
r := map[string]bool{}
for k := range m {
fn := func(r rune) rune {
if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
return -1
}
return r
}
k = strings.Map(fn, k)
k = strings.ToUpper(k)
if k == "" {
continue
}
r[k] = true
}
return r
}
package ktrie
import "unicode"
// KNode ...
type KNode struct {
val rune
end bool
links []*KNode
}
// NewKNode ...
func NewKNode(val rune) *KNode {
return &KNode{
val: val,
links: make([]*KNode, 0),
}
}
// Add ...
func (n *KNode) Add(rs []rune) {
cur := n
for k, v := range rs {
link := cur.linkByVal(v)
if link == nil {
link = NewKNode(v)
cur.links = append(cur.links, link)
}
if k == len(rs)-1 {
link.end = true
}
cur = link
}
}
// Find ...
func (n *KNode) Find(rs []rune) bool {
cur := n
for _, v := range rs {
cur = cur.linkByVal(v)
if cur == nil {
return false
}
}
return cur.end
}
// FindAsUpper ...
func (n *KNode) FindAsUpper(rs []rune) bool {
cur := n
for _, v := range rs {
cur = cur.linkByVal(unicode.ToUpper(v))
if cur == nil {
return false
}
}
return cur.end
}
func (n *KNode) linkByVal(val rune) *KNode {
for _, v := range n.links {
if v.val == val {
return v
}
}
return nil
}
// KTrie ...
type KTrie struct {
*KNode
maxDepth int
minDepth int
}
// NewKTrie ...
func NewKTrie(data map[string]bool) (*KTrie, error) {
n := NewKNode(0)
maxDepth := 0
minDepth := 9001
for k := range data {
rs := []rune(k)
l := len(rs)
n.Add(rs)
if l > maxDepth {
maxDepth = l
}
if l < minDepth {
minDepth = l
}
}
t := &KTrie{
maxDepth: maxDepth,
minDepth: minDepth,
KNode: n,
}
return t, nil
}
// MaxDepth ...
func (t *KTrie) MaxDepth() int {
return t.maxDepth
}
// MinDepth ...
func (t *KTrie) MinDepth() int {
return t.minDepth
}
... ... @@ -11,3 +11,8 @@ linters:
- wsl
- funlen
- godox
- goerr113
- exhaustive
- nestif
- gofumpt
- goconst
... ...
semi: false
singleQuote: true
proseWrap: always
printWidth: 80
printWidth: 100
... ...
dist: xenial
sudo: false
language: go
addons:
postgresql: "9.6"
postgresql: '9.6'
go:
- 1.13.x
- 1.14.x
- 1.15.x
- tip
matrix:
allow_failures:
- go: tip
env:
- GO111MODULE=on
go_import_path: github.com/go-pg/pg
before_install:
- psql -U postgres -c "CREATE EXTENSION hstore"
- curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0
- curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s --
-b $(go env GOPATH)/bin v1.28.3
... ...
# Changelog
## v10 (unreleased)
> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
- Added `pgext.OpenTemetryHook` that adds OpenTelemetry
[instrumentation](https://pg.uptrace.dev/tracing/).
- Added `pgext.DebugHook` that logs queries and errors.
- Added `db.Ping` to check if database is healthy.
- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage.
- `,msgpack` struct tag marshals data in MessagePack format using
https://github.com/vmihailenco/msgpack
- Deprecated types and funcs are removed.
## v9
- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and
nothing more.
- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL
`NULL`.
- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go
values, but it does not take in account if field is nullable or not.
- ORM supports DistinctOn.
- Hooks accept and return context.
- Client respects Context.Deadline when setting net.Conn deadline.
- Client listens on Context.Done while waiting for a connection from the pool
and returns an error when context is cancelled.
- `Query.Column` does not accept relation name any more. Use `Query.Relation`
instead which returns an error if relation does not exist.
- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct.
You can also use struct based filters via `Query.WhereStruct`.
- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to
`NextColumnScanner` and `AddColumnScanner` respectively.
- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`.
- `types.Q` is deprecated in favor of `pg.Safe`.
- `pg.Q` is deprecated in favor of `pg.SafeQuery`.
- `TableName` field is deprecated in favor of `tableName`.
- Always use `pg:"..."` struct field tag instead of `sql:"..."`.
- `pg:",override"` is deprecated in favor of `pg:",inherit"`.
## v8
- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept
`context.Context`. Queries are cancelled when context is cancelled.
- Model hooks are changed to accept `context.Context` as first argument.
- Fixed array and hstore parsers to handle multiple single quotes (#1235).
## v7
- DB.OnQueryProcessed is replaced with DB.AddQueryHook.
- Added WhereStruct.
- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if
page or limit params can't be parsed.
## v6.16
- Read buffer is re-worked. Default read buffer is increased to 65kb.
## v6.15
- Added Options.MinIdleConns.
- Options.MaxAge renamed to Options.MaxConnAge.
- PoolStats.FreeConns is renamed to PoolStats.IdleConns.
- New hook BeforeSelectQuery.
- `,override` is renamed to `,inherit`.
- Dialer.KeepAlive is set to 5 minutes by default.
- Added support "scram-sha-256" authentication.
## v6.14
- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation
detector.
## v6.12
- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and
`pg.ErrMultiRows` when `Returning` is used and model expects single row.
## v6.11
- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds
WHERE condition based on primary key when there are no conditions. Instead you
should use `db.Update(&strct)` or `db.Model(&strct).WherePK().Update()`.
## v6.10
- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce
column names without table alias.
## v6.9
- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g.
`pg:"fk:ParentId"` becomes `pg:"fk:parent_id"`. Old code should continue
working in most cases, but it is strongly advised to start using new
convention.
- uint and uint64 SQL type is changed from decimal to bigint according to the
lesser of two evils principle. Use `sql:"type:decimal"` to get old behavior.
## v6.8
- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior
users should add `sql:"on_delete:CASCADE"` tag on foreign key field.
## v6
- `types.Result` is renamed to `orm.Result`.
- Added `OnQueryProcessed` event that can be used to log / report queries
timing. Query logger is removed.
- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER
clause.
- `orm.Pager` is renamed to `orm.Pagination`.
- Support for net.IP and net.IPNet.
- Support for context.Context.
- Bulk/multi updates.
- Query.WhereGroup for enclosing conditions in parentheses.
## v5
- All fields are nullable by default. `,null` tag is replaced with `,notnull`.
- `Result.Affected` renamed to `Result.RowsAffected`.
- Added `Result.RowsReturned`.
- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate`
to `AfterInsert`.
- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`.
- Named placeholders are evaluated when query is executed.
- Added Update and Delete hooks.
- Order reworked to quote column names. OrderExpr added to bypass Order quoting
restrictions.
- Group reworked to quote column names. GroupExpr added to bypass Group quoting
restrictions.
## v4
- `Options.Host` and `Options.Port` merged into `Options.Addr`.
- Added `Options.MaxRetries`. Now queries are not retried by default.
- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`,
LoadColumn renamed to ScanColumn, `NewRecord() interface{}` changed to
`NewModel() ColumnScanner`, `AppendQuery(dst []byte) []byte` changed to
`AppendValue(dst []byte, quote bool) ([]byte, error)`.
- Structs, maps and slices are marshalled to JSON by default.
- Added support for scanning slices, .e.g. scanning `[]int`.
- Added object relational mapping.
See https://pg.uptrace.dev/changelog/
... ...
all:
go test ./...
go test ./... -short -race
go test ./... -run=NONE -bench=. -benchmem
TZ= go test ./...
TZ= go test ./... -short -race
TZ= go test ./... -run=NONE -bench=. -benchmem
env GOOS=linux GOARCH=386 go test ./...
go vet
golangci-lint run
.PHONY: cleanTest
cleanTest:
docker rm -fv pg || true
.PHONY: pre-test
pre-test: cleanTest
docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6
sleep 10
docker exec pg psql -U postgres -c "CREATE EXTENSION hstore"
.PHONY: test
test: pre-test
TZ= PGSSLMODE=disable go test ./... -v
... ...
# PostgreSQL client and ORM for Golang
[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=master)](https://travis-ci.org/go-pg/pg)
[![GoDoc](https://godoc.org/github.com/go-pg/pg?status.svg)](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc)
[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=v10)](https://travis-ci.org/go-pg/pg)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-pg/pg/v10)](https://pkg.go.dev/github.com/go-pg/pg/v10)
[![Documentation](https://img.shields.io/badge/pg-documentation-informational)](https://pg.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
- [Docs](https://pg.uptrace.dev)
> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions.
- [Documentation](https://pg.uptrace.dev)
- [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc)
- [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples)
- Example projects:
- [treemux](https://github.com/uptrace/go-treemux-realworld-example-app)
- [gin](https://github.com/gogjango/gjango)
- [go-kit](https://github.com/Tsovak/rest-api-demo)
- [aah framework](https://github.com/kieusonlam/golamapi)
- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn).
## Ecosystem
- Migrations by [vmihailenco](https://github.com/go-pg/migrations) and
[robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations).
- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna).
- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs.
- [Sharding](https://github.com/go-pg/sharding).
- [Model generator from SQL tables](https://github.com/dizzyfool/genna).
- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into
structs.
## Sponsors
- [**Uptrace.dev** - distributed traces and metrics](https://uptrace.dev)
## Features
... ... @@ -26,71 +32,200 @@
- sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and
[pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime).
- [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and
[sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer)
interfaces.
[sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces.
- Structs, maps and arrays are marshalled as JSON by default.
- PostgreSQL multidimensional Arrays using
[array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag)
and
[Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array).
and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array).
- Hstore using
[hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag)
and
[Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore).
and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore).
- [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType).
- All struct fields are nullable by default and zero values (empty string, 0,
zero time, empty map or slice, nil ptr) are marshalled as SQL `NULL`.
`pg:",notnull"` is used to add SQL `NOT NULL` constraint and `pg:",use_zero"`
to allow Go zero values.
- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map
or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL`
constraint and `pg:",use_zero"` to allow Go zero values.
- [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin).
- [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare).
- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener)
using `LISTEN` and `NOTIFY`.
- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom)
using `COPY FROM` and `COPY TO`.
- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and
canceling queries using context.Context.
- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using
`LISTEN` and `NOTIFY`.
- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using
`COPY FROM` and `COPY TO`.
- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using
context.Context.
- Automatic connection pooling with
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern)
support.
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
- Queries retry on network errors.
- Working with models using
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model) and
[SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Query).
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and
[SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query).
- Scanning variables using
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-SomeColumnsIntoVars)
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars)
and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan).
- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-SelectOrInsert)
- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert)
using on-conflict.
- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-OnConflictDoUpdate)
- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate)
using ORM.
- Bulk/batch
[inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-BulkInsert),
[updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Update-BulkUpdate),
and
[deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Delete-BulkDelete).
[inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert),
[updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and
[deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete).
- Common table expressions using
[WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-With)
and
[WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-WrapWith).
- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CountEstimate)
[WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and
[WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith).
- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate)
using `EXPLAIN` to get
[estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate).
- ORM supports
[has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasOne),
[belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-BelongsTo),
[has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasMany),
and
[many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ManyToMany)
[has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne),
[belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo),
[has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and
[many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany)
with composite/multi-column primary keys.
- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-SoftDelete).
- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CreateTable).
- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ForEach)
that calls a function for each row returned by the query without loading all
rows into the memory.
- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete).
- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable).
- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls
a function for each row returned by the query without loading all rows into the memory.
- Works with PgBouncer in transaction pooling mode.
## Installation
go-pg supports 2 last Go versions and requires a Go version with
[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go
module:
```shell
go mod init github.com/my/repo
```
And then install go-pg (note _v10_ in the import; omitting it is a popular mistake):
```shell
go get github.com/go-pg/pg/v10
```
## Quickstart
```go
package pg_test
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
)
type User struct {
Id int64
Name string
Emails []string
}
func (u User) String() string {
return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails)
}
type Story struct {
Id int64
Title string
AuthorId int64
Author *User `pg:"rel:has-one"`
}
func (s Story) String() string {
return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author)
}
func ExampleDB_Model() {
db := pg.Connect(&pg.Options{
User: "postgres",
})
defer db.Close()
err := createSchema(db)
if err != nil {
panic(err)
}
user1 := &User{
Name: "admin",
Emails: []string{"admin1@admin", "admin2@admin"},
}
_, err = db.Model(user1).Insert()
if err != nil {
panic(err)
}
_, err = db.Model(&User{
Name: "root",
Emails: []string{"root1@root", "root2@root"},
}).Insert()
if err != nil {
panic(err)
}
story1 := &Story{
Title: "Cool story",
AuthorId: user1.Id,
}
_, err = db.Model(story1).Insert()
if err != nil {
panic(err)
}
// Select user by primary key.
user := &User{Id: user1.Id}
err = db.Model(user).WherePK().Select()
if err != nil {
panic(err)
}
// Select all users.
var users []User
err = db.Model(&users).Select()
if err != nil {
panic(err)
}
// Select story and associated author in one query.
story := new(Story)
err = db.Model(story).
Relation("Author").
Where("story.id = ?", story1.Id).
Select()
if err != nil {
panic(err)
}
fmt.Println(user)
fmt.Println(users)
fmt.Println(story)
// Output: User<1 admin [admin1@admin admin2@admin]>
// [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>]
// Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>>
}
// createSchema creates database schema for User and Story models.
func createSchema(db *pg.DB) error {
models := []interface{}{
(*User)(nil),
(*Story)(nil),
}
for _, model := range models {
err := db.Model(model).CreateTable(&orm.CreateTableOptions{
Temp: true,
})
if err != nil {
return err
}
}
return nil
}
```
## See also
- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux)
- [Golang msgpack](https://github.com/vmihailenco/msgpack)
- [Golang message task queue](https://github.com/vmihailenco/taskq)
... ...
... ... @@ -5,12 +5,13 @@ import (
"io"
"time"
"go.opentelemetry.io/otel/api/kv"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel/label"
"go.opentelemetry.io/otel/trace"
"github.com/go-pg/pg/v10/internal"
"github.com/go-pg/pg/v10/internal/pool"
"github.com/go-pg/pg/v10/orm"
"github.com/go-pg/pg/v10/types"
)
type baseDB struct {
... ... @@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error {
err = internal.WithSpan(ctx, "pg.init_conn", func(ctx context.Context, span trace.Span) error {
return db.initConn(ctx, cn)
})
if err != nil {
db.pool.Remove(cn, err)
// It is safe to reset SingleConnPool if conn can't be initialized.
if p, ok := db.pool.(*pool.SingleConnPool); ok {
_ = p.Reset()
db.pool.Remove(ctx, cn, err)
// It is safe to reset StickyConnPool if conn can't be initialized.
if p, ok := db.pool.(*pool.StickyConnPool); ok {
_ = p.Reset(ctx)
}
if err := internal.Unwrap(err); err != nil {
return nil, err
... ... @@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (db *baseDB) initConn(c context.Context, cn *pool.Conn) error {
func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
if db.opt.TLSConfig != nil {
err := db.enableSSL(c, cn, db.opt.TLSConfig)
err := db.enableSSL(ctx, cn, db.opt.TLSConfig)
if err != nil {
return err
}
}
err := db.startup(c, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)
err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)
if err != nil {
return err
}
if db.opt.OnConnect != nil {
p := pool.NewSingleConnPool(nil)
p.SetConn(cn)
return db.opt.OnConnect(newConn(c, db.withPool(p)))
p := pool.NewSingleConnPool(db.pool, cn)
return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p)))
}
return nil
}
func (db *baseDB) releaseConn(cn *pool.Conn, err error) {
func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if isBadConn(err, false) {
db.pool.Remove(cn, err)
db.pool.Remove(ctx, cn, err)
} else {
db.pool.Put(cn)
db.pool.Put(ctx, cn)
}
}
func (db *baseDB) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_conn", func(ctx context.Context, span trace.Span) error {
cn, err := db.getConn(ctx)
if err != nil {
return err
... ... @@ -154,7 +154,7 @@ func (db *baseDB) withConn(
case <-ctx.Done():
err := db.cancelRequest(cn.ProcessID, cn.SecretKey)
if err != nil {
internal.Logger.Printf("cancelRequest failed: %s", err)
internal.Logger.Printf(ctx, "cancelRequest failed: %s", err)
}
// Signal end of conn use.
fnDone <- struct{}{}
... ... @@ -169,7 +169,7 @@ func (db *baseDB) withConn(
case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine
}
}
db.releaseConn(cn, err)
db.releaseConn(ctx, cn, err)
}()
err = fn(ctx, cn)
... ... @@ -179,9 +179,12 @@ func (db *baseDB) withConn(
func (db *baseDB) shouldRetry(err error) bool {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return true
case nil, context.Canceled, context.DeadlineExceeded:
return false
}
if pgerr, ok := err.(Error); ok {
switch pgerr.Field('C') {
case "40001", // serialization_failure
... ... @@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool {
return false
}
}
return isNetworkError(err)
if _, ok := err.(timeoutError); ok {
return true
}
return false
}
// Close closes the database client, releasing any open resources.
... ... @@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa
for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
attempt := attempt
lastErr = internal.WithSpan(ctx, "exec", func(ctx context.Context, span trace.Span) error {
lastErr = internal.WithSpan(ctx, "pg.exec", func(ctx context.Context, span trace.Span) error {
if attempt > 0 {
span.SetAttributes(kv.Int("retry", attempt))
span.SetAttributes(label.Int("retry", attempt))
if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
return err
... ... @@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params ..
for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
attempt := attempt
lastErr = internal.WithSpan(ctx, "query", func(ctx context.Context, span trace.Span) error {
lastErr = internal.WithSpan(ctx, "pg.query", func(ctx context.Context, span trace.Span) error {
if attempt > 0 {
span.SetAttributes(kv.Int("retry", attempt))
span.SetAttributes(label.Int("retry", attempt))
if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
return err
... ... @@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{}
return res, err
}
// TODO: don't get/put conn in the pool
// TODO: don't get/put conn in the pool.
func (db *baseDB) copyFrom(
ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{},
) (res Result, err error) {
... ... @@ -396,6 +404,7 @@ func (db *baseDB) copyFrom(
return nil, err
}
// Note that afterQuery uses the err.
defer func() {
if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
err = afterQueryErr
... ... @@ -434,7 +443,7 @@ func (db *baseDB) copyFrom(
return nil, err
}
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
res, err = readReadyForQuery(rd)
return err
})
... ... @@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{})
}
func (db *baseDB) copyTo(
c context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{},
ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{},
) (res Result, err error) {
var evt *QueryEvent
... ... @@ -472,25 +481,26 @@ func (db *baseDB) copyTo(
model, _ = params[len(params)-1].(orm.TableModel)
}
c, evt, err = db.beforeQuery(c, db.db, model, query, params, wb.Query())
ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query())
if err != nil {
return nil, err
}
// Note that afterQuery uses the err.
defer func() {
if afterQueryErr := db.afterQuery(c, evt, res, err); afterQueryErr != nil {
if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
err = afterQueryErr
}
}()
err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
return writeQueryMsg(wb, db.fmter, query, params...)
})
if err != nil {
return nil, err
}
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
err := readCopyOutResponse(rd)
if err != nil {
return err
... ... @@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que
return orm.NewQueryContext(c, db.db, model...)
}
// Select selects the model by primary key.
func (db *baseDB) Select(model interface{}) error {
return orm.Select(db.db, model)
}
// Insert inserts the model updating primary keys if they are empty.
func (db *baseDB) Insert(model ...interface{}) error {
return orm.Insert(db.db, model...)
}
// Update updates the model by primary key.
func (db *baseDB) Update(model interface{}) error {
return orm.Update(db.db, model)
}
// Delete deletes the model by primary key.
func (db *baseDB) Delete(model interface{}) error {
return orm.Delete(db.db, model)
}
// Delete forces delete of the model with deleted_at column.
func (db *baseDB) ForceDelete(model interface{}) error {
return orm.ForceDelete(db.db, model)
}
// CreateTable creates table for the model. It recognizes following field tags:
// - notnull - sets NOT NULL constraint.
// - unique - sets UNIQUE constraint.
// - default:value - sets default value.
func (db *baseDB) CreateTable(model interface{}, opt *orm.CreateTableOptions) error {
return orm.CreateTable(db.db, model, opt)
}
// DropTable drops table for the model.
func (db *baseDB) DropTable(model interface{}, opt *orm.DropTableOptions) error {
return orm.DropTable(db.db, model, opt)
}
func (db *baseDB) CreateComposite(model interface{}, opt *orm.CreateCompositeOptions) error {
return orm.CreateComposite(db.db, model, opt)
}
func (db *baseDB) DropComposite(model interface{}, opt *orm.DropCompositeOptions) error {
return orm.DropComposite(db.db, model, opt)
}
func (db *baseDB) Formatter() orm.QueryFormatter {
return db.fmter
}
... ... @@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery(
}
var res *result
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
var err error
res, err = readSimpleQuery(rd)
return err
... ... @@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData(
}
var res *result
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
var err error
res, err = readSimpleQueryData(c, rd, model)
return err
... ... @@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData(
// executions. Multiple queries or executions may be run concurrently
// from the returned statement.
func (db *baseDB) Prepare(q string) (*Stmt, error) {
return prepareStmt(db.withPool(pool.NewSingleConnPool(db.pool)), q)
return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q)
}
func (db *baseDB) prepare(
c context.Context, cn *pool.Conn, q string,
) (string, [][]byte, error) {
) (string, []types.ColumnInfo, error) {
name := cn.NextID()
err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
writeParseDescribeSyncMsg(wb, name, q)
... ... @@ -646,8 +610,8 @@ func (db *baseDB) prepare(
return "", nil, err
}
var columns [][]byte
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
var columns []types.ColumnInfo
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
columns, err = readParseDescribeSync(rd)
return err
})
... ...
... ... @@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB {
}
// Listen listens for notifications sent with NOTIFY command.
func (db *DB) Listen(channels ...string) *Listener {
func (db *DB) Listen(ctx context.Context, channels ...string) *Listener {
ln := &Listener{
db: db,
}
ln.init()
_ = ln.Listen(channels...)
_ = ln.Listen(ctx, channels...)
return ln
}
... ... @@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil)
// Every Conn must be returned to the database pool after use by
// calling Conn.Close.
func (db *DB) Conn() *Conn {
return newConn(db.ctx, db.baseDB.withPool(pool.NewSingleConnPool(db.pool)))
return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool)))
}
func newConn(ctx context.Context, baseDB *baseDB) *Conn {
... ...
package pg
import (
"io"
"net"
"github.com/go-pg/pg/v10/internal"
... ... @@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows
type Error interface {
error
// Field returns a string value associated with an error code.
// Field returns a string value associated with an error field.
//
// https://www.postgresql.org/docs/10/static/protocol-error-fields.html
Field(byte) string
Field(field byte) string
// IntegrityViolation reports whether an error is a part of
// Integrity Constraint Violation class of errors.
... ... @@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool {
if _, ok := err.(internal.Error); ok {
return false
}
if pgErr, ok := err.(Error); ok && pgErr.Field('S') != "FATAL" {
return false
if pgErr, ok := err.(Error); ok {
return pgErr.Field('S') == "FATAL"
}
if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false
return !netErr.Temporary()
}
}
return true
}
func isNetworkError(err error) bool {
if err == io.EOF {
return true
}
_, ok := err.(net.Error)
return ok
//------------------------------------------------------------------------------
type timeoutError interface {
Timeout() bool
}
... ...
... ... @@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10
go 1.11
require (
github.com/go-pg/pg/v9 v9.1.6 // indirect
github.com/go-pg/urlstruct v0.4.0
github.com/go-pg/zerochecker v0.1.1
github.com/golang/protobuf v1.4.2 // indirect
github.com/go-pg/zerochecker v0.2.0
github.com/golang/protobuf v1.4.3 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/onsi/ginkgo v1.10.1
github.com/onsi/gomega v1.7.0
github.com/segmentio/encoding v0.1.13
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo v1.14.2
github.com/onsi/gomega v1.10.3
github.com/stretchr/testify v1.6.1
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
github.com/vmihailenco/bufpool v0.1.11
github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1
github.com/vmihailenco/tagparser v0.1.1
go.opentelemetry.io/otel v0.6.0
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect
golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect
google.golang.org/appengine v1.6.6 // indirect
google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.24.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15
github.com/vmihailenco/msgpack/v4 v4.3.11 // indirect
github.com/vmihailenco/msgpack/v5 v5.0.0
github.com/vmihailenco/tagparser v0.1.2
go.opentelemetry.io/otel v0.14.0
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 // indirect
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.25.0 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f
mellium.im/sasl v0.2.1
)
... ...
... ... @@ -8,15 +8,17 @@ import (
"github.com/go-pg/pg/v10/orm"
)
type BeforeScanHook = orm.BeforeScanHook
type AfterScanHook = orm.AfterScanHook
type AfterSelectHook = orm.AfterSelectHook
type BeforeInsertHook = orm.BeforeInsertHook
type AfterInsertHook = orm.AfterInsertHook
type BeforeUpdateHook = orm.BeforeUpdateHook
type AfterUpdateHook = orm.AfterUpdateHook
type BeforeDeleteHook = orm.BeforeDeleteHook
type AfterDeleteHook = orm.AfterDeleteHook
type (
BeforeScanHook = orm.BeforeScanHook
AfterScanHook = orm.AfterScanHook
AfterSelectHook = orm.AfterSelectHook
BeforeInsertHook = orm.BeforeInsertHook
AfterInsertHook = orm.AfterInsertHook
BeforeUpdateHook = orm.BeforeUpdateHook
AfterUpdateHook = orm.AfterUpdateHook
BeforeDeleteHook = orm.BeforeDeleteHook
AfterDeleteHook = orm.AfterDeleteHook
)
//------------------------------------------------------------------------------
... ... @@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery(
fmtedQuery: fmtedQuery,
}
for _, hook := range db.queryHooks {
for i, hook := range db.queryHooks {
var err error
ctx, err = hook.BeforeQuery(ctx, event)
if err != nil {
return nil, nil, err
if err := db.afterQueryFromIndex(ctx, event, i); err != nil {
return ctx, nil, err
}
return ctx, nil, err
}
}
... ... @@ -117,14 +122,15 @@ func (db *baseDB) afterQuery(
event.Err = err
event.Result = res
return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
}
for _, hook := range db.queryHooks {
err := hook.AfterQuery(ctx, event)
if err != nil {
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
for ; hookIndex >= 0; hookIndex-- {
if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil {
return err
}
}
return nil
}
... ...
... ... @@ -4,8 +4,10 @@ import (
"fmt"
)
var ErrNoRows = Errorf("pg: no rows in result set")
var ErrMultiRows = Errorf("pg: multiple rows in result set")
var (
ErrNoRows = Errorf("pg: no rows in result set")
ErrMultiRows = Errorf("pg: multiple rows in result set")
)
type Error struct {
s string
... ...
... ... @@ -8,20 +8,20 @@ import (
"time"
)
// Retry backoff with jitter sleep to prevent overloaded conditions during intervals
// https://www.awsarchitectureblog.com/2015/03/backoff.html
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
retry = 0
panic("not reached")
}
backoff := minBackoff << uint(retry)
if backoff > maxBackoff || backoff < minBackoff {
backoff = maxBackoff
if minBackoff == 0 {
return 0
}
if backoff == 0 {
return 0
d := minBackoff << uint(retry)
d = minBackoff + time.Duration(rand.Int63n(int64(d)))
if d > maxBackoff || d < minBackoff {
d = maxBackoff
}
return time.Duration(rand.Int63n(int64(backoff)))
return d
}
... ...
package internal
import (
"context"
"fmt"
"log"
"os"
)
var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile)
var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags)
var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags)
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
... ...
... ... @@ -8,16 +8,15 @@ import (
"time"
"github.com/go-pg/pg/v10/internal"
"go.opentelemetry.io/otel/api/kv"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel/label"
"go.opentelemetry.io/otel/trace"
)
var noDeadline = time.Time{}
type Conn struct {
netConn net.Conn
rd *BufReader
rd *ReaderContext
ProcessID int32
SecretKey int32
... ... @@ -31,8 +30,6 @@ type Conn struct {
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
rd: NewBufReader(netConn),
createdAt: time.Now(),
}
cn.SetNetConn(netConn)
... ... @@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
if cn.rd != nil {
cn.rd.Reset(netConn)
}
}
func (cn *Conn) LockReader() {
if cn.rd != nil {
panic("not reached")
}
cn.rd = NewReaderContext()
cn.rd.Reset(cn.netConn)
}
func (cn *Conn) NetConn() net.Conn {
... ... @@ -68,30 +75,44 @@ func (cn *Conn) NextID() string {
}
func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *BufReader) error,
ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error,
) error {
return internal.WithSpan(ctx, "with_reader", func(ctx context.Context, span trace.Span) error {
err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
if err != nil {
return internal.WithSpan(ctx, "pg.with_reader", func(ctx context.Context, span trace.Span) error {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
span.RecordError(err)
return err
}
cn.rd.bytesRead = 0
err = fn(cn.rd)
span.SetAttributes(kv.Int64("net.read_bytes", cn.rd.bytesRead))
rd := cn.rd
if rd == nil {
rd = GetReaderContext()
defer PutReaderContext(rd)
rd.Reset(cn.netConn)
}
rd.bytesRead = 0
if err := fn(rd); err != nil {
span.RecordError(err)
return err
}
span.SetAttributes(label.Int64("net.read_bytes", rd.bytesRead))
return nil
})
}
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error,
) error {
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
wb := GetWriteBuffer()
defer PutWriteBuffer(wb)
if err := fn(wb); err != nil {
span.RecordError(err)
return err
}
... ... @@ -100,7 +121,7 @@ func (cn *Conn) WithWriter(
}
func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error {
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
return cn.writeBuffer(ctx, span, timeout, wb)
})
}
... ... @@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer(
timeout time.Duration,
wb *WriteBuffer,
) error {
err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
if err != nil {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
span.RecordError(err)
return err
}
span.SetAttributes(kv.Int("net.wrote_bytes", len(wb.Bytes)))
_, err = cn.netConn.Write(wb.Bytes)
span.SetAttributes(label.Int("net.wrote_bytes", len(wb.Bytes)))
if _, err := cn.netConn.Write(wb.Bytes); err != nil {
span.RecordError(err)
return err
}
return nil
}
func (cn *Conn) Close() error {
... ...
... ... @@ -11,8 +11,10 @@ import (
"github.com/go-pg/pg/v10/internal"
)
var ErrClosed = errors.New("pg: database is closed")
var ErrPoolTimeout = errors.New("pg: connection pool timeout")
var (
ErrClosed = errors.New("pg: database is closed")
ErrPoolTimeout = errors.New("pg: connection pool timeout")
)
var timers = sync.Pool{
New: func() interface{} {
... ... @@ -38,8 +40,8 @@ type Pooler interface {
CloseConn(*Conn) error
Get(context.Context) (*Conn, error)
Put(*Conn)
Remove(*Conn, error)
Put(context.Context, *Conn)
Remove(context.Context, *Conn, error)
Len() int
IdleLen() int
... ... @@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error {
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
err := p.waitTurn(c)
err := p.waitTurn(ctx)
if err != nil {
return nil, err
}
... ... @@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p.newConn(c, true)
newcn, err := p.newConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
... ... @@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn {
return cn
}
func (p *ConnPool) Put(cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf("Conn has unread data")
p.Remove(cn, BadConnError{})
return
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if !cn.pooled {
p.Remove(cn, nil)
p.Remove(ctx, cn, nil)
return
}
... ... @@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) {
p.freeTurn()
}
func (p *ConnPool) Remove(cn *Conn, reason error) {
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
... ... @@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) {
}
n, err := p.ReapStaleConns()
if err != nil {
internal.Logger.Printf("ReapStaleConns failed: %s", err)
internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err)
continue
}
atomic.AddUint32(&p.stats.StaleConns, uint32(n))
... ...