作者 唐旭辉

更新依赖

正在显示 50 个修改的文件 包含 1459 行增加1637 行删除

要显示太多修改。

为保证性能只显示 50 of 50+ 个文件。

... ... @@ -12,7 +12,7 @@ require (
github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect
github.com/fatih/structs v1.1.0 // indirect
github.com/gavv/httpexpect v2.0.0+incompatible
github.com/go-pg/pg/v10 v10.0.0-beta.2
github.com/go-pg/pg/v10 v10.7.3
github.com/google/go-querystring v1.0.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/imkira/go-interpol v1.1.0 // indirect
... ... @@ -20,8 +20,9 @@ require (
github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9
github.com/mattn/go-colorable v0.1.6 // indirect
github.com/moul/http2curl v1.0.0 // indirect
github.com/onsi/ginkgo v1.13.0
github.com/onsi/gomega v1.10.1
github.com/onsi/ginkgo v1.14.2
github.com/onsi/gomega v1.10.3
github.com/sclevine/agouti v3.0.0+incompatible // indirect
github.com/sergi/go-diff v1.1.0 // indirect
github.com/shopspring/decimal v1.2.0
github.com/smartystreets/goconvey v1.6.4 // indirect
... ...
... ... @@ -47,7 +47,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error {
Remark: dm.Remark,
}
if m.Id == 0 {
err = tx.Insert(m)
_, err = tx.Model(m).Insert()
dm.Partner.Id = m.Id
if err != nil {
return err
... ...
此 diff 太大无法显示。
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
The MIT License (MIT)
Copyright (c) 2015 codemodus
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# kace
go get "github.com/codemodus/kace"
Package kace provides common case conversion functions which take into
consideration common initialisms.
## Usage
```go
func Camel(s string) string
func Kebab(s string) string
func KebabUpper(s string) string
func Pascal(s string) string
func Snake(s string) string
func SnakeUpper(s string) string
type Kace
func New(initialisms map[string]bool) (*Kace, error)
func (k *Kace) Camel(s string) string
func (k *Kace) Kebab(s string) string
func (k *Kace) KebabUpper(s string) string
func (k *Kace) Pascal(s string) string
func (k *Kace) Snake(s string) string
func (k *Kace) SnakeUpper(s string) string
```
### Setup
```go
import (
"fmt"
"github.com/codemodus/kace"
)
func main() {
s := "this is a test sql."
fmt.Println(kace.Camel(s))
fmt.Println(kace.Pascal(s))
fmt.Println(kace.Snake(s))
fmt.Println(kace.SnakeUpper(s))
fmt.Println(kace.Kebab(s))
fmt.Println(kace.KebabUpper(s))
customInitialisms := map[string]bool{
"THIS": true,
}
k, err := kace.New(customInitialisms)
if err != nil {
// handle error
}
fmt.Println(k.Camel(s))
fmt.Println(k.Pascal(s))
fmt.Println(k.Snake(s))
fmt.Println(k.SnakeUpper(s))
fmt.Println(k.Kebab(s))
fmt.Println(k.KebabUpper(s))
// Output:
// thisIsATestSQL
// ThisIsATestSQL
// this_is_a_test_sql
// THIS_IS_A_TEST_SQL
// this-is-a-test-sql
// THIS-IS-A-TEST-SQL
// thisIsATestSql
// THISIsATestSql
// this_is_a_test_sql
// THIS_IS_A_TEST_SQL
// this-is-a-test-sql
// THIS-IS-A-TEST-SQL
}
```
## More Info
### TODO
#### Test Trie
Test the current trie.
## Documentation
View the [GoDoc](http://godoc.org/github.com/codemodus/kace)
## Benchmarks
benchmark iter time/iter bytes alloc allocs
--------- ---- --------- ----------- ------
BenchmarkCamel4 2000000 947.00 ns/op 112 B/op 3 allocs/op
BenchmarkSnake4 2000000 696.00 ns/op 128 B/op 2 allocs/op
BenchmarkSnakeUpper4 2000000 679.00 ns/op 128 B/op 2 allocs/op
BenchmarkKebab4 2000000 691.00 ns/op 128 B/op 2 allocs/op
BenchmarkKebabUpper4 2000000 677.00 ns/op 128 B/op 2 allocs/op
// Package kace provides common case conversion functions which take into
// consideration common initialisms.
package kace
import (
"fmt"
"strings"
"unicode"
"github.com/codemodus/kace/ktrie"
)
const (
kebabDelim = '-'
snakeDelim = '_'
none = rune(-1)
)
var (
ciTrie *ktrie.KTrie
)
func init() {
var err error
if ciTrie, err = ktrie.NewKTrie(ciMap); err != nil {
panic(err)
}
}
// Camel returns a camelCased string.
func Camel(s string) string {
return camelCase(ciTrie, s, false)
}
// Pascal returns a PascalCased string.
func Pascal(s string) string {
return camelCase(ciTrie, s, true)
}
// Kebab returns a kebab-cased string with all lowercase letters.
func Kebab(s string) string {
return delimitedCase(s, kebabDelim, false)
}
// KebabUpper returns a KEBAB-CASED string with all upper case letters.
func KebabUpper(s string) string {
return delimitedCase(s, kebabDelim, true)
}
// Snake returns a snake_cased string with all lowercase letters.
func Snake(s string) string {
return delimitedCase(s, snakeDelim, false)
}
// SnakeUpper returns a SNAKE_CASED string with all upper case letters.
func SnakeUpper(s string) string {
return delimitedCase(s, snakeDelim, true)
}
// Kace provides common case conversion methods which take into
// consideration common initialisms set by the user.
type Kace struct {
t *ktrie.KTrie
}
// New returns a pointer to an instance of kace loaded with a common
// initialsms trie based on the provided map. Before conversion to a
// trie, the provided map keys are all upper cased.
func New(initialisms map[string]bool) (*Kace, error) {
ci := initialisms
if ci == nil {
ci = map[string]bool{}
}
ci = sanitizeCI(ci)
t, err := ktrie.NewKTrie(ci)
if err != nil {
return nil, fmt.Errorf("kace: cannot create trie: %s", err)
}
k := &Kace{
t: t,
}
return k, nil
}
// Camel returns a camelCased string.
func (k *Kace) Camel(s string) string {
return camelCase(k.t, s, false)
}
// Pascal returns a PascalCased string.
func (k *Kace) Pascal(s string) string {
return camelCase(k.t, s, true)
}
// Snake returns a snake_cased string with all lowercase letters.
func (k *Kace) Snake(s string) string {
return delimitedCase(s, snakeDelim, false)
}
// SnakeUpper returns a SNAKE_CASED string with all upper case letters.
func (k *Kace) SnakeUpper(s string) string {
return delimitedCase(s, snakeDelim, true)
}
// Kebab returns a kebab-cased string with all lowercase letters.
func (k *Kace) Kebab(s string) string {
return delimitedCase(s, kebabDelim, false)
}
// KebabUpper returns a KEBAB-CASED string with all upper case letters.
func (k *Kace) KebabUpper(s string) string {
return delimitedCase(s, kebabDelim, true)
}
func camelCase(t *ktrie.KTrie, s string, ucFirst bool) string {
rs := []rune(s)
offset := 0
prev := none
for i := 0; i < len(rs); i++ {
r := rs[i]
switch {
case unicode.IsLetter(r):
ucCurr := isToBeUpper(r, prev, ucFirst)
if ucCurr || isSegmentStart(r, prev) {
prv, skip := updateRunes(rs, i, offset, t, ucCurr)
if skip > 0 {
i += skip - 1
prev = prv
continue
}
}
prev = updateRune(rs, i, offset, ucCurr)
continue
case unicode.IsNumber(r):
prev = updateRune(rs, i, offset, false)
continue
default:
prev = r
offset--
}
}
return string(rs[:len(rs)+offset])
}
func isToBeUpper(curr, prev rune, ucFirst bool) bool {
if prev == none {
return ucFirst
}
return isSegmentStart(curr, prev)
}
func isSegmentStart(curr, prev rune) bool {
if !unicode.IsLetter(prev) || unicode.IsUpper(curr) && unicode.IsLower(prev) {
return true
}
return false
}
func updateRune(rs []rune, i, offset int, upper bool) rune {
r := rs[i]
dest := i + offset
if dest < 0 || i > len(rs)-1 {
panic("this function has been used or designed incorrectly")
}
fn := unicode.ToLower
if upper {
fn = unicode.ToUpper
}
rs[dest] = fn(r)
return r
}
func updateRunes(rs []rune, i, offset int, t *ktrie.KTrie, upper bool) (rune, int) {
r := rs[i]
ns := nextSegment(rs, i)
ct := len(ns)
if ct < t.MinDepth() || ct > t.MaxDepth() || !t.FindAsUpper(ns) {
return r, 0
}
for j := i; j < i+ct; j++ {
r = updateRune(rs, j, offset, upper)
}
return r, ct
}
func nextSegment(rs []rune, i int) []rune {
for j := i; j < len(rs); j++ {
if !unicode.IsLetter(rs[j]) && !unicode.IsNumber(rs[j]) {
return rs[i:j]
}
if j == len(rs)-1 {
return rs[i : j+1]
}
}
return nil
}
func delimitedCase(s string, delim rune, upper bool) string {
buf := make([]rune, 0, len(s)*2)
for i := len(s); i > 0; i-- {
switch {
case unicode.IsLetter(rune(s[i-1])):
if i < len(s) && unicode.IsUpper(rune(s[i])) {
if i > 1 && unicode.IsLower(rune(s[i-1])) || i < len(s)-2 && unicode.IsLower(rune(s[i+1])) {
buf = append(buf, delim)
}
}
buf = appendCased(buf, upper, rune(s[i-1]))
case unicode.IsNumber(rune(s[i-1])):
if i == len(s) || i == 1 || unicode.IsNumber(rune(s[i])) {
buf = append(buf, rune(s[i-1]))
continue
}
buf = append(buf, delim, rune(s[i-1]))
default:
if i == len(s) {
continue
}
buf = append(buf, delim)
}
}
reverse(buf)
return string(buf)
}
func appendCased(rs []rune, upper bool, r rune) []rune {
if upper {
rs = append(rs, unicode.ToUpper(r))
return rs
}
rs = append(rs, unicode.ToLower(r))
return rs
}
func reverse(s []rune) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
var (
// github.com/golang/lint/blob/master/lint.go
ciMap = map[string]bool{
"ACL": true,
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ID": true,
"IP": true,
"JSON": true,
"LHS": true,
"QPS": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SQL": true,
"SSH": true,
"TCP": true,
"TLS": true,
"TTL": true,
"UDP": true,
"UI": true,
"UID": true,
"UUID": true,
"URI": true,
"URL": true,
"UTF8": true,
"VM": true,
"XML": true,
"XMPP": true,
"XSRF": true,
"XSS": true,
}
)
func sanitizeCI(m map[string]bool) map[string]bool {
r := map[string]bool{}
for k := range m {
fn := func(r rune) rune {
if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
return -1
}
return r
}
k = strings.Map(fn, k)
k = strings.ToUpper(k)
if k == "" {
continue
}
r[k] = true
}
return r
}
package ktrie
import "unicode"
// KNode ...
type KNode struct {
val rune
end bool
links []*KNode
}
// NewKNode ...
func NewKNode(val rune) *KNode {
return &KNode{
val: val,
links: make([]*KNode, 0),
}
}
// Add ...
func (n *KNode) Add(rs []rune) {
cur := n
for k, v := range rs {
link := cur.linkByVal(v)
if link == nil {
link = NewKNode(v)
cur.links = append(cur.links, link)
}
if k == len(rs)-1 {
link.end = true
}
cur = link
}
}
// Find ...
func (n *KNode) Find(rs []rune) bool {
cur := n
for _, v := range rs {
cur = cur.linkByVal(v)
if cur == nil {
return false
}
}
return cur.end
}
// FindAsUpper ...
func (n *KNode) FindAsUpper(rs []rune) bool {
cur := n
for _, v := range rs {
cur = cur.linkByVal(unicode.ToUpper(v))
if cur == nil {
return false
}
}
return cur.end
}
func (n *KNode) linkByVal(val rune) *KNode {
for _, v := range n.links {
if v.val == val {
return v
}
}
return nil
}
// KTrie ...
type KTrie struct {
*KNode
maxDepth int
minDepth int
}
// NewKTrie ...
func NewKTrie(data map[string]bool) (*KTrie, error) {
n := NewKNode(0)
maxDepth := 0
minDepth := 9001
for k := range data {
rs := []rune(k)
l := len(rs)
n.Add(rs)
if l > maxDepth {
maxDepth = l
}
if l < minDepth {
minDepth = l
}
}
t := &KTrie{
maxDepth: maxDepth,
minDepth: minDepth,
KNode: n,
}
return t, nil
}
// MaxDepth ...
func (t *KTrie) MaxDepth() int {
return t.maxDepth
}
// MinDepth ...
func (t *KTrie) MinDepth() int {
return t.minDepth
}
... ... @@ -11,3 +11,8 @@ linters:
- wsl
- funlen
- godox
- goerr113
- exhaustive
- nestif
- gofumpt
- goconst
... ...
semi: false
singleQuote: true
proseWrap: always
printWidth: 80
printWidth: 100
... ...
dist: xenial
sudo: false
language: go
addons:
postgresql: "9.6"
postgresql: '9.6'
go:
- 1.13.x
- 1.14.x
- 1.15.x
- tip
matrix:
allow_failures:
- go: tip
env:
- GO111MODULE=on
go_import_path: github.com/go-pg/pg
before_install:
- psql -U postgres -c "CREATE EXTENSION hstore"
- curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0
- curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s --
-b $(go env GOPATH)/bin v1.28.3
... ...
# Changelog
## v10 (unreleased)
> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
- Added `pgext.OpenTemetryHook` that adds OpenTelemetry
[instrumentation](https://pg.uptrace.dev/tracing/).
- Added `pgext.DebugHook` that logs queries and errors.
- Added `db.Ping` to check if database is healthy.
- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage.
- `,msgpack` struct tag marshals data in MessagePack format using
https://github.com/vmihailenco/msgpack
- Deprecated types and funcs are removed.
## v9
- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and
nothing more.
- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL
`NULL`.
- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go
values, but it does not take in account if field is nullable or not.
- ORM supports DistinctOn.
- Hooks accept and return context.
- Client respects Context.Deadline when setting net.Conn deadline.
- Client listens on Context.Done while waiting for a connection from the pool
and returns an error when context is cancelled.
- `Query.Column` does not accept relation name any more. Use `Query.Relation`
instead which returns an error if relation does not exist.
- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct.
You can also use struct based filters via `Query.WhereStruct`.
- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to
`NextColumnScanner` and `AddColumnScanner` respectively.
- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`.
- `types.Q` is deprecated in favor of `pg.Safe`.
- `pg.Q` is deprecated in favor of `pg.SafeQuery`.
- `TableName` field is deprecated in favor of `tableName`.
- Always use `pg:"..."` struct field tag instead of `sql:"..."`.
- `pg:",override"` is deprecated in favor of `pg:",inherit"`.
## v8
- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept
`context.Context`. Queries are cancelled when context is cancelled.
- Model hooks are changed to accept `context.Context` as first argument.
- Fixed array and hstore parsers to handle multiple single quotes (#1235).
## v7
- DB.OnQueryProcessed is replaced with DB.AddQueryHook.
- Added WhereStruct.
- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if
page or limit params can't be parsed.
## v6.16
- Read buffer is re-worked. Default read buffer is increased to 65kb.
## v6.15
- Added Options.MinIdleConns.
- Options.MaxAge renamed to Options.MaxConnAge.
- PoolStats.FreeConns is renamed to PoolStats.IdleConns.
- New hook BeforeSelectQuery.
- `,override` is renamed to `,inherit`.
- Dialer.KeepAlive is set to 5 minutes by default.
- Added support "scram-sha-256" authentication.
## v6.14
- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation
detector.
## v6.12
- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and
`pg.ErrMultiRows` when `Returning` is used and model expects single row.
## v6.11
- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds
WHERE condition based on primary key when there are no conditions. Instead you
should use `db.Update(&strct)` or `db.Model(&strct).WherePK().Update()`.
## v6.10
- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce
column names without table alias.
## v6.9
- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g.
`pg:"fk:ParentId"` becomes `pg:"fk:parent_id"`. Old code should continue
working in most cases, but it is strongly advised to start using new
convention.
- uint and uint64 SQL type is changed from decimal to bigint according to the
lesser of two evils principle. Use `sql:"type:decimal"` to get old behavior.
## v6.8
- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior
users should add `sql:"on_delete:CASCADE"` tag on foreign key field.
## v6
- `types.Result` is renamed to `orm.Result`.
- Added `OnQueryProcessed` event that can be used to log / report queries
timing. Query logger is removed.
- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER
clause.
- `orm.Pager` is renamed to `orm.Pagination`.
- Support for net.IP and net.IPNet.
- Support for context.Context.
- Bulk/multi updates.
- Query.WhereGroup for enclosing conditions in parentheses.
## v5
- All fields are nullable by default. `,null` tag is replaced with `,notnull`.
- `Result.Affected` renamed to `Result.RowsAffected`.
- Added `Result.RowsReturned`.
- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate`
to `AfterInsert`.
- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`.
- Named placeholders are evaluated when query is executed.
- Added Update and Delete hooks.
- Order reworked to quote column names. OrderExpr added to bypass Order quoting
restrictions.
- Group reworked to quote column names. GroupExpr added to bypass Group quoting
restrictions.
## v4
- `Options.Host` and `Options.Port` merged into `Options.Addr`.
- Added `Options.MaxRetries`. Now queries are not retried by default.
- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`,
LoadColumn renamed to ScanColumn, `NewRecord() interface{}` changed to
`NewModel() ColumnScanner`, `AppendQuery(dst []byte) []byte` changed to
`AppendValue(dst []byte, quote bool) ([]byte, error)`.
- Structs, maps and slices are marshalled to JSON by default.
- Added support for scanning slices, .e.g. scanning `[]int`.
- Added object relational mapping.
See https://pg.uptrace.dev/changelog/
... ...
all:
go test ./...
go test ./... -short -race
go test ./... -run=NONE -bench=. -benchmem
TZ= go test ./...
TZ= go test ./... -short -race
TZ= go test ./... -run=NONE -bench=. -benchmem
env GOOS=linux GOARCH=386 go test ./...
go vet
golangci-lint run
.PHONY: cleanTest
cleanTest:
docker rm -fv pg || true
.PHONY: pre-test
pre-test: cleanTest
docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6
sleep 10
docker exec pg psql -U postgres -c "CREATE EXTENSION hstore"
.PHONY: test
test: pre-test
TZ= PGSSLMODE=disable go test ./... -v
... ...
# PostgreSQL client and ORM for Golang
[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=master)](https://travis-ci.org/go-pg/pg)
[![GoDoc](https://godoc.org/github.com/go-pg/pg?status.svg)](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc)
[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=v10)](https://travis-ci.org/go-pg/pg)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-pg/pg/v10)](https://pkg.go.dev/github.com/go-pg/pg/v10)
[![Documentation](https://img.shields.io/badge/pg-documentation-informational)](https://pg.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
- [Docs](https://pg.uptrace.dev)
> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions.
- [Documentation](https://pg.uptrace.dev)
- [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc)
- [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples)
- Example projects:
- [treemux](https://github.com/uptrace/go-treemux-realworld-example-app)
- [gin](https://github.com/gogjango/gjango)
- [go-kit](https://github.com/Tsovak/rest-api-demo)
- [aah framework](https://github.com/kieusonlam/golamapi)
- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn).
## Ecosystem
- Migrations by [vmihailenco](https://github.com/go-pg/migrations) and
[robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations).
- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna).
- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs.
- [Sharding](https://github.com/go-pg/sharding).
- [Model generator from SQL tables](https://github.com/dizzyfool/genna).
- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into
structs.
## Sponsors
- [**Uptrace.dev** - distributed traces and metrics](https://uptrace.dev)
## Features
... ... @@ -26,71 +32,200 @@
- sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and
[pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime).
- [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and
[sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer)
interfaces.
[sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces.
- Structs, maps and arrays are marshalled as JSON by default.
- PostgreSQL multidimensional Arrays using
[array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag)
and
[Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array).
and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array).
- Hstore using
[hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag)
and
[Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore).
and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore).
- [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType).
- All struct fields are nullable by default and zero values (empty string, 0,
zero time, empty map or slice, nil ptr) are marshalled as SQL `NULL`.
`pg:",notnull"` is used to add SQL `NOT NULL` constraint and `pg:",use_zero"`
to allow Go zero values.
- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map
or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL`
constraint and `pg:",use_zero"` to allow Go zero values.
- [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin).
- [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare).
- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener)
using `LISTEN` and `NOTIFY`.
- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom)
using `COPY FROM` and `COPY TO`.
- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and
canceling queries using context.Context.
- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using
`LISTEN` and `NOTIFY`.
- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using
`COPY FROM` and `COPY TO`.
- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using
context.Context.
- Automatic connection pooling with
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern)
support.
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
- Queries retry on network errors.
- Working with models using
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model) and
[SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Query).
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and
[SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query).
- Scanning variables using
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-SomeColumnsIntoVars)
[ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars)
and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan).
- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-SelectOrInsert)
- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert)
using on-conflict.
- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-OnConflictDoUpdate)
- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate)
using ORM.
- Bulk/batch
[inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-BulkInsert),
[updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Update-BulkUpdate),
and
[deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Delete-BulkDelete).
[inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert),
[updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and
[deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete).
- Common table expressions using
[WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-With)
and
[WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-WrapWith).
- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CountEstimate)
[WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and
[WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith).
- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate)
using `EXPLAIN` to get
[estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate).
- ORM supports
[has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasOne),
[belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-BelongsTo),
[has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasMany),
and
[many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ManyToMany)
[has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne),
[belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo),
[has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and
[many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany)
with composite/multi-column primary keys.
- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-SoftDelete).
- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CreateTable).
- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ForEach)
that calls a function for each row returned by the query without loading all
rows into the memory.
- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete).
- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable).
- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls
a function for each row returned by the query without loading all rows into the memory.
- Works with PgBouncer in transaction pooling mode.
## Installation
go-pg supports 2 last Go versions and requires a Go version with
[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go
module:
```shell
go mod init github.com/my/repo
```
And then install go-pg (note _v10_ in the import; omitting it is a popular mistake):
```shell
go get github.com/go-pg/pg/v10
```
## Quickstart
```go
package pg_test
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
)
type User struct {
Id int64
Name string
Emails []string
}
func (u User) String() string {
return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails)
}
type Story struct {
Id int64
Title string
AuthorId int64
Author *User `pg:"rel:has-one"`
}
func (s Story) String() string {
return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author)
}
func ExampleDB_Model() {
db := pg.Connect(&pg.Options{
User: "postgres",
})
defer db.Close()
err := createSchema(db)
if err != nil {
panic(err)
}
user1 := &User{
Name: "admin",
Emails: []string{"admin1@admin", "admin2@admin"},
}
_, err = db.Model(user1).Insert()
if err != nil {
panic(err)
}
_, err = db.Model(&User{
Name: "root",
Emails: []string{"root1@root", "root2@root"},
}).Insert()
if err != nil {
panic(err)
}
story1 := &Story{
Title: "Cool story",
AuthorId: user1.Id,
}
_, err = db.Model(story1).Insert()
if err != nil {
panic(err)
}
// Select user by primary key.
user := &User{Id: user1.Id}
err = db.Model(user).WherePK().Select()
if err != nil {
panic(err)
}
// Select all users.
var users []User
err = db.Model(&users).Select()
if err != nil {
panic(err)
}
// Select story and associated author in one query.
story := new(Story)
err = db.Model(story).
Relation("Author").
Where("story.id = ?", story1.Id).
Select()
if err != nil {
panic(err)
}
fmt.Println(user)
fmt.Println(users)
fmt.Println(story)
// Output: User<1 admin [admin1@admin admin2@admin]>
// [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>]
// Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>>
}
// createSchema creates database schema for User and Story models.
func createSchema(db *pg.DB) error {
models := []interface{}{
(*User)(nil),
(*Story)(nil),
}
for _, model := range models {
err := db.Model(model).CreateTable(&orm.CreateTableOptions{
Temp: true,
})
if err != nil {
return err
}
}
return nil
}
```
## See also
- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux)
- [Golang msgpack](https://github.com/vmihailenco/msgpack)
- [Golang message task queue](https://github.com/vmihailenco/taskq)
... ...
... ... @@ -5,12 +5,13 @@ import (
"io"
"time"
"go.opentelemetry.io/otel/api/kv"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel/label"
"go.opentelemetry.io/otel/trace"
"github.com/go-pg/pg/v10/internal"
"github.com/go-pg/pg/v10/internal/pool"
"github.com/go-pg/pg/v10/orm"
"github.com/go-pg/pg/v10/types"
)
type baseDB struct {
... ... @@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error {
err = internal.WithSpan(ctx, "pg.init_conn", func(ctx context.Context, span trace.Span) error {
return db.initConn(ctx, cn)
})
if err != nil {
db.pool.Remove(cn, err)
// It is safe to reset SingleConnPool if conn can't be initialized.
if p, ok := db.pool.(*pool.SingleConnPool); ok {
_ = p.Reset()
db.pool.Remove(ctx, cn, err)
// It is safe to reset StickyConnPool if conn can't be initialized.
if p, ok := db.pool.(*pool.StickyConnPool); ok {
_ = p.Reset(ctx)
}
if err := internal.Unwrap(err); err != nil {
return nil, err
... ... @@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}
func (db *baseDB) initConn(c context.Context, cn *pool.Conn) error {
func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
if db.opt.TLSConfig != nil {
err := db.enableSSL(c, cn, db.opt.TLSConfig)
err := db.enableSSL(ctx, cn, db.opt.TLSConfig)
if err != nil {
return err
}
}
err := db.startup(c, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)
err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)
if err != nil {
return err
}
if db.opt.OnConnect != nil {
p := pool.NewSingleConnPool(nil)
p.SetConn(cn)
return db.opt.OnConnect(newConn(c, db.withPool(p)))
p := pool.NewSingleConnPool(db.pool, cn)
return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p)))
}
return nil
}
func (db *baseDB) releaseConn(cn *pool.Conn, err error) {
func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if isBadConn(err, false) {
db.pool.Remove(cn, err)
db.pool.Remove(ctx, cn, err)
} else {
db.pool.Put(cn)
db.pool.Put(ctx, cn)
}
}
func (db *baseDB) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_conn", func(ctx context.Context, span trace.Span) error {
cn, err := db.getConn(ctx)
if err != nil {
return err
... ... @@ -154,7 +154,7 @@ func (db *baseDB) withConn(
case <-ctx.Done():
err := db.cancelRequest(cn.ProcessID, cn.SecretKey)
if err != nil {
internal.Logger.Printf("cancelRequest failed: %s", err)
internal.Logger.Printf(ctx, "cancelRequest failed: %s", err)
}
// Signal end of conn use.
fnDone <- struct{}{}
... ... @@ -169,7 +169,7 @@ func (db *baseDB) withConn(
case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine
}
}
db.releaseConn(cn, err)
db.releaseConn(ctx, cn, err)
}()
err = fn(ctx, cn)
... ... @@ -179,9 +179,12 @@ func (db *baseDB) withConn(
func (db *baseDB) shouldRetry(err error) bool {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return true
case nil, context.Canceled, context.DeadlineExceeded:
return false
}
if pgerr, ok := err.(Error); ok {
switch pgerr.Field('C') {
case "40001", // serialization_failure
... ... @@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool {
return false
}
}
return isNetworkError(err)
if _, ok := err.(timeoutError); ok {
return true
}
return false
}
// Close closes the database client, releasing any open resources.
... ... @@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa
for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
attempt := attempt
lastErr = internal.WithSpan(ctx, "exec", func(ctx context.Context, span trace.Span) error {
lastErr = internal.WithSpan(ctx, "pg.exec", func(ctx context.Context, span trace.Span) error {
if attempt > 0 {
span.SetAttributes(kv.Int("retry", attempt))
span.SetAttributes(label.Int("retry", attempt))
if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
return err
... ... @@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params ..
for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
attempt := attempt
lastErr = internal.WithSpan(ctx, "query", func(ctx context.Context, span trace.Span) error {
lastErr = internal.WithSpan(ctx, "pg.query", func(ctx context.Context, span trace.Span) error {
if attempt > 0 {
span.SetAttributes(kv.Int("retry", attempt))
span.SetAttributes(label.Int("retry", attempt))
if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
return err
... ... @@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{}
return res, err
}
// TODO: don't get/put conn in the pool
// TODO: don't get/put conn in the pool.
func (db *baseDB) copyFrom(
ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{},
) (res Result, err error) {
... ... @@ -396,6 +404,7 @@ func (db *baseDB) copyFrom(
return nil, err
}
// Note that afterQuery uses the err.
defer func() {
if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
err = afterQueryErr
... ... @@ -434,7 +443,7 @@ func (db *baseDB) copyFrom(
return nil, err
}
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
res, err = readReadyForQuery(rd)
return err
})
... ... @@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{})
}
func (db *baseDB) copyTo(
c context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{},
ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{},
) (res Result, err error) {
var evt *QueryEvent
... ... @@ -472,25 +481,26 @@ func (db *baseDB) copyTo(
model, _ = params[len(params)-1].(orm.TableModel)
}
c, evt, err = db.beforeQuery(c, db.db, model, query, params, wb.Query())
ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query())
if err != nil {
return nil, err
}
// Note that afterQuery uses the err.
defer func() {
if afterQueryErr := db.afterQuery(c, evt, res, err); afterQueryErr != nil {
if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
err = afterQueryErr
}
}()
err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
return writeQueryMsg(wb, db.fmter, query, params...)
})
if err != nil {
return nil, err
}
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
err := readCopyOutResponse(rd)
if err != nil {
return err
... ... @@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que
return orm.NewQueryContext(c, db.db, model...)
}
// Select selects the model by primary key.
func (db *baseDB) Select(model interface{}) error {
return orm.Select(db.db, model)
}
// Insert inserts the model updating primary keys if they are empty.
func (db *baseDB) Insert(model ...interface{}) error {
return orm.Insert(db.db, model...)
}
// Update updates the model by primary key.
func (db *baseDB) Update(model interface{}) error {
return orm.Update(db.db, model)
}
// Delete deletes the model by primary key.
func (db *baseDB) Delete(model interface{}) error {
return orm.Delete(db.db, model)
}
// Delete forces delete of the model with deleted_at column.
func (db *baseDB) ForceDelete(model interface{}) error {
return orm.ForceDelete(db.db, model)
}
// CreateTable creates table for the model. It recognizes following field tags:
// - notnull - sets NOT NULL constraint.
// - unique - sets UNIQUE constraint.
// - default:value - sets default value.
func (db *baseDB) CreateTable(model interface{}, opt *orm.CreateTableOptions) error {
return orm.CreateTable(db.db, model, opt)
}
// DropTable drops table for the model.
func (db *baseDB) DropTable(model interface{}, opt *orm.DropTableOptions) error {
return orm.DropTable(db.db, model, opt)
}
func (db *baseDB) CreateComposite(model interface{}, opt *orm.CreateCompositeOptions) error {
return orm.CreateComposite(db.db, model, opt)
}
func (db *baseDB) DropComposite(model interface{}, opt *orm.DropCompositeOptions) error {
return orm.DropComposite(db.db, model, opt)
}
func (db *baseDB) Formatter() orm.QueryFormatter {
return db.fmter
}
... ... @@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery(
}
var res *result
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
var err error
res, err = readSimpleQuery(rd)
return err
... ... @@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData(
}
var res *result
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
var err error
res, err = readSimpleQueryData(c, rd, model)
return err
... ... @@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData(
// executions. Multiple queries or executions may be run concurrently
// from the returned statement.
func (db *baseDB) Prepare(q string) (*Stmt, error) {
return prepareStmt(db.withPool(pool.NewSingleConnPool(db.pool)), q)
return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q)
}
func (db *baseDB) prepare(
c context.Context, cn *pool.Conn, q string,
) (string, [][]byte, error) {
) (string, []types.ColumnInfo, error) {
name := cn.NextID()
err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
writeParseDescribeSyncMsg(wb, name, q)
... ... @@ -646,8 +610,8 @@ func (db *baseDB) prepare(
return "", nil, err
}
var columns [][]byte
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
var columns []types.ColumnInfo
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
columns, err = readParseDescribeSync(rd)
return err
})
... ...
... ... @@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB {
}
// Listen listens for notifications sent with NOTIFY command.
func (db *DB) Listen(channels ...string) *Listener {
func (db *DB) Listen(ctx context.Context, channels ...string) *Listener {
ln := &Listener{
db: db,
}
ln.init()
_ = ln.Listen(channels...)
_ = ln.Listen(ctx, channels...)
return ln
}
... ... @@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil)
// Every Conn must be returned to the database pool after use by
// calling Conn.Close.
func (db *DB) Conn() *Conn {
return newConn(db.ctx, db.baseDB.withPool(pool.NewSingleConnPool(db.pool)))
return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool)))
}
func newConn(ctx context.Context, baseDB *baseDB) *Conn {
... ...
package pg
import (
"io"
"net"
"github.com/go-pg/pg/v10/internal"
... ... @@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows
type Error interface {
error
// Field returns a string value associated with an error code.
// Field returns a string value associated with an error field.
//
// https://www.postgresql.org/docs/10/static/protocol-error-fields.html
Field(byte) string
Field(field byte) string
// IntegrityViolation reports whether an error is a part of
// Integrity Constraint Violation class of errors.
... ... @@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool {
if _, ok := err.(internal.Error); ok {
return false
}
if pgErr, ok := err.(Error); ok && pgErr.Field('S') != "FATAL" {
return false
if pgErr, ok := err.(Error); ok {
return pgErr.Field('S') == "FATAL"
}
if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false
return !netErr.Temporary()
}
}
return true
}
func isNetworkError(err error) bool {
if err == io.EOF {
return true
}
_, ok := err.(net.Error)
return ok
//------------------------------------------------------------------------------
type timeoutError interface {
Timeout() bool
}
... ...
... ... @@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10
go 1.11
require (
github.com/go-pg/pg/v9 v9.1.6 // indirect
github.com/go-pg/urlstruct v0.4.0
github.com/go-pg/zerochecker v0.1.1
github.com/golang/protobuf v1.4.2 // indirect
github.com/go-pg/zerochecker v0.2.0
github.com/golang/protobuf v1.4.3 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/onsi/ginkgo v1.10.1
github.com/onsi/gomega v1.7.0
github.com/segmentio/encoding v0.1.13
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo v1.14.2
github.com/onsi/gomega v1.10.3
github.com/stretchr/testify v1.6.1
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
github.com/vmihailenco/bufpool v0.1.11
github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1
github.com/vmihailenco/tagparser v0.1.1
go.opentelemetry.io/otel v0.6.0
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect
golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect
google.golang.org/appengine v1.6.6 // indirect
google.golang.org/grpc v1.29.1
google.golang.org/protobuf v1.24.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15
github.com/vmihailenco/msgpack/v4 v4.3.11 // indirect
github.com/vmihailenco/msgpack/v5 v5.0.0
github.com/vmihailenco/tagparser v0.1.2
go.opentelemetry.io/otel v0.14.0
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 // indirect
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.25.0 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f
mellium.im/sasl v0.2.1
)
... ...
... ... @@ -8,15 +8,17 @@ import (
"github.com/go-pg/pg/v10/orm"
)
type BeforeScanHook = orm.BeforeScanHook
type AfterScanHook = orm.AfterScanHook
type AfterSelectHook = orm.AfterSelectHook
type BeforeInsertHook = orm.BeforeInsertHook
type AfterInsertHook = orm.AfterInsertHook
type BeforeUpdateHook = orm.BeforeUpdateHook
type AfterUpdateHook = orm.AfterUpdateHook
type BeforeDeleteHook = orm.BeforeDeleteHook
type AfterDeleteHook = orm.AfterDeleteHook
type (
BeforeScanHook = orm.BeforeScanHook
AfterScanHook = orm.AfterScanHook
AfterSelectHook = orm.AfterSelectHook
BeforeInsertHook = orm.BeforeInsertHook
AfterInsertHook = orm.AfterInsertHook
BeforeUpdateHook = orm.BeforeUpdateHook
AfterUpdateHook = orm.AfterUpdateHook
BeforeDeleteHook = orm.BeforeDeleteHook
AfterDeleteHook = orm.AfterDeleteHook
)
//------------------------------------------------------------------------------
... ... @@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery(
fmtedQuery: fmtedQuery,
}
for _, hook := range db.queryHooks {
for i, hook := range db.queryHooks {
var err error
ctx, err = hook.BeforeQuery(ctx, event)
if err != nil {
return nil, nil, err
if err := db.afterQueryFromIndex(ctx, event, i); err != nil {
return ctx, nil, err
}
return ctx, nil, err
}
}
... ... @@ -117,14 +122,15 @@ func (db *baseDB) afterQuery(
event.Err = err
event.Result = res
return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
}
for _, hook := range db.queryHooks {
err := hook.AfterQuery(ctx, event)
if err != nil {
func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
for ; hookIndex >= 0; hookIndex-- {
if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil {
return err
}
}
return nil
}
... ...
... ... @@ -4,8 +4,10 @@ import (
"fmt"
)
var ErrNoRows = Errorf("pg: no rows in result set")
var ErrMultiRows = Errorf("pg: multiple rows in result set")
var (
ErrNoRows = Errorf("pg: no rows in result set")
ErrMultiRows = Errorf("pg: multiple rows in result set")
)
type Error struct {
s string
... ...
... ... @@ -8,20 +8,20 @@ import (
"time"
)
// Retry backoff with jitter sleep to prevent overloaded conditions during intervals
// https://www.awsarchitectureblog.com/2015/03/backoff.html
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
retry = 0
panic("not reached")
}
backoff := minBackoff << uint(retry)
if backoff > maxBackoff || backoff < minBackoff {
backoff = maxBackoff
if minBackoff == 0 {
return 0
}
if backoff == 0 {
return 0
d := minBackoff << uint(retry)
d = minBackoff + time.Duration(rand.Int63n(int64(d)))
if d > maxBackoff || d < minBackoff {
d = maxBackoff
}
return time.Duration(rand.Int63n(int64(backoff)))
return d
}
... ...
package internal
import (
"context"
"fmt"
"log"
"os"
)
var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile)
var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags)
var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags)
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
... ...
... ... @@ -8,16 +8,15 @@ import (
"time"
"github.com/go-pg/pg/v10/internal"
"go.opentelemetry.io/otel/api/kv"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel/label"
"go.opentelemetry.io/otel/trace"
)
var noDeadline = time.Time{}
type Conn struct {
netConn net.Conn
rd *BufReader
rd *ReaderContext
ProcessID int32
SecretKey int32
... ... @@ -31,8 +30,6 @@ type Conn struct {
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
rd: NewBufReader(netConn),
createdAt: time.Now(),
}
cn.SetNetConn(netConn)
... ... @@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
if cn.rd != nil {
cn.rd.Reset(netConn)
}
}
func (cn *Conn) LockReader() {
if cn.rd != nil {
panic("not reached")
}
cn.rd = NewReaderContext()
cn.rd.Reset(cn.netConn)
}
func (cn *Conn) NetConn() net.Conn {
... ... @@ -68,30 +75,44 @@ func (cn *Conn) NextID() string {
}
func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *BufReader) error,
ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error,
) error {
return internal.WithSpan(ctx, "with_reader", func(ctx context.Context, span trace.Span) error {
err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
if err != nil {
return internal.WithSpan(ctx, "pg.with_reader", func(ctx context.Context, span trace.Span) error {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
span.RecordError(err)
return err
}
cn.rd.bytesRead = 0
err = fn(cn.rd)
span.SetAttributes(kv.Int64("net.read_bytes", cn.rd.bytesRead))
rd := cn.rd
if rd == nil {
rd = GetReaderContext()
defer PutReaderContext(rd)
rd.Reset(cn.netConn)
}
rd.bytesRead = 0
if err := fn(rd); err != nil {
span.RecordError(err)
return err
}
span.SetAttributes(label.Int64("net.read_bytes", rd.bytesRead))
return nil
})
}
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error,
) error {
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
wb := GetWriteBuffer()
defer PutWriteBuffer(wb)
if err := fn(wb); err != nil {
span.RecordError(err)
return err
}
... ... @@ -100,7 +121,7 @@ func (cn *Conn) WithWriter(
}
func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error {
return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error {
return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
return cn.writeBuffer(ctx, span, timeout, wb)
})
}
... ... @@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer(
timeout time.Duration,
wb *WriteBuffer,
) error {
err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
if err != nil {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
span.RecordError(err)
return err
}
span.SetAttributes(kv.Int("net.wrote_bytes", len(wb.Bytes)))
_, err = cn.netConn.Write(wb.Bytes)
span.SetAttributes(label.Int("net.wrote_bytes", len(wb.Bytes)))
if _, err := cn.netConn.Write(wb.Bytes); err != nil {
span.RecordError(err)
return err
}
return nil
}
func (cn *Conn) Close() error {
... ...
... ... @@ -11,8 +11,10 @@ import (
"github.com/go-pg/pg/v10/internal"
)
var ErrClosed = errors.New("pg: database is closed")
var ErrPoolTimeout = errors.New("pg: connection pool timeout")
var (
ErrClosed = errors.New("pg: database is closed")
ErrPoolTimeout = errors.New("pg: connection pool timeout")
)
var timers = sync.Pool{
New: func() interface{} {
... ... @@ -38,8 +40,8 @@ type Pooler interface {
CloseConn(*Conn) error
Get(context.Context) (*Conn, error)
Put(*Conn)
Remove(*Conn, error)
Put(context.Context, *Conn)
Remove(context.Context, *Conn, error)
Len() int
IdleLen() int
... ... @@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error {
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
err := p.waitTurn(c)
err := p.waitTurn(ctx)
if err != nil {
return nil, err
}
... ... @@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p.newConn(c, true)
newcn, err := p.newConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
... ... @@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn {
return cn
}
func (p *ConnPool) Put(cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf("Conn has unread data")
p.Remove(cn, BadConnError{})
return
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if !cn.pooled {
p.Remove(cn, nil)
p.Remove(ctx, cn, nil)
return
}
... ... @@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) {
p.freeTurn()
}
func (p *ConnPool) Remove(cn *Conn, reason error) {
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
... ... @@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) {
}
n, err := p.ReapStaleConns()
if err != nil {
internal.Logger.Printf("ReapStaleConns failed: %s", err)
internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err)
continue
}
atomic.AddUint32(&p.stats.StaleConns, uint32(n))
... ...
package pool
import (
"context"
"errors"
"fmt"
"sync/atomic"
)
const (
stateDefault = 0
stateInited = 1
stateClosed = 2
)
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "pg: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
import "context"
type SingleConnPool struct {
pool Pooler
level int32 // atomic
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
cn *Conn
stickyErr error
}
var _ Pooler = (*SingleConnPool)(nil)
func NewSingleConnPool(pool Pooler) *SingleConnPool {
p, ok := pool.(*SingleConnPool)
if !ok {
p = &SingleConnPool{
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
ch: make(chan *Conn, 1),
}
}
atomic.AddInt32(&p.level, 1)
return p
}
func (p *SingleConnPool) SetConn(cn *Conn) {
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
p.ch <- cn
} else {
panic("not reached")
cn: cn,
}
}
func (p *SingleConnPool) NewConn(c context.Context) (*Conn, error) {
return p.pool.NewConn(c)
func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
// In worst case this races with Close which is not a very common operation.
for i := 0; i < 1000; i++ {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
cn, err := p.pool.Get(c)
if err != nil {
return nil, err
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
return cn, nil
}
p.pool.Remove(cn, ErrClosed)
case stateInited:
if err := p.badConnError(); err != nil {
return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop")
return p.cn, nil
}
func (p *SingleConnPool) Put(cn *Conn) {
defer func() {
if recover() != nil {
p.freeConn(cn)
}
}()
p.ch <- cn
}
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) freeConn(cn *Conn) {
if err := p.badConnError(); err != nil {
p.pool.Remove(cn, err)
} else {
p.pool.Put(cn)
}
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.cn = nil
p.stickyErr = reason
}
func (p *SingleConnPool) Remove(cn *Conn, reason error) {
defer func() {
if recover() != nil {
p.pool.Remove(cn, ErrClosed)
}
}()
p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
func (p *SingleConnPool) Close() error {
p.cn = nil
p.stickyErr = ErrClosed
return nil
}
func (p *SingleConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
}
func (p *SingleConnPool) IdleLen() int {
return len(p.ch)
return 0
}
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) Close() error {
level := atomic.AddInt32(&p.level, -1)
if level > 0 {
return nil
}
for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed
}
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
cn, ok := <-p.ch
if ok {
p.freeConn(cn)
}
return nil
}
}
return errors.New("pg: SingleConnPool.Close: infinite loop")
}
func (p *SingleConnPool) Reset() error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return errors.New("pg: SingleConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("pg: invalid SingleConnPool state: %d", state)
}
return nil
}
func (p *SingleConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
err := v.(BadConnError)
if err.wrapped != nil {
return err
}
}
return nil
}
... ...
package pool
import (
"context"
"errors"
"fmt"
"sync/atomic"
)
const (
stateDefault = 0
stateInited = 1
stateClosed = 2
)
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "pg: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
//------------------------------------------------------------------------------
type StickyConnPool struct {
pool Pooler
shared int32 // atomic
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
}
var _ Pooler = (*StickyConnPool)(nil)
func NewStickyConnPool(pool Pooler) *StickyConnPool {
p, ok := pool.(*StickyConnPool)
if !ok {
p = &StickyConnPool{
pool: pool,
ch: make(chan *Conn, 1),
}
}
atomic.AddInt32(&p.shared, 1)
return p
}
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *StickyConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
// In worst case this races with Close which is not a very common operation.
for i := 0; i < 1000; i++ {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
}
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
return cn, nil
}
p.pool.Remove(ctx, cn, ErrClosed)
case stateInited:
if err := p.badConnError(); err != nil {
return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop")
}
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
defer func() {
if recover() != nil {
p.freeConn(ctx, cn)
}
}()
p.ch <- cn
}
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
if err := p.badConnError(); err != nil {
p.pool.Remove(ctx, cn, err)
} else {
p.pool.Put(ctx, cn)
}
}
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
defer func() {
if recover() != nil {
p.pool.Remove(ctx, cn, ErrClosed)
}
}()
p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
}
func (p *StickyConnPool) Close() error {
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
return nil
}
for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed
}
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
cn, ok := <-p.ch
if ok {
p.freeConn(context.TODO(), cn)
}
return nil
}
}
return errors.New("pg: StickyConnPool.Close: infinite loop")
}
func (p *StickyConnPool) Reset(ctx context.Context) error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(ctx, cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return errors.New("pg: StickyConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("pg: invalid StickyConnPool state: %d", state)
}
return nil
}
func (p *StickyConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
err := v.(BadConnError)
if err.wrapped != nil {
return err
}
}
return nil
}
func (p *StickyConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
}
func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
... ...
package pool
import (
"sync"
)
type Reader interface {
Buffered() int
... ... @@ -10,8 +14,67 @@ type Reader interface {
ReadSlice(byte) ([]byte, error)
Discard(int) (int, error)
//ReadBytes(fn func(byte) bool) ([]byte, error)
//ReadN(int) ([]byte, error)
// ReadBytes(fn func(byte) bool) ([]byte, error)
// ReadN(int) ([]byte, error)
ReadFull() ([]byte, error)
ReadFullTemp() ([]byte, error)
}
type ColumnInfo struct {
Index int16
DataType int32
Name string
}
type ColumnAlloc struct {
columns []ColumnInfo
}
func NewColumnAlloc() *ColumnAlloc {
return new(ColumnAlloc)
}
func (c *ColumnAlloc) Reset() {
c.columns = c.columns[:0]
}
func (c *ColumnAlloc) New(index int16, name []byte) *ColumnInfo {
c.columns = append(c.columns, ColumnInfo{
Index: index,
Name: string(name),
})
return &c.columns[len(c.columns)-1]
}
func (c *ColumnAlloc) Columns() []ColumnInfo {
return c.columns
}
type ReaderContext struct {
*BufReader
ColumnAlloc *ColumnAlloc
}
func NewReaderContext() *ReaderContext {
const bufSize = 1 << 20 // 1mb
return &ReaderContext{
BufReader: NewBufReader(bufSize),
ColumnAlloc: NewColumnAlloc(),
}
}
var readerPool = sync.Pool{
New: func() interface{} {
return NewReaderContext()
},
}
func GetReaderContext() *ReaderContext {
rd := readerPool.Get().(*ReaderContext)
return rd
}
func PutReaderContext(rd *ReaderContext) {
rd.ColumnAlloc.Reset()
readerPool.Put(rd)
}
... ...
... ... @@ -10,11 +10,7 @@ import (
"io"
)
const defaultBufSize = 65536
type BufReader struct {
Columns [][]byte
rd io.Reader // reader provided by the client
buf []byte
... ... @@ -24,25 +20,24 @@ type BufReader struct {
err error
available int // bytes available for reading
bytesRd BytesReader // reusable bytes reader
brd BytesReader // reusable bytes reader
}
func NewBufReader(rd io.Reader) *BufReader {
func NewBufReader(bufSize int) *BufReader {
return &BufReader{
rd: rd,
buf: make([]byte, defaultBufSize),
buf: make([]byte, bufSize),
available: -1,
}
}
func (b *BufReader) BytesReader(n int) *BytesReader {
if b.Buffered() < n {
return nil
if n == -1 {
n = 0
}
buf := b.buf[b.r : b.r+n]
b.r += n
b.bytesRd.Reset(buf)
return &b.bytesRd
b.brd.Reset(buf)
return &b.brd
}
func (b *BufReader) SetAvailable(n int) {
... ... @@ -67,11 +62,11 @@ func (b *BufReader) Reset(rd io.Reader) {
// Buffered returns the number of bytes that can be read from the current buffer.
func (b *BufReader) Buffered() int {
d := b.w - b.r
if b.available != -1 && d > b.available {
return b.available
buffered := b.w - b.r
if b.available == -1 || buffered <= b.available {
return buffered
}
return d
return b.available
}
func (b *BufReader) Bytes() []byte {
... ... @@ -122,7 +117,7 @@ func (b *BufReader) fill() {
// Read new data: try a limited number of times.
const maxConsecutiveEmptyReads = 100
for i := maxConsecutiveEmptyReads; i > 0; i-- {
n, err := b.readDirectly(b.buf[b.w:])
n, err := b.read(b.buf[b.w:])
b.w += n
if err != nil {
b.err = err
... ... @@ -163,7 +158,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) {
if len(p) >= len(b.buf) {
// Large read, empty buffer.
// Read directly into p to avoid copy.
n, err = b.readDirectly(p)
n, err = b.read(p)
if n > 0 {
b.changeAvailable(-n)
b.lastByte = int(p[n-1])
... ... @@ -175,7 +170,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) {
// Do not use b.fill, which will loop.
b.r = 0
b.w = 0
n, b.err = b.readDirectly(b.buf)
n, b.err = b.read(b.buf)
if n == 0 {
return 0, b.readErr()
}
... ... @@ -259,7 +254,7 @@ func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) {
// Pending error?
if b.err != nil {
line = b.flush() //nolint
line = b.flush()
err = b.readErr()
break
}
... ... @@ -429,7 +424,7 @@ func (b *BufReader) ReadFullTemp() ([]byte, error) {
return b.ReadFull()
}
func (b *BufReader) readDirectly(buf []byte) (int, error) {
func (b *BufReader) read(buf []byte) (int, error) {
n, err := b.rd.Read(buf)
b.bytesRead += int64(n)
return n, err
... ...
... ... @@ -6,20 +6,22 @@ import (
"sync"
)
var pool = sync.Pool{
const defaultBufSize = 65 << 10 // 65kb
var wbPool = sync.Pool{
New: func() interface{} {
return NewWriteBuffer()
},
}
func GetWriteBuffer() *WriteBuffer {
wb := pool.Get().(*WriteBuffer)
wb.Reset()
wb := wbPool.Get().(*WriteBuffer)
return wb
}
func PutWriteBuffer(wb *WriteBuffer) {
pool.Put(wb)
wb.Reset()
wbPool.Put(wb)
}
type WriteBuffer struct {
... ... @@ -39,10 +41,6 @@ func (buf *WriteBuffer) Reset() {
buf.Bytes = buf.Bytes[:0]
}
func (buf *WriteBuffer) ResetBuffer(b []byte) {
buf.Bytes = b[:0]
}
func (buf *WriteBuffer) StartMessage(c byte) {
if c == 0 {
buf.msgStart = len(buf.Bytes)
... ...
... ... @@ -5,12 +5,14 @@ import (
"reflect"
"time"
"go.opentelemetry.io/otel/api/global"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
)
var tracer = otel.Tracer("github.com/go-pg/pg")
func Sleep(ctx context.Context, dur time.Duration) error {
return WithSpan(ctx, "sleep", func(ctx context.Context, span trace.Span) error {
return WithSpan(ctx, "time.Sleep", func(ctx context.Context, span trace.Span) error {
t := time.NewTimer(dur)
defer t.Stop()
... ... @@ -80,11 +82,11 @@ func WithSpan(
name string,
fn func(context.Context, trace.Span) error,
) error {
if !trace.SpanFromContext(ctx).IsRecording() {
return fn(ctx, trace.NoopSpan{})
if span := trace.SpanFromContext(ctx); !span.IsRecording() {
return fn(ctx, span)
}
ctx, span := global.Tracer("go-pg").Start(ctx, name)
ctx, span := tracer.Start(ctx, name)
defer span.End()
return fn(ctx, span)
... ...
... ... @@ -15,8 +15,10 @@ import (
const gopgChannel = "gopg:ping"
var errListenerClosed = errors.New("pg: listener is closed")
var errPingTimeout = errors.New("pg: ping timeout")
var (
errListenerClosed = errors.New("pg: listener is closed")
errPingTimeout = errors.New("pg: ping timeout")
)
// Notification which is received with LISTEN command.
type Notification struct {
... ... @@ -38,11 +40,14 @@ type Listener struct {
closed bool
chOnce sync.Once
ch chan *Notification
ch chan Notification
pingCh chan struct{}
}
func (ln *Listener) String() string {
ln.mu.Lock()
defer ln.mu.Unlock()
return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", "))
}
... ... @@ -50,9 +55,9 @@ func (ln *Listener) init() {
ln.exit = make(chan struct{})
}
func (ln *Listener) connWithLock() (*pool.Conn, error) {
func (ln *Listener) connWithLock(ctx context.Context) (*pool.Conn, error) {
ln.mu.Lock()
cn, err := ln.conn()
cn, err := ln.conn(ctx)
ln.mu.Unlock()
switch err {
... ... @@ -64,12 +69,12 @@ func (ln *Listener) connWithLock() (*pool.Conn, error) {
_ = ln.Close()
return nil, errListenerClosed
default:
internal.Logger.Printf("pg: Listen failed: %s", err)
internal.Logger.Printf(ctx, "pg: Listen failed: %s", err)
return nil, err
}
}
func (ln *Listener) conn() (*pool.Conn, error) {
func (ln *Listener) conn(ctx context.Context) (*pool.Conn, error) {
if ln.closed {
return nil, errListenerClosed
}
... ... @@ -78,21 +83,20 @@ func (ln *Listener) conn() (*pool.Conn, error) {
return ln.cn, nil
}
c := context.TODO()
cn, err := ln.db.pool.NewConn(c)
cn, err := ln.db.pool.NewConn(ctx)
if err != nil {
return nil, err
}
err = ln.db.initConn(c, cn)
if err != nil {
if err := ln.db.initConn(ctx, cn); err != nil {
_ = ln.db.pool.CloseConn(cn)
return nil, err
}
cn.LockReader()
if len(ln.channels) > 0 {
err := ln.listen(c, cn, ln.channels...)
err := ln.listen(ctx, cn, ln.channels...)
if err != nil {
_ = ln.db.pool.CloseConn(cn)
return nil, err
... ... @@ -103,19 +107,19 @@ func (ln *Listener) conn() (*pool.Conn, error) {
return cn, nil
}
func (ln *Listener) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
func (ln *Listener) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
ln.mu.Lock()
if ln.cn == cn {
if isBadConn(err, allowTimeout) {
ln.reconnect(err)
ln.reconnect(ctx, err)
}
}
ln.mu.Unlock()
}
func (ln *Listener) reconnect(reason error) {
func (ln *Listener) reconnect(ctx context.Context, reason error) {
_ = ln.closeTheCn(reason)
_, _ = ln.conn()
_, _ = ln.conn(ctx)
}
func (ln *Listener) closeTheCn(reason error) error {
... ... @@ -123,7 +127,7 @@ func (ln *Listener) closeTheCn(reason error) error {
return nil
}
if !ln.closed {
internal.Logger.Printf("pg: discarding bad listener connection: %s", reason)
internal.Logger.Printf(ln.db.ctx, "pg: discarding bad listener connection: %s", reason)
}
err := ln.db.pool.CloseConn(ln.cn)
... ... @@ -146,31 +150,62 @@ func (ln *Listener) Close() error {
}
// Listen starts listening for notifications on channels.
func (ln *Listener) Listen(channels ...string) error {
func (ln *Listener) Listen(ctx context.Context, channels ...string) error {
// Always append channels so DB.Listen works correctly.
ln.mu.Lock()
ln.channels = appendIfNotExists(ln.channels, channels...)
ln.mu.Unlock()
cn, err := ln.connWithLock()
cn, err := ln.connWithLock(ctx)
if err != nil {
return err
}
err = ln.listen(context.TODO(), cn, channels...)
if err != nil {
ln.releaseConn(cn, err, false)
if err := ln.listen(ctx, cn, channels...); err != nil {
ln.releaseConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) error {
err := cn.WithWriter(c, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
func (ln *Listener) listen(ctx context.Context, cn *pool.Conn, channels ...string) error {
err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
for _, channel := range channels {
err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel))
if err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)); err != nil {
return err
}
}
return nil
})
return err
}
// Unlisten stops listening for notifications on channels.
func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error {
ln.mu.Lock()
ln.channels = removeIfExists(ln.channels, channels...)
ln.mu.Unlock()
cn, err := ln.conn(ctx)
if err != nil {
return err
}
if err := ln.unlisten(ctx, cn, channels...); err != nil {
ln.releaseConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) unlisten(ctx context.Context, cn *pool.Conn, channels ...string) error {
err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
for _, channel := range channels {
if err := writeQueryMsg(wb, ln.db.fmter, "UNLISTEN ?", pgChan(channel)); err != nil {
return err
}
}
return nil
})
... ... @@ -179,24 +214,26 @@ func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string)
// Receive indefinitely waits for a notification. This is low-level API
// and in most cases Channel should be used instead.
func (ln *Listener) Receive() (channel string, payload string, err error) {
return ln.ReceiveTimeout(0)
func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) {
return ln.ReceiveTimeout(ctx, 0)
}
// ReceiveTimeout waits for a notification until timeout is reached.
// This is low-level API and in most cases Channel should be used instead.
func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload string, err error) {
cn, err := ln.connWithLock()
func (ln *Listener) ReceiveTimeout(
ctx context.Context, timeout time.Duration,
) (channel, payload string, err error) {
cn, err := ln.connWithLock(ctx)
if err != nil {
return "", "", err
}
err = cn.WithReader(context.TODO(), timeout, func(rd *pool.BufReader) error {
err = cn.WithReader(ctx, timeout, func(rd *pool.ReaderContext) error {
channel, payload, err = readNotification(rd)
return err
})
if err != nil {
ln.releaseConn(cn, err, timeout > 0)
ln.releaseConn(ctx, cn, err, timeout > 0)
return "", "", err
}
... ... @@ -208,17 +245,17 @@ func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload stri
//
// The channel is closed with Listener. Receive* APIs can not be used
// after channel is created.
func (ln *Listener) Channel() <-chan *Notification {
func (ln *Listener) Channel() <-chan Notification {
return ln.channel(100)
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
func (ln *Listener) ChannelSize(size int) <-chan *Notification {
func (ln *Listener) ChannelSize(size int) <-chan Notification {
return ln.channel(size)
}
func (ln *Listener) channel(size int) <-chan *Notification {
func (ln *Listener) channel(size int) <-chan Notification {
ln.chOnce.Do(func() {
ln.initChannel(size)
})
... ... @@ -230,29 +267,33 @@ func (ln *Listener) channel(size int) <-chan *Notification {
}
func (ln *Listener) initChannel(size int) {
const timeout = 30 * time.Second
const pingTimeout = time.Second
const chanSendTimeout = time.Minute
_ = ln.Listen(gopgChannel)
ctx := ln.db.ctx
_ = ln.Listen(ctx, gopgChannel)
ln.ch = make(chan *Notification, size)
ln.ch = make(chan Notification, size)
ln.pingCh = make(chan struct{}, 1)
go func() {
timer := time.NewTimer(timeout)
timer := time.NewTimer(time.Minute)
timer.Stop()
var errCount int
for {
channel, payload, err := ln.Receive()
channel, payload, err := ln.Receive(ctx)
if err != nil {
if err == errListenerClosed {
close(ln.ch)
return
}
if errCount > 0 {
time.Sleep(ln.db.retryBackoff(errCount))
time.Sleep(500 * time.Millisecond)
}
errCount++
continue
}
... ... @@ -268,28 +309,31 @@ func (ln *Listener) initChannel(size int) {
case gopgChannel:
// ignore
default:
timer.Reset(timeout)
timer.Reset(chanSendTimeout)
select {
case ln.ch <- &Notification{channel, payload}:
case ln.ch <- Notification{channel, payload}:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
internal.Logger.Printf(
ctx,
"pg: %s channel is full for %s (notification is dropped)",
ln, timeout)
ln,
chanSendTimeout,
)
}
}
}
}()
go func() {
timer := time.NewTimer(timeout)
timer := time.NewTimer(time.Minute)
timer.Stop()
healthy := true
for {
timer.Reset(timeout)
timer.Reset(pingTimeout)
select {
case <-ln.pingCh:
healthy = true
... ... @@ -305,7 +349,7 @@ func (ln *Listener) initChannel(size int) {
pingErr = errPingTimeout
}
ln.mu.Lock()
ln.reconnect(pingErr)
ln.reconnect(ctx, pingErr)
ln.mu.Unlock()
}
case <-ln.exit:
... ... @@ -333,6 +377,20 @@ loop:
return ss
}
func removeIfExists(ss []string, es ...string) []string {
for _, e := range es {
for i, s := range ss {
if s == e {
last := len(ss) - 1
ss[i] = ss[last]
ss = ss[:last]
break
}
}
}
return ss
}
type pgChan string
var _ types.ValueAppender = pgChan("")
... ...
... ... @@ -19,6 +19,7 @@ import (
"github.com/go-pg/pg/v10/types"
)
// https://www.postgresql.org/docs/current/protocol-message-formats.html
const (
commandCompleteMsg = 'C'
errorResponseMsg = 'E'
... ... @@ -84,7 +85,7 @@ func (db *baseDB) startup(
return err
}
return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
for {
typ, msgLen, err := readMessageType(rd)
if err != nil {
... ... @@ -137,7 +138,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi
return err
}
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error {
err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
c, err := rd.ReadByte()
if err != nil {
return err
... ... @@ -156,7 +157,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi
}
func (db *baseDB) auth(
c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string,
c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
) error {
num, err := readInt32(rd)
if err != nil {
... ... @@ -178,7 +179,7 @@ func (db *baseDB) auth(
}
func (db *baseDB) authCleartext(
c context.Context, cn *pool.Conn, rd *pool.BufReader, password string,
c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string,
) error {
err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
writePasswordMsg(wb, password)
... ... @@ -191,7 +192,7 @@ func (db *baseDB) authCleartext(
}
func (db *baseDB) authMD5(
c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string,
c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
) error {
b, err := rd.ReadN(4)
if err != nil {
... ... @@ -210,7 +211,7 @@ func (db *baseDB) authMD5(
return readAuthOK(rd)
}
func readAuthOK(rd *pool.BufReader) error {
func readAuthOK(rd *pool.ReaderContext) error {
c, _, err := readMessageType(rd)
if err != nil {
return err
... ... @@ -238,7 +239,7 @@ func readAuthOK(rd *pool.BufReader) error {
}
func (db *baseDB) authSASL(
c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string,
c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
) error {
s, err := readString(rd)
if err != nil {
... ... @@ -332,7 +333,7 @@ func (db *baseDB) authSASL(
}
}
func readAuthSASLFinal(rd *pool.BufReader, client *sasl.Negotiator) error {
func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error {
c, n, err := readMessageType(rd)
if err != nil {
return err
... ... @@ -485,8 +486,8 @@ func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) {
writeSyncMsg(buf)
}
func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) {
var columns [][]byte
func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) {
var columns []types.ColumnInfo
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
... ... @@ -500,7 +501,7 @@ func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) {
return nil, err
}
case rowDescriptionMsg: // Response to the DESCRIBE message.
columns, err = readRowDescription(rd, nil)
columns, err = readRowDescription(rd, pool.NewColumnAlloc())
if err != nil {
return nil, err
}
... ... @@ -582,7 +583,7 @@ func writeCloseMsg(buf *pool.WriteBuffer, name string) {
buf.FinishMessage()
}
func readCloseCompleteMsg(rd *pool.BufReader) error {
func readCloseCompleteMsg(rd *pool.ReaderContext) error {
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
... ... @@ -612,7 +613,7 @@ func readCloseCompleteMsg(rd *pool.BufReader) error {
}
}
func readSimpleQuery(rd *pool.BufReader) (*result, error) {
func readSimpleQuery(rd *pool.ReaderContext) (*result, error) {
var res result
var firstErr error
for {
... ... @@ -675,7 +676,7 @@ func readSimpleQuery(rd *pool.BufReader) (*result, error) {
}
}
func readExtQuery(rd *pool.BufReader) (*result, error) {
func readExtQuery(rd *pool.ReaderContext) (*result, error) {
var res result
var firstErr error
for {
... ... @@ -739,42 +740,47 @@ func readExtQuery(rd *pool.BufReader) (*result, error) {
}
}
func readRowDescription(rd *pool.BufReader, columns [][]byte) ([][]byte, error) {
colNum, err := readInt16(rd)
func readRowDescription(
rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc,
) ([]types.ColumnInfo, error) {
numCol, err := readInt16(rd)
if err != nil {
return nil, err
}
columns = setByteSliceLen(columns, int(colNum))
for i := 0; i < int(colNum); i++ {
for i := 0; i < int(numCol); i++ {
b, err := rd.ReadSlice(0)
if err != nil {
return nil, err
}
columns[i] = append(columns[i][:0], b[:len(b)-1]...)
_, err = rd.ReadN(18)
if err != nil {
col := columnAlloc.New(int16(i), b[:len(b)-1])
if _, err := rd.ReadN(6); err != nil {
return nil, err
}
}
return columns, nil
}
dataType, err := readInt32(rd)
if err != nil {
return nil, err
}
col.DataType = dataType
func setByteSliceLen(b [][]byte, n int) [][]byte {
if n <= cap(b) {
return b[:n]
if _, err := rd.ReadN(8); err != nil {
return nil, err
}
}
b = b[:cap(b)]
b = append(b, make([][]byte, n-cap(b))...)
return b
return columnAlloc.Columns(), nil
}
func readDataRow(
ctx context.Context, rd *pool.BufReader, scanner orm.ColumnScanner, columns [][]byte,
ctx context.Context,
rd *pool.ReaderContext,
columns []types.ColumnInfo,
scanner orm.ColumnScanner,
) error {
colNum, err := readInt16(rd)
numCol, err := readInt16(rd)
if err != nil {
return err
}
... ... @@ -787,35 +793,28 @@ func readDataRow(
var firstErr error
for colIdx := int16(0); colIdx < colNum; colIdx++ {
for colIdx := int16(0); colIdx < numCol; colIdx++ {
n, err := readInt32(rd)
if err != nil {
return err
}
column := internal.BytesToString(columns[colIdx])
var colRd types.Reader
if n >= 0 {
bytesRd := rd.BytesReader(int(n))
if bytesRd != nil {
colRd = bytesRd
if int(n) <= rd.Buffered() {
colRd = rd.BytesReader(int(n))
} else {
rd.SetAvailable(int(n))
colRd = rd
}
} else {
colRd = rd.BytesReader(0)
}
err = scanner.ScanColumn(int(colIdx), column, colRd, int(n))
if err != nil && firstErr == nil {
column := columns[colIdx]
if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil {
firstErr = internal.Errorf(err.Error())
}
if rd == colRd {
if rd.Available() > 0 {
_, err = rd.Discard(rd.Available())
if err != nil && firstErr == nil {
if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil {
firstErr = err
}
}
... ... @@ -841,8 +840,9 @@ func newModel(mod interface{}) (orm.Model, error) {
}
func readSimpleQueryData(
ctx context.Context, rd *pool.BufReader, mod interface{},
ctx context.Context, rd *pool.ReaderContext, mod interface{},
) (*result, error) {
var columns []types.ColumnInfo
var res result
var firstErr error
for {
... ... @@ -853,7 +853,7 @@ func readSimpleQueryData(
switch c {
case rowDescriptionMsg:
rd.Columns, err = readRowDescription(rd, rd.Columns[:0])
columns, err = readRowDescription(rd, rd.ColumnAlloc)
if err != nil {
return nil, err
}
... ... @@ -870,7 +870,7 @@ func readSimpleQueryData(
}
case dataRowMsg:
scanner := res.model.NextColumnScanner()
if err := readDataRow(ctx, rd, scanner, rd.Columns); err != nil {
if err := readDataRow(ctx, rd, columns, scanner); err != nil {
if firstErr == nil {
firstErr = err
}
... ... @@ -925,7 +925,7 @@ func readSimpleQueryData(
}
func readExtQueryData(
ctx context.Context, rd *pool.BufReader, mod interface{}, columns [][]byte,
ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo,
) (*result, error) {
var res result
var firstErr error
... ... @@ -954,7 +954,7 @@ func readExtQueryData(
}
scanner := res.model.NextColumnScanner()
if err := readDataRow(ctx, rd, scanner, columns); err != nil {
if err := readDataRow(ctx, rd, columns, scanner); err != nil {
if firstErr == nil {
firstErr = err
}
... ... @@ -1004,7 +1004,7 @@ func readExtQueryData(
}
}
func readCopyInResponse(rd *pool.BufReader) error {
func readCopyInResponse(rd *pool.ReaderContext) error {
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
... ... @@ -1044,7 +1044,7 @@ func readCopyInResponse(rd *pool.BufReader) error {
}
}
func readCopyOutResponse(rd *pool.BufReader) error {
func readCopyOutResponse(rd *pool.ReaderContext) error {
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
... ... @@ -1084,7 +1084,7 @@ func readCopyOutResponse(rd *pool.BufReader) error {
}
}
func readCopyData(rd *pool.BufReader, w io.Writer) (*result, error) {
func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) {
var res result
var firstErr error
for {
... ... @@ -1162,7 +1162,7 @@ func writeCopyDone(buf *pool.WriteBuffer) {
buf.FinishMessage()
}
func readReadyForQuery(rd *pool.BufReader) (*result, error) {
func readReadyForQuery(rd *pool.ReaderContext) (*result, error) {
var res result
var firstErr error
for {
... ... @@ -1211,7 +1211,7 @@ func readReadyForQuery(rd *pool.BufReader) (*result, error) {
}
}
func readNotification(rd *pool.BufReader) (channel, payload string, err error) {
func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) {
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
... ... @@ -1269,17 +1269,17 @@ func terminateConn(cn *pool.Conn) error {
//------------------------------------------------------------------------------
func logNotice(rd *pool.BufReader, msgLen int) error {
func logNotice(rd *pool.ReaderContext, msgLen int) error {
_, err := rd.ReadN(msgLen)
return err
}
func logParameterStatus(rd *pool.BufReader, msgLen int) error {
func logParameterStatus(rd *pool.ReaderContext, msgLen int) error {
_, err := rd.ReadN(msgLen)
return err
}
func readInt16(rd *pool.BufReader) (int16, error) {
func readInt16(rd *pool.ReaderContext) (int16, error) {
b, err := rd.ReadN(2)
if err != nil {
return 0, err
... ... @@ -1287,7 +1287,7 @@ func readInt16(rd *pool.BufReader) (int16, error) {
return int16(binary.BigEndian.Uint16(b)), nil
}
func readInt32(rd *pool.BufReader) (int32, error) {
func readInt32(rd *pool.ReaderContext) (int32, error) {
b, err := rd.ReadN(4)
if err != nil {
return 0, err
... ... @@ -1295,7 +1295,7 @@ func readInt32(rd *pool.BufReader) (int32, error) {
return int32(binary.BigEndian.Uint32(b)), nil
}
func readString(rd *pool.BufReader) (string, error) {
func readString(rd *pool.ReaderContext) (string, error) {
b, err := rd.ReadSlice(0)
if err != nil {
return "", err
... ... @@ -1303,7 +1303,7 @@ func readString(rd *pool.BufReader) (string, error) {
return string(b[:len(b)-1]), nil
}
func readError(rd *pool.BufReader) (error, error) {
func readError(rd *pool.ReaderContext) (error, error) {
m := make(map[byte]string)
for {
c, err := rd.ReadByte()
... ... @@ -1322,7 +1322,7 @@ func readError(rd *pool.BufReader) (error, error) {
return internal.NewPGError(m), nil
}
func readMessageType(rd *pool.BufReader) (byte, int, error) {
func readMessageType(rd *pool.ReaderContext) (byte, int, error) {
c, err := rd.ReadByte()
if err != nil {
return 0, 0, err
... ...
... ... @@ -13,7 +13,8 @@ import (
"strings"
"time"
"go.opentelemetry.io/otel/api/trace"
"go.opentelemetry.io/otel/label"
"go.opentelemetry.io/otel/trace"
"github.com/go-pg/pg/v10/internal"
"github.com/go-pg/pg/v10/internal/pool"
... ... @@ -31,6 +32,10 @@ type Options struct {
// Network and Addr options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Hook that is called after new connection is established
// and user is authenticated.
OnConnect func(ctx context.Context, cn *Conn) error
User string
Password string
Database string
... ... @@ -53,10 +58,6 @@ type Options struct {
// with a timeout instead of blocking.
WriteTimeout time.Duration
// Hook that is called after new connection is established
// and user is authenticated.
OnConnect func(*Conn) error
// Maximum number of retries before giving up.
// Default is to not retry failed queries.
MaxRetries int
... ... @@ -110,6 +111,10 @@ func (opt *Options) init() {
opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
}
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.Dialer == nil {
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
... ... @@ -140,10 +145,6 @@ func (opt *Options) init() {
}
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.IdleTimeout == 0 {
opt.IdleTimeout = 5 * time.Minute
}
... ... @@ -262,9 +263,16 @@ func ParseURL(sURL string) (*Options, error) {
func (opt *Options) getDialer() func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
var conn net.Conn
err := internal.WithSpan(ctx, "dialer", func(ctx context.Context, span trace.Span) error {
err := internal.WithSpan(ctx, "pg.dial", func(ctx context.Context, span trace.Span) error {
span.SetAttributes(
label.String("db.connection_string", opt.Addr),
)
var err error
conn, err = opt.Dialer(ctx, opt.Network, opt.Addr)
if err != nil {
span.RecordError(err)
}
return err
})
return conn, err
... ...
... ... @@ -8,50 +8,62 @@ type CreateCompositeOptions struct {
Varchar int // replaces PostgreSQL data type `text` with `varchar(n)`
}
func CreateComposite(db DB, model interface{}, opt *CreateCompositeOptions) error {
q := NewQuery(db, model)
_, err := q.db.Exec(&createCompositeQuery{
type CreateCompositeQuery struct {
q *Query
opt *CreateCompositeOptions
}
var (
_ QueryAppender = (*CreateCompositeQuery)(nil)
_ QueryCommand = (*CreateCompositeQuery)(nil)
)
func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery {
return &CreateCompositeQuery{
q: q,
opt: opt,
})
return err
}
}
type createCompositeQuery struct {
q *Query
opt *CreateCompositeOptions
func (q *CreateCompositeQuery) String() string {
b, err := q.AppendQuery(defaultFmter, nil)
if err != nil {
panic(err)
}
return string(b)
}
var _ QueryAppender = (*createCompositeQuery)(nil)
var _ queryCommand = (*createCompositeQuery)(nil)
func (q *CreateCompositeQuery) Operation() QueryOp {
return CreateCompositeOp
}
func (q *createCompositeQuery) Clone() queryCommand {
return &createCompositeQuery{
func (q *CreateCompositeQuery) Clone() QueryCommand {
return &CreateCompositeQuery{
q: q.q.Clone(),
opt: q.opt,
}
}
func (q *createCompositeQuery) Query() *Query {
func (q *CreateCompositeQuery) Query() *Query {
return q.q
}
func (q *createCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
return q.AppendQuery(dummyFormatter{}, b)
}
func (q *createCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
if q.q.stickyErr != nil {
return nil, q.q.stickyErr
}
if q.q.model == nil {
if q.q.tableModel == nil {
return nil, errModelNil
}
table := q.q.model.Table()
table := q.q.tableModel.Table()
b = append(b, "CREATE TYPE "...)
b = append(b, q.q.model.Table().Alias...)
b = append(b, table.Alias...)
b = append(b, " AS ("...)
for i, field := range table.Fields {
... ...
... ... @@ -5,43 +5,55 @@ type DropCompositeOptions struct {
Cascade bool
}
func DropComposite(db DB, model interface{}, opt *DropCompositeOptions) error {
q := NewQuery(db, model)
_, err := q.db.Exec(&dropCompositeQuery{
type DropCompositeQuery struct {
q *Query
opt *DropCompositeOptions
}
var (
_ QueryAppender = (*DropCompositeQuery)(nil)
_ QueryCommand = (*DropCompositeQuery)(nil)
)
func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery {
return &DropCompositeQuery{
q: q,
opt: opt,
})
return err
}
}
type dropCompositeQuery struct {
q *Query
opt *DropCompositeOptions
func (q *DropCompositeQuery) String() string {
b, err := q.AppendQuery(defaultFmter, nil)
if err != nil {
panic(err)
}
return string(b)
}
var _ QueryAppender = (*dropCompositeQuery)(nil)
var _ queryCommand = (*dropCompositeQuery)(nil)
func (q *DropCompositeQuery) Operation() QueryOp {
return DropCompositeOp
}
func (q *dropCompositeQuery) Clone() queryCommand {
return &dropCompositeQuery{
func (q *DropCompositeQuery) Clone() QueryCommand {
return &DropCompositeQuery{
q: q.q.Clone(),
opt: q.opt,
}
}
func (q *dropCompositeQuery) Query() *Query {
func (q *DropCompositeQuery) Query() *Query {
return q.q
}
func (q *dropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
return q.AppendQuery(dummyFormatter{}, b)
}
func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
if q.q.stickyErr != nil {
return nil, q.q.stickyErr
}
if q.q.model == nil {
if q.q.tableModel == nil {
return nil, errModelNil
}
... ... @@ -49,7 +61,7 @@ func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte
if q.opt != nil && q.opt.IfExists {
b = append(b, "IF EXISTS "...)
}
b = append(b, q.q.model.Table().Alias...)
b = append(b, q.q.tableModel.Table().Alias...)
if q.opt != nil && q.opt.Cascade {
b = append(b, " CASCADE"...)
}
... ...
package orm
import (
"github.com/go-pg/pg/v10/internal"
)
// Delete deletes a given model from the db
func Delete(db DB, model interface{}) error {
res, err := NewQuery(db, model).WherePK().Delete()
if err != nil {
return err
}
return internal.AssertOneRow(res.RowsAffected())
}
"reflect"
// ForceDelete force deletes a given model from the db
func ForceDelete(db DB, model interface{}) error {
res, err := NewQuery(db, model).WherePK().ForceDelete()
if err != nil {
return err
}
return internal.AssertOneRow(res.RowsAffected())
}
"github.com/go-pg/pg/v10/types"
)
type deleteQuery struct {
type DeleteQuery struct {
q *Query
placeholder bool
}
var _ QueryAppender = (*deleteQuery)(nil)
var _ queryCommand = (*deleteQuery)(nil)
var (
_ QueryAppender = (*DeleteQuery)(nil)
_ QueryCommand = (*DeleteQuery)(nil)
)
func newDeleteQuery(q *Query) *deleteQuery {
return &deleteQuery{
func NewDeleteQuery(q *Query) *DeleteQuery {
return &DeleteQuery{
q: q,
}
}
func (q *deleteQuery) Operation() string {
func (q *DeleteQuery) String() string {
b, err := q.AppendQuery(defaultFmter, nil)
if err != nil {
panic(err)
}
return string(b)
}
func (q *DeleteQuery) Operation() QueryOp {
return DeleteOp
}
func (q *deleteQuery) Clone() queryCommand {
return &deleteQuery{
func (q *DeleteQuery) Clone() QueryCommand {
return &DeleteQuery{
q: q.q.Clone(),
placeholder: q.placeholder,
}
}
func (q *deleteQuery) Query() *Query {
func (q *DeleteQuery) Query() *Query {
return q.q
}
func (q *deleteQuery) AppendTemplate(b []byte) ([]byte, error) {
cp := q.Clone().(*deleteQuery)
func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) {
cp := q.Clone().(*DeleteQuery)
cp.placeholder = true
return cp.AppendQuery(dummyFormatter{}, b)
}
func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
if q.q.stickyErr != nil {
return nil, q.q.stickyErr
}
... ... @@ -84,7 +78,8 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
}
b = append(b, " WHERE "...)
value := q.q.model.Value()
value := q.q.tableModel.Value()
if q.q.isSliceModelWithData() {
if len(q.q.where) > 0 {
b, err = q.q.appendWhere(fmter, b)
... ... @@ -92,7 +87,7 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
return nil, err
}
} else {
table := q.q.model.Table()
table := q.q.tableModel.Table()
err = table.checkPKs()
if err != nil {
return nil, err
... ... @@ -116,3 +111,48 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
return b, q.q.stickyErr
}
func appendColumnAndSliceValue(
fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field,
) []byte {
if len(fields) > 1 {
b = append(b, '(')
}
b = appendColumns(b, alias, fields)
if len(fields) > 1 {
b = append(b, ')')
}
b = append(b, " IN ("...)
isPlaceholder := isPlaceholderFormatter(fmter)
sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, ", "...)
}
el := indirect(slice.Index(i))
if len(fields) > 1 {
b = append(b, '(')
}
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
if isPlaceholder {
b = append(b, '?')
} else {
b = f.AppendValue(b, el, 1)
}
}
if len(fields) > 1 {
b = append(b, ')')
}
}
b = append(b, ')')
return b
}
... ...
... ... @@ -70,10 +70,10 @@ func (f *Field) Value(strct reflect.Value) reflect.Value {
}
func (f *Field) HasZeroValue(strct reflect.Value) bool {
return f.hasZeroField(strct, f.Index)
return f.hasZeroValue(strct, f.Index)
}
func (f *Field) hasZeroField(v reflect.Value, index []int) bool {
func (f *Field) hasZeroValue(v reflect.Value, index []int) bool {
for _, idx := range index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
... ... @@ -106,10 +106,21 @@ func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte {
}
func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error {
fv := f.Value(strct)
if f.scan == nil {
return fmt.Errorf("pg: ScanValue(unsupported %s)", fv.Type())
return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type)
}
var fv reflect.Value
if n == -1 {
var ok bool
fv, ok = fieldByIndex(strct, f.Index)
if !ok {
return nil
}
} else {
fv = fieldByIndexAlloc(strct, f.Index)
}
return f.scan(fv, rd, n)
}
... ...
... ... @@ -26,8 +26,10 @@ type SafeQueryAppender struct {
params []interface{}
}
var _ QueryAppender = (*SafeQueryAppender)(nil)
var _ types.ValueAppender = (*SafeQueryAppender)(nil)
var (
_ QueryAppender = (*SafeQueryAppender)(nil)
_ types.ValueAppender = (*SafeQueryAppender)(nil)
)
//nolint
func SafeQuery(query string, params ...interface{}) *SafeQueryAppender {
... ... @@ -57,8 +59,10 @@ type condGroupAppender struct {
cond []queryWithSepAppender
}
var _ QueryAppender = (*condAppender)(nil)
var _ queryWithSepAppender = (*condAppender)(nil)
var (
_ QueryAppender = (*condAppender)(nil)
_ queryWithSepAppender = (*condAppender)(nil)
)
func (q *condGroupAppender) AppendSep(b []byte) []byte {
return append(b, q.sep...)
... ... @@ -87,8 +91,10 @@ type condAppender struct {
params []interface{}
}
var _ QueryAppender = (*condAppender)(nil)
var _ queryWithSepAppender = (*condAppender)(nil)
var (
_ QueryAppender = (*condAppender)(nil)
_ queryWithSepAppender = (*condAppender)(nil)
)
func (q *condAppender) AppendSep(b []byte) []byte {
return append(b, q.sep...)
... ... @@ -192,9 +198,9 @@ func (f *Formatter) WithModel(model interface{}) *Formatter {
case TableModel:
return f.WithTableModel(model)
case *Query:
return f.WithTableModel(model.model)
case queryCommand:
return f.WithTableModel(model.Query().model)
return f.WithTableModel(model.tableModel)
case QueryCommand:
return f.WithTableModel(model.Query().tableModel)
default:
panic(fmt.Errorf("pg: unsupported model %T", model))
}
... ...
... ... @@ -7,44 +7,51 @@ import (
type hookStubs struct{}
var _ AfterSelectHook = (*hookStubs)(nil)
var _ BeforeInsertHook = (*hookStubs)(nil)
var _ AfterInsertHook = (*hookStubs)(nil)
var _ BeforeUpdateHook = (*hookStubs)(nil)
var _ AfterUpdateHook = (*hookStubs)(nil)
var _ BeforeDeleteHook = (*hookStubs)(nil)
var _ AfterDeleteHook = (*hookStubs)(nil)
func (hookStubs) AfterSelect(c context.Context) error {
var (
_ AfterScanHook = (*hookStubs)(nil)
_ AfterSelectHook = (*hookStubs)(nil)
_ BeforeInsertHook = (*hookStubs)(nil)
_ AfterInsertHook = (*hookStubs)(nil)
_ BeforeUpdateHook = (*hookStubs)(nil)
_ AfterUpdateHook = (*hookStubs)(nil)
_ BeforeDeleteHook = (*hookStubs)(nil)
_ AfterDeleteHook = (*hookStubs)(nil)
)
func (hookStubs) AfterScan(ctx context.Context) error {
return nil
}
func (hookStubs) AfterSelect(ctx context.Context) error {
return nil
}
func (hookStubs) BeforeInsert(c context.Context) (context.Context, error) {
return c, nil
func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (hookStubs) AfterInsert(c context.Context) error {
func (hookStubs) AfterInsert(ctx context.Context) error {
return nil
}
func (hookStubs) BeforeUpdate(c context.Context) (context.Context, error) {
return c, nil
func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (hookStubs) AfterUpdate(c context.Context) error {
func (hookStubs) AfterUpdate(ctx context.Context) error {
return nil
}
func (hookStubs) BeforeDelete(c context.Context) (context.Context, error) {
return c, nil
func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (hookStubs) AfterDelete(c context.Context) error {
func (hookStubs) AfterDelete(ctx context.Context) error {
return nil
}
func callHookSlice(
c context.Context,
ctx context.Context,
slice reflect.Value,
ptr bool,
hook func(context.Context, reflect.Value) (context.Context, error),
... ... @@ -58,16 +65,16 @@ func callHookSlice(
}
var err error
c, err = hook(c, v)
ctx, err = hook(ctx, v)
if err != nil && firstErr == nil {
firstErr = err
}
}
return c, firstErr
return ctx, firstErr
}
func callHookSlice2(
c context.Context,
ctx context.Context,
slice reflect.Value,
ptr bool,
hook func(context.Context, reflect.Value) error,
... ... @@ -81,7 +88,7 @@ func callHookSlice2(
v = v.Addr()
}
err := hook(c, v)
err := hook(ctx, v)
if err != nil && firstErr == nil {
firstErr = err
}
... ... @@ -98,8 +105,8 @@ type BeforeScanHook interface {
var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
func callBeforeScanHook(c context.Context, v reflect.Value) error {
return v.Interface().(BeforeScanHook).BeforeScan(c)
func callBeforeScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(BeforeScanHook).BeforeScan(ctx)
}
//------------------------------------------------------------------------------
... ... @@ -110,8 +117,8 @@ type AfterScanHook interface {
var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()
func callAfterScanHook(c context.Context, v reflect.Value) error {
return v.Interface().(AfterScanHook).AfterScan(c)
func callAfterScanHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterScanHook).AfterScan(ctx)
}
//------------------------------------------------------------------------------
... ... @@ -122,14 +129,14 @@ type AfterSelectHook interface {
var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem()
func callAfterSelectHook(c context.Context, v reflect.Value) error {
return v.Interface().(AfterSelectHook).AfterSelect(c)
func callAfterSelectHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterSelectHook).AfterSelect(ctx)
}
func callAfterSelectHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) error {
return callHookSlice2(c, slice, ptr, callAfterSelectHook)
return callHookSlice2(ctx, slice, ptr, callAfterSelectHook)
}
//------------------------------------------------------------------------------
... ... @@ -140,14 +147,14 @@ type BeforeInsertHook interface {
var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem()
func callBeforeInsertHook(c context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeInsertHook).BeforeInsert(c)
func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeInsertHook).BeforeInsert(ctx)
}
func callBeforeInsertHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) (context.Context, error) {
return callHookSlice(c, slice, ptr, callBeforeInsertHook)
return callHookSlice(ctx, slice, ptr, callBeforeInsertHook)
}
//------------------------------------------------------------------------------
... ... @@ -158,14 +165,14 @@ type AfterInsertHook interface {
var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem()
func callAfterInsertHook(c context.Context, v reflect.Value) error {
return v.Interface().(AfterInsertHook).AfterInsert(c)
func callAfterInsertHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterInsertHook).AfterInsert(ctx)
}
func callAfterInsertHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) error {
return callHookSlice2(c, slice, ptr, callAfterInsertHook)
return callHookSlice2(ctx, slice, ptr, callAfterInsertHook)
}
//------------------------------------------------------------------------------
... ... @@ -176,14 +183,14 @@ type BeforeUpdateHook interface {
var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem()
func callBeforeUpdateHook(c context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeUpdateHook).BeforeUpdate(c)
func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx)
}
func callBeforeUpdateHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) (context.Context, error) {
return callHookSlice(c, slice, ptr, callBeforeUpdateHook)
return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook)
}
//------------------------------------------------------------------------------
... ... @@ -194,14 +201,14 @@ type AfterUpdateHook interface {
var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem()
func callAfterUpdateHook(c context.Context, v reflect.Value) error {
return v.Interface().(AfterUpdateHook).AfterUpdate(c)
func callAfterUpdateHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterUpdateHook).AfterUpdate(ctx)
}
func callAfterUpdateHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) error {
return callHookSlice2(c, slice, ptr, callAfterUpdateHook)
return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook)
}
//------------------------------------------------------------------------------
... ... @@ -212,14 +219,14 @@ type BeforeDeleteHook interface {
var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem()
func callBeforeDeleteHook(c context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeDeleteHook).BeforeDelete(c)
func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) {
return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx)
}
func callBeforeDeleteHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) (context.Context, error) {
return callHookSlice(c, slice, ptr, callBeforeDeleteHook)
return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook)
}
//------------------------------------------------------------------------------
... ... @@ -230,12 +237,12 @@ type AfterDeleteHook interface {
var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem()
func callAfterDeleteHook(c context.Context, v reflect.Value) error {
return v.Interface().(AfterDeleteHook).AfterDelete(c)
func callAfterDeleteHook(ctx context.Context, v reflect.Value) error {
return v.Interface().(AfterDeleteHook).AfterDelete(ctx)
}
func callAfterDeleteHookSlice(
c context.Context, slice reflect.Value, ptr bool,
ctx context.Context, slice reflect.Value, ptr bool,
) error {
return callHookSlice2(c, slice, ptr, callAfterDeleteHook)
return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook)
}
... ...
... ... @@ -3,55 +3,59 @@ package orm
import (
"fmt"
"reflect"
"sort"
"github.com/go-pg/pg/v10/types"
)
func Insert(db DB, model ...interface{}) error {
_, err := NewQuery(db, model...).Insert()
return err
}
type insertQuery struct {
type InsertQuery struct {
q *Query
returningFields []*Field
placeholder bool
}
var _ queryCommand = (*insertQuery)(nil)
var _ QueryCommand = (*InsertQuery)(nil)
func newInsertQuery(q *Query) *insertQuery {
return &insertQuery{
func NewInsertQuery(q *Query) *InsertQuery {
return &InsertQuery{
q: q,
}
}
func (q *insertQuery) Operation() string {
func (q *InsertQuery) String() string {
b, err := q.AppendQuery(defaultFmter, nil)
if err != nil {
panic(err)
}
return string(b)
}
func (q *InsertQuery) Operation() QueryOp {
return InsertOp
}
func (q *insertQuery) Clone() queryCommand {
return &insertQuery{
func (q *InsertQuery) Clone() QueryCommand {
return &InsertQuery{
q: q.q.Clone(),
placeholder: q.placeholder,
}
}
func (q *insertQuery) Query() *Query {
func (q *InsertQuery) Query() *Query {
return q.q
}
var _ TemplateAppender = (*insertQuery)(nil)
var _ TemplateAppender = (*InsertQuery)(nil)
func (q *insertQuery) AppendTemplate(b []byte) ([]byte, error) {
cp := q.Clone().(*insertQuery)
func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) {
cp := q.Clone().(*InsertQuery)
cp.placeholder = true
return cp.AppendQuery(dummyFormatter{}, b)
}
var _ QueryAppender = (*insertQuery)(nil)
var _ QueryAppender = (*InsertQuery)(nil)
func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
if q.q.stickyErr != nil {
return nil, q.q.stickyErr
}
... ... @@ -73,6 +77,60 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
return nil, err
}
b, err = q.appendColumnsValues(fmter, b)
if err != nil {
return nil, err
}
if q.q.onConflict != nil {
b = append(b, " ON CONFLICT "...)
b, err = q.q.onConflict.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if q.q.onConflictDoUpdate() {
if len(q.q.set) > 0 {
b, err = q.q.appendSet(fmter, b)
if err != nil {
return nil, err
}
} else {
fields, err := q.q.getDataFields()
if err != nil {
return nil, err
}
if len(fields) == 0 {
fields = q.q.tableModel.Table().DataFields
}
b = q.appendSetExcluded(b, fields)
}
if len(q.q.updWhere) > 0 {
b = append(b, " WHERE "...)
b, err = q.q.appendUpdWhere(fmter, b)
if err != nil {
return nil, err
}
}
}
}
if len(q.q.returning) > 0 {
b, err = q.q.appendReturning(fmter, b)
if err != nil {
return nil, err
}
} else if len(q.returningFields) > 0 {
b = appendReturningFields(b, q.returningFields)
}
return b, q.q.stickyErr
}
func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) {
if q.q.hasMultiTables() {
if q.q.columns != nil {
b = append(b, " ("...)
... ... @@ -82,13 +140,21 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
}
b = append(b, ")"...)
}
b = append(b, " SELECT * FROM "...)
b, err = q.q.appendOtherTables(fmter, b)
if err != nil {
return nil, err
}
} else {
if !q.q.hasModel() {
return b, nil
}
if m, ok := q.q.model.(*mapModel); ok {
return q.appendMapColumnsValues(b, m.m), nil
}
if !q.q.hasTableModel() {
return nil, errModelNil
}
... ... @@ -98,14 +164,14 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
}
if len(fields) == 0 {
fields = q.q.model.Table().Fields
fields = q.q.tableModel.Table().Fields
}
value := q.q.model.Value()
value := q.q.tableModel.Value()
b = append(b, " ("...)
b = q.appendColumns(b, fields)
b = append(b, ") VALUES ("...)
if m, ok := q.q.model.(*sliceTableModel); ok {
if m, ok := q.q.tableModel.(*sliceTableModel); ok {
if m.sliceLen == 0 {
err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type())
return nil, err
... ... @@ -121,57 +187,46 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
}
}
b = append(b, ")"...)
}
if q.q.onConflict != nil {
b = append(b, " ON CONFLICT "...)
b, err = q.q.onConflict.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
return b, nil
}
if q.q.onConflictDoUpdate() {
if len(q.q.set) > 0 {
b, err = q.q.appendSet(fmter, b)
if err != nil {
return nil, err
}
} else {
fields, err := q.q.getDataFields()
if err != nil {
return nil, err
}
func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte {
keys := make([]string, 0, len(m))
if len(fields) == 0 {
fields = q.q.model.Table().DataFields
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
b = q.appendSetExcluded(b, fields)
}
b = append(b, " ("...)
if len(q.q.updWhere) > 0 {
b = append(b, " WHERE "...)
b, err = q.q.appendUpdWhere(fmter, b)
if err != nil {
return nil, err
}
}
for i, k := range keys {
if i > 0 {
b = append(b, ", "...)
}
b = types.AppendIdent(b, k, 1)
}
if len(q.q.returning) > 0 {
b, err = q.q.appendReturning(fmter, b)
if err != nil {
return nil, err
b = append(b, ") VALUES ("...)
for i, k := range keys {
if i > 0 {
b = append(b, ", "...)
}
if q.placeholder {
b = append(b, '?')
} else {
b = types.Append(b, m[k], 1)
}
} else if len(q.returningFields) > 0 {
b = appendReturningFields(b, q.returningFields)
}
return b, q.q.stickyErr
b = append(b, ")"...)
return b
}
func (q *insertQuery) appendValues(
func (q *InsertQuery) appendValues(
fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value,
) (_ []byte, err error) {
for i, f := range fields {
... ... @@ -214,7 +269,7 @@ func (q *insertQuery) appendValues(
return b, nil
}
func (q *insertQuery) appendSliceValues(
func (q *InsertQuery) appendSliceValues(
fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value,
) (_ []byte, err error) {
if q.placeholder {
... ... @@ -247,7 +302,7 @@ func (q *insertQuery) appendSliceValues(
return b, nil
}
func (q *insertQuery) addReturningField(field *Field) {
func (q *InsertQuery) addReturningField(field *Field) {
if len(q.q.returning) > 0 {
return
}
... ... @@ -259,7 +314,7 @@ func (q *insertQuery) addReturningField(field *Field) {
q.returningFields = append(q.returningFields, field)
}
func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte {
func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte {
b = append(b, " SET "...)
for i, f := range fields {
if i > 0 {
... ... @@ -272,7 +327,7 @@ func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte {
return b
}
func (q *insertQuery) appendColumns(b []byte, fields []*Field) []byte {
func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte {
b = appendColumns(b, "", fields)
for i, v := range q.q.extraValues {
if i > 0 || len(fields) > 0 {
... ...
... ... @@ -64,16 +64,16 @@ func (j *join) manyQuery(q *Query) (*Query, error) {
baseTable := j.BaseModel.Table()
var where []byte
if len(j.Rel.FKs) > 1 {
if len(j.Rel.JoinFKs) > 1 {
where = append(where, '(')
}
where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.FKs)
if len(j.Rel.FKs) > 1 {
where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs)
if len(j.Rel.JoinFKs) > 1 {
where = append(where, ')')
}
where = append(where, " IN ("...)
where = appendChildValues(
where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.FKValues)
where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs)
where = append(where, ")"...)
q = q.Where(internal.BytesToString(where))
... ... @@ -126,7 +126,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) {
join = append(join, " AS "...)
join = append(join, j.Rel.M2MTableAlias...)
join = append(join, " ON ("...)
for i, col := range j.Rel.BaseFKs {
for i, col := range j.Rel.M2MBaseFKs {
if i > 0 {
join = append(join, ", "...)
}
... ... @@ -140,10 +140,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) {
q = q.Join(internal.BytesToString(join))
joinTable := j.JoinModel.Table()
for i, col := range j.Rel.JoinFKs {
if i >= len(joinTable.PKs) {
break
}
for i, col := range j.Rel.M2MJoinFKs {
pk := joinTable.PKs[i]
q = q.Where("?.? = ?.?",
joinTable.Alias, pk.Column,
... ... @@ -242,7 +239,7 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b
isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag)
b = append(b, "LEFT JOIN "...)
b = fmter.FormatQuery(b, string(j.JoinModel.Table().FullNameForSelects))
b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects))
b = append(b, " AS "...)
b = j.appendAlias(b)
... ... @@ -252,38 +249,22 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b
b = append(b, '(')
}
if len(j.Rel.FKs) > 1 {
if len(j.Rel.BaseFKs) > 1 {
b = append(b, '(')
}
if j.Rel.Type == HasOneRelation {
for i, fk := range j.Rel.FKs {
for i, baseFK := range j.Rel.BaseFKs {
if i > 0 {
b = append(b, " AND "...)
}
b = j.appendAlias(b)
b = append(b, '.')
b = append(b, j.Rel.JoinTable.PKs[i].Column...)
b = append(b, j.Rel.JoinFKs[i].Column...)
b = append(b, " = "...)
b = j.appendBaseAlias(b)
b = append(b, '.')
b = append(b, fk.Column...)
}
} else {
baseTable := j.BaseModel.Table()
for i, fk := range j.Rel.FKs {
if i > 0 {
b = append(b, " AND "...)
}
b = j.appendAlias(b)
b = append(b, '.')
b = append(b, fk.Column...)
b = append(b, " = "...)
b = j.appendBaseAlias(b)
b = append(b, '.')
b = append(b, baseTable.PKs[i].Column...)
}
b = append(b, baseFK.Column...)
}
if len(j.Rel.FKs) > 1 {
if len(j.Rel.BaseFKs) > 1 {
b = append(b, ')')
}
... ...
... ... @@ -31,6 +31,7 @@ type HooklessModel interface {
type Model interface {
HooklessModel
AfterScanHook
AfterSelectHook
BeforeInsertHook
... ... @@ -43,45 +44,90 @@ type Model interface {
AfterDeleteHook
}
func NewModel(values ...interface{}) (Model, error) {
func NewModel(value interface{}) (Model, error) {
return newModel(value, false)
}
func newScanModel(values []interface{}) (Model, error) {
if len(values) > 1 {
return Scan(values...), nil
}
return newModel(values[0], true)
}
v0 := values[0]
switch v0 := v0.(type) {
func newModel(value interface{}, scan bool) (Model, error) {
switch value := value.(type) {
case Model:
return v0, nil
return value, nil
case HooklessModel:
return newModelWithHookStubs(v0), nil
return newModelWithHookStubs(value), nil
case types.ValueScanner, sql.Scanner:
return Scan(v0), nil
if !scan {
return nil, fmt.Errorf("pg: Model(unsupported %T)", value)
}
return Scan(value), nil
}
v := reflect.ValueOf(v0)
v := reflect.ValueOf(value)
if !v.IsValid() {
return nil, errModelNil
}
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("pg: Model(non-pointer %T)", v0)
return nil, fmt.Errorf("pg: Model(non-pointer %T)", value)
}
if v.IsNil() {
typ := v.Type().Elem()
if typ.Kind() == reflect.Struct {
return newStructTableModel(GetTable(typ)), nil
}
return nil, errModelNil
}
v = v.Elem()
if v.Kind() == reflect.Interface {
if !v.IsNil() {
v = v.Elem()
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String())
}
}
}
switch v.Kind() {
case reflect.Struct:
if v.Type() != timeType {
return newStructTableModelValue(v), nil
}
case reflect.Slice:
typ := v.Type()
elemType := indirectType(typ.Elem())
if elemType.Kind() == reflect.Struct && elemType != timeType {
elemType := sliceElemType(v)
switch elemType.Kind() {
case reflect.Struct:
if elemType != timeType {
return newSliceTableModel(v, elemType), nil
}
case reflect.Map:
if err := validMap(elemType); err != nil {
return nil, err
}
slicePtr := v.Addr().Interface().(*[]map[string]interface{})
return newMapSliceModel(slicePtr), nil
}
return newSliceModel(v, elemType), nil
case reflect.Map:
typ := v.Type()
if err := validMap(typ); err != nil {
return nil, err
}
mapPtr := v.Addr().Interface().(*map[string]interface{})
return newMapModel(mapPtr), nil
}
return Scan(v0), nil
if !scan {
return nil, fmt.Errorf("pg: Model(unsupported %T)", value)
}
return Scan(value), nil
}
type modelWithHookStubs struct {
... ... @@ -94,3 +140,11 @@ func newModelWithHookStubs(m HooklessModel) Model {
HooklessModel: m,
}
}
func validMap(typ reflect.Type) error {
if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface {
return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})",
typ.String())
}
return nil
}
... ...
... ... @@ -22,6 +22,6 @@ func (m Discard) AddColumnScanner(ColumnScanner) error {
return nil
}
func (m Discard) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {
func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
return nil
}
... ...
package orm
import (
"github.com/go-pg/pg/v10/types"
)
type mapModel struct {
hookStubs
ptr *map[string]interface{}
m map[string]interface{}
}
var _ Model = (*mapModel)(nil)
func newMapModel(ptr *map[string]interface{}) *mapModel {
model := &mapModel{
ptr: ptr,
}
if ptr != nil {
model.m = *ptr
}
return model
}
func (m *mapModel) Init() error {
return nil
}
func (m *mapModel) NextColumnScanner() ColumnScanner {
if m.m == nil {
m.m = make(map[string]interface{})
*m.ptr = m.m
}
return m
}
func (m mapModel) AddColumnScanner(ColumnScanner) error {
return nil
}
func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
val, err := types.ReadColumnValue(col, rd, n)
if err != nil {
return err
}
m.m[col.Name] = val
return nil
}
func (mapModel) useQueryOne() bool {
return true
}
... ...
package orm
type mapSliceModel struct {
mapModel
slice *[]map[string]interface{}
}
var _ Model = (*mapSliceModel)(nil)
func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel {
return &mapSliceModel{
slice: ptr,
}
}
func (m *mapSliceModel) Init() error {
slice := *m.slice
if len(slice) > 0 {
*m.slice = slice[:0]
}
return nil
}
func (m *mapSliceModel) NextColumnScanner() ColumnScanner {
slice := *m.slice
if len(slice) == cap(slice) {
m.mapModel.m = make(map[string]interface{})
*m.slice = append(slice, m.mapModel.m) //nolint:gocritic
return m
}
slice = slice[:len(slice)+1]
el := slice[len(slice)-1]
if el != nil {
m.mapModel.m = el
} else {
el = make(map[string]interface{})
slice[len(slice)-1] = el
m.mapModel.m = el
}
*m.slice = slice
return m
}
func (mapSliceModel) useQueryOne() {} //nolint:unused
... ...
... ... @@ -29,12 +29,12 @@ func (m scanValuesModel) NextColumnScanner() ColumnScanner {
return m
}
func (m scanValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {
if colIdx >= len(m.values) {
func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
if int(col.Index) >= len(m.values) {
return fmt.Errorf("pg: no Scan var for column index=%d name=%q",
colIdx, colName)
col.Index, col.Name)
}
return types.Scan(m.values[colIdx], rd, n)
return types.Scan(m.values[col.Index], rd, n)
}
//------------------------------------------------------------------------------
... ... @@ -60,10 +60,10 @@ func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner {
return m
}
func (m scanReflectValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {
if colIdx >= len(m.values) {
func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
if int(col.Index) >= len(m.values) {
return fmt.Errorf("pg: no Scan var for column index=%d name=%q",
colIdx, colName)
col.Index, col.Name)
}
return types.ScanValue(m.values[colIdx], rd, n)
return types.ScanValue(m.values[col.Index], rd, n)
}
... ...
... ... @@ -34,7 +34,7 @@ func (m *sliceModel) NextColumnScanner() ColumnScanner {
return m
}
func (m *sliceModel) ScanColumn(colIdx int, _ string, rd types.Reader, n int) error {
func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
if m.nextElem == nil {
m.nextElem = internal.MakeSliceNextElemFunc(m.slice)
}
... ...
... ... @@ -27,56 +27,8 @@ type TableModel interface {
Kind() reflect.Kind
Value() reflect.Value
setSoftDeleteField()
scanColumn(int, string, types.Reader, int) (bool, error)
}
func newTableModel(value interface{}) (TableModel, error) {
if value, ok := value.(TableModel); ok {
return value, nil
}
v := reflect.ValueOf(value)
if !v.IsValid() {
return nil, errModelNil
}
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("pg: Model(non-pointer %T)", value)
}
if v.IsNil() {
typ := v.Type().Elem()
if typ.Kind() == reflect.Struct {
return newStructTableModel(GetTable(typ)), nil
}
return nil, errModelNil
}
v = v.Elem()
if v.Kind() == reflect.Interface {
if !v.IsNil() {
v = v.Elem()
if v.Kind() != reflect.Ptr {
return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String())
}
}
}
return newTableModelValue(v)
}
func newTableModelValue(v reflect.Value) (TableModel, error) {
switch v.Kind() {
case reflect.Struct:
return newStructTableModelValue(v), nil
case reflect.Slice:
elemType := sliceElemType(v)
if elemType.Kind() == reflect.Struct {
return newSliceTableModel(v, elemType), nil
}
}
return nil, fmt.Errorf("pg: Model(unsupported %s)", v.Type())
setSoftDeleteField() error
scanColumn(types.ColumnInfo, types.Reader, int) (bool, error)
}
func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) {
... ...
... ... @@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"github.com/go-pg/pg/v10/internal/pool"
"github.com/go-pg/pg/v10/types"
)
... ... @@ -60,7 +61,7 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error {
dstValues, ok := m.dstValues[string(buf)]
if !ok {
return fmt.Errorf(
"pg: relation=%q has no base %s with id=%q (check join conditions)",
"pg: relation=%q does not have base %s with id=%q (check join conditions)",
m.rel.Field.GoName, m.baseTable, buf)
}
... ... @@ -76,31 +77,35 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error {
}
func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) {
for i, col := range m.rel.BaseFKs {
for i, col := range m.rel.M2MBaseFKs {
if i > 0 {
b = append(b, ',')
}
if s, ok := m.columns[col]; ok {
b = append(b, s...)
} else {
return nil, fmt.Errorf("pg: %s has no column=%q",
return nil, fmt.Errorf("pg: %s does not have column=%q",
m.sliceTableModel, col)
}
}
return b, nil
}
func (m *m2mModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {
ok, err := m.sliceTableModel.scanColumn(colIdx, colName, rd, n)
if ok {
func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
if n > 0 {
b, err := rd.ReadFullTemp()
if err != nil {
return err
}
tmp, err := rd.ReadFullTemp()
if err != nil {
return err
m.columns[col.Name] = string(b)
rd = pool.NewBytesReader(b)
} else {
m.columns[col.Name] = ""
}
m.columns[colName] = string(tmp)
if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok {
return err
}
return nil
}
... ...