作者 唐旭辉

更新依赖

正在显示 50 个修改的文件 包含 1484 行增加1662 行删除

要显示太多修改。

为保证性能只显示 50 of 50+ 个文件。

@@ -12,7 +12,7 @@ require ( @@ -12,7 +12,7 @@ require (
12 github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect 12 github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect
13 github.com/fatih/structs v1.1.0 // indirect 13 github.com/fatih/structs v1.1.0 // indirect
14 github.com/gavv/httpexpect v2.0.0+incompatible 14 github.com/gavv/httpexpect v2.0.0+incompatible
15 - github.com/go-pg/pg/v10 v10.0.0-beta.2 15 + github.com/go-pg/pg/v10 v10.7.3
16 github.com/google/go-querystring v1.0.0 // indirect 16 github.com/google/go-querystring v1.0.0 // indirect
17 github.com/gorilla/websocket v1.4.2 // indirect 17 github.com/gorilla/websocket v1.4.2 // indirect
18 github.com/imkira/go-interpol v1.1.0 // indirect 18 github.com/imkira/go-interpol v1.1.0 // indirect
@@ -20,8 +20,9 @@ require ( @@ -20,8 +20,9 @@ require (
20 github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9 20 github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9
21 github.com/mattn/go-colorable v0.1.6 // indirect 21 github.com/mattn/go-colorable v0.1.6 // indirect
22 github.com/moul/http2curl v1.0.0 // indirect 22 github.com/moul/http2curl v1.0.0 // indirect
23 - github.com/onsi/ginkgo v1.13.0  
24 - github.com/onsi/gomega v1.10.1 23 + github.com/onsi/ginkgo v1.14.2
  24 + github.com/onsi/gomega v1.10.3
  25 + github.com/sclevine/agouti v3.0.0+incompatible // indirect
25 github.com/sergi/go-diff v1.1.0 // indirect 26 github.com/sergi/go-diff v1.1.0 // indirect
26 github.com/shopspring/decimal v1.2.0 27 github.com/shopspring/decimal v1.2.0
27 github.com/smartystreets/goconvey v1.6.4 // indirect 28 github.com/smartystreets/goconvey v1.6.4 // indirect
@@ -47,7 +47,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error { @@ -47,7 +47,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error {
47 Remark: dm.Remark, 47 Remark: dm.Remark,
48 } 48 }
49 if m.Id == 0 { 49 if m.Id == 0 {
50 - err = tx.Insert(m) 50 + _, err = tx.Model(m).Insert()
51 dm.Partner.Id = m.Id 51 dm.Partner.Id = m.Id
52 if err != nil { 52 if err != nil {
53 return err 53 return err
此 diff 太大无法显示。
1 -# Compiled Object files, Static and Dynamic libs (Shared Objects)  
2 -*.o  
3 -*.a  
4 -*.so  
5 -  
6 -# Folders  
7 -_obj  
8 -_test  
9 -  
10 -# Architecture specific extensions/prefixes  
11 -*.[568vq]  
12 -[568vq].out  
13 -  
14 -*.cgo1.go  
15 -*.cgo2.c  
16 -_cgo_defun.c  
17 -_cgo_gotypes.go  
18 -_cgo_export.*  
19 -  
20 -_testmain.go  
21 -  
22 -*.exe  
23 -*.test  
24 -*.prof  
1 -The MIT License (MIT)  
2 -  
3 -Copyright (c) 2015 codemodus  
4 -  
5 -Permission is hereby granted, free of charge, to any person obtaining a copy  
6 -of this software and associated documentation files (the "Software"), to deal  
7 -in the Software without restriction, including without limitation the rights  
8 -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell  
9 -copies of the Software, and to permit persons to whom the Software is  
10 -furnished to do so, subject to the following conditions:  
11 -  
12 -The above copyright notice and this permission notice shall be included in all  
13 -copies or substantial portions of the Software.  
14 -  
15 -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR  
16 -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  
17 -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE  
18 -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER  
19 -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,  
20 -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE  
21 -SOFTWARE.  
22 -  
1 -# kace  
2 -  
3 - go get "github.com/codemodus/kace"  
4 -  
5 -Package kace provides common case conversion functions which take into  
6 -consideration common initialisms.  
7 -  
8 -## Usage  
9 -  
10 -```go  
11 -func Camel(s string) string  
12 -func Kebab(s string) string  
13 -func KebabUpper(s string) string  
14 -func Pascal(s string) string  
15 -func Snake(s string) string  
16 -func SnakeUpper(s string) string  
17 -type Kace  
18 - func New(initialisms map[string]bool) (*Kace, error)  
19 - func (k *Kace) Camel(s string) string  
20 - func (k *Kace) Kebab(s string) string  
21 - func (k *Kace) KebabUpper(s string) string  
22 - func (k *Kace) Pascal(s string) string  
23 - func (k *Kace) Snake(s string) string  
24 - func (k *Kace) SnakeUpper(s string) string  
25 -```  
26 -  
27 -### Setup  
28 -  
29 -```go  
30 -import (  
31 - "fmt"  
32 -  
33 - "github.com/codemodus/kace"  
34 -)  
35 -  
36 -func main() {  
37 - s := "this is a test sql."  
38 -  
39 - fmt.Println(kace.Camel(s))  
40 - fmt.Println(kace.Pascal(s))  
41 -  
42 - fmt.Println(kace.Snake(s))  
43 - fmt.Println(kace.SnakeUpper(s))  
44 -  
45 - fmt.Println(kace.Kebab(s))  
46 - fmt.Println(kace.KebabUpper(s))  
47 -  
48 - customInitialisms := map[string]bool{  
49 - "THIS": true,  
50 - }  
51 - k, err := kace.New(customInitialisms)  
52 - if err != nil {  
53 - // handle error  
54 - }  
55 -  
56 - fmt.Println(k.Camel(s))  
57 - fmt.Println(k.Pascal(s))  
58 -  
59 - fmt.Println(k.Snake(s))  
60 - fmt.Println(k.SnakeUpper(s))  
61 -  
62 - fmt.Println(k.Kebab(s))  
63 - fmt.Println(k.KebabUpper(s))  
64 -  
65 - // Output:  
66 - // thisIsATestSQL  
67 - // ThisIsATestSQL  
68 - // this_is_a_test_sql  
69 - // THIS_IS_A_TEST_SQL  
70 - // this-is-a-test-sql  
71 - // THIS-IS-A-TEST-SQL  
72 - // thisIsATestSql  
73 - // THISIsATestSql  
74 - // this_is_a_test_sql  
75 - // THIS_IS_A_TEST_SQL  
76 - // this-is-a-test-sql  
77 - // THIS-IS-A-TEST-SQL  
78 -}  
79 -```  
80 -  
81 -## More Info  
82 -  
83 -### TODO  
84 -  
85 -#### Test Trie  
86 -  
87 - Test the current trie.  
88 -  
89 -## Documentation  
90 -  
91 -View the [GoDoc](http://godoc.org/github.com/codemodus/kace)  
92 -  
93 -## Benchmarks  
94 -  
95 - benchmark iter time/iter bytes alloc allocs  
96 - --------- ---- --------- ----------- ------  
97 - BenchmarkCamel4 2000000 947.00 ns/op 112 B/op 3 allocs/op  
98 - BenchmarkSnake4 2000000 696.00 ns/op 128 B/op 2 allocs/op  
99 - BenchmarkSnakeUpper4 2000000 679.00 ns/op 128 B/op 2 allocs/op  
100 - BenchmarkKebab4 2000000 691.00 ns/op 128 B/op 2 allocs/op  
101 - BenchmarkKebabUpper4 2000000 677.00 ns/op 128 B/op 2 allocs/op  
1 -module github.com/codemodus/kace  
1 -// Package kace provides common case conversion functions which take into  
2 -// consideration common initialisms.  
3 -package kace  
4 -  
5 -import (  
6 - "fmt"  
7 - "strings"  
8 - "unicode"  
9 -  
10 - "github.com/codemodus/kace/ktrie"  
11 -)  
12 -  
13 -const (  
14 - kebabDelim = '-'  
15 - snakeDelim = '_'  
16 - none = rune(-1)  
17 -)  
18 -  
19 -var (  
20 - ciTrie *ktrie.KTrie  
21 -)  
22 -  
23 -func init() {  
24 - var err error  
25 - if ciTrie, err = ktrie.NewKTrie(ciMap); err != nil {  
26 - panic(err)  
27 - }  
28 -}  
29 -  
30 -// Camel returns a camelCased string.  
31 -func Camel(s string) string {  
32 - return camelCase(ciTrie, s, false)  
33 -}  
34 -  
35 -// Pascal returns a PascalCased string.  
36 -func Pascal(s string) string {  
37 - return camelCase(ciTrie, s, true)  
38 -}  
39 -  
40 -// Kebab returns a kebab-cased string with all lowercase letters.  
41 -func Kebab(s string) string {  
42 - return delimitedCase(s, kebabDelim, false)  
43 -}  
44 -  
45 -// KebabUpper returns a KEBAB-CASED string with all upper case letters.  
46 -func KebabUpper(s string) string {  
47 - return delimitedCase(s, kebabDelim, true)  
48 -}  
49 -  
50 -// Snake returns a snake_cased string with all lowercase letters.  
51 -func Snake(s string) string {  
52 - return delimitedCase(s, snakeDelim, false)  
53 -}  
54 -  
55 -// SnakeUpper returns a SNAKE_CASED string with all upper case letters.  
56 -func SnakeUpper(s string) string {  
57 - return delimitedCase(s, snakeDelim, true)  
58 -}  
59 -  
60 -// Kace provides common case conversion methods which take into  
61 -// consideration common initialisms set by the user.  
62 -type Kace struct {  
63 - t *ktrie.KTrie  
64 -}  
65 -  
66 -// New returns a pointer to an instance of kace loaded with a common  
67 -// initialsms trie based on the provided map. Before conversion to a  
68 -// trie, the provided map keys are all upper cased.  
69 -func New(initialisms map[string]bool) (*Kace, error) {  
70 - ci := initialisms  
71 - if ci == nil {  
72 - ci = map[string]bool{}  
73 - }  
74 -  
75 - ci = sanitizeCI(ci)  
76 -  
77 - t, err := ktrie.NewKTrie(ci)  
78 - if err != nil {  
79 - return nil, fmt.Errorf("kace: cannot create trie: %s", err)  
80 - }  
81 -  
82 - k := &Kace{  
83 - t: t,  
84 - }  
85 -  
86 - return k, nil  
87 -}  
88 -  
89 -// Camel returns a camelCased string.  
90 -func (k *Kace) Camel(s string) string {  
91 - return camelCase(k.t, s, false)  
92 -}  
93 -  
94 -// Pascal returns a PascalCased string.  
95 -func (k *Kace) Pascal(s string) string {  
96 - return camelCase(k.t, s, true)  
97 -}  
98 -  
99 -// Snake returns a snake_cased string with all lowercase letters.  
100 -func (k *Kace) Snake(s string) string {  
101 - return delimitedCase(s, snakeDelim, false)  
102 -}  
103 -  
104 -// SnakeUpper returns a SNAKE_CASED string with all upper case letters.  
105 -func (k *Kace) SnakeUpper(s string) string {  
106 - return delimitedCase(s, snakeDelim, true)  
107 -}  
108 -  
109 -// Kebab returns a kebab-cased string with all lowercase letters.  
110 -func (k *Kace) Kebab(s string) string {  
111 - return delimitedCase(s, kebabDelim, false)  
112 -}  
113 -  
114 -// KebabUpper returns a KEBAB-CASED string with all upper case letters.  
115 -func (k *Kace) KebabUpper(s string) string {  
116 - return delimitedCase(s, kebabDelim, true)  
117 -}  
118 -  
119 -func camelCase(t *ktrie.KTrie, s string, ucFirst bool) string {  
120 - rs := []rune(s)  
121 - offset := 0  
122 - prev := none  
123 -  
124 - for i := 0; i < len(rs); i++ {  
125 - r := rs[i]  
126 -  
127 - switch {  
128 - case unicode.IsLetter(r):  
129 - ucCurr := isToBeUpper(r, prev, ucFirst)  
130 -  
131 - if ucCurr || isSegmentStart(r, prev) {  
132 - prv, skip := updateRunes(rs, i, offset, t, ucCurr)  
133 - if skip > 0 {  
134 - i += skip - 1  
135 - prev = prv  
136 - continue  
137 - }  
138 - }  
139 -  
140 - prev = updateRune(rs, i, offset, ucCurr)  
141 - continue  
142 -  
143 - case unicode.IsNumber(r):  
144 - prev = updateRune(rs, i, offset, false)  
145 - continue  
146 -  
147 - default:  
148 - prev = r  
149 - offset--  
150 - }  
151 - }  
152 -  
153 - return string(rs[:len(rs)+offset])  
154 -}  
155 -  
156 -func isToBeUpper(curr, prev rune, ucFirst bool) bool {  
157 - if prev == none {  
158 - return ucFirst  
159 - }  
160 -  
161 - return isSegmentStart(curr, prev)  
162 -}  
163 -  
164 -func isSegmentStart(curr, prev rune) bool {  
165 - if !unicode.IsLetter(prev) || unicode.IsUpper(curr) && unicode.IsLower(prev) {  
166 - return true  
167 - }  
168 -  
169 - return false  
170 -}  
171 -  
172 -func updateRune(rs []rune, i, offset int, upper bool) rune {  
173 - r := rs[i]  
174 -  
175 - dest := i + offset  
176 - if dest < 0 || i > len(rs)-1 {  
177 - panic("this function has been used or designed incorrectly")  
178 - }  
179 -  
180 - fn := unicode.ToLower  
181 - if upper {  
182 - fn = unicode.ToUpper  
183 - }  
184 -  
185 - rs[dest] = fn(r)  
186 -  
187 - return r  
188 -}  
189 -  
190 -func updateRunes(rs []rune, i, offset int, t *ktrie.KTrie, upper bool) (rune, int) {  
191 - r := rs[i]  
192 - ns := nextSegment(rs, i)  
193 - ct := len(ns)  
194 -  
195 - if ct < t.MinDepth() || ct > t.MaxDepth() || !t.FindAsUpper(ns) {  
196 - return r, 0  
197 - }  
198 -  
199 - for j := i; j < i+ct; j++ {  
200 - r = updateRune(rs, j, offset, upper)  
201 - }  
202 -  
203 - return r, ct  
204 -}  
205 -  
206 -func nextSegment(rs []rune, i int) []rune {  
207 - for j := i; j < len(rs); j++ {  
208 - if !unicode.IsLetter(rs[j]) && !unicode.IsNumber(rs[j]) {  
209 - return rs[i:j]  
210 - }  
211 -  
212 - if j == len(rs)-1 {  
213 - return rs[i : j+1]  
214 - }  
215 - }  
216 -  
217 - return nil  
218 -}  
219 -  
220 -func delimitedCase(s string, delim rune, upper bool) string {  
221 - buf := make([]rune, 0, len(s)*2)  
222 -  
223 - for i := len(s); i > 0; i-- {  
224 - switch {  
225 - case unicode.IsLetter(rune(s[i-1])):  
226 - if i < len(s) && unicode.IsUpper(rune(s[i])) {  
227 - if i > 1 && unicode.IsLower(rune(s[i-1])) || i < len(s)-2 && unicode.IsLower(rune(s[i+1])) {  
228 - buf = append(buf, delim)  
229 - }  
230 - }  
231 -  
232 - buf = appendCased(buf, upper, rune(s[i-1]))  
233 -  
234 - case unicode.IsNumber(rune(s[i-1])):  
235 - if i == len(s) || i == 1 || unicode.IsNumber(rune(s[i])) {  
236 - buf = append(buf, rune(s[i-1]))  
237 - continue  
238 - }  
239 -  
240 - buf = append(buf, delim, rune(s[i-1]))  
241 -  
242 - default:  
243 - if i == len(s) {  
244 - continue  
245 - }  
246 -  
247 - buf = append(buf, delim)  
248 - }  
249 - }  
250 -  
251 - reverse(buf)  
252 -  
253 - return string(buf)  
254 -}  
255 -  
256 -func appendCased(rs []rune, upper bool, r rune) []rune {  
257 - if upper {  
258 - rs = append(rs, unicode.ToUpper(r))  
259 - return rs  
260 - }  
261 -  
262 - rs = append(rs, unicode.ToLower(r))  
263 -  
264 - return rs  
265 -}  
266 -  
267 -func reverse(s []rune) {  
268 - for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {  
269 - s[i], s[j] = s[j], s[i]  
270 - }  
271 -}  
272 -  
273 -var (  
274 - // github.com/golang/lint/blob/master/lint.go  
275 - ciMap = map[string]bool{  
276 - "ACL": true,  
277 - "API": true,  
278 - "ASCII": true,  
279 - "CPU": true,  
280 - "CSS": true,  
281 - "DNS": true,  
282 - "EOF": true,  
283 - "GUID": true,  
284 - "HTML": true,  
285 - "HTTP": true,  
286 - "HTTPS": true,  
287 - "ID": true,  
288 - "IP": true,  
289 - "JSON": true,  
290 - "LHS": true,  
291 - "QPS": true,  
292 - "RAM": true,  
293 - "RHS": true,  
294 - "RPC": true,  
295 - "SLA": true,  
296 - "SMTP": true,  
297 - "SQL": true,  
298 - "SSH": true,  
299 - "TCP": true,  
300 - "TLS": true,  
301 - "TTL": true,  
302 - "UDP": true,  
303 - "UI": true,  
304 - "UID": true,  
305 - "UUID": true,  
306 - "URI": true,  
307 - "URL": true,  
308 - "UTF8": true,  
309 - "VM": true,  
310 - "XML": true,  
311 - "XMPP": true,  
312 - "XSRF": true,  
313 - "XSS": true,  
314 - }  
315 -)  
316 -  
317 -func sanitizeCI(m map[string]bool) map[string]bool {  
318 - r := map[string]bool{}  
319 -  
320 - for k := range m {  
321 - fn := func(r rune) rune {  
322 - if !unicode.IsLetter(r) && !unicode.IsNumber(r) {  
323 - return -1  
324 - }  
325 - return r  
326 - }  
327 -  
328 - k = strings.Map(fn, k)  
329 - k = strings.ToUpper(k)  
330 -  
331 - if k == "" {  
332 - continue  
333 - }  
334 -  
335 - r[k] = true  
336 - }  
337 -  
338 - return r  
339 -}  
1 -package ktrie  
2 -  
3 -import "unicode"  
4 -  
5 -// KNode ...  
6 -type KNode struct {  
7 - val rune  
8 - end bool  
9 - links []*KNode  
10 -}  
11 -  
12 -// NewKNode ...  
13 -func NewKNode(val rune) *KNode {  
14 - return &KNode{  
15 - val: val,  
16 - links: make([]*KNode, 0),  
17 - }  
18 -}  
19 -  
20 -// Add ...  
21 -func (n *KNode) Add(rs []rune) {  
22 - cur := n  
23 -  
24 - for k, v := range rs {  
25 - link := cur.linkByVal(v)  
26 -  
27 - if link == nil {  
28 - link = NewKNode(v)  
29 - cur.links = append(cur.links, link)  
30 - }  
31 -  
32 - if k == len(rs)-1 {  
33 - link.end = true  
34 - }  
35 -  
36 - cur = link  
37 - }  
38 -}  
39 -  
40 -// Find ...  
41 -func (n *KNode) Find(rs []rune) bool {  
42 - cur := n  
43 -  
44 - for _, v := range rs {  
45 - cur = cur.linkByVal(v)  
46 -  
47 - if cur == nil {  
48 - return false  
49 - }  
50 - }  
51 -  
52 - return cur.end  
53 -}  
54 -  
55 -// FindAsUpper ...  
56 -func (n *KNode) FindAsUpper(rs []rune) bool {  
57 - cur := n  
58 -  
59 - for _, v := range rs {  
60 - cur = cur.linkByVal(unicode.ToUpper(v))  
61 -  
62 - if cur == nil {  
63 - return false  
64 - }  
65 - }  
66 -  
67 - return cur.end  
68 -}  
69 -  
70 -func (n *KNode) linkByVal(val rune) *KNode {  
71 - for _, v := range n.links {  
72 - if v.val == val {  
73 - return v  
74 - }  
75 - }  
76 -  
77 - return nil  
78 -}  
79 -  
80 -// KTrie ...  
81 -type KTrie struct {  
82 - *KNode  
83 -  
84 - maxDepth int  
85 - minDepth int  
86 -}  
87 -  
88 -// NewKTrie ...  
89 -func NewKTrie(data map[string]bool) (*KTrie, error) {  
90 - n := NewKNode(0)  
91 -  
92 - maxDepth := 0  
93 - minDepth := 9001  
94 -  
95 - for k := range data {  
96 - rs := []rune(k)  
97 - l := len(rs)  
98 -  
99 - n.Add(rs)  
100 -  
101 - if l > maxDepth {  
102 - maxDepth = l  
103 - }  
104 - if l < minDepth {  
105 - minDepth = l  
106 - }  
107 - }  
108 -  
109 - t := &KTrie{  
110 - maxDepth: maxDepth,  
111 - minDepth: minDepth,  
112 - KNode: n,  
113 - }  
114 -  
115 - return t, nil  
116 -}  
117 -  
118 -// MaxDepth ...  
119 -func (t *KTrie) MaxDepth() int {  
120 - return t.maxDepth  
121 -}  
122 -  
123 -// MinDepth ...  
124 -func (t *KTrie) MinDepth() int {  
125 - return t.minDepth  
126 -}  
@@ -11,3 +11,8 @@ linters: @@ -11,3 +11,8 @@ linters:
11 - wsl 11 - wsl
12 - funlen 12 - funlen
13 - godox 13 - godox
  14 + - goerr113
  15 + - exhaustive
  16 + - nestif
  17 + - gofumpt
  18 + - goconst
1 semi: false 1 semi: false
2 singleQuote: true 2 singleQuote: true
3 proseWrap: always 3 proseWrap: always
4 -printWidth: 80 4 +printWidth: 100
1 dist: xenial 1 dist: xenial
2 -sudo: false  
3 language: go 2 language: go
4 3
5 addons: 4 addons:
6 - postgresql: "9.6" 5 + postgresql: '9.6'
7 6
8 go: 7 go:
9 - - 1.13.x  
10 - 1.14.x 8 - 1.14.x
  9 + - 1.15.x
11 - tip 10 - tip
12 11
13 matrix: 12 matrix:
14 allow_failures: 13 allow_failures:
15 - go: tip 14 - go: tip
16 15
17 -env:  
18 - - GO111MODULE=on  
19 -  
20 go_import_path: github.com/go-pg/pg 16 go_import_path: github.com/go-pg/pg
21 17
22 before_install: 18 before_install:
23 - psql -U postgres -c "CREATE EXTENSION hstore" 19 - psql -U postgres -c "CREATE EXTENSION hstore"
24 - - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0 20 + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s --
  21 + -b $(go env GOPATH)/bin v1.28.3
1 # Changelog 1 # Changelog
2 2
3 -## v10 (unreleased) 3 +> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
4 4
5 -- Added `pgext.OpenTemetryHook` that adds OpenTelemetry  
6 - [instrumentation](https://pg.uptrace.dev/tracing/).  
7 -- Added `pgext.DebugHook` that logs queries and errors.  
8 -- Added `db.Ping` to check if database is healthy.  
9 -- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage.  
10 -- `,msgpack` struct tag marshals data in MessagePack format using  
11 - https://github.com/vmihailenco/msgpack  
12 -- Deprecated types and funcs are removed.  
13 -  
14 -## v9  
15 -  
16 -- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and  
17 - nothing more.  
18 -- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL  
19 - `NULL`.  
20 -- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go  
21 - values, but it does not take in account if field is nullable or not.  
22 -- ORM supports DistinctOn.  
23 -- Hooks accept and return context.  
24 -- Client respects Context.Deadline when setting net.Conn deadline.  
25 -- Client listens on Context.Done while waiting for a connection from the pool  
26 - and returns an error when context is cancelled.  
27 -- `Query.Column` does not accept relation name any more. Use `Query.Relation`  
28 - instead which returns an error if relation does not exist.  
29 -- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct.  
30 - You can also use struct based filters via `Query.WhereStruct`.  
31 -- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to  
32 - `NextColumnScanner` and `AddColumnScanner` respectively.  
33 -- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`.  
34 -- `types.Q` is deprecated in favor of `pg.Safe`.  
35 -- `pg.Q` is deprecated in favor of `pg.SafeQuery`.  
36 -- `TableName` field is deprecated in favor of `tableName`.  
37 -- Always use `pg:"..."` struct field tag instead of `sql:"..."`.  
38 -- `pg:",override"` is deprecated in favor of `pg:",inherit"`.  
39 -  
40 -## v8  
41 -  
42 -- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept  
43 - `context.Context`. Queries are cancelled when context is cancelled.  
44 -- Model hooks are changed to accept `context.Context` as first argument.  
45 -- Fixed array and hstore parsers to handle multiple single quotes (#1235).  
46 -  
47 -## v7  
48 -  
49 -- DB.OnQueryProcessed is replaced with DB.AddQueryHook.  
50 -- Added WhereStruct.  
51 -- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if  
52 - page or limit params can't be parsed.  
53 -  
54 -## v6.16  
55 -  
56 -- Read buffer is re-worked. Default read buffer is increased to 65kb.  
57 -  
58 -## v6.15  
59 -  
60 -- Added Options.MinIdleConns.  
61 -- Options.MaxAge renamed to Options.MaxConnAge.  
62 -- PoolStats.FreeConns is renamed to PoolStats.IdleConns.  
63 -- New hook BeforeSelectQuery.  
64 -- `,override` is renamed to `,inherit`.  
65 -- Dialer.KeepAlive is set to 5 minutes by default.  
66 -- Added support "scram-sha-256" authentication.  
67 -  
68 -## v6.14  
69 -  
70 -- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation  
71 - detector.  
72 -  
73 -## v6.12  
74 -  
75 -- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and  
76 - `pg.ErrMultiRows` when `Returning` is used and model expects single row.  
77 -  
78 -## v6.11  
79 -  
80 -- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds  
81 - WHERE condition based on primary key when there are no conditions. Instead you  
82 - should use `db.Update(&strct)` or `db.Model(&strct).WherePK().Update()`.  
83 -  
84 -## v6.10  
85 -  
86 -- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce  
87 - column names without table alias.  
88 -  
89 -## v6.9  
90 -  
91 -- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g.  
92 - `pg:"fk:ParentId"` becomes `pg:"fk:parent_id"`. Old code should continue  
93 - working in most cases, but it is strongly advised to start using new  
94 - convention.  
95 -- uint and uint64 SQL type is changed from decimal to bigint according to the  
96 - lesser of two evils principle. Use `sql:"type:decimal"` to get old behavior.  
97 -  
98 -## v6.8  
99 -  
100 -- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior  
101 - users should add `sql:"on_delete:CASCADE"` tag on foreign key field.  
102 -  
103 -## v6  
104 -  
105 -- `types.Result` is renamed to `orm.Result`.  
106 -- Added `OnQueryProcessed` event that can be used to log / report queries  
107 - timing. Query logger is removed.  
108 -- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER  
109 - clause.  
110 -- `orm.Pager` is renamed to `orm.Pagination`.  
111 -- Support for net.IP and net.IPNet.  
112 -- Support for context.Context.  
113 -- Bulk/multi updates.  
114 -- Query.WhereGroup for enclosing conditions in parentheses.  
115 -  
116 -## v5  
117 -  
118 -- All fields are nullable by default. `,null` tag is replaced with `,notnull`.  
119 -- `Result.Affected` renamed to `Result.RowsAffected`.  
120 -- Added `Result.RowsReturned`.  
121 -- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate`  
122 - to `AfterInsert`.  
123 -- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`.  
124 -- Named placeholders are evaluated when query is executed.  
125 -- Added Update and Delete hooks.  
126 -- Order reworked to quote column names. OrderExpr added to bypass Order quoting  
127 - restrictions.  
128 -- Group reworked to quote column names. GroupExpr added to bypass Group quoting  
129 - restrictions.  
130 -  
131 -## v4  
132 -  
133 -- `Options.Host` and `Options.Port` merged into `Options.Addr`.  
134 -- Added `Options.MaxRetries`. Now queries are not retried by default.  
135 -- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`,  
136 - LoadColumn renamed to ScanColumn, `NewRecord() interface{}` changed to  
137 - `NewModel() ColumnScanner`, `AppendQuery(dst []byte) []byte` changed to  
138 - `AppendValue(dst []byte, quote bool) ([]byte, error)`.  
139 -- Structs, maps and slices are marshalled to JSON by default.  
140 -- Added support for scanning slices, .e.g. scanning `[]int`.  
141 -- Added object relational mapping. 5 +See https://pg.uptrace.dev/changelog/
1 all: 1 all:
2 - go test ./...  
3 - go test ./... -short -race  
4 - go test ./... -run=NONE -bench=. -benchmem 2 + TZ= go test ./...
  3 + TZ= go test ./... -short -race
  4 + TZ= go test ./... -run=NONE -bench=. -benchmem
5 env GOOS=linux GOARCH=386 go test ./... 5 env GOOS=linux GOARCH=386 go test ./...
  6 + go vet
6 golangci-lint run 7 golangci-lint run
  8 +
  9 +.PHONY: cleanTest
  10 +cleanTest:
  11 + docker rm -fv pg || true
  12 +
  13 +.PHONY: pre-test
  14 +pre-test: cleanTest
  15 + docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6
  16 + sleep 10
  17 + docker exec pg psql -U postgres -c "CREATE EXTENSION hstore"
  18 +
  19 +.PHONY: test
  20 +test: pre-test
  21 + TZ= PGSSLMODE=disable go test ./... -v
1 # PostgreSQL client and ORM for Golang 1 # PostgreSQL client and ORM for Golang
2 2
3 -[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=master)](https://travis-ci.org/go-pg/pg)  
4 -[![GoDoc](https://godoc.org/github.com/go-pg/pg?status.svg)](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) 3 +[![Build Status](https://travis-ci.org/go-pg/pg.svg?branch=v10)](https://travis-ci.org/go-pg/pg)
  4 +[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-pg/pg/v10)](https://pkg.go.dev/github.com/go-pg/pg/v10)
  5 +[![Documentation](https://img.shields.io/badge/pg-documentation-informational)](https://pg.uptrace.dev/)
  6 +[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
5 7
6 -- [Docs](https://pg.uptrace.dev) 8 +> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev)
  9 +
  10 +- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions.
  11 +- [Documentation](https://pg.uptrace.dev)
7 - [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) 12 - [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc)
8 - [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples) 13 - [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples)
  14 +- Example projects:
  15 + - [treemux](https://github.com/uptrace/go-treemux-realworld-example-app)
  16 + - [gin](https://github.com/gogjango/gjango)
  17 + - [go-kit](https://github.com/Tsovak/rest-api-demo)
  18 + - [aah framework](https://github.com/kieusonlam/golamapi)
  19 +- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn).
9 20
10 ## Ecosystem 21 ## Ecosystem
11 22
12 - Migrations by [vmihailenco](https://github.com/go-pg/migrations) and 23 - Migrations by [vmihailenco](https://github.com/go-pg/migrations) and
13 [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations). 24 [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations).
  25 +- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna).
  26 +- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs.
14 - [Sharding](https://github.com/go-pg/sharding). 27 - [Sharding](https://github.com/go-pg/sharding).
15 -- [Model generator from SQL tables](https://github.com/dizzyfool/genna).  
16 -- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into  
17 - structs.  
18 -  
19 -## Sponsors  
20 -  
21 -- [**Uptrace.dev** - distributed traces and metrics](https://uptrace.dev)  
22 28
23 ## Features 29 ## Features
24 30
@@ -26,71 +32,200 @@ @@ -26,71 +32,200 @@
26 - sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and 32 - sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and
27 [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime). 33 [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime).
28 - [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and 34 - [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and
29 - [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer)  
30 - interfaces. 35 + [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces.
31 - Structs, maps and arrays are marshalled as JSON by default. 36 - Structs, maps and arrays are marshalled as JSON by default.
32 - PostgreSQL multidimensional Arrays using 37 - PostgreSQL multidimensional Arrays using
33 [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag) 38 [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag)
34 - and  
35 - [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array). 39 + and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array).
36 - Hstore using 40 - Hstore using
37 [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag) 41 [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag)
38 - and  
39 - [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore). 42 + and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore).
40 - [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType). 43 - [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType).
41 -- All struct fields are nullable by default and zero values (empty string, 0,  
42 - zero time, empty map or slice, nil ptr) are marshalled as SQL `NULL`.  
43 - `pg:",notnull"` is used to add SQL `NOT NULL` constraint and `pg:",use_zero"`  
44 - to allow Go zero values. 44 +- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map
  45 + or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL`
  46 + constraint and `pg:",use_zero"` to allow Go zero values.
45 - [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin). 47 - [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin).
46 - [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare). 48 - [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare).
47 -- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener)  
48 - using `LISTEN` and `NOTIFY`.  
49 -- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom)  
50 - using `COPY FROM` and `COPY TO`.  
51 -- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and  
52 - canceling queries using context.Context. 49 +- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using
  50 + `LISTEN` and `NOTIFY`.
  51 +- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using
  52 + `COPY FROM` and `COPY TO`.
  53 +- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using
  54 + context.Context.
53 - Automatic connection pooling with 55 - Automatic connection pooling with
54 - [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern)  
55 - support. 56 + [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support.
56 - Queries retry on network errors. 57 - Queries retry on network errors.
57 - Working with models using 58 - Working with models using
58 - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model) and  
59 - [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Query). 59 + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and
  60 + [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query).
60 - Scanning variables using 61 - Scanning variables using
61 - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-SomeColumnsIntoVars) 62 + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars)
62 and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan). 63 and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan).
63 -- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-SelectOrInsert) 64 +- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert)
64 using on-conflict. 65 using on-conflict.
65 -- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-OnConflictDoUpdate) 66 +- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate)
66 using ORM. 67 using ORM.
67 - Bulk/batch 68 - Bulk/batch
68 - [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-BulkInsert),  
69 - [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Update-BulkUpdate),  
70 - and  
71 - [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Delete-BulkDelete). 69 + [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert),
  70 + [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and
  71 + [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete).
72 - Common table expressions using 72 - Common table expressions using
73 - [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-With)  
74 - and  
75 - [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-WrapWith).  
76 -- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CountEstimate) 73 + [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and
  74 + [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith).
  75 +- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate)
77 using `EXPLAIN` to get 76 using `EXPLAIN` to get
78 [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). 77 [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate).
79 - ORM supports 78 - ORM supports
80 - [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasOne),  
81 - [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-BelongsTo),  
82 - [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasMany),  
83 - and  
84 - [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ManyToMany) 79 + [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne),
  80 + [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo),
  81 + [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and
  82 + [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany)
85 with composite/multi-column primary keys. 83 with composite/multi-column primary keys.
86 -- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-SoftDelete).  
87 -- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CreateTable).  
88 -- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ForEach)  
89 - that calls a function for each row returned by the query without loading all  
90 - rows into the memory. 84 +- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete).
  85 +- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable).
  86 +- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls
  87 + a function for each row returned by the query without loading all rows into the memory.
91 - Works with PgBouncer in transaction pooling mode. 88 - Works with PgBouncer in transaction pooling mode.
92 89
  90 +## Installation
  91 +
  92 +go-pg supports 2 last Go versions and requires a Go version with
  93 +[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go
  94 +module:
  95 +
  96 +```shell
  97 +go mod init github.com/my/repo
  98 +```
  99 +
  100 +And then install go-pg (note _v10_ in the import; omitting it is a popular mistake):
  101 +
  102 +```shell
  103 +go get github.com/go-pg/pg/v10
  104 +```
  105 +
  106 +## Quickstart
  107 +
  108 +```go
  109 +package pg_test
  110 +
  111 +import (
  112 + "fmt"
  113 +
  114 + "github.com/go-pg/pg/v10"
  115 + "github.com/go-pg/pg/v10/orm"
  116 +)
  117 +
  118 +type User struct {
  119 + Id int64
  120 + Name string
  121 + Emails []string
  122 +}
  123 +
  124 +func (u User) String() string {
  125 + return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails)
  126 +}
  127 +
  128 +type Story struct {
  129 + Id int64
  130 + Title string
  131 + AuthorId int64
  132 + Author *User `pg:"rel:has-one"`
  133 +}
  134 +
  135 +func (s Story) String() string {
  136 + return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author)
  137 +}
  138 +
  139 +func ExampleDB_Model() {
  140 + db := pg.Connect(&pg.Options{
  141 + User: "postgres",
  142 + })
  143 + defer db.Close()
  144 +
  145 + err := createSchema(db)
  146 + if err != nil {
  147 + panic(err)
  148 + }
  149 +
  150 + user1 := &User{
  151 + Name: "admin",
  152 + Emails: []string{"admin1@admin", "admin2@admin"},
  153 + }
  154 + _, err = db.Model(user1).Insert()
  155 + if err != nil {
  156 + panic(err)
  157 + }
  158 +
  159 + _, err = db.Model(&User{
  160 + Name: "root",
  161 + Emails: []string{"root1@root", "root2@root"},
  162 + }).Insert()
  163 + if err != nil {
  164 + panic(err)
  165 + }
  166 +
  167 + story1 := &Story{
  168 + Title: "Cool story",
  169 + AuthorId: user1.Id,
  170 + }
  171 + _, err = db.Model(story1).Insert()
  172 + if err != nil {
  173 + panic(err)
  174 + }
  175 +
  176 + // Select user by primary key.
  177 + user := &User{Id: user1.Id}
  178 + err = db.Model(user).WherePK().Select()
  179 + if err != nil {
  180 + panic(err)
  181 + }
  182 +
  183 + // Select all users.
  184 + var users []User
  185 + err = db.Model(&users).Select()
  186 + if err != nil {
  187 + panic(err)
  188 + }
  189 +
  190 + // Select story and associated author in one query.
  191 + story := new(Story)
  192 + err = db.Model(story).
  193 + Relation("Author").
  194 + Where("story.id = ?", story1.Id).
  195 + Select()
  196 + if err != nil {
  197 + panic(err)
  198 + }
  199 +
  200 + fmt.Println(user)
  201 + fmt.Println(users)
  202 + fmt.Println(story)
  203 + // Output: User<1 admin [admin1@admin admin2@admin]>
  204 + // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>]
  205 + // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>>
  206 +}
  207 +
  208 +// createSchema creates database schema for User and Story models.
  209 +func createSchema(db *pg.DB) error {
  210 + models := []interface{}{
  211 + (*User)(nil),
  212 + (*Story)(nil),
  213 + }
  214 +
  215 + for _, model := range models {
  216 + err := db.Model(model).CreateTable(&orm.CreateTableOptions{
  217 + Temp: true,
  218 + })
  219 + if err != nil {
  220 + return err
  221 + }
  222 + }
  223 + return nil
  224 +}
  225 +```
  226 +
93 ## See also 227 ## See also
94 228
  229 +- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux)
95 - [Golang msgpack](https://github.com/vmihailenco/msgpack) 230 - [Golang msgpack](https://github.com/vmihailenco/msgpack)
96 - [Golang message task queue](https://github.com/vmihailenco/taskq) 231 - [Golang message task queue](https://github.com/vmihailenco/taskq)
@@ -5,12 +5,13 @@ import ( @@ -5,12 +5,13 @@ import (
5 "io" 5 "io"
6 "time" 6 "time"
7 7
8 - "go.opentelemetry.io/otel/api/kv"  
9 - "go.opentelemetry.io/otel/api/trace" 8 + "go.opentelemetry.io/otel/label"
  9 + "go.opentelemetry.io/otel/trace"
10 10
11 "github.com/go-pg/pg/v10/internal" 11 "github.com/go-pg/pg/v10/internal"
12 "github.com/go-pg/pg/v10/internal/pool" 12 "github.com/go-pg/pg/v10/internal/pool"
13 "github.com/go-pg/pg/v10/orm" 13 "github.com/go-pg/pg/v10/orm"
  14 + "github.com/go-pg/pg/v10/types"
14 ) 15 )
15 16
16 type baseDB struct { 17 type baseDB struct {
@@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { @@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
83 return cn, nil 84 return cn, nil
84 } 85 }
85 86
86 - err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error { 87 + err = internal.WithSpan(ctx, "pg.init_conn", func(ctx context.Context, span trace.Span) error {
87 return db.initConn(ctx, cn) 88 return db.initConn(ctx, cn)
88 }) 89 })
89 if err != nil { 90 if err != nil {
90 - db.pool.Remove(cn, err)  
91 - // It is safe to reset SingleConnPool if conn can't be initialized.  
92 - if p, ok := db.pool.(*pool.SingleConnPool); ok {  
93 - _ = p.Reset() 91 + db.pool.Remove(ctx, cn, err)
  92 + // It is safe to reset StickyConnPool if conn can't be initialized.
  93 + if p, ok := db.pool.(*pool.StickyConnPool); ok {
  94 + _ = p.Reset(ctx)
94 } 95 }
95 if err := internal.Unwrap(err); err != nil { 96 if err := internal.Unwrap(err); err != nil {
96 return nil, err 97 return nil, err
@@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { @@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) {
101 return cn, nil 102 return cn, nil
102 } 103 }
103 104
104 -func (db *baseDB) initConn(c context.Context, cn *pool.Conn) error { 105 +func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error {
105 if cn.Inited { 106 if cn.Inited {
106 return nil 107 return nil
107 } 108 }
108 cn.Inited = true 109 cn.Inited = true
109 110
110 if db.opt.TLSConfig != nil { 111 if db.opt.TLSConfig != nil {
111 - err := db.enableSSL(c, cn, db.opt.TLSConfig) 112 + err := db.enableSSL(ctx, cn, db.opt.TLSConfig)
112 if err != nil { 113 if err != nil {
113 return err 114 return err
114 } 115 }
115 } 116 }
116 117
117 - err := db.startup(c, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) 118 + err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)
118 if err != nil { 119 if err != nil {
119 return err 120 return err
120 } 121 }
121 122
122 if db.opt.OnConnect != nil { 123 if db.opt.OnConnect != nil {
123 - p := pool.NewSingleConnPool(nil)  
124 - p.SetConn(cn)  
125 - return db.opt.OnConnect(newConn(c, db.withPool(p))) 124 + p := pool.NewSingleConnPool(db.pool, cn)
  125 + return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p)))
126 } 126 }
127 127
128 return nil 128 return nil
129 } 129 }
130 130
131 -func (db *baseDB) releaseConn(cn *pool.Conn, err error) { 131 +func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
132 if isBadConn(err, false) { 132 if isBadConn(err, false) {
133 - db.pool.Remove(cn, err) 133 + db.pool.Remove(ctx, cn, err)
134 } else { 134 } else {
135 - db.pool.Put(cn) 135 + db.pool.Put(ctx, cn)
136 } 136 }
137 } 137 }
138 138
139 func (db *baseDB) withConn( 139 func (db *baseDB) withConn(
140 ctx context.Context, fn func(context.Context, *pool.Conn) error, 140 ctx context.Context, fn func(context.Context, *pool.Conn) error,
141 ) error { 141 ) error {
142 - return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error { 142 + return internal.WithSpan(ctx, "pg.with_conn", func(ctx context.Context, span trace.Span) error {
143 cn, err := db.getConn(ctx) 143 cn, err := db.getConn(ctx)
144 if err != nil { 144 if err != nil {
145 return err 145 return err
@@ -154,7 +154,7 @@ func (db *baseDB) withConn( @@ -154,7 +154,7 @@ func (db *baseDB) withConn(
154 case <-ctx.Done(): 154 case <-ctx.Done():
155 err := db.cancelRequest(cn.ProcessID, cn.SecretKey) 155 err := db.cancelRequest(cn.ProcessID, cn.SecretKey)
156 if err != nil { 156 if err != nil {
157 - internal.Logger.Printf("cancelRequest failed: %s", err) 157 + internal.Logger.Printf(ctx, "cancelRequest failed: %s", err)
158 } 158 }
159 // Signal end of conn use. 159 // Signal end of conn use.
160 fnDone <- struct{}{} 160 fnDone <- struct{}{}
@@ -169,7 +169,7 @@ func (db *baseDB) withConn( @@ -169,7 +169,7 @@ func (db *baseDB) withConn(
169 case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine 169 case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine
170 } 170 }
171 } 171 }
172 - db.releaseConn(cn, err) 172 + db.releaseConn(ctx, cn, err)
173 }() 173 }()
174 174
175 err = fn(ctx, cn) 175 err = fn(ctx, cn)
@@ -179,9 +179,12 @@ func (db *baseDB) withConn( @@ -179,9 +179,12 @@ func (db *baseDB) withConn(
179 179
180 func (db *baseDB) shouldRetry(err error) bool { 180 func (db *baseDB) shouldRetry(err error) bool {
181 switch err { 181 switch err {
  182 + case io.EOF, io.ErrUnexpectedEOF:
  183 + return true
182 case nil, context.Canceled, context.DeadlineExceeded: 184 case nil, context.Canceled, context.DeadlineExceeded:
183 return false 185 return false
184 } 186 }
  187 +
185 if pgerr, ok := err.(Error); ok { 188 if pgerr, ok := err.(Error); ok {
186 switch pgerr.Field('C') { 189 switch pgerr.Field('C') {
187 case "40001", // serialization_failure 190 case "40001", // serialization_failure
@@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool { @@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool {
194 return false 197 return false
195 } 198 }
196 } 199 }
197 - return isNetworkError(err) 200 +
  201 + if _, ok := err.(timeoutError); ok {
  202 + return true
  203 + }
  204 +
  205 + return false
198 } 206 }
199 207
200 // Close closes the database client, releasing any open resources. 208 // Close closes the database client, releasing any open resources.
@@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa @@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa
233 for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { 241 for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
234 attempt := attempt 242 attempt := attempt
235 243
236 - lastErr = internal.WithSpan(ctx, "exec", func(ctx context.Context, span trace.Span) error { 244 + lastErr = internal.WithSpan(ctx, "pg.exec", func(ctx context.Context, span trace.Span) error {
237 if attempt > 0 { 245 if attempt > 0 {
238 - span.SetAttributes(kv.Int("retry", attempt)) 246 + span.SetAttributes(label.Int("retry", attempt))
239 247
240 if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { 248 if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
241 return err 249 return err
@@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params .. @@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params ..
311 for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { 319 for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ {
312 attempt := attempt 320 attempt := attempt
313 321
314 - lastErr = internal.WithSpan(ctx, "query", func(ctx context.Context, span trace.Span) error { 322 + lastErr = internal.WithSpan(ctx, "pg.query", func(ctx context.Context, span trace.Span) error {
315 if attempt > 0 { 323 if attempt > 0 {
316 - span.SetAttributes(kv.Int("retry", attempt)) 324 + span.SetAttributes(label.Int("retry", attempt))
317 325
318 if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { 326 if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil {
319 return err 327 return err
@@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{} @@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{}
373 return res, err 381 return res, err
374 } 382 }
375 383
376 -// TODO: don't get/put conn in the pool 384 +// TODO: don't get/put conn in the pool.
377 func (db *baseDB) copyFrom( 385 func (db *baseDB) copyFrom(
378 ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}, 386 ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{},
379 ) (res Result, err error) { 387 ) (res Result, err error) {
@@ -396,6 +404,7 @@ func (db *baseDB) copyFrom( @@ -396,6 +404,7 @@ func (db *baseDB) copyFrom(
396 return nil, err 404 return nil, err
397 } 405 }
398 406
  407 + // Note that afterQuery uses the err.
399 defer func() { 408 defer func() {
400 if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { 409 if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
401 err = afterQueryErr 410 err = afterQueryErr
@@ -434,7 +443,7 @@ func (db *baseDB) copyFrom( @@ -434,7 +443,7 @@ func (db *baseDB) copyFrom(
434 return nil, err 443 return nil, err
435 } 444 }
436 445
437 - err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 446 + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
438 res, err = readReadyForQuery(rd) 447 res, err = readReadyForQuery(rd)
439 return err 448 return err
440 }) 449 })
@@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{}) @@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{})
456 } 465 }
457 466
458 func (db *baseDB) copyTo( 467 func (db *baseDB) copyTo(
459 - c context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{}, 468 + ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{},
460 ) (res Result, err error) { 469 ) (res Result, err error) {
461 var evt *QueryEvent 470 var evt *QueryEvent
462 471
@@ -472,25 +481,26 @@ func (db *baseDB) copyTo( @@ -472,25 +481,26 @@ func (db *baseDB) copyTo(
472 model, _ = params[len(params)-1].(orm.TableModel) 481 model, _ = params[len(params)-1].(orm.TableModel)
473 } 482 }
474 483
475 - c, evt, err = db.beforeQuery(c, db.db, model, query, params, wb.Query()) 484 + ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query())
476 if err != nil { 485 if err != nil {
477 return nil, err 486 return nil, err
478 } 487 }
479 488
  489 + // Note that afterQuery uses the err.
480 defer func() { 490 defer func() {
481 - if afterQueryErr := db.afterQuery(c, evt, res, err); afterQueryErr != nil { 491 + if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil {
482 err = afterQueryErr 492 err = afterQueryErr
483 } 493 }
484 }() 494 }()
485 495
486 - err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { 496 + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
487 return writeQueryMsg(wb, db.fmter, query, params...) 497 return writeQueryMsg(wb, db.fmter, query, params...)
488 }) 498 })
489 if err != nil { 499 if err != nil {
490 return nil, err 500 return nil, err
491 } 501 }
492 502
493 - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 503 + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
494 err := readCopyOutResponse(rd) 504 err := readCopyOutResponse(rd)
495 if err != nil { 505 if err != nil {
496 return err 506 return err
@@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que @@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que
522 return orm.NewQueryContext(c, db.db, model...) 532 return orm.NewQueryContext(c, db.db, model...)
523 } 533 }
524 534
525 -// Select selects the model by primary key.  
526 -func (db *baseDB) Select(model interface{}) error {  
527 - return orm.Select(db.db, model)  
528 -}  
529 -  
530 -// Insert inserts the model updating primary keys if they are empty.  
531 -func (db *baseDB) Insert(model ...interface{}) error {  
532 - return orm.Insert(db.db, model...)  
533 -}  
534 -  
535 -// Update updates the model by primary key.  
536 -func (db *baseDB) Update(model interface{}) error {  
537 - return orm.Update(db.db, model)  
538 -}  
539 -  
540 -// Delete deletes the model by primary key.  
541 -func (db *baseDB) Delete(model interface{}) error {  
542 - return orm.Delete(db.db, model)  
543 -}  
544 -  
545 -// Delete forces delete of the model with deleted_at column.  
546 -func (db *baseDB) ForceDelete(model interface{}) error {  
547 - return orm.ForceDelete(db.db, model)  
548 -}  
549 -  
550 -// CreateTable creates table for the model. It recognizes following field tags:  
551 -// - notnull - sets NOT NULL constraint.  
552 -// - unique - sets UNIQUE constraint.  
553 -// - default:value - sets default value.  
554 -func (db *baseDB) CreateTable(model interface{}, opt *orm.CreateTableOptions) error {  
555 - return orm.CreateTable(db.db, model, opt)  
556 -}  
557 -  
558 -// DropTable drops table for the model.  
559 -func (db *baseDB) DropTable(model interface{}, opt *orm.DropTableOptions) error {  
560 - return orm.DropTable(db.db, model, opt)  
561 -}  
562 -  
563 -func (db *baseDB) CreateComposite(model interface{}, opt *orm.CreateCompositeOptions) error {  
564 - return orm.CreateComposite(db.db, model, opt)  
565 -}  
566 -  
567 -func (db *baseDB) DropComposite(model interface{}, opt *orm.DropCompositeOptions) error {  
568 - return orm.DropComposite(db.db, model, opt)  
569 -}  
570 -  
571 func (db *baseDB) Formatter() orm.QueryFormatter { 535 func (db *baseDB) Formatter() orm.QueryFormatter {
572 return db.fmter 536 return db.fmter
573 } 537 }
@@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery( @@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery(
597 } 561 }
598 562
599 var res *result 563 var res *result
600 - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 564 + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
601 var err error 565 var err error
602 res, err = readSimpleQuery(rd) 566 res, err = readSimpleQuery(rd)
603 return err 567 return err
@@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData( @@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData(
616 } 580 }
617 581
618 var res *result 582 var res *result
619 - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 583 + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
620 var err error 584 var err error
621 res, err = readSimpleQueryData(c, rd, model) 585 res, err = readSimpleQueryData(c, rd, model)
622 return err 586 return err
@@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData( @@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData(
631 // executions. Multiple queries or executions may be run concurrently 595 // executions. Multiple queries or executions may be run concurrently
632 // from the returned statement. 596 // from the returned statement.
633 func (db *baseDB) Prepare(q string) (*Stmt, error) { 597 func (db *baseDB) Prepare(q string) (*Stmt, error) {
634 - return prepareStmt(db.withPool(pool.NewSingleConnPool(db.pool)), q) 598 + return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q)
635 } 599 }
636 600
637 func (db *baseDB) prepare( 601 func (db *baseDB) prepare(
638 c context.Context, cn *pool.Conn, q string, 602 c context.Context, cn *pool.Conn, q string,
639 -) (string, [][]byte, error) { 603 +) (string, []types.ColumnInfo, error) {
640 name := cn.NextID() 604 name := cn.NextID()
641 err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { 605 err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
642 writeParseDescribeSyncMsg(wb, name, q) 606 writeParseDescribeSyncMsg(wb, name, q)
@@ -646,8 +610,8 @@ func (db *baseDB) prepare( @@ -646,8 +610,8 @@ func (db *baseDB) prepare(
646 return "", nil, err 610 return "", nil, err
647 } 611 }
648 612
649 - var columns [][]byte  
650 - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 613 + var columns []types.ColumnInfo
  614 + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
651 columns, err = readParseDescribeSync(rd) 615 columns, err = readParseDescribeSync(rd)
652 return err 616 return err
653 }) 617 })
@@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB { @@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB {
75 } 75 }
76 76
77 // Listen listens for notifications sent with NOTIFY command. 77 // Listen listens for notifications sent with NOTIFY command.
78 -func (db *DB) Listen(channels ...string) *Listener { 78 +func (db *DB) Listen(ctx context.Context, channels ...string) *Listener {
79 ln := &Listener{ 79 ln := &Listener{
80 db: db, 80 db: db,
81 } 81 }
82 ln.init() 82 ln.init()
83 - _ = ln.Listen(channels...) 83 + _ = ln.Listen(ctx, channels...)
84 return ln 84 return ln
85 } 85 }
86 86
@@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil) @@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil)
105 // Every Conn must be returned to the database pool after use by 105 // Every Conn must be returned to the database pool after use by
106 // calling Conn.Close. 106 // calling Conn.Close.
107 func (db *DB) Conn() *Conn { 107 func (db *DB) Conn() *Conn {
108 - return newConn(db.ctx, db.baseDB.withPool(pool.NewSingleConnPool(db.pool))) 108 + return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool)))
109 } 109 }
110 110
111 func newConn(ctx context.Context, baseDB *baseDB) *Conn { 111 func newConn(ctx context.Context, baseDB *baseDB) *Conn {
1 package pg 1 package pg
2 2
3 import ( 3 import (
4 - "io"  
5 "net" 4 "net"
6 5
7 "github.com/go-pg/pg/v10/internal" 6 "github.com/go-pg/pg/v10/internal"
@@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows @@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows
22 type Error interface { 21 type Error interface {
23 error 22 error
24 23
25 - // Field returns a string value associated with an error code. 24 + // Field returns a string value associated with an error field.
26 // 25 //
27 // https://www.postgresql.org/docs/10/static/protocol-error-fields.html 26 // https://www.postgresql.org/docs/10/static/protocol-error-fields.html
28 - Field(byte) string 27 + Field(field byte) string
29 28
30 // IntegrityViolation reports whether an error is a part of 29 // IntegrityViolation reports whether an error is a part of
31 // Integrity Constraint Violation class of errors. 30 // Integrity Constraint Violation class of errors.
@@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool { @@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool {
43 if _, ok := err.(internal.Error); ok { 42 if _, ok := err.(internal.Error); ok {
44 return false 43 return false
45 } 44 }
46 - if pgErr, ok := err.(Error); ok && pgErr.Field('S') != "FATAL" {  
47 - return false 45 + if pgErr, ok := err.(Error); ok {
  46 + return pgErr.Field('S') == "FATAL"
48 } 47 }
49 if allowTimeout { 48 if allowTimeout {
50 if netErr, ok := err.(net.Error); ok && netErr.Timeout() { 49 if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
51 - return false 50 + return !netErr.Temporary()
52 } 51 }
53 } 52 }
54 return true 53 return true
55 } 54 }
56 55
57 -func isNetworkError(err error) bool {  
58 - if err == io.EOF {  
59 - return true  
60 - }  
61 - _, ok := err.(net.Error)  
62 - return ok 56 +//------------------------------------------------------------------------------
  57 +
  58 +type timeoutError interface {
  59 + Timeout() bool
63 } 60 }
@@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10 @@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10
3 go 1.11 3 go 1.11
4 4
5 require ( 5 require (
6 - github.com/go-pg/pg/v9 v9.1.6 // indirect  
7 - github.com/go-pg/urlstruct v0.4.0  
8 - github.com/go-pg/zerochecker v0.1.1  
9 - github.com/golang/protobuf v1.4.2 // indirect 6 + github.com/go-pg/zerochecker v0.2.0
  7 + github.com/golang/protobuf v1.4.3 // indirect
10 github.com/jinzhu/inflection v1.0.0 8 github.com/jinzhu/inflection v1.0.0
11 - github.com/onsi/ginkgo v1.10.1  
12 - github.com/onsi/gomega v1.7.0  
13 - github.com/segmentio/encoding v0.1.13 9 + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
  10 + github.com/onsi/ginkgo v1.14.2
  11 + github.com/onsi/gomega v1.10.3
  12 + github.com/stretchr/testify v1.6.1
14 github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc 13 github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
15 github.com/vmihailenco/bufpool v0.1.11 14 github.com/vmihailenco/bufpool v0.1.11
16 - github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1  
17 - github.com/vmihailenco/tagparser v0.1.1  
18 - go.opentelemetry.io/otel v0.6.0  
19 - golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect  
20 - golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect  
21 - golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect  
22 - google.golang.org/appengine v1.6.6 // indirect  
23 - google.golang.org/grpc v1.29.1  
24 - google.golang.org/protobuf v1.24.0 // indirect  
25 - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 15 + github.com/vmihailenco/msgpack/v4 v4.3.11 // indirect
  16 + github.com/vmihailenco/msgpack/v5 v5.0.0
  17 + github.com/vmihailenco/tagparser v0.1.2
  18 + go.opentelemetry.io/otel v0.14.0
  19 + golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 // indirect
  20 + golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect
  21 + golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
  22 + google.golang.org/appengine v1.6.7 // indirect
  23 + google.golang.org/protobuf v1.25.0 // indirect
  24 + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f
26 mellium.im/sasl v0.2.1 25 mellium.im/sasl v0.2.1
27 ) 26 )
@@ -8,15 +8,17 @@ import ( @@ -8,15 +8,17 @@ import (
8 "github.com/go-pg/pg/v10/orm" 8 "github.com/go-pg/pg/v10/orm"
9 ) 9 )
10 10
11 -type BeforeScanHook = orm.BeforeScanHook  
12 -type AfterScanHook = orm.AfterScanHook  
13 -type AfterSelectHook = orm.AfterSelectHook  
14 -type BeforeInsertHook = orm.BeforeInsertHook  
15 -type AfterInsertHook = orm.AfterInsertHook  
16 -type BeforeUpdateHook = orm.BeforeUpdateHook  
17 -type AfterUpdateHook = orm.AfterUpdateHook  
18 -type BeforeDeleteHook = orm.BeforeDeleteHook  
19 -type AfterDeleteHook = orm.AfterDeleteHook 11 +type (
  12 + BeforeScanHook = orm.BeforeScanHook
  13 + AfterScanHook = orm.AfterScanHook
  14 + AfterSelectHook = orm.AfterSelectHook
  15 + BeforeInsertHook = orm.BeforeInsertHook
  16 + AfterInsertHook = orm.AfterInsertHook
  17 + BeforeUpdateHook = orm.BeforeUpdateHook
  18 + AfterUpdateHook = orm.AfterUpdateHook
  19 + BeforeDeleteHook = orm.BeforeDeleteHook
  20 + AfterDeleteHook = orm.AfterDeleteHook
  21 +)
20 22
21 //------------------------------------------------------------------------------ 23 //------------------------------------------------------------------------------
22 24
@@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery( @@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery(
94 fmtedQuery: fmtedQuery, 96 fmtedQuery: fmtedQuery,
95 } 97 }
96 98
97 - for _, hook := range db.queryHooks { 99 + for i, hook := range db.queryHooks {
98 var err error 100 var err error
99 ctx, err = hook.BeforeQuery(ctx, event) 101 ctx, err = hook.BeforeQuery(ctx, event)
100 if err != nil { 102 if err != nil {
101 - return nil, nil, err 103 + if err := db.afterQueryFromIndex(ctx, event, i); err != nil {
  104 + return ctx, nil, err
  105 + }
  106 + return ctx, nil, err
102 } 107 }
103 } 108 }
104 109
@@ -117,14 +122,15 @@ func (db *baseDB) afterQuery( @@ -117,14 +122,15 @@ func (db *baseDB) afterQuery(
117 122
118 event.Err = err 123 event.Err = err
119 event.Result = res 124 event.Result = res
  125 + return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
  126 +}
120 127
121 - for _, hook := range db.queryHooks {  
122 - err := hook.AfterQuery(ctx, event)  
123 - if err != nil { 128 +func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error {
  129 + for ; hookIndex >= 0; hookIndex-- {
  130 + if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil {
124 return err 131 return err
125 } 132 }
126 } 133 }
127 -  
128 return nil 134 return nil
129 } 135 }
130 136
@@ -4,8 +4,10 @@ import ( @@ -4,8 +4,10 @@ import (
4 "fmt" 4 "fmt"
5 ) 5 )
6 6
7 -var ErrNoRows = Errorf("pg: no rows in result set")  
8 -var ErrMultiRows = Errorf("pg: multiple rows in result set") 7 +var (
  8 + ErrNoRows = Errorf("pg: no rows in result set")
  9 + ErrMultiRows = Errorf("pg: multiple rows in result set")
  10 +)
9 11
10 type Error struct { 12 type Error struct {
11 s string 13 s string
@@ -8,20 +8,20 @@ import ( @@ -8,20 +8,20 @@ import (
8 "time" 8 "time"
9 ) 9 )
10 10
11 -// Retry backoff with jitter sleep to prevent overloaded conditions during intervals  
12 -// https://www.awsarchitectureblog.com/2015/03/backoff.html  
13 func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { 11 func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
14 if retry < 0 { 12 if retry < 0 {
15 - retry = 0 13 + panic("not reached")
16 } 14 }
17 -  
18 - backoff := minBackoff << uint(retry)  
19 - if backoff > maxBackoff || backoff < minBackoff {  
20 - backoff = maxBackoff 15 + if minBackoff == 0 {
  16 + return 0
21 } 17 }
22 18
23 - if backoff == 0 {  
24 - return 0 19 + d := minBackoff << uint(retry)
  20 + d = minBackoff + time.Duration(rand.Int63n(int64(d)))
  21 +
  22 + if d > maxBackoff || d < minBackoff {
  23 + d = maxBackoff
25 } 24 }
26 - return time.Duration(rand.Int63n(int64(backoff))) 25 +
  26 + return d
27 } 27 }
1 package internal 1 package internal
2 2
3 import ( 3 import (
  4 + "context"
  5 + "fmt"
4 "log" 6 "log"
5 "os" 7 "os"
6 ) 8 )
7 9
8 -var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile) 10 +var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags)
  11 +
  12 +var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags)
  13 +
  14 +type Logging interface {
  15 + Printf(ctx context.Context, format string, v ...interface{})
  16 +}
  17 +
  18 +type logger struct {
  19 + log *log.Logger
  20 +}
  21 +
  22 +func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
  23 + _ = l.log.Output(2, fmt.Sprintf(format, v...))
  24 +}
  25 +
  26 +var Logger Logging = &logger{
  27 + log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
  28 +}
@@ -8,16 +8,15 @@ import ( @@ -8,16 +8,15 @@ import (
8 "time" 8 "time"
9 9
10 "github.com/go-pg/pg/v10/internal" 10 "github.com/go-pg/pg/v10/internal"
11 - "go.opentelemetry.io/otel/api/kv"  
12 - "go.opentelemetry.io/otel/api/trace" 11 + "go.opentelemetry.io/otel/label"
  12 + "go.opentelemetry.io/otel/trace"
13 ) 13 )
14 14
15 var noDeadline = time.Time{} 15 var noDeadline = time.Time{}
16 16
17 type Conn struct { 17 type Conn struct {
18 netConn net.Conn 18 netConn net.Conn
19 -  
20 - rd *BufReader 19 + rd *ReaderContext
21 20
22 ProcessID int32 21 ProcessID int32
23 SecretKey int32 22 SecretKey int32
@@ -31,8 +30,6 @@ type Conn struct { @@ -31,8 +30,6 @@ type Conn struct {
31 30
32 func NewConn(netConn net.Conn) *Conn { 31 func NewConn(netConn net.Conn) *Conn {
33 cn := &Conn{ 32 cn := &Conn{
34 - rd: NewBufReader(netConn),  
35 -  
36 createdAt: time.Now(), 33 createdAt: time.Now(),
37 } 34 }
38 cn.SetNetConn(netConn) 35 cn.SetNetConn(netConn)
@@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr { @@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
55 52
56 func (cn *Conn) SetNetConn(netConn net.Conn) { 53 func (cn *Conn) SetNetConn(netConn net.Conn) {
57 cn.netConn = netConn 54 cn.netConn = netConn
58 - cn.rd.Reset(netConn) 55 + if cn.rd != nil {
  56 + cn.rd.Reset(netConn)
  57 + }
  58 +}
  59 +
  60 +func (cn *Conn) LockReader() {
  61 + if cn.rd != nil {
  62 + panic("not reached")
  63 + }
  64 + cn.rd = NewReaderContext()
  65 + cn.rd.Reset(cn.netConn)
59 } 66 }
60 67
61 func (cn *Conn) NetConn() net.Conn { 68 func (cn *Conn) NetConn() net.Conn {
@@ -68,30 +75,44 @@ func (cn *Conn) NextID() string { @@ -68,30 +75,44 @@ func (cn *Conn) NextID() string {
68 } 75 }
69 76
70 func (cn *Conn) WithReader( 77 func (cn *Conn) WithReader(
71 - ctx context.Context, timeout time.Duration, fn func(rd *BufReader) error, 78 + ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error,
72 ) error { 79 ) error {
73 - return internal.WithSpan(ctx, "with_reader", func(ctx context.Context, span trace.Span) error {  
74 - err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))  
75 - if err != nil { 80 + return internal.WithSpan(ctx, "pg.with_reader", func(ctx context.Context, span trace.Span) error {
  81 + if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
  82 + span.RecordError(err)
76 return err 83 return err
77 } 84 }
78 85
79 - cn.rd.bytesRead = 0  
80 - err = fn(cn.rd)  
81 - span.SetAttributes(kv.Int64("net.read_bytes", cn.rd.bytesRead)) 86 + rd := cn.rd
  87 + if rd == nil {
  88 + rd = GetReaderContext()
  89 + defer PutReaderContext(rd)
82 90
83 - return err 91 + rd.Reset(cn.netConn)
  92 + }
  93 +
  94 + rd.bytesRead = 0
  95 +
  96 + if err := fn(rd); err != nil {
  97 + span.RecordError(err)
  98 + return err
  99 + }
  100 +
  101 + span.SetAttributes(label.Int64("net.read_bytes", rd.bytesRead))
  102 +
  103 + return nil
84 }) 104 })
85 } 105 }
86 106
87 func (cn *Conn) WithWriter( 107 func (cn *Conn) WithWriter(
88 ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error, 108 ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error,
89 ) error { 109 ) error {
90 - return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error { 110 + return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
91 wb := GetWriteBuffer() 111 wb := GetWriteBuffer()
92 defer PutWriteBuffer(wb) 112 defer PutWriteBuffer(wb)
93 113
94 if err := fn(wb); err != nil { 114 if err := fn(wb); err != nil {
  115 + span.RecordError(err)
95 return err 116 return err
96 } 117 }
97 118
@@ -100,7 +121,7 @@ func (cn *Conn) WithWriter( @@ -100,7 +121,7 @@ func (cn *Conn) WithWriter(
100 } 121 }
101 122
102 func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error { 123 func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error {
103 - return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error { 124 + return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error {
104 return cn.writeBuffer(ctx, span, timeout, wb) 125 return cn.writeBuffer(ctx, span, timeout, wb)
105 }) 126 })
106 } 127 }
@@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer( @@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer(
111 timeout time.Duration, 132 timeout time.Duration,
112 wb *WriteBuffer, 133 wb *WriteBuffer,
113 ) error { 134 ) error {
114 - err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))  
115 - if err != nil { 135 + if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
  136 + span.RecordError(err)
  137 + return err
  138 + }
  139 +
  140 + span.SetAttributes(label.Int("net.wrote_bytes", len(wb.Bytes)))
  141 +
  142 + if _, err := cn.netConn.Write(wb.Bytes); err != nil {
  143 + span.RecordError(err)
116 return err 144 return err
117 } 145 }
118 146
119 - span.SetAttributes(kv.Int("net.wrote_bytes", len(wb.Bytes)))  
120 - _, err = cn.netConn.Write(wb.Bytes)  
121 - return err 147 + return nil
122 } 148 }
123 149
124 func (cn *Conn) Close() error { 150 func (cn *Conn) Close() error {
@@ -11,8 +11,10 @@ import ( @@ -11,8 +11,10 @@ import (
11 "github.com/go-pg/pg/v10/internal" 11 "github.com/go-pg/pg/v10/internal"
12 ) 12 )
13 13
14 -var ErrClosed = errors.New("pg: database is closed")  
15 -var ErrPoolTimeout = errors.New("pg: connection pool timeout") 14 +var (
  15 + ErrClosed = errors.New("pg: database is closed")
  16 + ErrPoolTimeout = errors.New("pg: connection pool timeout")
  17 +)
16 18
17 var timers = sync.Pool{ 19 var timers = sync.Pool{
18 New: func() interface{} { 20 New: func() interface{} {
@@ -38,8 +40,8 @@ type Pooler interface { @@ -38,8 +40,8 @@ type Pooler interface {
38 CloseConn(*Conn) error 40 CloseConn(*Conn) error
39 41
40 Get(context.Context) (*Conn, error) 42 Get(context.Context) (*Conn, error)
41 - Put(*Conn)  
42 - Remove(*Conn, error) 43 + Put(context.Context, *Conn)
  44 + Remove(context.Context, *Conn, error)
43 45
44 Len() int 46 Len() int
45 IdleLen() int 47 IdleLen() int
@@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error { @@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error {
216 } 218 }
217 219
218 // Get returns existed connection from the pool or creates a new one. 220 // Get returns existed connection from the pool or creates a new one.
219 -func (p *ConnPool) Get(c context.Context) (*Conn, error) { 221 +func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
220 if p.closed() { 222 if p.closed() {
221 return nil, ErrClosed 223 return nil, ErrClosed
222 } 224 }
223 225
224 - err := p.waitTurn(c) 226 + err := p.waitTurn(ctx)
225 if err != nil { 227 if err != nil {
226 return nil, err 228 return nil, err
227 } 229 }
@@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) { @@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
246 248
247 atomic.AddUint32(&p.stats.Misses, 1) 249 atomic.AddUint32(&p.stats.Misses, 1)
248 250
249 - newcn, err := p.newConn(c, true) 251 + newcn, err := p.newConn(ctx, true)
250 if err != nil { 252 if err != nil {
251 p.freeTurn() 253 p.freeTurn()
252 return nil, err 254 return nil, err
@@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn { @@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn {
312 return cn 314 return cn
313 } 315 }
314 316
315 -func (p *ConnPool) Put(cn *Conn) {  
316 - if cn.rd.Buffered() > 0 {  
317 - internal.Logger.Printf("Conn has unread data")  
318 - p.Remove(cn, BadConnError{})  
319 - return  
320 - }  
321 - 317 +func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
322 if !cn.pooled { 318 if !cn.pooled {
323 - p.Remove(cn, nil) 319 + p.Remove(ctx, cn, nil)
324 return 320 return
325 } 321 }
326 322
@@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) { @@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) {
331 p.freeTurn() 327 p.freeTurn()
332 } 328 }
333 329
334 -func (p *ConnPool) Remove(cn *Conn, reason error) { 330 +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
335 p.removeConnWithLock(cn) 331 p.removeConnWithLock(cn)
336 p.freeTurn() 332 p.freeTurn()
337 _ = p.closeConn(cn) 333 _ = p.closeConn(cn)
@@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) { @@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) {
446 } 442 }
447 n, err := p.ReapStaleConns() 443 n, err := p.ReapStaleConns()
448 if err != nil { 444 if err != nil {
449 - internal.Logger.Printf("ReapStaleConns failed: %s", err) 445 + internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err)
450 continue 446 continue
451 } 447 }
452 atomic.AddUint32(&p.stats.StaleConns, uint32(n)) 448 atomic.AddUint32(&p.stats.StaleConns, uint32(n))
1 package pool 1 package pool
2 2
3 -import (  
4 - "context"  
5 - "errors"  
6 - "fmt"  
7 - "sync/atomic"  
8 -)  
9 -  
10 -const (  
11 - stateDefault = 0  
12 - stateInited = 1  
13 - stateClosed = 2  
14 -)  
15 -  
16 -type BadConnError struct {  
17 - wrapped error  
18 -}  
19 -  
20 -var _ error = (*BadConnError)(nil)  
21 -  
22 -func (e BadConnError) Error() string {  
23 - s := "pg: Conn is in a bad state"  
24 - if e.wrapped != nil {  
25 - s += ": " + e.wrapped.Error()  
26 - }  
27 - return s  
28 -}  
29 -  
30 -func (e BadConnError) Unwrap() error {  
31 - return e.wrapped  
32 -} 3 +import "context"
33 4
34 type SingleConnPool struct { 5 type SingleConnPool struct {
35 - pool Pooler  
36 - level int32 // atomic  
37 -  
38 - state uint32 // atomic  
39 - ch chan *Conn  
40 -  
41 - _badConnError atomic.Value 6 + pool Pooler
  7 + cn *Conn
  8 + stickyErr error
42 } 9 }
43 10
44 var _ Pooler = (*SingleConnPool)(nil) 11 var _ Pooler = (*SingleConnPool)(nil)
45 12
46 -func NewSingleConnPool(pool Pooler) *SingleConnPool {  
47 - p, ok := pool.(*SingleConnPool)  
48 - if !ok {  
49 - p = &SingleConnPool{  
50 - pool: pool,  
51 - ch: make(chan *Conn, 1),  
52 - } 13 +func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
  14 + return &SingleConnPool{
  15 + pool: pool,
  16 + cn: cn,
53 } 17 }
54 - atomic.AddInt32(&p.level, 1)  
55 - return p  
56 } 18 }
57 19
58 -func (p *SingleConnPool) SetConn(cn *Conn) {  
59 - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {  
60 - p.ch <- cn  
61 - } else {  
62 - panic("not reached")  
63 - }  
64 -}  
65 -  
66 -func (p *SingleConnPool) NewConn(c context.Context) (*Conn, error) {  
67 - return p.pool.NewConn(c) 20 +func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) {
  21 + return p.pool.NewConn(ctx)
68 } 22 }
69 23
70 func (p *SingleConnPool) CloseConn(cn *Conn) error { 24 func (p *SingleConnPool) CloseConn(cn *Conn) error {
71 return p.pool.CloseConn(cn) 25 return p.pool.CloseConn(cn)
72 } 26 }
73 27
74 -func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {  
75 - // In worst case this races with Close which is not a very common operation.  
76 - for i := 0; i < 1000; i++ {  
77 - switch atomic.LoadUint32(&p.state) {  
78 - case stateDefault:  
79 - cn, err := p.pool.Get(c)  
80 - if err != nil {  
81 - return nil, err  
82 - }  
83 - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {  
84 - return cn, nil  
85 - }  
86 - p.pool.Remove(cn, ErrClosed)  
87 - case stateInited:  
88 - if err := p.badConnError(); err != nil {  
89 - return nil, err  
90 - }  
91 - cn, ok := <-p.ch  
92 - if !ok {  
93 - return nil, ErrClosed  
94 - }  
95 - return cn, nil  
96 - case stateClosed:  
97 - return nil, ErrClosed  
98 - default:  
99 - panic("not reached")  
100 - } 28 +func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
  29 + if p.stickyErr != nil {
  30 + return nil, p.stickyErr
101 } 31 }
102 - return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop") 32 + return p.cn, nil
103 } 33 }
104 34
105 -func (p *SingleConnPool) Put(cn *Conn) {  
106 - defer func() {  
107 - if recover() != nil {  
108 - p.freeConn(cn)  
109 - }  
110 - }()  
111 - p.ch <- cn  
112 -} 35 +func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
113 36
114 -func (p *SingleConnPool) freeConn(cn *Conn) {  
115 - if err := p.badConnError(); err != nil {  
116 - p.pool.Remove(cn, err)  
117 - } else {  
118 - p.pool.Put(cn)  
119 - } 37 +func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
  38 + p.cn = nil
  39 + p.stickyErr = reason
120 } 40 }
121 41
122 -func (p *SingleConnPool) Remove(cn *Conn, reason error) {  
123 - defer func() {  
124 - if recover() != nil {  
125 - p.pool.Remove(cn, ErrClosed)  
126 - }  
127 - }()  
128 - p._badConnError.Store(BadConnError{wrapped: reason})  
129 - p.ch <- cn 42 +func (p *SingleConnPool) Close() error {
  43 + p.cn = nil
  44 + p.stickyErr = ErrClosed
  45 + return nil
130 } 46 }
131 47
132 func (p *SingleConnPool) Len() int { 48 func (p *SingleConnPool) Len() int {
133 - switch atomic.LoadUint32(&p.state) {  
134 - case stateDefault:  
135 - return 0  
136 - case stateInited:  
137 - return 1  
138 - case stateClosed:  
139 - return 0  
140 - default:  
141 - panic("not reached")  
142 - } 49 + return 0
143 } 50 }
144 51
145 func (p *SingleConnPool) IdleLen() int { 52 func (p *SingleConnPool) IdleLen() int {
146 - return len(p.ch) 53 + return 0
147 } 54 }
148 55
149 func (p *SingleConnPool) Stats() *Stats { 56 func (p *SingleConnPool) Stats() *Stats {
150 return &Stats{} 57 return &Stats{}
151 } 58 }
152 -  
153 -func (p *SingleConnPool) Close() error {  
154 - level := atomic.AddInt32(&p.level, -1)  
155 - if level > 0 {  
156 - return nil  
157 - }  
158 -  
159 - for i := 0; i < 1000; i++ {  
160 - state := atomic.LoadUint32(&p.state)  
161 - if state == stateClosed {  
162 - return ErrClosed  
163 - }  
164 - if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {  
165 - close(p.ch)  
166 - cn, ok := <-p.ch  
167 - if ok {  
168 - p.freeConn(cn)  
169 - }  
170 - return nil  
171 - }  
172 - }  
173 -  
174 - return errors.New("pg: SingleConnPool.Close: infinite loop")  
175 -}  
176 -  
177 -func (p *SingleConnPool) Reset() error {  
178 - if p.badConnError() == nil {  
179 - return nil  
180 - }  
181 -  
182 - select {  
183 - case cn, ok := <-p.ch:  
184 - if !ok {  
185 - return ErrClosed  
186 - }  
187 - p.pool.Remove(cn, ErrClosed)  
188 - p._badConnError.Store(BadConnError{wrapped: nil})  
189 - default:  
190 - return errors.New("pg: SingleConnPool does not have a Conn")  
191 - }  
192 -  
193 - if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {  
194 - state := atomic.LoadUint32(&p.state)  
195 - return fmt.Errorf("pg: invalid SingleConnPool state: %d", state)  
196 - }  
197 -  
198 - return nil  
199 -}  
200 -  
201 -func (p *SingleConnPool) badConnError() error {  
202 - if v := p._badConnError.Load(); v != nil {  
203 - err := v.(BadConnError)  
204 - if err.wrapped != nil {  
205 - return err  
206 - }  
207 - }  
208 - return nil  
209 -}  
  1 +package pool
  2 +
  3 +import (
  4 + "context"
  5 + "errors"
  6 + "fmt"
  7 + "sync/atomic"
  8 +)
  9 +
  10 +const (
  11 + stateDefault = 0
  12 + stateInited = 1
  13 + stateClosed = 2
  14 +)
  15 +
  16 +type BadConnError struct {
  17 + wrapped error
  18 +}
  19 +
  20 +var _ error = (*BadConnError)(nil)
  21 +
  22 +func (e BadConnError) Error() string {
  23 + s := "pg: Conn is in a bad state"
  24 + if e.wrapped != nil {
  25 + s += ": " + e.wrapped.Error()
  26 + }
  27 + return s
  28 +}
  29 +
  30 +func (e BadConnError) Unwrap() error {
  31 + return e.wrapped
  32 +}
  33 +
  34 +//------------------------------------------------------------------------------
  35 +
  36 +type StickyConnPool struct {
  37 + pool Pooler
  38 + shared int32 // atomic
  39 +
  40 + state uint32 // atomic
  41 + ch chan *Conn
  42 +
  43 + _badConnError atomic.Value
  44 +}
  45 +
  46 +var _ Pooler = (*StickyConnPool)(nil)
  47 +
  48 +func NewStickyConnPool(pool Pooler) *StickyConnPool {
  49 + p, ok := pool.(*StickyConnPool)
  50 + if !ok {
  51 + p = &StickyConnPool{
  52 + pool: pool,
  53 + ch: make(chan *Conn, 1),
  54 + }
  55 + }
  56 + atomic.AddInt32(&p.shared, 1)
  57 + return p
  58 +}
  59 +
  60 +func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
  61 + return p.pool.NewConn(ctx)
  62 +}
  63 +
  64 +func (p *StickyConnPool) CloseConn(cn *Conn) error {
  65 + return p.pool.CloseConn(cn)
  66 +}
  67 +
  68 +func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
  69 + // In worst case this races with Close which is not a very common operation.
  70 + for i := 0; i < 1000; i++ {
  71 + switch atomic.LoadUint32(&p.state) {
  72 + case stateDefault:
  73 + cn, err := p.pool.Get(ctx)
  74 + if err != nil {
  75 + return nil, err
  76 + }
  77 + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
  78 + return cn, nil
  79 + }
  80 + p.pool.Remove(ctx, cn, ErrClosed)
  81 + case stateInited:
  82 + if err := p.badConnError(); err != nil {
  83 + return nil, err
  84 + }
  85 + cn, ok := <-p.ch
  86 + if !ok {
  87 + return nil, ErrClosed
  88 + }
  89 + return cn, nil
  90 + case stateClosed:
  91 + return nil, ErrClosed
  92 + default:
  93 + panic("not reached")
  94 + }
  95 + }
  96 + return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop")
  97 +}
  98 +
  99 +func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
  100 + defer func() {
  101 + if recover() != nil {
  102 + p.freeConn(ctx, cn)
  103 + }
  104 + }()
  105 + p.ch <- cn
  106 +}
  107 +
  108 +func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
  109 + if err := p.badConnError(); err != nil {
  110 + p.pool.Remove(ctx, cn, err)
  111 + } else {
  112 + p.pool.Put(ctx, cn)
  113 + }
  114 +}
  115 +
  116 +func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
  117 + defer func() {
  118 + if recover() != nil {
  119 + p.pool.Remove(ctx, cn, ErrClosed)
  120 + }
  121 + }()
  122 + p._badConnError.Store(BadConnError{wrapped: reason})
  123 + p.ch <- cn
  124 +}
  125 +
  126 +func (p *StickyConnPool) Close() error {
  127 + if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
  128 + return nil
  129 + }
  130 +
  131 + for i := 0; i < 1000; i++ {
  132 + state := atomic.LoadUint32(&p.state)
  133 + if state == stateClosed {
  134 + return ErrClosed
  135 + }
  136 + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
  137 + close(p.ch)
  138 + cn, ok := <-p.ch
  139 + if ok {
  140 + p.freeConn(context.TODO(), cn)
  141 + }
  142 + return nil
  143 + }
  144 + }
  145 +
  146 + return errors.New("pg: StickyConnPool.Close: infinite loop")
  147 +}
  148 +
  149 +func (p *StickyConnPool) Reset(ctx context.Context) error {
  150 + if p.badConnError() == nil {
  151 + return nil
  152 + }
  153 +
  154 + select {
  155 + case cn, ok := <-p.ch:
  156 + if !ok {
  157 + return ErrClosed
  158 + }
  159 + p.pool.Remove(ctx, cn, ErrClosed)
  160 + p._badConnError.Store(BadConnError{wrapped: nil})
  161 + default:
  162 + return errors.New("pg: StickyConnPool does not have a Conn")
  163 + }
  164 +
  165 + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
  166 + state := atomic.LoadUint32(&p.state)
  167 + return fmt.Errorf("pg: invalid StickyConnPool state: %d", state)
  168 + }
  169 +
  170 + return nil
  171 +}
  172 +
  173 +func (p *StickyConnPool) badConnError() error {
  174 + if v := p._badConnError.Load(); v != nil {
  175 + err := v.(BadConnError)
  176 + if err.wrapped != nil {
  177 + return err
  178 + }
  179 + }
  180 + return nil
  181 +}
  182 +
  183 +func (p *StickyConnPool) Len() int {
  184 + switch atomic.LoadUint32(&p.state) {
  185 + case stateDefault:
  186 + return 0
  187 + case stateInited:
  188 + return 1
  189 + case stateClosed:
  190 + return 0
  191 + default:
  192 + panic("not reached")
  193 + }
  194 +}
  195 +
  196 +func (p *StickyConnPool) IdleLen() int {
  197 + return len(p.ch)
  198 +}
  199 +
  200 +func (p *StickyConnPool) Stats() *Stats {
  201 + return &Stats{}
  202 +}
1 package pool 1 package pool
2 2
  3 +import (
  4 + "sync"
  5 +)
  6 +
3 type Reader interface { 7 type Reader interface {
4 Buffered() int 8 Buffered() int
5 9
@@ -10,8 +14,67 @@ type Reader interface { @@ -10,8 +14,67 @@ type Reader interface {
10 ReadSlice(byte) ([]byte, error) 14 ReadSlice(byte) ([]byte, error)
11 Discard(int) (int, error) 15 Discard(int) (int, error)
12 16
13 - //ReadBytes(fn func(byte) bool) ([]byte, error)  
14 - //ReadN(int) ([]byte, error) 17 + // ReadBytes(fn func(byte) bool) ([]byte, error)
  18 + // ReadN(int) ([]byte, error)
15 ReadFull() ([]byte, error) 19 ReadFull() ([]byte, error)
16 ReadFullTemp() ([]byte, error) 20 ReadFullTemp() ([]byte, error)
17 } 21 }
  22 +
  23 +type ColumnInfo struct {
  24 + Index int16
  25 + DataType int32
  26 + Name string
  27 +}
  28 +
  29 +type ColumnAlloc struct {
  30 + columns []ColumnInfo
  31 +}
  32 +
  33 +func NewColumnAlloc() *ColumnAlloc {
  34 + return new(ColumnAlloc)
  35 +}
  36 +
  37 +func (c *ColumnAlloc) Reset() {
  38 + c.columns = c.columns[:0]
  39 +}
  40 +
  41 +func (c *ColumnAlloc) New(index int16, name []byte) *ColumnInfo {
  42 + c.columns = append(c.columns, ColumnInfo{
  43 + Index: index,
  44 + Name: string(name),
  45 + })
  46 + return &c.columns[len(c.columns)-1]
  47 +}
  48 +
  49 +func (c *ColumnAlloc) Columns() []ColumnInfo {
  50 + return c.columns
  51 +}
  52 +
  53 +type ReaderContext struct {
  54 + *BufReader
  55 + ColumnAlloc *ColumnAlloc
  56 +}
  57 +
  58 +func NewReaderContext() *ReaderContext {
  59 + const bufSize = 1 << 20 // 1mb
  60 + return &ReaderContext{
  61 + BufReader: NewBufReader(bufSize),
  62 + ColumnAlloc: NewColumnAlloc(),
  63 + }
  64 +}
  65 +
  66 +var readerPool = sync.Pool{
  67 + New: func() interface{} {
  68 + return NewReaderContext()
  69 + },
  70 +}
  71 +
  72 +func GetReaderContext() *ReaderContext {
  73 + rd := readerPool.Get().(*ReaderContext)
  74 + return rd
  75 +}
  76 +
  77 +func PutReaderContext(rd *ReaderContext) {
  78 + rd.ColumnAlloc.Reset()
  79 + readerPool.Put(rd)
  80 +}
@@ -10,11 +10,7 @@ import ( @@ -10,11 +10,7 @@ import (
10 "io" 10 "io"
11 ) 11 )
12 12
13 -const defaultBufSize = 65536  
14 -  
15 type BufReader struct { 13 type BufReader struct {
16 - Columns [][]byte  
17 -  
18 rd io.Reader // reader provided by the client 14 rd io.Reader // reader provided by the client
19 15
20 buf []byte 16 buf []byte
@@ -24,25 +20,24 @@ type BufReader struct { @@ -24,25 +20,24 @@ type BufReader struct {
24 err error 20 err error
25 21
26 available int // bytes available for reading 22 available int // bytes available for reading
27 - bytesRd BytesReader // reusable bytes reader 23 + brd BytesReader // reusable bytes reader
28 } 24 }
29 25
30 -func NewBufReader(rd io.Reader) *BufReader { 26 +func NewBufReader(bufSize int) *BufReader {
31 return &BufReader{ 27 return &BufReader{
32 - rd: rd,  
33 - buf: make([]byte, defaultBufSize), 28 + buf: make([]byte, bufSize),
34 available: -1, 29 available: -1,
35 } 30 }
36 } 31 }
37 32
38 func (b *BufReader) BytesReader(n int) *BytesReader { 33 func (b *BufReader) BytesReader(n int) *BytesReader {
39 - if b.Buffered() < n {  
40 - return nil 34 + if n == -1 {
  35 + n = 0
41 } 36 }
42 buf := b.buf[b.r : b.r+n] 37 buf := b.buf[b.r : b.r+n]
43 b.r += n 38 b.r += n
44 - b.bytesRd.Reset(buf)  
45 - return &b.bytesRd 39 + b.brd.Reset(buf)
  40 + return &b.brd
46 } 41 }
47 42
48 func (b *BufReader) SetAvailable(n int) { 43 func (b *BufReader) SetAvailable(n int) {
@@ -67,11 +62,11 @@ func (b *BufReader) Reset(rd io.Reader) { @@ -67,11 +62,11 @@ func (b *BufReader) Reset(rd io.Reader) {
67 62
68 // Buffered returns the number of bytes that can be read from the current buffer. 63 // Buffered returns the number of bytes that can be read from the current buffer.
69 func (b *BufReader) Buffered() int { 64 func (b *BufReader) Buffered() int {
70 - d := b.w - b.r  
71 - if b.available != -1 && d > b.available {  
72 - return b.available 65 + buffered := b.w - b.r
  66 + if b.available == -1 || buffered <= b.available {
  67 + return buffered
73 } 68 }
74 - return d 69 + return b.available
75 } 70 }
76 71
77 func (b *BufReader) Bytes() []byte { 72 func (b *BufReader) Bytes() []byte {
@@ -122,7 +117,7 @@ func (b *BufReader) fill() { @@ -122,7 +117,7 @@ func (b *BufReader) fill() {
122 // Read new data: try a limited number of times. 117 // Read new data: try a limited number of times.
123 const maxConsecutiveEmptyReads = 100 118 const maxConsecutiveEmptyReads = 100
124 for i := maxConsecutiveEmptyReads; i > 0; i-- { 119 for i := maxConsecutiveEmptyReads; i > 0; i-- {
125 - n, err := b.readDirectly(b.buf[b.w:]) 120 + n, err := b.read(b.buf[b.w:])
126 b.w += n 121 b.w += n
127 if err != nil { 122 if err != nil {
128 b.err = err 123 b.err = err
@@ -163,7 +158,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { @@ -163,7 +158,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) {
163 if len(p) >= len(b.buf) { 158 if len(p) >= len(b.buf) {
164 // Large read, empty buffer. 159 // Large read, empty buffer.
165 // Read directly into p to avoid copy. 160 // Read directly into p to avoid copy.
166 - n, err = b.readDirectly(p) 161 + n, err = b.read(p)
167 if n > 0 { 162 if n > 0 {
168 b.changeAvailable(-n) 163 b.changeAvailable(-n)
169 b.lastByte = int(p[n-1]) 164 b.lastByte = int(p[n-1])
@@ -175,7 +170,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { @@ -175,7 +170,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) {
175 // Do not use b.fill, which will loop. 170 // Do not use b.fill, which will loop.
176 b.r = 0 171 b.r = 0
177 b.w = 0 172 b.w = 0
178 - n, b.err = b.readDirectly(b.buf) 173 + n, b.err = b.read(b.buf)
179 if n == 0 { 174 if n == 0 {
180 return 0, b.readErr() 175 return 0, b.readErr()
181 } 176 }
@@ -259,7 +254,7 @@ func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) { @@ -259,7 +254,7 @@ func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) {
259 254
260 // Pending error? 255 // Pending error?
261 if b.err != nil { 256 if b.err != nil {
262 - line = b.flush() //nolint 257 + line = b.flush()
263 err = b.readErr() 258 err = b.readErr()
264 break 259 break
265 } 260 }
@@ -429,7 +424,7 @@ func (b *BufReader) ReadFullTemp() ([]byte, error) { @@ -429,7 +424,7 @@ func (b *BufReader) ReadFullTemp() ([]byte, error) {
429 return b.ReadFull() 424 return b.ReadFull()
430 } 425 }
431 426
432 -func (b *BufReader) readDirectly(buf []byte) (int, error) { 427 +func (b *BufReader) read(buf []byte) (int, error) {
433 n, err := b.rd.Read(buf) 428 n, err := b.rd.Read(buf)
434 b.bytesRead += int64(n) 429 b.bytesRead += int64(n)
435 return n, err 430 return n, err
@@ -6,20 +6,22 @@ import ( @@ -6,20 +6,22 @@ import (
6 "sync" 6 "sync"
7 ) 7 )
8 8
9 -var pool = sync.Pool{ 9 +const defaultBufSize = 65 << 10 // 65kb
  10 +
  11 +var wbPool = sync.Pool{
10 New: func() interface{} { 12 New: func() interface{} {
11 return NewWriteBuffer() 13 return NewWriteBuffer()
12 }, 14 },
13 } 15 }
14 16
15 func GetWriteBuffer() *WriteBuffer { 17 func GetWriteBuffer() *WriteBuffer {
16 - wb := pool.Get().(*WriteBuffer)  
17 - wb.Reset() 18 + wb := wbPool.Get().(*WriteBuffer)
18 return wb 19 return wb
19 } 20 }
20 21
21 func PutWriteBuffer(wb *WriteBuffer) { 22 func PutWriteBuffer(wb *WriteBuffer) {
22 - pool.Put(wb) 23 + wb.Reset()
  24 + wbPool.Put(wb)
23 } 25 }
24 26
25 type WriteBuffer struct { 27 type WriteBuffer struct {
@@ -39,10 +41,6 @@ func (buf *WriteBuffer) Reset() { @@ -39,10 +41,6 @@ func (buf *WriteBuffer) Reset() {
39 buf.Bytes = buf.Bytes[:0] 41 buf.Bytes = buf.Bytes[:0]
40 } 42 }
41 43
42 -func (buf *WriteBuffer) ResetBuffer(b []byte) {  
43 - buf.Bytes = b[:0]  
44 -}  
45 -  
46 func (buf *WriteBuffer) StartMessage(c byte) { 44 func (buf *WriteBuffer) StartMessage(c byte) {
47 if c == 0 { 45 if c == 0 {
48 buf.msgStart = len(buf.Bytes) 46 buf.msgStart = len(buf.Bytes)
@@ -5,12 +5,14 @@ import ( @@ -5,12 +5,14 @@ import (
5 "reflect" 5 "reflect"
6 "time" 6 "time"
7 7
8 - "go.opentelemetry.io/otel/api/global"  
9 - "go.opentelemetry.io/otel/api/trace" 8 + "go.opentelemetry.io/otel"
  9 + "go.opentelemetry.io/otel/trace"
10 ) 10 )
11 11
  12 +var tracer = otel.Tracer("github.com/go-pg/pg")
  13 +
12 func Sleep(ctx context.Context, dur time.Duration) error { 14 func Sleep(ctx context.Context, dur time.Duration) error {
13 - return WithSpan(ctx, "sleep", func(ctx context.Context, span trace.Span) error { 15 + return WithSpan(ctx, "time.Sleep", func(ctx context.Context, span trace.Span) error {
14 t := time.NewTimer(dur) 16 t := time.NewTimer(dur)
15 defer t.Stop() 17 defer t.Stop()
16 18
@@ -80,11 +82,11 @@ func WithSpan( @@ -80,11 +82,11 @@ func WithSpan(
80 name string, 82 name string,
81 fn func(context.Context, trace.Span) error, 83 fn func(context.Context, trace.Span) error,
82 ) error { 84 ) error {
83 - if !trace.SpanFromContext(ctx).IsRecording() {  
84 - return fn(ctx, trace.NoopSpan{}) 85 + if span := trace.SpanFromContext(ctx); !span.IsRecording() {
  86 + return fn(ctx, span)
85 } 87 }
86 88
87 - ctx, span := global.Tracer("go-pg").Start(ctx, name) 89 + ctx, span := tracer.Start(ctx, name)
88 defer span.End() 90 defer span.End()
89 91
90 return fn(ctx, span) 92 return fn(ctx, span)
@@ -15,8 +15,10 @@ import ( @@ -15,8 +15,10 @@ import (
15 15
16 const gopgChannel = "gopg:ping" 16 const gopgChannel = "gopg:ping"
17 17
18 -var errListenerClosed = errors.New("pg: listener is closed")  
19 -var errPingTimeout = errors.New("pg: ping timeout") 18 +var (
  19 + errListenerClosed = errors.New("pg: listener is closed")
  20 + errPingTimeout = errors.New("pg: ping timeout")
  21 +)
20 22
21 // Notification which is received with LISTEN command. 23 // Notification which is received with LISTEN command.
22 type Notification struct { 24 type Notification struct {
@@ -38,11 +40,14 @@ type Listener struct { @@ -38,11 +40,14 @@ type Listener struct {
38 closed bool 40 closed bool
39 41
40 chOnce sync.Once 42 chOnce sync.Once
41 - ch chan *Notification 43 + ch chan Notification
42 pingCh chan struct{} 44 pingCh chan struct{}
43 } 45 }
44 46
45 func (ln *Listener) String() string { 47 func (ln *Listener) String() string {
  48 + ln.mu.Lock()
  49 + defer ln.mu.Unlock()
  50 +
46 return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", ")) 51 return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", "))
47 } 52 }
48 53
@@ -50,9 +55,9 @@ func (ln *Listener) init() { @@ -50,9 +55,9 @@ func (ln *Listener) init() {
50 ln.exit = make(chan struct{}) 55 ln.exit = make(chan struct{})
51 } 56 }
52 57
53 -func (ln *Listener) connWithLock() (*pool.Conn, error) { 58 +func (ln *Listener) connWithLock(ctx context.Context) (*pool.Conn, error) {
54 ln.mu.Lock() 59 ln.mu.Lock()
55 - cn, err := ln.conn() 60 + cn, err := ln.conn(ctx)
56 ln.mu.Unlock() 61 ln.mu.Unlock()
57 62
58 switch err { 63 switch err {
@@ -64,12 +69,12 @@ func (ln *Listener) connWithLock() (*pool.Conn, error) { @@ -64,12 +69,12 @@ func (ln *Listener) connWithLock() (*pool.Conn, error) {
64 _ = ln.Close() 69 _ = ln.Close()
65 return nil, errListenerClosed 70 return nil, errListenerClosed
66 default: 71 default:
67 - internal.Logger.Printf("pg: Listen failed: %s", err) 72 + internal.Logger.Printf(ctx, "pg: Listen failed: %s", err)
68 return nil, err 73 return nil, err
69 } 74 }
70 } 75 }
71 76
72 -func (ln *Listener) conn() (*pool.Conn, error) { 77 +func (ln *Listener) conn(ctx context.Context) (*pool.Conn, error) {
73 if ln.closed { 78 if ln.closed {
74 return nil, errListenerClosed 79 return nil, errListenerClosed
75 } 80 }
@@ -78,21 +83,20 @@ func (ln *Listener) conn() (*pool.Conn, error) { @@ -78,21 +83,20 @@ func (ln *Listener) conn() (*pool.Conn, error) {
78 return ln.cn, nil 83 return ln.cn, nil
79 } 84 }
80 85
81 - c := context.TODO()  
82 -  
83 - cn, err := ln.db.pool.NewConn(c) 86 + cn, err := ln.db.pool.NewConn(ctx)
84 if err != nil { 87 if err != nil {
85 return nil, err 88 return nil, err
86 } 89 }
87 90
88 - err = ln.db.initConn(c, cn)  
89 - if err != nil { 91 + if err := ln.db.initConn(ctx, cn); err != nil {
90 _ = ln.db.pool.CloseConn(cn) 92 _ = ln.db.pool.CloseConn(cn)
91 return nil, err 93 return nil, err
92 } 94 }
93 95
  96 + cn.LockReader()
  97 +
94 if len(ln.channels) > 0 { 98 if len(ln.channels) > 0 {
95 - err := ln.listen(c, cn, ln.channels...) 99 + err := ln.listen(ctx, cn, ln.channels...)
96 if err != nil { 100 if err != nil {
97 _ = ln.db.pool.CloseConn(cn) 101 _ = ln.db.pool.CloseConn(cn)
98 return nil, err 102 return nil, err
@@ -103,19 +107,19 @@ func (ln *Listener) conn() (*pool.Conn, error) { @@ -103,19 +107,19 @@ func (ln *Listener) conn() (*pool.Conn, error) {
103 return cn, nil 107 return cn, nil
104 } 108 }
105 109
106 -func (ln *Listener) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { 110 +func (ln *Listener) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
107 ln.mu.Lock() 111 ln.mu.Lock()
108 if ln.cn == cn { 112 if ln.cn == cn {
109 if isBadConn(err, allowTimeout) { 113 if isBadConn(err, allowTimeout) {
110 - ln.reconnect(err) 114 + ln.reconnect(ctx, err)
111 } 115 }
112 } 116 }
113 ln.mu.Unlock() 117 ln.mu.Unlock()
114 } 118 }
115 119
116 -func (ln *Listener) reconnect(reason error) { 120 +func (ln *Listener) reconnect(ctx context.Context, reason error) {
117 _ = ln.closeTheCn(reason) 121 _ = ln.closeTheCn(reason)
118 - _, _ = ln.conn() 122 + _, _ = ln.conn(ctx)
119 } 123 }
120 124
121 func (ln *Listener) closeTheCn(reason error) error { 125 func (ln *Listener) closeTheCn(reason error) error {
@@ -123,7 +127,7 @@ func (ln *Listener) closeTheCn(reason error) error { @@ -123,7 +127,7 @@ func (ln *Listener) closeTheCn(reason error) error {
123 return nil 127 return nil
124 } 128 }
125 if !ln.closed { 129 if !ln.closed {
126 - internal.Logger.Printf("pg: discarding bad listener connection: %s", reason) 130 + internal.Logger.Printf(ln.db.ctx, "pg: discarding bad listener connection: %s", reason)
127 } 131 }
128 132
129 err := ln.db.pool.CloseConn(ln.cn) 133 err := ln.db.pool.CloseConn(ln.cn)
@@ -146,29 +150,60 @@ func (ln *Listener) Close() error { @@ -146,29 +150,60 @@ func (ln *Listener) Close() error {
146 } 150 }
147 151
148 // Listen starts listening for notifications on channels. 152 // Listen starts listening for notifications on channels.
149 -func (ln *Listener) Listen(channels ...string) error { 153 +func (ln *Listener) Listen(ctx context.Context, channels ...string) error {
150 // Always append channels so DB.Listen works correctly. 154 // Always append channels so DB.Listen works correctly.
  155 + ln.mu.Lock()
151 ln.channels = appendIfNotExists(ln.channels, channels...) 156 ln.channels = appendIfNotExists(ln.channels, channels...)
  157 + ln.mu.Unlock()
152 158
153 - cn, err := ln.connWithLock() 159 + cn, err := ln.connWithLock(ctx)
154 if err != nil { 160 if err != nil {
155 return err 161 return err
156 } 162 }
157 163
158 - err = ln.listen(context.TODO(), cn, channels...) 164 + if err := ln.listen(ctx, cn, channels...); err != nil {
  165 + ln.releaseConn(ctx, cn, err, false)
  166 + return err
  167 + }
  168 +
  169 + return nil
  170 +}
  171 +
  172 +func (ln *Listener) listen(ctx context.Context, cn *pool.Conn, channels ...string) error {
  173 + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
  174 + for _, channel := range channels {
  175 + if err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)); err != nil {
  176 + return err
  177 + }
  178 + }
  179 + return nil
  180 + })
  181 + return err
  182 +}
  183 +
  184 +// Unlisten stops listening for notifications on channels.
  185 +func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error {
  186 + ln.mu.Lock()
  187 + ln.channels = removeIfExists(ln.channels, channels...)
  188 + ln.mu.Unlock()
  189 +
  190 + cn, err := ln.conn(ctx)
159 if err != nil { 191 if err != nil {
160 - ln.releaseConn(cn, err, false) 192 + return err
  193 + }
  194 +
  195 + if err := ln.unlisten(ctx, cn, channels...); err != nil {
  196 + ln.releaseConn(ctx, cn, err, false)
161 return err 197 return err
162 } 198 }
163 199
164 return nil 200 return nil
165 } 201 }
166 202
167 -func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) error {  
168 - err := cn.WithWriter(c, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { 203 +func (ln *Listener) unlisten(ctx context.Context, cn *pool.Conn, channels ...string) error {
  204 + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
169 for _, channel := range channels { 205 for _, channel := range channels {
170 - err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel))  
171 - if err != nil { 206 + if err := writeQueryMsg(wb, ln.db.fmter, "UNLISTEN ?", pgChan(channel)); err != nil {
172 return err 207 return err
173 } 208 }
174 } 209 }
@@ -179,24 +214,26 @@ func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) @@ -179,24 +214,26 @@ func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string)
179 214
180 // Receive indefinitely waits for a notification. This is low-level API 215 // Receive indefinitely waits for a notification. This is low-level API
181 // and in most cases Channel should be used instead. 216 // and in most cases Channel should be used instead.
182 -func (ln *Listener) Receive() (channel string, payload string, err error) {  
183 - return ln.ReceiveTimeout(0) 217 +func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) {
  218 + return ln.ReceiveTimeout(ctx, 0)
184 } 219 }
185 220
186 // ReceiveTimeout waits for a notification until timeout is reached. 221 // ReceiveTimeout waits for a notification until timeout is reached.
187 // This is low-level API and in most cases Channel should be used instead. 222 // This is low-level API and in most cases Channel should be used instead.
188 -func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload string, err error) {  
189 - cn, err := ln.connWithLock() 223 +func (ln *Listener) ReceiveTimeout(
  224 + ctx context.Context, timeout time.Duration,
  225 +) (channel, payload string, err error) {
  226 + cn, err := ln.connWithLock(ctx)
190 if err != nil { 227 if err != nil {
191 return "", "", err 228 return "", "", err
192 } 229 }
193 230
194 - err = cn.WithReader(context.TODO(), timeout, func(rd *pool.BufReader) error { 231 + err = cn.WithReader(ctx, timeout, func(rd *pool.ReaderContext) error {
195 channel, payload, err = readNotification(rd) 232 channel, payload, err = readNotification(rd)
196 return err 233 return err
197 }) 234 })
198 if err != nil { 235 if err != nil {
199 - ln.releaseConn(cn, err, timeout > 0) 236 + ln.releaseConn(ctx, cn, err, timeout > 0)
200 return "", "", err 237 return "", "", err
201 } 238 }
202 239
@@ -208,17 +245,17 @@ func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload stri @@ -208,17 +245,17 @@ func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload stri
208 // 245 //
209 // The channel is closed with Listener. Receive* APIs can not be used 246 // The channel is closed with Listener. Receive* APIs can not be used
210 // after channel is created. 247 // after channel is created.
211 -func (ln *Listener) Channel() <-chan *Notification { 248 +func (ln *Listener) Channel() <-chan Notification {
212 return ln.channel(100) 249 return ln.channel(100)
213 } 250 }
214 251
215 // ChannelSize is like Channel, but creates a Go channel 252 // ChannelSize is like Channel, but creates a Go channel
216 // with specified buffer size. 253 // with specified buffer size.
217 -func (ln *Listener) ChannelSize(size int) <-chan *Notification { 254 +func (ln *Listener) ChannelSize(size int) <-chan Notification {
218 return ln.channel(size) 255 return ln.channel(size)
219 } 256 }
220 257
221 -func (ln *Listener) channel(size int) <-chan *Notification { 258 +func (ln *Listener) channel(size int) <-chan Notification {
222 ln.chOnce.Do(func() { 259 ln.chOnce.Do(func() {
223 ln.initChannel(size) 260 ln.initChannel(size)
224 }) 261 })
@@ -230,29 +267,33 @@ func (ln *Listener) channel(size int) <-chan *Notification { @@ -230,29 +267,33 @@ func (ln *Listener) channel(size int) <-chan *Notification {
230 } 267 }
231 268
232 func (ln *Listener) initChannel(size int) { 269 func (ln *Listener) initChannel(size int) {
233 - const timeout = 30 * time.Second 270 + const pingTimeout = time.Second
  271 + const chanSendTimeout = time.Minute
234 272
235 - _ = ln.Listen(gopgChannel) 273 + ctx := ln.db.ctx
  274 + _ = ln.Listen(ctx, gopgChannel)
236 275
237 - ln.ch = make(chan *Notification, size) 276 + ln.ch = make(chan Notification, size)
238 ln.pingCh = make(chan struct{}, 1) 277 ln.pingCh = make(chan struct{}, 1)
239 278
240 go func() { 279 go func() {
241 - timer := time.NewTimer(timeout) 280 + timer := time.NewTimer(time.Minute)
242 timer.Stop() 281 timer.Stop()
243 282
244 var errCount int 283 var errCount int
245 for { 284 for {
246 - channel, payload, err := ln.Receive() 285 + channel, payload, err := ln.Receive(ctx)
247 if err != nil { 286 if err != nil {
248 if err == errListenerClosed { 287 if err == errListenerClosed {
249 close(ln.ch) 288 close(ln.ch)
250 return 289 return
251 } 290 }
  291 +
252 if errCount > 0 { 292 if errCount > 0 {
253 - time.Sleep(ln.db.retryBackoff(errCount)) 293 + time.Sleep(500 * time.Millisecond)
254 } 294 }
255 errCount++ 295 errCount++
  296 +
256 continue 297 continue
257 } 298 }
258 299
@@ -268,28 +309,31 @@ func (ln *Listener) initChannel(size int) { @@ -268,28 +309,31 @@ func (ln *Listener) initChannel(size int) {
268 case gopgChannel: 309 case gopgChannel:
269 // ignore 310 // ignore
270 default: 311 default:
271 - timer.Reset(timeout) 312 + timer.Reset(chanSendTimeout)
272 select { 313 select {
273 - case ln.ch <- &Notification{channel, payload}: 314 + case ln.ch <- Notification{channel, payload}:
274 if !timer.Stop() { 315 if !timer.Stop() {
275 <-timer.C 316 <-timer.C
276 } 317 }
277 case <-timer.C: 318 case <-timer.C:
278 internal.Logger.Printf( 319 internal.Logger.Printf(
  320 + ctx,
279 "pg: %s channel is full for %s (notification is dropped)", 321 "pg: %s channel is full for %s (notification is dropped)",
280 - ln, timeout) 322 + ln,
  323 + chanSendTimeout,
  324 + )
281 } 325 }
282 } 326 }
283 } 327 }
284 }() 328 }()
285 329
286 go func() { 330 go func() {
287 - timer := time.NewTimer(timeout) 331 + timer := time.NewTimer(time.Minute)
288 timer.Stop() 332 timer.Stop()
289 333
290 healthy := true 334 healthy := true
291 for { 335 for {
292 - timer.Reset(timeout) 336 + timer.Reset(pingTimeout)
293 select { 337 select {
294 case <-ln.pingCh: 338 case <-ln.pingCh:
295 healthy = true 339 healthy = true
@@ -305,7 +349,7 @@ func (ln *Listener) initChannel(size int) { @@ -305,7 +349,7 @@ func (ln *Listener) initChannel(size int) {
305 pingErr = errPingTimeout 349 pingErr = errPingTimeout
306 } 350 }
307 ln.mu.Lock() 351 ln.mu.Lock()
308 - ln.reconnect(pingErr) 352 + ln.reconnect(ctx, pingErr)
309 ln.mu.Unlock() 353 ln.mu.Unlock()
310 } 354 }
311 case <-ln.exit: 355 case <-ln.exit:
@@ -333,6 +377,20 @@ loop: @@ -333,6 +377,20 @@ loop:
333 return ss 377 return ss
334 } 378 }
335 379
  380 +func removeIfExists(ss []string, es ...string) []string {
  381 + for _, e := range es {
  382 + for i, s := range ss {
  383 + if s == e {
  384 + last := len(ss) - 1
  385 + ss[i] = ss[last]
  386 + ss = ss[:last]
  387 + break
  388 + }
  389 + }
  390 + }
  391 + return ss
  392 +}
  393 +
336 type pgChan string 394 type pgChan string
337 395
338 var _ types.ValueAppender = pgChan("") 396 var _ types.ValueAppender = pgChan("")
@@ -19,6 +19,7 @@ import ( @@ -19,6 +19,7 @@ import (
19 "github.com/go-pg/pg/v10/types" 19 "github.com/go-pg/pg/v10/types"
20 ) 20 )
21 21
  22 +// https://www.postgresql.org/docs/current/protocol-message-formats.html
22 const ( 23 const (
23 commandCompleteMsg = 'C' 24 commandCompleteMsg = 'C'
24 errorResponseMsg = 'E' 25 errorResponseMsg = 'E'
@@ -84,7 +85,7 @@ func (db *baseDB) startup( @@ -84,7 +85,7 @@ func (db *baseDB) startup(
84 return err 85 return err
85 } 86 }
86 87
87 - return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 88 + return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
88 for { 89 for {
89 typ, msgLen, err := readMessageType(rd) 90 typ, msgLen, err := readMessageType(rd)
90 if err != nil { 91 if err != nil {
@@ -137,7 +138,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi @@ -137,7 +138,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi
137 return err 138 return err
138 } 139 }
139 140
140 - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { 141 + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error {
141 c, err := rd.ReadByte() 142 c, err := rd.ReadByte()
142 if err != nil { 143 if err != nil {
143 return err 144 return err
@@ -156,7 +157,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi @@ -156,7 +157,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi
156 } 157 }
157 158
158 func (db *baseDB) auth( 159 func (db *baseDB) auth(
159 - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, 160 + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
160 ) error { 161 ) error {
161 num, err := readInt32(rd) 162 num, err := readInt32(rd)
162 if err != nil { 163 if err != nil {
@@ -178,7 +179,7 @@ func (db *baseDB) auth( @@ -178,7 +179,7 @@ func (db *baseDB) auth(
178 } 179 }
179 180
180 func (db *baseDB) authCleartext( 181 func (db *baseDB) authCleartext(
181 - c context.Context, cn *pool.Conn, rd *pool.BufReader, password string, 182 + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string,
182 ) error { 183 ) error {
183 err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { 184 err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error {
184 writePasswordMsg(wb, password) 185 writePasswordMsg(wb, password)
@@ -191,7 +192,7 @@ func (db *baseDB) authCleartext( @@ -191,7 +192,7 @@ func (db *baseDB) authCleartext(
191 } 192 }
192 193
193 func (db *baseDB) authMD5( 194 func (db *baseDB) authMD5(
194 - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, 195 + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
195 ) error { 196 ) error {
196 b, err := rd.ReadN(4) 197 b, err := rd.ReadN(4)
197 if err != nil { 198 if err != nil {
@@ -210,7 +211,7 @@ func (db *baseDB) authMD5( @@ -210,7 +211,7 @@ func (db *baseDB) authMD5(
210 return readAuthOK(rd) 211 return readAuthOK(rd)
211 } 212 }
212 213
213 -func readAuthOK(rd *pool.BufReader) error { 214 +func readAuthOK(rd *pool.ReaderContext) error {
214 c, _, err := readMessageType(rd) 215 c, _, err := readMessageType(rd)
215 if err != nil { 216 if err != nil {
216 return err 217 return err
@@ -238,7 +239,7 @@ func readAuthOK(rd *pool.BufReader) error { @@ -238,7 +239,7 @@ func readAuthOK(rd *pool.BufReader) error {
238 } 239 }
239 240
240 func (db *baseDB) authSASL( 241 func (db *baseDB) authSASL(
241 - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, 242 + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string,
242 ) error { 243 ) error {
243 s, err := readString(rd) 244 s, err := readString(rd)
244 if err != nil { 245 if err != nil {
@@ -332,7 +333,7 @@ func (db *baseDB) authSASL( @@ -332,7 +333,7 @@ func (db *baseDB) authSASL(
332 } 333 }
333 } 334 }
334 335
335 -func readAuthSASLFinal(rd *pool.BufReader, client *sasl.Negotiator) error { 336 +func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error {
336 c, n, err := readMessageType(rd) 337 c, n, err := readMessageType(rd)
337 if err != nil { 338 if err != nil {
338 return err 339 return err
@@ -485,8 +486,8 @@ func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { @@ -485,8 +486,8 @@ func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) {
485 writeSyncMsg(buf) 486 writeSyncMsg(buf)
486 } 487 }
487 488
488 -func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) {  
489 - var columns [][]byte 489 +func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) {
  490 + var columns []types.ColumnInfo
490 var firstErr error 491 var firstErr error
491 for { 492 for {
492 c, msgLen, err := readMessageType(rd) 493 c, msgLen, err := readMessageType(rd)
@@ -500,7 +501,7 @@ func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) { @@ -500,7 +501,7 @@ func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) {
500 return nil, err 501 return nil, err
501 } 502 }
502 case rowDescriptionMsg: // Response to the DESCRIBE message. 503 case rowDescriptionMsg: // Response to the DESCRIBE message.
503 - columns, err = readRowDescription(rd, nil) 504 + columns, err = readRowDescription(rd, pool.NewColumnAlloc())
504 if err != nil { 505 if err != nil {
505 return nil, err 506 return nil, err
506 } 507 }
@@ -582,7 +583,7 @@ func writeCloseMsg(buf *pool.WriteBuffer, name string) { @@ -582,7 +583,7 @@ func writeCloseMsg(buf *pool.WriteBuffer, name string) {
582 buf.FinishMessage() 583 buf.FinishMessage()
583 } 584 }
584 585
585 -func readCloseCompleteMsg(rd *pool.BufReader) error { 586 +func readCloseCompleteMsg(rd *pool.ReaderContext) error {
586 for { 587 for {
587 c, msgLen, err := readMessageType(rd) 588 c, msgLen, err := readMessageType(rd)
588 if err != nil { 589 if err != nil {
@@ -612,7 +613,7 @@ func readCloseCompleteMsg(rd *pool.BufReader) error { @@ -612,7 +613,7 @@ func readCloseCompleteMsg(rd *pool.BufReader) error {
612 } 613 }
613 } 614 }
614 615
615 -func readSimpleQuery(rd *pool.BufReader) (*result, error) { 616 +func readSimpleQuery(rd *pool.ReaderContext) (*result, error) {
616 var res result 617 var res result
617 var firstErr error 618 var firstErr error
618 for { 619 for {
@@ -675,7 +676,7 @@ func readSimpleQuery(rd *pool.BufReader) (*result, error) { @@ -675,7 +676,7 @@ func readSimpleQuery(rd *pool.BufReader) (*result, error) {
675 } 676 }
676 } 677 }
677 678
678 -func readExtQuery(rd *pool.BufReader) (*result, error) { 679 +func readExtQuery(rd *pool.ReaderContext) (*result, error) {
679 var res result 680 var res result
680 var firstErr error 681 var firstErr error
681 for { 682 for {
@@ -739,42 +740,47 @@ func readExtQuery(rd *pool.BufReader) (*result, error) { @@ -739,42 +740,47 @@ func readExtQuery(rd *pool.BufReader) (*result, error) {
739 } 740 }
740 } 741 }
741 742
742 -func readRowDescription(rd *pool.BufReader, columns [][]byte) ([][]byte, error) {  
743 - colNum, err := readInt16(rd) 743 +func readRowDescription(
  744 + rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc,
  745 +) ([]types.ColumnInfo, error) {
  746 + numCol, err := readInt16(rd)
744 if err != nil { 747 if err != nil {
745 return nil, err 748 return nil, err
746 } 749 }
747 750
748 - columns = setByteSliceLen(columns, int(colNum))  
749 - for i := 0; i < int(colNum); i++ { 751 + for i := 0; i < int(numCol); i++ {
750 b, err := rd.ReadSlice(0) 752 b, err := rd.ReadSlice(0)
751 if err != nil { 753 if err != nil {
752 return nil, err 754 return nil, err
753 } 755 }
754 - columns[i] = append(columns[i][:0], b[:len(b)-1]...)  
755 756
756 - _, err = rd.ReadN(18)  
757 - if err != nil { 757 + col := columnAlloc.New(int16(i), b[:len(b)-1])
  758 +
  759 + if _, err := rd.ReadN(6); err != nil {
758 return nil, err 760 return nil, err
759 } 761 }
760 - }  
761 762
762 - return columns, nil  
763 -} 763 + dataType, err := readInt32(rd)
  764 + if err != nil {
  765 + return nil, err
  766 + }
  767 + col.DataType = dataType
764 768
765 -func setByteSliceLen(b [][]byte, n int) [][]byte {  
766 - if n <= cap(b) {  
767 - return b[:n] 769 + if _, err := rd.ReadN(8); err != nil {
  770 + return nil, err
  771 + }
768 } 772 }
769 - b = b[:cap(b)]  
770 - b = append(b, make([][]byte, n-cap(b))...)  
771 - return b 773 +
  774 + return columnAlloc.Columns(), nil
772 } 775 }
773 776
774 func readDataRow( 777 func readDataRow(
775 - ctx context.Context, rd *pool.BufReader, scanner orm.ColumnScanner, columns [][]byte, 778 + ctx context.Context,
  779 + rd *pool.ReaderContext,
  780 + columns []types.ColumnInfo,
  781 + scanner orm.ColumnScanner,
776 ) error { 782 ) error {
777 - colNum, err := readInt16(rd) 783 + numCol, err := readInt16(rd)
778 if err != nil { 784 if err != nil {
779 return err 785 return err
780 } 786 }
@@ -787,35 +793,28 @@ func readDataRow( @@ -787,35 +793,28 @@ func readDataRow(
787 793
788 var firstErr error 794 var firstErr error
789 795
790 - for colIdx := int16(0); colIdx < colNum; colIdx++ { 796 + for colIdx := int16(0); colIdx < numCol; colIdx++ {
791 n, err := readInt32(rd) 797 n, err := readInt32(rd)
792 if err != nil { 798 if err != nil {
793 return err 799 return err
794 } 800 }
795 801
796 - column := internal.BytesToString(columns[colIdx])  
797 var colRd types.Reader 802 var colRd types.Reader
798 - if n >= 0 {  
799 - bytesRd := rd.BytesReader(int(n))  
800 - if bytesRd != nil {  
801 - colRd = bytesRd  
802 - } else {  
803 - rd.SetAvailable(int(n))  
804 - colRd = rd  
805 - } 803 + if int(n) <= rd.Buffered() {
  804 + colRd = rd.BytesReader(int(n))
806 } else { 805 } else {
807 - colRd = rd.BytesReader(0) 806 + rd.SetAvailable(int(n))
  807 + colRd = rd
808 } 808 }
809 809
810 - err = scanner.ScanColumn(int(colIdx), column, colRd, int(n))  
811 - if err != nil && firstErr == nil { 810 + column := columns[colIdx]
  811 + if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil {
812 firstErr = internal.Errorf(err.Error()) 812 firstErr = internal.Errorf(err.Error())
813 } 813 }
814 814
815 if rd == colRd { 815 if rd == colRd {
816 if rd.Available() > 0 { 816 if rd.Available() > 0 {
817 - _, err = rd.Discard(rd.Available())  
818 - if err != nil && firstErr == nil { 817 + if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil {
819 firstErr = err 818 firstErr = err
820 } 819 }
821 } 820 }
@@ -841,8 +840,9 @@ func newModel(mod interface{}) (orm.Model, error) { @@ -841,8 +840,9 @@ func newModel(mod interface{}) (orm.Model, error) {
841 } 840 }
842 841
843 func readSimpleQueryData( 842 func readSimpleQueryData(
844 - ctx context.Context, rd *pool.BufReader, mod interface{}, 843 + ctx context.Context, rd *pool.ReaderContext, mod interface{},
845 ) (*result, error) { 844 ) (*result, error) {
  845 + var columns []types.ColumnInfo
846 var res result 846 var res result
847 var firstErr error 847 var firstErr error
848 for { 848 for {
@@ -853,7 +853,7 @@ func readSimpleQueryData( @@ -853,7 +853,7 @@ func readSimpleQueryData(
853 853
854 switch c { 854 switch c {
855 case rowDescriptionMsg: 855 case rowDescriptionMsg:
856 - rd.Columns, err = readRowDescription(rd, rd.Columns[:0]) 856 + columns, err = readRowDescription(rd, rd.ColumnAlloc)
857 if err != nil { 857 if err != nil {
858 return nil, err 858 return nil, err
859 } 859 }
@@ -870,7 +870,7 @@ func readSimpleQueryData( @@ -870,7 +870,7 @@ func readSimpleQueryData(
870 } 870 }
871 case dataRowMsg: 871 case dataRowMsg:
872 scanner := res.model.NextColumnScanner() 872 scanner := res.model.NextColumnScanner()
873 - if err := readDataRow(ctx, rd, scanner, rd.Columns); err != nil { 873 + if err := readDataRow(ctx, rd, columns, scanner); err != nil {
874 if firstErr == nil { 874 if firstErr == nil {
875 firstErr = err 875 firstErr = err
876 } 876 }
@@ -925,7 +925,7 @@ func readSimpleQueryData( @@ -925,7 +925,7 @@ func readSimpleQueryData(
925 } 925 }
926 926
927 func readExtQueryData( 927 func readExtQueryData(
928 - ctx context.Context, rd *pool.BufReader, mod interface{}, columns [][]byte, 928 + ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo,
929 ) (*result, error) { 929 ) (*result, error) {
930 var res result 930 var res result
931 var firstErr error 931 var firstErr error
@@ -954,7 +954,7 @@ func readExtQueryData( @@ -954,7 +954,7 @@ func readExtQueryData(
954 } 954 }
955 955
956 scanner := res.model.NextColumnScanner() 956 scanner := res.model.NextColumnScanner()
957 - if err := readDataRow(ctx, rd, scanner, columns); err != nil { 957 + if err := readDataRow(ctx, rd, columns, scanner); err != nil {
958 if firstErr == nil { 958 if firstErr == nil {
959 firstErr = err 959 firstErr = err
960 } 960 }
@@ -1004,7 +1004,7 @@ func readExtQueryData( @@ -1004,7 +1004,7 @@ func readExtQueryData(
1004 } 1004 }
1005 } 1005 }
1006 1006
1007 -func readCopyInResponse(rd *pool.BufReader) error { 1007 +func readCopyInResponse(rd *pool.ReaderContext) error {
1008 var firstErr error 1008 var firstErr error
1009 for { 1009 for {
1010 c, msgLen, err := readMessageType(rd) 1010 c, msgLen, err := readMessageType(rd)
@@ -1044,7 +1044,7 @@ func readCopyInResponse(rd *pool.BufReader) error { @@ -1044,7 +1044,7 @@ func readCopyInResponse(rd *pool.BufReader) error {
1044 } 1044 }
1045 } 1045 }
1046 1046
1047 -func readCopyOutResponse(rd *pool.BufReader) error { 1047 +func readCopyOutResponse(rd *pool.ReaderContext) error {
1048 var firstErr error 1048 var firstErr error
1049 for { 1049 for {
1050 c, msgLen, err := readMessageType(rd) 1050 c, msgLen, err := readMessageType(rd)
@@ -1084,7 +1084,7 @@ func readCopyOutResponse(rd *pool.BufReader) error { @@ -1084,7 +1084,7 @@ func readCopyOutResponse(rd *pool.BufReader) error {
1084 } 1084 }
1085 } 1085 }
1086 1086
1087 -func readCopyData(rd *pool.BufReader, w io.Writer) (*result, error) { 1087 +func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) {
1088 var res result 1088 var res result
1089 var firstErr error 1089 var firstErr error
1090 for { 1090 for {
@@ -1162,7 +1162,7 @@ func writeCopyDone(buf *pool.WriteBuffer) { @@ -1162,7 +1162,7 @@ func writeCopyDone(buf *pool.WriteBuffer) {
1162 buf.FinishMessage() 1162 buf.FinishMessage()
1163 } 1163 }
1164 1164
1165 -func readReadyForQuery(rd *pool.BufReader) (*result, error) { 1165 +func readReadyForQuery(rd *pool.ReaderContext) (*result, error) {
1166 var res result 1166 var res result
1167 var firstErr error 1167 var firstErr error
1168 for { 1168 for {
@@ -1211,7 +1211,7 @@ func readReadyForQuery(rd *pool.BufReader) (*result, error) { @@ -1211,7 +1211,7 @@ func readReadyForQuery(rd *pool.BufReader) (*result, error) {
1211 } 1211 }
1212 } 1212 }
1213 1213
1214 -func readNotification(rd *pool.BufReader) (channel, payload string, err error) { 1214 +func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) {
1215 for { 1215 for {
1216 c, msgLen, err := readMessageType(rd) 1216 c, msgLen, err := readMessageType(rd)
1217 if err != nil { 1217 if err != nil {
@@ -1269,17 +1269,17 @@ func terminateConn(cn *pool.Conn) error { @@ -1269,17 +1269,17 @@ func terminateConn(cn *pool.Conn) error {
1269 1269
1270 //------------------------------------------------------------------------------ 1270 //------------------------------------------------------------------------------
1271 1271
1272 -func logNotice(rd *pool.BufReader, msgLen int) error { 1272 +func logNotice(rd *pool.ReaderContext, msgLen int) error {
1273 _, err := rd.ReadN(msgLen) 1273 _, err := rd.ReadN(msgLen)
1274 return err 1274 return err
1275 } 1275 }
1276 1276
1277 -func logParameterStatus(rd *pool.BufReader, msgLen int) error { 1277 +func logParameterStatus(rd *pool.ReaderContext, msgLen int) error {
1278 _, err := rd.ReadN(msgLen) 1278 _, err := rd.ReadN(msgLen)
1279 return err 1279 return err
1280 } 1280 }
1281 1281
1282 -func readInt16(rd *pool.BufReader) (int16, error) { 1282 +func readInt16(rd *pool.ReaderContext) (int16, error) {
1283 b, err := rd.ReadN(2) 1283 b, err := rd.ReadN(2)
1284 if err != nil { 1284 if err != nil {
1285 return 0, err 1285 return 0, err
@@ -1287,7 +1287,7 @@ func readInt16(rd *pool.BufReader) (int16, error) { @@ -1287,7 +1287,7 @@ func readInt16(rd *pool.BufReader) (int16, error) {
1287 return int16(binary.BigEndian.Uint16(b)), nil 1287 return int16(binary.BigEndian.Uint16(b)), nil
1288 } 1288 }
1289 1289
1290 -func readInt32(rd *pool.BufReader) (int32, error) { 1290 +func readInt32(rd *pool.ReaderContext) (int32, error) {
1291 b, err := rd.ReadN(4) 1291 b, err := rd.ReadN(4)
1292 if err != nil { 1292 if err != nil {
1293 return 0, err 1293 return 0, err
@@ -1295,7 +1295,7 @@ func readInt32(rd *pool.BufReader) (int32, error) { @@ -1295,7 +1295,7 @@ func readInt32(rd *pool.BufReader) (int32, error) {
1295 return int32(binary.BigEndian.Uint32(b)), nil 1295 return int32(binary.BigEndian.Uint32(b)), nil
1296 } 1296 }
1297 1297
1298 -func readString(rd *pool.BufReader) (string, error) { 1298 +func readString(rd *pool.ReaderContext) (string, error) {
1299 b, err := rd.ReadSlice(0) 1299 b, err := rd.ReadSlice(0)
1300 if err != nil { 1300 if err != nil {
1301 return "", err 1301 return "", err
@@ -1303,7 +1303,7 @@ func readString(rd *pool.BufReader) (string, error) { @@ -1303,7 +1303,7 @@ func readString(rd *pool.BufReader) (string, error) {
1303 return string(b[:len(b)-1]), nil 1303 return string(b[:len(b)-1]), nil
1304 } 1304 }
1305 1305
1306 -func readError(rd *pool.BufReader) (error, error) { 1306 +func readError(rd *pool.ReaderContext) (error, error) {
1307 m := make(map[byte]string) 1307 m := make(map[byte]string)
1308 for { 1308 for {
1309 c, err := rd.ReadByte() 1309 c, err := rd.ReadByte()
@@ -1322,7 +1322,7 @@ func readError(rd *pool.BufReader) (error, error) { @@ -1322,7 +1322,7 @@ func readError(rd *pool.BufReader) (error, error) {
1322 return internal.NewPGError(m), nil 1322 return internal.NewPGError(m), nil
1323 } 1323 }
1324 1324
1325 -func readMessageType(rd *pool.BufReader) (byte, int, error) { 1325 +func readMessageType(rd *pool.ReaderContext) (byte, int, error) {
1326 c, err := rd.ReadByte() 1326 c, err := rd.ReadByte()
1327 if err != nil { 1327 if err != nil {
1328 return 0, 0, err 1328 return 0, 0, err
@@ -13,7 +13,8 @@ import ( @@ -13,7 +13,8 @@ import (
13 "strings" 13 "strings"
14 "time" 14 "time"
15 15
16 - "go.opentelemetry.io/otel/api/trace" 16 + "go.opentelemetry.io/otel/label"
  17 + "go.opentelemetry.io/otel/trace"
17 18
18 "github.com/go-pg/pg/v10/internal" 19 "github.com/go-pg/pg/v10/internal"
19 "github.com/go-pg/pg/v10/internal/pool" 20 "github.com/go-pg/pg/v10/internal/pool"
@@ -31,6 +32,10 @@ type Options struct { @@ -31,6 +32,10 @@ type Options struct {
31 // Network and Addr options. 32 // Network and Addr options.
32 Dialer func(ctx context.Context, network, addr string) (net.Conn, error) 33 Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
33 34
  35 + // Hook that is called after new connection is established
  36 + // and user is authenticated.
  37 + OnConnect func(ctx context.Context, cn *Conn) error
  38 +
34 User string 39 User string
35 Password string 40 Password string
36 Database string 41 Database string
@@ -53,10 +58,6 @@ type Options struct { @@ -53,10 +58,6 @@ type Options struct {
53 // with a timeout instead of blocking. 58 // with a timeout instead of blocking.
54 WriteTimeout time.Duration 59 WriteTimeout time.Duration
55 60
56 - // Hook that is called after new connection is established  
57 - // and user is authenticated.  
58 - OnConnect func(*Conn) error  
59 -  
60 // Maximum number of retries before giving up. 61 // Maximum number of retries before giving up.
61 // Default is to not retry failed queries. 62 // Default is to not retry failed queries.
62 MaxRetries int 63 MaxRetries int
@@ -110,6 +111,10 @@ func (opt *Options) init() { @@ -110,6 +111,10 @@ func (opt *Options) init() {
110 opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" 111 opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
111 } 112 }
112 } 113 }
  114 +
  115 + if opt.DialTimeout == 0 {
  116 + opt.DialTimeout = 5 * time.Second
  117 + }
113 if opt.Dialer == nil { 118 if opt.Dialer == nil {
114 opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { 119 opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
115 netDialer := &net.Dialer{ 120 netDialer := &net.Dialer{
@@ -140,10 +145,6 @@ func (opt *Options) init() { @@ -140,10 +145,6 @@ func (opt *Options) init() {
140 } 145 }
141 } 146 }
142 147
143 - if opt.DialTimeout == 0 {  
144 - opt.DialTimeout = 5 * time.Second  
145 - }  
146 -  
147 if opt.IdleTimeout == 0 { 148 if opt.IdleTimeout == 0 {
148 opt.IdleTimeout = 5 * time.Minute 149 opt.IdleTimeout = 5 * time.Minute
149 } 150 }
@@ -262,9 +263,16 @@ func ParseURL(sURL string) (*Options, error) { @@ -262,9 +263,16 @@ func ParseURL(sURL string) (*Options, error) {
262 func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { 263 func (opt *Options) getDialer() func(context.Context) (net.Conn, error) {
263 return func(ctx context.Context) (net.Conn, error) { 264 return func(ctx context.Context) (net.Conn, error) {
264 var conn net.Conn 265 var conn net.Conn
265 - err := internal.WithSpan(ctx, "dialer", func(ctx context.Context, span trace.Span) error { 266 + err := internal.WithSpan(ctx, "pg.dial", func(ctx context.Context, span trace.Span) error {
  267 + span.SetAttributes(
  268 + label.String("db.connection_string", opt.Addr),
  269 + )
  270 +
266 var err error 271 var err error
267 conn, err = opt.Dialer(ctx, opt.Network, opt.Addr) 272 conn, err = opt.Dialer(ctx, opt.Network, opt.Addr)
  273 + if err != nil {
  274 + span.RecordError(err)
  275 + }
268 return err 276 return err
269 }) 277 })
270 return conn, err 278 return conn, err
@@ -8,50 +8,62 @@ type CreateCompositeOptions struct { @@ -8,50 +8,62 @@ type CreateCompositeOptions struct {
8 Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` 8 Varchar int // replaces PostgreSQL data type `text` with `varchar(n)`
9 } 9 }
10 10
11 -func CreateComposite(db DB, model interface{}, opt *CreateCompositeOptions) error {  
12 - q := NewQuery(db, model)  
13 - _, err := q.db.Exec(&createCompositeQuery{ 11 +type CreateCompositeQuery struct {
  12 + q *Query
  13 + opt *CreateCompositeOptions
  14 +}
  15 +
  16 +var (
  17 + _ QueryAppender = (*CreateCompositeQuery)(nil)
  18 + _ QueryCommand = (*CreateCompositeQuery)(nil)
  19 +)
  20 +
  21 +func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery {
  22 + return &CreateCompositeQuery{
14 q: q, 23 q: q,
15 opt: opt, 24 opt: opt,
16 - })  
17 - return err 25 + }
18 } 26 }
19 27
20 -type createCompositeQuery struct {  
21 - q *Query  
22 - opt *CreateCompositeOptions 28 +func (q *CreateCompositeQuery) String() string {
  29 + b, err := q.AppendQuery(defaultFmter, nil)
  30 + if err != nil {
  31 + panic(err)
  32 + }
  33 + return string(b)
23 } 34 }
24 35
25 -var _ QueryAppender = (*createCompositeQuery)(nil)  
26 -var _ queryCommand = (*createCompositeQuery)(nil) 36 +func (q *CreateCompositeQuery) Operation() QueryOp {
  37 + return CreateCompositeOp
  38 +}
27 39
28 -func (q *createCompositeQuery) Clone() queryCommand {  
29 - return &createCompositeQuery{ 40 +func (q *CreateCompositeQuery) Clone() QueryCommand {
  41 + return &CreateCompositeQuery{
30 q: q.q.Clone(), 42 q: q.q.Clone(),
31 opt: q.opt, 43 opt: q.opt,
32 } 44 }
33 } 45 }
34 46
35 -func (q *createCompositeQuery) Query() *Query { 47 +func (q *CreateCompositeQuery) Query() *Query {
36 return q.q 48 return q.q
37 } 49 }
38 50
39 -func (q *createCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { 51 +func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
40 return q.AppendQuery(dummyFormatter{}, b) 52 return q.AppendQuery(dummyFormatter{}, b)
41 } 53 }
42 54
43 -func (q *createCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { 55 +func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
44 if q.q.stickyErr != nil { 56 if q.q.stickyErr != nil {
45 return nil, q.q.stickyErr 57 return nil, q.q.stickyErr
46 } 58 }
47 - if q.q.model == nil { 59 + if q.q.tableModel == nil {
48 return nil, errModelNil 60 return nil, errModelNil
49 } 61 }
50 62
51 - table := q.q.model.Table() 63 + table := q.q.tableModel.Table()
52 64
53 b = append(b, "CREATE TYPE "...) 65 b = append(b, "CREATE TYPE "...)
54 - b = append(b, q.q.model.Table().Alias...) 66 + b = append(b, table.Alias...)
55 b = append(b, " AS ("...) 67 b = append(b, " AS ("...)
56 68
57 for i, field := range table.Fields { 69 for i, field := range table.Fields {
@@ -5,43 +5,55 @@ type DropCompositeOptions struct { @@ -5,43 +5,55 @@ type DropCompositeOptions struct {
5 Cascade bool 5 Cascade bool
6 } 6 }
7 7
8 -func DropComposite(db DB, model interface{}, opt *DropCompositeOptions) error {  
9 - q := NewQuery(db, model)  
10 - _, err := q.db.Exec(&dropCompositeQuery{ 8 +type DropCompositeQuery struct {
  9 + q *Query
  10 + opt *DropCompositeOptions
  11 +}
  12 +
  13 +var (
  14 + _ QueryAppender = (*DropCompositeQuery)(nil)
  15 + _ QueryCommand = (*DropCompositeQuery)(nil)
  16 +)
  17 +
  18 +func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery {
  19 + return &DropCompositeQuery{
11 q: q, 20 q: q,
12 opt: opt, 21 opt: opt,
13 - })  
14 - return err 22 + }
15 } 23 }
16 24
17 -type dropCompositeQuery struct {  
18 - q *Query  
19 - opt *DropCompositeOptions 25 +func (q *DropCompositeQuery) String() string {
  26 + b, err := q.AppendQuery(defaultFmter, nil)
  27 + if err != nil {
  28 + panic(err)
  29 + }
  30 + return string(b)
20 } 31 }
21 32
22 -var _ QueryAppender = (*dropCompositeQuery)(nil)  
23 -var _ queryCommand = (*dropCompositeQuery)(nil) 33 +func (q *DropCompositeQuery) Operation() QueryOp {
  34 + return DropCompositeOp
  35 +}
24 36
25 -func (q *dropCompositeQuery) Clone() queryCommand {  
26 - return &dropCompositeQuery{ 37 +func (q *DropCompositeQuery) Clone() QueryCommand {
  38 + return &DropCompositeQuery{
27 q: q.q.Clone(), 39 q: q.q.Clone(),
28 opt: q.opt, 40 opt: q.opt,
29 } 41 }
30 } 42 }
31 43
32 -func (q *dropCompositeQuery) Query() *Query { 44 +func (q *DropCompositeQuery) Query() *Query {
33 return q.q 45 return q.q
34 } 46 }
35 47
36 -func (q *dropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { 48 +func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) {
37 return q.AppendQuery(dummyFormatter{}, b) 49 return q.AppendQuery(dummyFormatter{}, b)
38 } 50 }
39 51
40 -func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { 52 +func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) {
41 if q.q.stickyErr != nil { 53 if q.q.stickyErr != nil {
42 return nil, q.q.stickyErr 54 return nil, q.q.stickyErr
43 } 55 }
44 - if q.q.model == nil { 56 + if q.q.tableModel == nil {
45 return nil, errModelNil 57 return nil, errModelNil
46 } 58 }
47 59
@@ -49,7 +61,7 @@ func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte @@ -49,7 +61,7 @@ func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte
49 if q.opt != nil && q.opt.IfExists { 61 if q.opt != nil && q.opt.IfExists {
50 b = append(b, "IF EXISTS "...) 62 b = append(b, "IF EXISTS "...)
51 } 63 }
52 - b = append(b, q.q.model.Table().Alias...) 64 + b = append(b, q.q.tableModel.Table().Alias...)
53 if q.opt != nil && q.opt.Cascade { 65 if q.opt != nil && q.opt.Cascade {
54 b = append(b, " CASCADE"...) 66 b = append(b, " CASCADE"...)
55 } 67 }
1 package orm 1 package orm
2 2
3 import ( 3 import (
4 - "github.com/go-pg/pg/v10/internal"  
5 -)  
6 -  
7 -// Delete deletes a given model from the db  
8 -func Delete(db DB, model interface{}) error {  
9 - res, err := NewQuery(db, model).WherePK().Delete()  
10 - if err != nil {  
11 - return err  
12 - }  
13 - return internal.AssertOneRow(res.RowsAffected())  
14 -} 4 + "reflect"
15 5
16 -// ForceDelete force deletes a given model from the db  
17 -func ForceDelete(db DB, model interface{}) error {  
18 - res, err := NewQuery(db, model).WherePK().ForceDelete()  
19 - if err != nil {  
20 - return err  
21 - }  
22 - return internal.AssertOneRow(res.RowsAffected())  
23 -} 6 + "github.com/go-pg/pg/v10/types"
  7 +)
24 8
25 -type deleteQuery struct { 9 +type DeleteQuery struct {
26 q *Query 10 q *Query
27 placeholder bool 11 placeholder bool
28 } 12 }
29 13
30 -var _ QueryAppender = (*deleteQuery)(nil)  
31 -var _ queryCommand = (*deleteQuery)(nil) 14 +var (
  15 + _ QueryAppender = (*DeleteQuery)(nil)
  16 + _ QueryCommand = (*DeleteQuery)(nil)
  17 +)
32 18
33 -func newDeleteQuery(q *Query) *deleteQuery {  
34 - return &deleteQuery{ 19 +func NewDeleteQuery(q *Query) *DeleteQuery {
  20 + return &DeleteQuery{
35 q: q, 21 q: q,
36 } 22 }
37 } 23 }
38 24
39 -func (q *deleteQuery) Operation() string { 25 +func (q *DeleteQuery) String() string {
  26 + b, err := q.AppendQuery(defaultFmter, nil)
  27 + if err != nil {
  28 + panic(err)
  29 + }
  30 + return string(b)
  31 +}
  32 +
  33 +func (q *DeleteQuery) Operation() QueryOp {
40 return DeleteOp 34 return DeleteOp
41 } 35 }
42 36
43 -func (q *deleteQuery) Clone() queryCommand {  
44 - return &deleteQuery{ 37 +func (q *DeleteQuery) Clone() QueryCommand {
  38 + return &DeleteQuery{
45 q: q.q.Clone(), 39 q: q.q.Clone(),
46 placeholder: q.placeholder, 40 placeholder: q.placeholder,
47 } 41 }
48 } 42 }
49 43
50 -func (q *deleteQuery) Query() *Query { 44 +func (q *DeleteQuery) Query() *Query {
51 return q.q 45 return q.q
52 } 46 }
53 47
54 -func (q *deleteQuery) AppendTemplate(b []byte) ([]byte, error) {  
55 - cp := q.Clone().(*deleteQuery) 48 +func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) {
  49 + cp := q.Clone().(*DeleteQuery)
56 cp.placeholder = true 50 cp.placeholder = true
57 return cp.AppendQuery(dummyFormatter{}, b) 51 return cp.AppendQuery(dummyFormatter{}, b)
58 } 52 }
59 53
60 -func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { 54 +func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
61 if q.q.stickyErr != nil { 55 if q.q.stickyErr != nil {
62 return nil, q.q.stickyErr 56 return nil, q.q.stickyErr
63 } 57 }
@@ -84,7 +78,8 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -84,7 +78,8 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
84 } 78 }
85 79
86 b = append(b, " WHERE "...) 80 b = append(b, " WHERE "...)
87 - value := q.q.model.Value() 81 + value := q.q.tableModel.Value()
  82 +
88 if q.q.isSliceModelWithData() { 83 if q.q.isSliceModelWithData() {
89 if len(q.q.where) > 0 { 84 if len(q.q.where) > 0 {
90 b, err = q.q.appendWhere(fmter, b) 85 b, err = q.q.appendWhere(fmter, b)
@@ -92,7 +87,7 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -92,7 +87,7 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
92 return nil, err 87 return nil, err
93 } 88 }
94 } else { 89 } else {
95 - table := q.q.model.Table() 90 + table := q.q.tableModel.Table()
96 err = table.checkPKs() 91 err = table.checkPKs()
97 if err != nil { 92 if err != nil {
98 return nil, err 93 return nil, err
@@ -116,3 +111,48 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -116,3 +111,48 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
116 111
117 return b, q.q.stickyErr 112 return b, q.q.stickyErr
118 } 113 }
  114 +
  115 +func appendColumnAndSliceValue(
  116 + fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field,
  117 +) []byte {
  118 + if len(fields) > 1 {
  119 + b = append(b, '(')
  120 + }
  121 + b = appendColumns(b, alias, fields)
  122 + if len(fields) > 1 {
  123 + b = append(b, ')')
  124 + }
  125 +
  126 + b = append(b, " IN ("...)
  127 +
  128 + isPlaceholder := isPlaceholderFormatter(fmter)
  129 + sliceLen := slice.Len()
  130 + for i := 0; i < sliceLen; i++ {
  131 + if i > 0 {
  132 + b = append(b, ", "...)
  133 + }
  134 +
  135 + el := indirect(slice.Index(i))
  136 +
  137 + if len(fields) > 1 {
  138 + b = append(b, '(')
  139 + }
  140 + for i, f := range fields {
  141 + if i > 0 {
  142 + b = append(b, ", "...)
  143 + }
  144 + if isPlaceholder {
  145 + b = append(b, '?')
  146 + } else {
  147 + b = f.AppendValue(b, el, 1)
  148 + }
  149 + }
  150 + if len(fields) > 1 {
  151 + b = append(b, ')')
  152 + }
  153 + }
  154 +
  155 + b = append(b, ')')
  156 +
  157 + return b
  158 +}
@@ -70,10 +70,10 @@ func (f *Field) Value(strct reflect.Value) reflect.Value { @@ -70,10 +70,10 @@ func (f *Field) Value(strct reflect.Value) reflect.Value {
70 } 70 }
71 71
72 func (f *Field) HasZeroValue(strct reflect.Value) bool { 72 func (f *Field) HasZeroValue(strct reflect.Value) bool {
73 - return f.hasZeroField(strct, f.Index) 73 + return f.hasZeroValue(strct, f.Index)
74 } 74 }
75 75
76 -func (f *Field) hasZeroField(v reflect.Value, index []int) bool { 76 +func (f *Field) hasZeroValue(v reflect.Value, index []int) bool {
77 for _, idx := range index { 77 for _, idx := range index {
78 if v.Kind() == reflect.Ptr { 78 if v.Kind() == reflect.Ptr {
79 if v.IsNil() { 79 if v.IsNil() {
@@ -106,10 +106,21 @@ func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { @@ -106,10 +106,21 @@ func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte {
106 } 106 }
107 107
108 func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { 108 func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error {
109 - fv := f.Value(strct)  
110 if f.scan == nil { 109 if f.scan == nil {
111 - return fmt.Errorf("pg: ScanValue(unsupported %s)", fv.Type()) 110 + return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type)
112 } 111 }
  112 +
  113 + var fv reflect.Value
  114 + if n == -1 {
  115 + var ok bool
  116 + fv, ok = fieldByIndex(strct, f.Index)
  117 + if !ok {
  118 + return nil
  119 + }
  120 + } else {
  121 + fv = fieldByIndexAlloc(strct, f.Index)
  122 + }
  123 +
113 return f.scan(fv, rd, n) 124 return f.scan(fv, rd, n)
114 } 125 }
115 126
@@ -26,8 +26,10 @@ type SafeQueryAppender struct { @@ -26,8 +26,10 @@ type SafeQueryAppender struct {
26 params []interface{} 26 params []interface{}
27 } 27 }
28 28
29 -var _ QueryAppender = (*SafeQueryAppender)(nil)  
30 -var _ types.ValueAppender = (*SafeQueryAppender)(nil) 29 +var (
  30 + _ QueryAppender = (*SafeQueryAppender)(nil)
  31 + _ types.ValueAppender = (*SafeQueryAppender)(nil)
  32 +)
31 33
32 //nolint 34 //nolint
33 func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { 35 func SafeQuery(query string, params ...interface{}) *SafeQueryAppender {
@@ -57,8 +59,10 @@ type condGroupAppender struct { @@ -57,8 +59,10 @@ type condGroupAppender struct {
57 cond []queryWithSepAppender 59 cond []queryWithSepAppender
58 } 60 }
59 61
60 -var _ QueryAppender = (*condAppender)(nil)  
61 -var _ queryWithSepAppender = (*condAppender)(nil) 62 +var (
  63 + _ QueryAppender = (*condAppender)(nil)
  64 + _ queryWithSepAppender = (*condAppender)(nil)
  65 +)
62 66
63 func (q *condGroupAppender) AppendSep(b []byte) []byte { 67 func (q *condGroupAppender) AppendSep(b []byte) []byte {
64 return append(b, q.sep...) 68 return append(b, q.sep...)
@@ -87,8 +91,10 @@ type condAppender struct { @@ -87,8 +91,10 @@ type condAppender struct {
87 params []interface{} 91 params []interface{}
88 } 92 }
89 93
90 -var _ QueryAppender = (*condAppender)(nil)  
91 -var _ queryWithSepAppender = (*condAppender)(nil) 94 +var (
  95 + _ QueryAppender = (*condAppender)(nil)
  96 + _ queryWithSepAppender = (*condAppender)(nil)
  97 +)
92 98
93 func (q *condAppender) AppendSep(b []byte) []byte { 99 func (q *condAppender) AppendSep(b []byte) []byte {
94 return append(b, q.sep...) 100 return append(b, q.sep...)
@@ -192,9 +198,9 @@ func (f *Formatter) WithModel(model interface{}) *Formatter { @@ -192,9 +198,9 @@ func (f *Formatter) WithModel(model interface{}) *Formatter {
192 case TableModel: 198 case TableModel:
193 return f.WithTableModel(model) 199 return f.WithTableModel(model)
194 case *Query: 200 case *Query:
195 - return f.WithTableModel(model.model)  
196 - case queryCommand:  
197 - return f.WithTableModel(model.Query().model) 201 + return f.WithTableModel(model.tableModel)
  202 + case QueryCommand:
  203 + return f.WithTableModel(model.Query().tableModel)
198 default: 204 default:
199 panic(fmt.Errorf("pg: unsupported model %T", model)) 205 panic(fmt.Errorf("pg: unsupported model %T", model))
200 } 206 }
@@ -7,44 +7,51 @@ import ( @@ -7,44 +7,51 @@ import (
7 7
8 type hookStubs struct{} 8 type hookStubs struct{}
9 9
10 -var _ AfterSelectHook = (*hookStubs)(nil)  
11 -var _ BeforeInsertHook = (*hookStubs)(nil)  
12 -var _ AfterInsertHook = (*hookStubs)(nil)  
13 -var _ BeforeUpdateHook = (*hookStubs)(nil)  
14 -var _ AfterUpdateHook = (*hookStubs)(nil)  
15 -var _ BeforeDeleteHook = (*hookStubs)(nil)  
16 -var _ AfterDeleteHook = (*hookStubs)(nil)  
17 -  
18 -func (hookStubs) AfterSelect(c context.Context) error { 10 +var (
  11 + _ AfterScanHook = (*hookStubs)(nil)
  12 + _ AfterSelectHook = (*hookStubs)(nil)
  13 + _ BeforeInsertHook = (*hookStubs)(nil)
  14 + _ AfterInsertHook = (*hookStubs)(nil)
  15 + _ BeforeUpdateHook = (*hookStubs)(nil)
  16 + _ AfterUpdateHook = (*hookStubs)(nil)
  17 + _ BeforeDeleteHook = (*hookStubs)(nil)
  18 + _ AfterDeleteHook = (*hookStubs)(nil)
  19 +)
  20 +
  21 +func (hookStubs) AfterScan(ctx context.Context) error {
  22 + return nil
  23 +}
  24 +
  25 +func (hookStubs) AfterSelect(ctx context.Context) error {
19 return nil 26 return nil
20 } 27 }
21 28
22 -func (hookStubs) BeforeInsert(c context.Context) (context.Context, error) {  
23 - return c, nil 29 +func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) {
  30 + return ctx, nil
24 } 31 }
25 32
26 -func (hookStubs) AfterInsert(c context.Context) error { 33 +func (hookStubs) AfterInsert(ctx context.Context) error {
27 return nil 34 return nil
28 } 35 }
29 36
30 -func (hookStubs) BeforeUpdate(c context.Context) (context.Context, error) {  
31 - return c, nil 37 +func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) {
  38 + return ctx, nil
32 } 39 }
33 40
34 -func (hookStubs) AfterUpdate(c context.Context) error { 41 +func (hookStubs) AfterUpdate(ctx context.Context) error {
35 return nil 42 return nil
36 } 43 }
37 44
38 -func (hookStubs) BeforeDelete(c context.Context) (context.Context, error) {  
39 - return c, nil 45 +func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) {
  46 + return ctx, nil
40 } 47 }
41 48
42 -func (hookStubs) AfterDelete(c context.Context) error { 49 +func (hookStubs) AfterDelete(ctx context.Context) error {
43 return nil 50 return nil
44 } 51 }
45 52
46 func callHookSlice( 53 func callHookSlice(
47 - c context.Context, 54 + ctx context.Context,
48 slice reflect.Value, 55 slice reflect.Value,
49 ptr bool, 56 ptr bool,
50 hook func(context.Context, reflect.Value) (context.Context, error), 57 hook func(context.Context, reflect.Value) (context.Context, error),
@@ -58,16 +65,16 @@ func callHookSlice( @@ -58,16 +65,16 @@ func callHookSlice(
58 } 65 }
59 66
60 var err error 67 var err error
61 - c, err = hook(c, v) 68 + ctx, err = hook(ctx, v)
62 if err != nil && firstErr == nil { 69 if err != nil && firstErr == nil {
63 firstErr = err 70 firstErr = err
64 } 71 }
65 } 72 }
66 - return c, firstErr 73 + return ctx, firstErr
67 } 74 }
68 75
69 func callHookSlice2( 76 func callHookSlice2(
70 - c context.Context, 77 + ctx context.Context,
71 slice reflect.Value, 78 slice reflect.Value,
72 ptr bool, 79 ptr bool,
73 hook func(context.Context, reflect.Value) error, 80 hook func(context.Context, reflect.Value) error,
@@ -81,7 +88,7 @@ func callHookSlice2( @@ -81,7 +88,7 @@ func callHookSlice2(
81 v = v.Addr() 88 v = v.Addr()
82 } 89 }
83 90
84 - err := hook(c, v) 91 + err := hook(ctx, v)
85 if err != nil && firstErr == nil { 92 if err != nil && firstErr == nil {
86 firstErr = err 93 firstErr = err
87 } 94 }
@@ -98,8 +105,8 @@ type BeforeScanHook interface { @@ -98,8 +105,8 @@ type BeforeScanHook interface {
98 105
99 var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() 106 var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
100 107
101 -func callBeforeScanHook(c context.Context, v reflect.Value) error {  
102 - return v.Interface().(BeforeScanHook).BeforeScan(c) 108 +func callBeforeScanHook(ctx context.Context, v reflect.Value) error {
  109 + return v.Interface().(BeforeScanHook).BeforeScan(ctx)
103 } 110 }
104 111
105 //------------------------------------------------------------------------------ 112 //------------------------------------------------------------------------------
@@ -110,8 +117,8 @@ type AfterScanHook interface { @@ -110,8 +117,8 @@ type AfterScanHook interface {
110 117
111 var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() 118 var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()
112 119
113 -func callAfterScanHook(c context.Context, v reflect.Value) error {  
114 - return v.Interface().(AfterScanHook).AfterScan(c) 120 +func callAfterScanHook(ctx context.Context, v reflect.Value) error {
  121 + return v.Interface().(AfterScanHook).AfterScan(ctx)
115 } 122 }
116 123
117 //------------------------------------------------------------------------------ 124 //------------------------------------------------------------------------------
@@ -122,14 +129,14 @@ type AfterSelectHook interface { @@ -122,14 +129,14 @@ type AfterSelectHook interface {
122 129
123 var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() 130 var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem()
124 131
125 -func callAfterSelectHook(c context.Context, v reflect.Value) error {  
126 - return v.Interface().(AfterSelectHook).AfterSelect(c) 132 +func callAfterSelectHook(ctx context.Context, v reflect.Value) error {
  133 + return v.Interface().(AfterSelectHook).AfterSelect(ctx)
127 } 134 }
128 135
129 func callAfterSelectHookSlice( 136 func callAfterSelectHookSlice(
130 - c context.Context, slice reflect.Value, ptr bool, 137 + ctx context.Context, slice reflect.Value, ptr bool,
131 ) error { 138 ) error {
132 - return callHookSlice2(c, slice, ptr, callAfterSelectHook) 139 + return callHookSlice2(ctx, slice, ptr, callAfterSelectHook)
133 } 140 }
134 141
135 //------------------------------------------------------------------------------ 142 //------------------------------------------------------------------------------
@@ -140,14 +147,14 @@ type BeforeInsertHook interface { @@ -140,14 +147,14 @@ type BeforeInsertHook interface {
140 147
141 var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() 148 var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem()
142 149
143 -func callBeforeInsertHook(c context.Context, v reflect.Value) (context.Context, error) {  
144 - return v.Interface().(BeforeInsertHook).BeforeInsert(c) 150 +func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) {
  151 + return v.Interface().(BeforeInsertHook).BeforeInsert(ctx)
145 } 152 }
146 153
147 func callBeforeInsertHookSlice( 154 func callBeforeInsertHookSlice(
148 - c context.Context, slice reflect.Value, ptr bool, 155 + ctx context.Context, slice reflect.Value, ptr bool,
149 ) (context.Context, error) { 156 ) (context.Context, error) {
150 - return callHookSlice(c, slice, ptr, callBeforeInsertHook) 157 + return callHookSlice(ctx, slice, ptr, callBeforeInsertHook)
151 } 158 }
152 159
153 //------------------------------------------------------------------------------ 160 //------------------------------------------------------------------------------
@@ -158,14 +165,14 @@ type AfterInsertHook interface { @@ -158,14 +165,14 @@ type AfterInsertHook interface {
158 165
159 var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() 166 var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem()
160 167
161 -func callAfterInsertHook(c context.Context, v reflect.Value) error {  
162 - return v.Interface().(AfterInsertHook).AfterInsert(c) 168 +func callAfterInsertHook(ctx context.Context, v reflect.Value) error {
  169 + return v.Interface().(AfterInsertHook).AfterInsert(ctx)
163 } 170 }
164 171
165 func callAfterInsertHookSlice( 172 func callAfterInsertHookSlice(
166 - c context.Context, slice reflect.Value, ptr bool, 173 + ctx context.Context, slice reflect.Value, ptr bool,
167 ) error { 174 ) error {
168 - return callHookSlice2(c, slice, ptr, callAfterInsertHook) 175 + return callHookSlice2(ctx, slice, ptr, callAfterInsertHook)
169 } 176 }
170 177
171 //------------------------------------------------------------------------------ 178 //------------------------------------------------------------------------------
@@ -176,14 +183,14 @@ type BeforeUpdateHook interface { @@ -176,14 +183,14 @@ type BeforeUpdateHook interface {
176 183
177 var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() 184 var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem()
178 185
179 -func callBeforeUpdateHook(c context.Context, v reflect.Value) (context.Context, error) {  
180 - return v.Interface().(BeforeUpdateHook).BeforeUpdate(c) 186 +func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) {
  187 + return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx)
181 } 188 }
182 189
183 func callBeforeUpdateHookSlice( 190 func callBeforeUpdateHookSlice(
184 - c context.Context, slice reflect.Value, ptr bool, 191 + ctx context.Context, slice reflect.Value, ptr bool,
185 ) (context.Context, error) { 192 ) (context.Context, error) {
186 - return callHookSlice(c, slice, ptr, callBeforeUpdateHook) 193 + return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook)
187 } 194 }
188 195
189 //------------------------------------------------------------------------------ 196 //------------------------------------------------------------------------------
@@ -194,14 +201,14 @@ type AfterUpdateHook interface { @@ -194,14 +201,14 @@ type AfterUpdateHook interface {
194 201
195 var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() 202 var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem()
196 203
197 -func callAfterUpdateHook(c context.Context, v reflect.Value) error {  
198 - return v.Interface().(AfterUpdateHook).AfterUpdate(c) 204 +func callAfterUpdateHook(ctx context.Context, v reflect.Value) error {
  205 + return v.Interface().(AfterUpdateHook).AfterUpdate(ctx)
199 } 206 }
200 207
201 func callAfterUpdateHookSlice( 208 func callAfterUpdateHookSlice(
202 - c context.Context, slice reflect.Value, ptr bool, 209 + ctx context.Context, slice reflect.Value, ptr bool,
203 ) error { 210 ) error {
204 - return callHookSlice2(c, slice, ptr, callAfterUpdateHook) 211 + return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook)
205 } 212 }
206 213
207 //------------------------------------------------------------------------------ 214 //------------------------------------------------------------------------------
@@ -212,14 +219,14 @@ type BeforeDeleteHook interface { @@ -212,14 +219,14 @@ type BeforeDeleteHook interface {
212 219
213 var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() 220 var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem()
214 221
215 -func callBeforeDeleteHook(c context.Context, v reflect.Value) (context.Context, error) {  
216 - return v.Interface().(BeforeDeleteHook).BeforeDelete(c) 222 +func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) {
  223 + return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx)
217 } 224 }
218 225
219 func callBeforeDeleteHookSlice( 226 func callBeforeDeleteHookSlice(
220 - c context.Context, slice reflect.Value, ptr bool, 227 + ctx context.Context, slice reflect.Value, ptr bool,
221 ) (context.Context, error) { 228 ) (context.Context, error) {
222 - return callHookSlice(c, slice, ptr, callBeforeDeleteHook) 229 + return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook)
223 } 230 }
224 231
225 //------------------------------------------------------------------------------ 232 //------------------------------------------------------------------------------
@@ -230,12 +237,12 @@ type AfterDeleteHook interface { @@ -230,12 +237,12 @@ type AfterDeleteHook interface {
230 237
231 var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() 238 var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem()
232 239
233 -func callAfterDeleteHook(c context.Context, v reflect.Value) error {  
234 - return v.Interface().(AfterDeleteHook).AfterDelete(c) 240 +func callAfterDeleteHook(ctx context.Context, v reflect.Value) error {
  241 + return v.Interface().(AfterDeleteHook).AfterDelete(ctx)
235 } 242 }
236 243
237 func callAfterDeleteHookSlice( 244 func callAfterDeleteHookSlice(
238 - c context.Context, slice reflect.Value, ptr bool, 245 + ctx context.Context, slice reflect.Value, ptr bool,
239 ) error { 246 ) error {
240 - return callHookSlice2(c, slice, ptr, callAfterDeleteHook) 247 + return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook)
241 } 248 }
@@ -3,55 +3,59 @@ package orm @@ -3,55 +3,59 @@ package orm
3 import ( 3 import (
4 "fmt" 4 "fmt"
5 "reflect" 5 "reflect"
  6 + "sort"
6 7
7 "github.com/go-pg/pg/v10/types" 8 "github.com/go-pg/pg/v10/types"
8 ) 9 )
9 10
10 -func Insert(db DB, model ...interface{}) error {  
11 - _, err := NewQuery(db, model...).Insert()  
12 - return err  
13 -}  
14 -  
15 -type insertQuery struct { 11 +type InsertQuery struct {
16 q *Query 12 q *Query
17 returningFields []*Field 13 returningFields []*Field
18 placeholder bool 14 placeholder bool
19 } 15 }
20 16
21 -var _ queryCommand = (*insertQuery)(nil) 17 +var _ QueryCommand = (*InsertQuery)(nil)
22 18
23 -func newInsertQuery(q *Query) *insertQuery {  
24 - return &insertQuery{ 19 +func NewInsertQuery(q *Query) *InsertQuery {
  20 + return &InsertQuery{
25 q: q, 21 q: q,
26 } 22 }
27 } 23 }
28 24
29 -func (q *insertQuery) Operation() string { 25 +func (q *InsertQuery) String() string {
  26 + b, err := q.AppendQuery(defaultFmter, nil)
  27 + if err != nil {
  28 + panic(err)
  29 + }
  30 + return string(b)
  31 +}
  32 +
  33 +func (q *InsertQuery) Operation() QueryOp {
30 return InsertOp 34 return InsertOp
31 } 35 }
32 36
33 -func (q *insertQuery) Clone() queryCommand {  
34 - return &insertQuery{ 37 +func (q *InsertQuery) Clone() QueryCommand {
  38 + return &InsertQuery{
35 q: q.q.Clone(), 39 q: q.q.Clone(),
36 placeholder: q.placeholder, 40 placeholder: q.placeholder,
37 } 41 }
38 } 42 }
39 43
40 -func (q *insertQuery) Query() *Query { 44 +func (q *InsertQuery) Query() *Query {
41 return q.q 45 return q.q
42 } 46 }
43 47
44 -var _ TemplateAppender = (*insertQuery)(nil) 48 +var _ TemplateAppender = (*InsertQuery)(nil)
45 49
46 -func (q *insertQuery) AppendTemplate(b []byte) ([]byte, error) {  
47 - cp := q.Clone().(*insertQuery) 50 +func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) {
  51 + cp := q.Clone().(*InsertQuery)
48 cp.placeholder = true 52 cp.placeholder = true
49 return cp.AppendQuery(dummyFormatter{}, b) 53 return cp.AppendQuery(dummyFormatter{}, b)
50 } 54 }
51 55
52 -var _ QueryAppender = (*insertQuery)(nil) 56 +var _ QueryAppender = (*InsertQuery)(nil)
53 57
54 -func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { 58 +func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) {
55 if q.q.stickyErr != nil { 59 if q.q.stickyErr != nil {
56 return nil, q.q.stickyErr 60 return nil, q.q.stickyErr
57 } 61 }
@@ -73,54 +77,9 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -73,54 +77,9 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
73 return nil, err 77 return nil, err
74 } 78 }
75 79
76 - if q.q.hasMultiTables() {  
77 - if q.q.columns != nil {  
78 - b = append(b, " ("...)  
79 - b, err = q.q.appendColumns(fmter, b)  
80 - if err != nil {  
81 - return nil, err  
82 - }  
83 - b = append(b, ")"...)  
84 - }  
85 - b = append(b, " SELECT * FROM "...)  
86 - b, err = q.q.appendOtherTables(fmter, b)  
87 - if err != nil {  
88 - return nil, err  
89 - }  
90 - } else {  
91 - if !q.q.hasModel() {  
92 - return nil, errModelNil  
93 - }  
94 -  
95 - fields, err := q.q.getFields()  
96 - if err != nil {  
97 - return nil, err  
98 - }  
99 -  
100 - if len(fields) == 0 {  
101 - fields = q.q.model.Table().Fields  
102 - }  
103 - value := q.q.model.Value()  
104 -  
105 - b = append(b, " ("...)  
106 - b = q.appendColumns(b, fields)  
107 - b = append(b, ") VALUES ("...)  
108 - if m, ok := q.q.model.(*sliceTableModel); ok {  
109 - if m.sliceLen == 0 {  
110 - err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type())  
111 - return nil, err  
112 - }  
113 - b, err = q.appendSliceValues(fmter, b, fields, value)  
114 - if err != nil {  
115 - return nil, err  
116 - }  
117 - } else {  
118 - b, err = q.appendValues(fmter, b, fields, value)  
119 - if err != nil {  
120 - return nil, err  
121 - }  
122 - }  
123 - b = append(b, ")"...) 80 + b, err = q.appendColumnsValues(fmter, b)
  81 + if err != nil {
  82 + return nil, err
124 } 83 }
125 84
126 if q.q.onConflict != nil { 85 if q.q.onConflict != nil {
@@ -143,7 +102,7 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -143,7 +102,7 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
143 } 102 }
144 103
145 if len(fields) == 0 { 104 if len(fields) == 0 {
146 - fields = q.q.model.Table().DataFields 105 + fields = q.q.tableModel.Table().DataFields
147 } 106 }
148 107
149 b = q.appendSetExcluded(b, fields) 108 b = q.appendSetExcluded(b, fields)
@@ -171,7 +130,103 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err @@ -171,7 +130,103 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err
171 return b, q.q.stickyErr 130 return b, q.q.stickyErr
172 } 131 }
173 132
174 -func (q *insertQuery) appendValues( 133 +func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) {
  134 + if q.q.hasMultiTables() {
  135 + if q.q.columns != nil {
  136 + b = append(b, " ("...)
  137 + b, err = q.q.appendColumns(fmter, b)
  138 + if err != nil {
  139 + return nil, err
  140 + }
  141 + b = append(b, ")"...)
  142 + }
  143 +
  144 + b = append(b, " SELECT * FROM "...)
  145 + b, err = q.q.appendOtherTables(fmter, b)
  146 + if err != nil {
  147 + return nil, err
  148 + }
  149 +
  150 + return b, nil
  151 + }
  152 +
  153 + if m, ok := q.q.model.(*mapModel); ok {
  154 + return q.appendMapColumnsValues(b, m.m), nil
  155 + }
  156 +
  157 + if !q.q.hasTableModel() {
  158 + return nil, errModelNil
  159 + }
  160 +
  161 + fields, err := q.q.getFields()
  162 + if err != nil {
  163 + return nil, err
  164 + }
  165 +
  166 + if len(fields) == 0 {
  167 + fields = q.q.tableModel.Table().Fields
  168 + }
  169 + value := q.q.tableModel.Value()
  170 +
  171 + b = append(b, " ("...)
  172 + b = q.appendColumns(b, fields)
  173 + b = append(b, ") VALUES ("...)
  174 + if m, ok := q.q.tableModel.(*sliceTableModel); ok {
  175 + if m.sliceLen == 0 {
  176 + err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type())
  177 + return nil, err
  178 + }
  179 + b, err = q.appendSliceValues(fmter, b, fields, value)
  180 + if err != nil {
  181 + return nil, err
  182 + }
  183 + } else {
  184 + b, err = q.appendValues(fmter, b, fields, value)
  185 + if err != nil {
  186 + return nil, err
  187 + }
  188 + }
  189 + b = append(b, ")"...)
  190 +
  191 + return b, nil
  192 +}
  193 +
  194 +func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte {
  195 + keys := make([]string, 0, len(m))
  196 +
  197 + for k := range m {
  198 + keys = append(keys, k)
  199 + }
  200 + sort.Strings(keys)
  201 +
  202 + b = append(b, " ("...)
  203 +
  204 + for i, k := range keys {
  205 + if i > 0 {
  206 + b = append(b, ", "...)
  207 + }
  208 + b = types.AppendIdent(b, k, 1)
  209 + }
  210 +
  211 + b = append(b, ") VALUES ("...)
  212 +
  213 + for i, k := range keys {
  214 + if i > 0 {
  215 + b = append(b, ", "...)
  216 + }
  217 + if q.placeholder {
  218 + b = append(b, '?')
  219 + } else {
  220 + b = types.Append(b, m[k], 1)
  221 + }
  222 + }
  223 +
  224 + b = append(b, ")"...)
  225 +
  226 + return b
  227 +}
  228 +
  229 +func (q *InsertQuery) appendValues(
175 fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, 230 fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value,
176 ) (_ []byte, err error) { 231 ) (_ []byte, err error) {
177 for i, f := range fields { 232 for i, f := range fields {
@@ -214,7 +269,7 @@ func (q *insertQuery) appendValues( @@ -214,7 +269,7 @@ func (q *insertQuery) appendValues(
214 return b, nil 269 return b, nil
215 } 270 }
216 271
217 -func (q *insertQuery) appendSliceValues( 272 +func (q *InsertQuery) appendSliceValues(
218 fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, 273 fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value,
219 ) (_ []byte, err error) { 274 ) (_ []byte, err error) {
220 if q.placeholder { 275 if q.placeholder {
@@ -247,7 +302,7 @@ func (q *insertQuery) appendSliceValues( @@ -247,7 +302,7 @@ func (q *insertQuery) appendSliceValues(
247 return b, nil 302 return b, nil
248 } 303 }
249 304
250 -func (q *insertQuery) addReturningField(field *Field) { 305 +func (q *InsertQuery) addReturningField(field *Field) {
251 if len(q.q.returning) > 0 { 306 if len(q.q.returning) > 0 {
252 return 307 return
253 } 308 }
@@ -259,7 +314,7 @@ func (q *insertQuery) addReturningField(field *Field) { @@ -259,7 +314,7 @@ func (q *insertQuery) addReturningField(field *Field) {
259 q.returningFields = append(q.returningFields, field) 314 q.returningFields = append(q.returningFields, field)
260 } 315 }
261 316
262 -func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { 317 +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte {
263 b = append(b, " SET "...) 318 b = append(b, " SET "...)
264 for i, f := range fields { 319 for i, f := range fields {
265 if i > 0 { 320 if i > 0 {
@@ -272,7 +327,7 @@ func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { @@ -272,7 +327,7 @@ func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte {
272 return b 327 return b
273 } 328 }
274 329
275 -func (q *insertQuery) appendColumns(b []byte, fields []*Field) []byte { 330 +func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte {
276 b = appendColumns(b, "", fields) 331 b = appendColumns(b, "", fields)
277 for i, v := range q.q.extraValues { 332 for i, v := range q.q.extraValues {
278 if i > 0 || len(fields) > 0 { 333 if i > 0 || len(fields) > 0 {
@@ -64,16 +64,16 @@ func (j *join) manyQuery(q *Query) (*Query, error) { @@ -64,16 +64,16 @@ func (j *join) manyQuery(q *Query) (*Query, error) {
64 64
65 baseTable := j.BaseModel.Table() 65 baseTable := j.BaseModel.Table()
66 var where []byte 66 var where []byte
67 - if len(j.Rel.FKs) > 1 { 67 + if len(j.Rel.JoinFKs) > 1 {
68 where = append(where, '(') 68 where = append(where, '(')
69 } 69 }
70 - where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.FKs)  
71 - if len(j.Rel.FKs) > 1 { 70 + where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs)
  71 + if len(j.Rel.JoinFKs) > 1 {
72 where = append(where, ')') 72 where = append(where, ')')
73 } 73 }
74 where = append(where, " IN ("...) 74 where = append(where, " IN ("...)
75 where = appendChildValues( 75 where = appendChildValues(
76 - where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.FKValues) 76 + where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs)
77 where = append(where, ")"...) 77 where = append(where, ")"...)
78 q = q.Where(internal.BytesToString(where)) 78 q = q.Where(internal.BytesToString(where))
79 79
@@ -126,7 +126,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { @@ -126,7 +126,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) {
126 join = append(join, " AS "...) 126 join = append(join, " AS "...)
127 join = append(join, j.Rel.M2MTableAlias...) 127 join = append(join, j.Rel.M2MTableAlias...)
128 join = append(join, " ON ("...) 128 join = append(join, " ON ("...)
129 - for i, col := range j.Rel.BaseFKs { 129 + for i, col := range j.Rel.M2MBaseFKs {
130 if i > 0 { 130 if i > 0 {
131 join = append(join, ", "...) 131 join = append(join, ", "...)
132 } 132 }
@@ -140,10 +140,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { @@ -140,10 +140,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) {
140 q = q.Join(internal.BytesToString(join)) 140 q = q.Join(internal.BytesToString(join))
141 141
142 joinTable := j.JoinModel.Table() 142 joinTable := j.JoinModel.Table()
143 - for i, col := range j.Rel.JoinFKs {  
144 - if i >= len(joinTable.PKs) {  
145 - break  
146 - } 143 + for i, col := range j.Rel.M2MJoinFKs {
147 pk := joinTable.PKs[i] 144 pk := joinTable.PKs[i]
148 q = q.Where("?.? = ?.?", 145 q = q.Where("?.? = ?.?",
149 joinTable.Alias, pk.Column, 146 joinTable.Alias, pk.Column,
@@ -242,7 +239,7 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b @@ -242,7 +239,7 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b
242 isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) 239 isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag)
243 240
244 b = append(b, "LEFT JOIN "...) 241 b = append(b, "LEFT JOIN "...)
245 - b = fmter.FormatQuery(b, string(j.JoinModel.Table().FullNameForSelects)) 242 + b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects))
246 b = append(b, " AS "...) 243 b = append(b, " AS "...)
247 b = j.appendAlias(b) 244 b = j.appendAlias(b)
248 245
@@ -252,38 +249,22 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b @@ -252,38 +249,22 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b
252 b = append(b, '(') 249 b = append(b, '(')
253 } 250 }
254 251
255 - if len(j.Rel.FKs) > 1 { 252 + if len(j.Rel.BaseFKs) > 1 {
256 b = append(b, '(') 253 b = append(b, '(')
257 } 254 }
258 - if j.Rel.Type == HasOneRelation {  
259 - for i, fk := range j.Rel.FKs {  
260 - if i > 0 {  
261 - b = append(b, " AND "...)  
262 - }  
263 - b = j.appendAlias(b)  
264 - b = append(b, '.')  
265 - b = append(b, j.Rel.JoinTable.PKs[i].Column...)  
266 - b = append(b, " = "...)  
267 - b = j.appendBaseAlias(b)  
268 - b = append(b, '.')  
269 - b = append(b, fk.Column...)  
270 - }  
271 - } else {  
272 - baseTable := j.BaseModel.Table()  
273 - for i, fk := range j.Rel.FKs {  
274 - if i > 0 {  
275 - b = append(b, " AND "...)  
276 - }  
277 - b = j.appendAlias(b)  
278 - b = append(b, '.')  
279 - b = append(b, fk.Column...)  
280 - b = append(b, " = "...)  
281 - b = j.appendBaseAlias(b)  
282 - b = append(b, '.')  
283 - b = append(b, baseTable.PKs[i].Column...) 255 + for i, baseFK := range j.Rel.BaseFKs {
  256 + if i > 0 {
  257 + b = append(b, " AND "...)
284 } 258 }
  259 + b = j.appendAlias(b)
  260 + b = append(b, '.')
  261 + b = append(b, j.Rel.JoinFKs[i].Column...)
  262 + b = append(b, " = "...)
  263 + b = j.appendBaseAlias(b)
  264 + b = append(b, '.')
  265 + b = append(b, baseFK.Column...)
285 } 266 }
286 - if len(j.Rel.FKs) > 1 { 267 + if len(j.Rel.BaseFKs) > 1 {
287 b = append(b, ')') 268 b = append(b, ')')
288 } 269 }
289 270
@@ -31,6 +31,7 @@ type HooklessModel interface { @@ -31,6 +31,7 @@ type HooklessModel interface {
31 type Model interface { 31 type Model interface {
32 HooklessModel 32 HooklessModel
33 33
  34 + AfterScanHook
34 AfterSelectHook 35 AfterSelectHook
35 36
36 BeforeInsertHook 37 BeforeInsertHook
@@ -43,45 +44,90 @@ type Model interface { @@ -43,45 +44,90 @@ type Model interface {
43 AfterDeleteHook 44 AfterDeleteHook
44 } 45 }
45 46
46 -func NewModel(values ...interface{}) (Model, error) { 47 +func NewModel(value interface{}) (Model, error) {
  48 + return newModel(value, false)
  49 +}
  50 +
  51 +func newScanModel(values []interface{}) (Model, error) {
47 if len(values) > 1 { 52 if len(values) > 1 {
48 return Scan(values...), nil 53 return Scan(values...), nil
49 } 54 }
  55 + return newModel(values[0], true)
  56 +}
50 57
51 - v0 := values[0]  
52 - switch v0 := v0.(type) { 58 +func newModel(value interface{}, scan bool) (Model, error) {
  59 + switch value := value.(type) {
53 case Model: 60 case Model:
54 - return v0, nil 61 + return value, nil
55 case HooklessModel: 62 case HooklessModel:
56 - return newModelWithHookStubs(v0), nil 63 + return newModelWithHookStubs(value), nil
57 case types.ValueScanner, sql.Scanner: 64 case types.ValueScanner, sql.Scanner:
58 - return Scan(v0), nil 65 + if !scan {
  66 + return nil, fmt.Errorf("pg: Model(unsupported %T)", value)
  67 + }
  68 + return Scan(value), nil
59 } 69 }
60 70
61 - v := reflect.ValueOf(v0) 71 + v := reflect.ValueOf(value)
62 if !v.IsValid() { 72 if !v.IsValid() {
63 return nil, errModelNil 73 return nil, errModelNil
64 } 74 }
65 if v.Kind() != reflect.Ptr { 75 if v.Kind() != reflect.Ptr {
66 - return nil, fmt.Errorf("pg: Model(non-pointer %T)", v0) 76 + return nil, fmt.Errorf("pg: Model(non-pointer %T)", value)
  77 + }
  78 +
  79 + if v.IsNil() {
  80 + typ := v.Type().Elem()
  81 + if typ.Kind() == reflect.Struct {
  82 + return newStructTableModel(GetTable(typ)), nil
  83 + }
  84 + return nil, errModelNil
67 } 85 }
  86 +
68 v = v.Elem() 87 v = v.Elem()
69 88
  89 + if v.Kind() == reflect.Interface {
  90 + if !v.IsNil() {
  91 + v = v.Elem()
  92 + if v.Kind() != reflect.Ptr {
  93 + return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String())
  94 + }
  95 + }
  96 + }
  97 +
70 switch v.Kind() { 98 switch v.Kind() {
71 case reflect.Struct: 99 case reflect.Struct:
72 if v.Type() != timeType { 100 if v.Type() != timeType {
73 return newStructTableModelValue(v), nil 101 return newStructTableModelValue(v), nil
74 } 102 }
75 case reflect.Slice: 103 case reflect.Slice:
76 - typ := v.Type()  
77 - elemType := indirectType(typ.Elem())  
78 - if elemType.Kind() == reflect.Struct && elemType != timeType {  
79 - return newSliceTableModel(v, elemType), nil 104 + elemType := sliceElemType(v)
  105 + switch elemType.Kind() {
  106 + case reflect.Struct:
  107 + if elemType != timeType {
  108 + return newSliceTableModel(v, elemType), nil
  109 + }
  110 + case reflect.Map:
  111 + if err := validMap(elemType); err != nil {
  112 + return nil, err
  113 + }
  114 + slicePtr := v.Addr().Interface().(*[]map[string]interface{})
  115 + return newMapSliceModel(slicePtr), nil
80 } 116 }
81 return newSliceModel(v, elemType), nil 117 return newSliceModel(v, elemType), nil
  118 + case reflect.Map:
  119 + typ := v.Type()
  120 + if err := validMap(typ); err != nil {
  121 + return nil, err
  122 + }
  123 + mapPtr := v.Addr().Interface().(*map[string]interface{})
  124 + return newMapModel(mapPtr), nil
82 } 125 }
83 126
84 - return Scan(v0), nil 127 + if !scan {
  128 + return nil, fmt.Errorf("pg: Model(unsupported %T)", value)
  129 + }
  130 + return Scan(value), nil
85 } 131 }
86 132
87 type modelWithHookStubs struct { 133 type modelWithHookStubs struct {
@@ -94,3 +140,11 @@ func newModelWithHookStubs(m HooklessModel) Model { @@ -94,3 +140,11 @@ func newModelWithHookStubs(m HooklessModel) Model {
94 HooklessModel: m, 140 HooklessModel: m,
95 } 141 }
96 } 142 }
  143 +
  144 +func validMap(typ reflect.Type) error {
  145 + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface {
  146 + return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})",
  147 + typ.String())
  148 + }
  149 + return nil
  150 +}
@@ -22,6 +22,6 @@ func (m Discard) AddColumnScanner(ColumnScanner) error { @@ -22,6 +22,6 @@ func (m Discard) AddColumnScanner(ColumnScanner) error {
22 return nil 22 return nil
23 } 23 }
24 24
25 -func (m Discard) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error { 25 +func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
26 return nil 26 return nil
27 } 27 }
  1 +package orm
  2 +
  3 +import (
  4 + "github.com/go-pg/pg/v10/types"
  5 +)
  6 +
  7 +type mapModel struct {
  8 + hookStubs
  9 + ptr *map[string]interface{}
  10 + m map[string]interface{}
  11 +}
  12 +
  13 +var _ Model = (*mapModel)(nil)
  14 +
  15 +func newMapModel(ptr *map[string]interface{}) *mapModel {
  16 + model := &mapModel{
  17 + ptr: ptr,
  18 + }
  19 + if ptr != nil {
  20 + model.m = *ptr
  21 + }
  22 + return model
  23 +}
  24 +
  25 +func (m *mapModel) Init() error {
  26 + return nil
  27 +}
  28 +
  29 +func (m *mapModel) NextColumnScanner() ColumnScanner {
  30 + if m.m == nil {
  31 + m.m = make(map[string]interface{})
  32 + *m.ptr = m.m
  33 + }
  34 + return m
  35 +}
  36 +
  37 +func (m mapModel) AddColumnScanner(ColumnScanner) error {
  38 + return nil
  39 +}
  40 +
  41 +func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
  42 + val, err := types.ReadColumnValue(col, rd, n)
  43 + if err != nil {
  44 + return err
  45 + }
  46 +
  47 + m.m[col.Name] = val
  48 + return nil
  49 +}
  50 +
  51 +func (mapModel) useQueryOne() bool {
  52 + return true
  53 +}
  1 +package orm
  2 +
  3 +type mapSliceModel struct {
  4 + mapModel
  5 + slice *[]map[string]interface{}
  6 +}
  7 +
  8 +var _ Model = (*mapSliceModel)(nil)
  9 +
  10 +func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel {
  11 + return &mapSliceModel{
  12 + slice: ptr,
  13 + }
  14 +}
  15 +
  16 +func (m *mapSliceModel) Init() error {
  17 + slice := *m.slice
  18 + if len(slice) > 0 {
  19 + *m.slice = slice[:0]
  20 + }
  21 + return nil
  22 +}
  23 +
  24 +func (m *mapSliceModel) NextColumnScanner() ColumnScanner {
  25 + slice := *m.slice
  26 + if len(slice) == cap(slice) {
  27 + m.mapModel.m = make(map[string]interface{})
  28 + *m.slice = append(slice, m.mapModel.m) //nolint:gocritic
  29 + return m
  30 + }
  31 +
  32 + slice = slice[:len(slice)+1]
  33 + el := slice[len(slice)-1]
  34 + if el != nil {
  35 + m.mapModel.m = el
  36 + } else {
  37 + el = make(map[string]interface{})
  38 + slice[len(slice)-1] = el
  39 + m.mapModel.m = el
  40 + }
  41 + *m.slice = slice
  42 + return m
  43 +}
  44 +
  45 +func (mapSliceModel) useQueryOne() {} //nolint:unused
@@ -29,12 +29,12 @@ func (m scanValuesModel) NextColumnScanner() ColumnScanner { @@ -29,12 +29,12 @@ func (m scanValuesModel) NextColumnScanner() ColumnScanner {
29 return m 29 return m
30 } 30 }
31 31
32 -func (m scanValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {  
33 - if colIdx >= len(m.values) { 32 +func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
  33 + if int(col.Index) >= len(m.values) {
34 return fmt.Errorf("pg: no Scan var for column index=%d name=%q", 34 return fmt.Errorf("pg: no Scan var for column index=%d name=%q",
35 - colIdx, colName) 35 + col.Index, col.Name)
36 } 36 }
37 - return types.Scan(m.values[colIdx], rd, n) 37 + return types.Scan(m.values[col.Index], rd, n)
38 } 38 }
39 39
40 //------------------------------------------------------------------------------ 40 //------------------------------------------------------------------------------
@@ -60,10 +60,10 @@ func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { @@ -60,10 +60,10 @@ func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner {
60 return m 60 return m
61 } 61 }
62 62
63 -func (m scanReflectValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {  
64 - if colIdx >= len(m.values) { 63 +func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
  64 + if int(col.Index) >= len(m.values) {
65 return fmt.Errorf("pg: no Scan var for column index=%d name=%q", 65 return fmt.Errorf("pg: no Scan var for column index=%d name=%q",
66 - colIdx, colName) 66 + col.Index, col.Name)
67 } 67 }
68 - return types.ScanValue(m.values[colIdx], rd, n) 68 + return types.ScanValue(m.values[col.Index], rd, n)
69 } 69 }
@@ -34,7 +34,7 @@ func (m *sliceModel) NextColumnScanner() ColumnScanner { @@ -34,7 +34,7 @@ func (m *sliceModel) NextColumnScanner() ColumnScanner {
34 return m 34 return m
35 } 35 }
36 36
37 -func (m *sliceModel) ScanColumn(colIdx int, _ string, rd types.Reader, n int) error { 37 +func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
38 if m.nextElem == nil { 38 if m.nextElem == nil {
39 m.nextElem = internal.MakeSliceNextElemFunc(m.slice) 39 m.nextElem = internal.MakeSliceNextElemFunc(m.slice)
40 } 40 }
@@ -27,56 +27,8 @@ type TableModel interface { @@ -27,56 +27,8 @@ type TableModel interface {
27 Kind() reflect.Kind 27 Kind() reflect.Kind
28 Value() reflect.Value 28 Value() reflect.Value
29 29
30 - setSoftDeleteField()  
31 - scanColumn(int, string, types.Reader, int) (bool, error)  
32 -}  
33 -  
34 -func newTableModel(value interface{}) (TableModel, error) {  
35 - if value, ok := value.(TableModel); ok {  
36 - return value, nil  
37 - }  
38 -  
39 - v := reflect.ValueOf(value)  
40 - if !v.IsValid() {  
41 - return nil, errModelNil  
42 - }  
43 - if v.Kind() != reflect.Ptr {  
44 - return nil, fmt.Errorf("pg: Model(non-pointer %T)", value)  
45 - }  
46 -  
47 - if v.IsNil() {  
48 - typ := v.Type().Elem()  
49 - if typ.Kind() == reflect.Struct {  
50 - return newStructTableModel(GetTable(typ)), nil  
51 - }  
52 - return nil, errModelNil  
53 - }  
54 -  
55 - v = v.Elem()  
56 - if v.Kind() == reflect.Interface {  
57 - if !v.IsNil() {  
58 - v = v.Elem()  
59 - if v.Kind() != reflect.Ptr {  
60 - return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String())  
61 - }  
62 - }  
63 - }  
64 -  
65 - return newTableModelValue(v)  
66 -}  
67 -  
68 -func newTableModelValue(v reflect.Value) (TableModel, error) {  
69 - switch v.Kind() {  
70 - case reflect.Struct:  
71 - return newStructTableModelValue(v), nil  
72 - case reflect.Slice:  
73 - elemType := sliceElemType(v)  
74 - if elemType.Kind() == reflect.Struct {  
75 - return newSliceTableModel(v, elemType), nil  
76 - }  
77 - }  
78 -  
79 - return nil, fmt.Errorf("pg: Model(unsupported %s)", v.Type()) 30 + setSoftDeleteField() error
  31 + scanColumn(types.ColumnInfo, types.Reader, int) (bool, error)
80 } 32 }
81 33
82 func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { 34 func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) {
@@ -4,6 +4,7 @@ import ( @@ -4,6 +4,7 @@ import (
4 "fmt" 4 "fmt"
5 "reflect" 5 "reflect"
6 6
  7 + "github.com/go-pg/pg/v10/internal/pool"
7 "github.com/go-pg/pg/v10/types" 8 "github.com/go-pg/pg/v10/types"
8 ) 9 )
9 10
@@ -60,7 +61,7 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { @@ -60,7 +61,7 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error {
60 dstValues, ok := m.dstValues[string(buf)] 61 dstValues, ok := m.dstValues[string(buf)]
61 if !ok { 62 if !ok {
62 return fmt.Errorf( 63 return fmt.Errorf(
63 - "pg: relation=%q has no base %s with id=%q (check join conditions)", 64 + "pg: relation=%q does not have base %s with id=%q (check join conditions)",
64 m.rel.Field.GoName, m.baseTable, buf) 65 m.rel.Field.GoName, m.baseTable, buf)
65 } 66 }
66 67
@@ -76,31 +77,35 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { @@ -76,31 +77,35 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error {
76 } 77 }
77 78
78 func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { 79 func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) {
79 - for i, col := range m.rel.BaseFKs { 80 + for i, col := range m.rel.M2MBaseFKs {
80 if i > 0 { 81 if i > 0 {
81 b = append(b, ',') 82 b = append(b, ',')
82 } 83 }
83 if s, ok := m.columns[col]; ok { 84 if s, ok := m.columns[col]; ok {
84 b = append(b, s...) 85 b = append(b, s...)
85 } else { 86 } else {
86 - return nil, fmt.Errorf("pg: %s has no column=%q", 87 + return nil, fmt.Errorf("pg: %s does not have column=%q",
87 m.sliceTableModel, col) 88 m.sliceTableModel, col)
88 } 89 }
89 } 90 }
90 return b, nil 91 return b, nil
91 } 92 }
92 93
93 -func (m *m2mModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error {  
94 - ok, err := m.sliceTableModel.scanColumn(colIdx, colName, rd, n)  
95 - if ok {  
96 - return err 94 +func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error {
  95 + if n > 0 {
  96 + b, err := rd.ReadFullTemp()
  97 + if err != nil {
  98 + return err
  99 + }
  100 +
  101 + m.columns[col.Name] = string(b)
  102 + rd = pool.NewBytesReader(b)
  103 + } else {
  104 + m.columns[col.Name] = ""
97 } 105 }
98 106
99 - tmp, err := rd.ReadFullTemp()  
100 - if err != nil { 107 + if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok {
101 return err 108 return err
102 } 109 }
103 -  
104 - m.columns[colName] = string(tmp)  
105 return nil 110 return nil
106 } 111 }