正在显示
50 个修改的文件
包含
1484 行增加
和
1662 行删除
| @@ -12,7 +12,7 @@ require ( | @@ -12,7 +12,7 @@ require ( | ||
| 12 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect | 12 | github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072 // indirect |
| 13 | github.com/fatih/structs v1.1.0 // indirect | 13 | github.com/fatih/structs v1.1.0 // indirect |
| 14 | github.com/gavv/httpexpect v2.0.0+incompatible | 14 | github.com/gavv/httpexpect v2.0.0+incompatible |
| 15 | - github.com/go-pg/pg/v10 v10.0.0-beta.2 | 15 | + github.com/go-pg/pg/v10 v10.7.3 |
| 16 | github.com/google/go-querystring v1.0.0 // indirect | 16 | github.com/google/go-querystring v1.0.0 // indirect |
| 17 | github.com/gorilla/websocket v1.4.2 // indirect | 17 | github.com/gorilla/websocket v1.4.2 // indirect |
| 18 | github.com/imkira/go-interpol v1.1.0 // indirect | 18 | github.com/imkira/go-interpol v1.1.0 // indirect |
| @@ -20,8 +20,9 @@ require ( | @@ -20,8 +20,9 @@ require ( | ||
| 20 | github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9 | 20 | github.com/linmadan/egglib-go v0.0.0-20191217144343-ca4539f95bf9 |
| 21 | github.com/mattn/go-colorable v0.1.6 // indirect | 21 | github.com/mattn/go-colorable v0.1.6 // indirect |
| 22 | github.com/moul/http2curl v1.0.0 // indirect | 22 | github.com/moul/http2curl v1.0.0 // indirect |
| 23 | - github.com/onsi/ginkgo v1.13.0 | ||
| 24 | - github.com/onsi/gomega v1.10.1 | 23 | + github.com/onsi/ginkgo v1.14.2 |
| 24 | + github.com/onsi/gomega v1.10.3 | ||
| 25 | + github.com/sclevine/agouti v3.0.0+incompatible // indirect | ||
| 25 | github.com/sergi/go-diff v1.1.0 // indirect | 26 | github.com/sergi/go-diff v1.1.0 // indirect |
| 26 | github.com/shopspring/decimal v1.2.0 | 27 | github.com/shopspring/decimal v1.2.0 |
| 27 | github.com/smartystreets/goconvey v1.6.4 // indirect | 28 | github.com/smartystreets/goconvey v1.6.4 // indirect |
| @@ -47,7 +47,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error { | @@ -47,7 +47,7 @@ func (repository *PartnerInfoRepository) Save(dm *domain.PartnerInfo) error { | ||
| 47 | Remark: dm.Remark, | 47 | Remark: dm.Remark, |
| 48 | } | 48 | } |
| 49 | if m.Id == 0 { | 49 | if m.Id == 0 { |
| 50 | - err = tx.Insert(m) | 50 | + _, err = tx.Model(m).Insert() |
| 51 | dm.Partner.Id = m.Id | 51 | dm.Partner.Id = m.Id |
| 52 | if err != nil { | 52 | if err != nil { |
| 53 | return err | 53 | return err |
tags
已删除
100644 → 0
此 diff 太大无法显示。
| 1 | -# Compiled Object files, Static and Dynamic libs (Shared Objects) | ||
| 2 | -*.o | ||
| 3 | -*.a | ||
| 4 | -*.so | ||
| 5 | - | ||
| 6 | -# Folders | ||
| 7 | -_obj | ||
| 8 | -_test | ||
| 9 | - | ||
| 10 | -# Architecture specific extensions/prefixes | ||
| 11 | -*.[568vq] | ||
| 12 | -[568vq].out | ||
| 13 | - | ||
| 14 | -*.cgo1.go | ||
| 15 | -*.cgo2.c | ||
| 16 | -_cgo_defun.c | ||
| 17 | -_cgo_gotypes.go | ||
| 18 | -_cgo_export.* | ||
| 19 | - | ||
| 20 | -_testmain.go | ||
| 21 | - | ||
| 22 | -*.exe | ||
| 23 | -*.test | ||
| 24 | -*.prof |
vendor/github.com/codemodus/kace/LICENSE
已删除
100644 → 0
| 1 | -The MIT License (MIT) | ||
| 2 | - | ||
| 3 | -Copyright (c) 2015 codemodus | ||
| 4 | - | ||
| 5 | -Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| 6 | -of this software and associated documentation files (the "Software"), to deal | ||
| 7 | -in the Software without restriction, including without limitation the rights | ||
| 8 | -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| 9 | -copies of the Software, and to permit persons to whom the Software is | ||
| 10 | -furnished to do so, subject to the following conditions: | ||
| 11 | - | ||
| 12 | -The above copyright notice and this permission notice shall be included in all | ||
| 13 | -copies or substantial portions of the Software. | ||
| 14 | - | ||
| 15 | -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| 16 | -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| 17 | -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| 18 | -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| 19 | -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| 20 | -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| 21 | -SOFTWARE. | ||
| 22 | - |
| 1 | -# kace | ||
| 2 | - | ||
| 3 | - go get "github.com/codemodus/kace" | ||
| 4 | - | ||
| 5 | -Package kace provides common case conversion functions which take into | ||
| 6 | -consideration common initialisms. | ||
| 7 | - | ||
| 8 | -## Usage | ||
| 9 | - | ||
| 10 | -```go | ||
| 11 | -func Camel(s string) string | ||
| 12 | -func Kebab(s string) string | ||
| 13 | -func KebabUpper(s string) string | ||
| 14 | -func Pascal(s string) string | ||
| 15 | -func Snake(s string) string | ||
| 16 | -func SnakeUpper(s string) string | ||
| 17 | -type Kace | ||
| 18 | - func New(initialisms map[string]bool) (*Kace, error) | ||
| 19 | - func (k *Kace) Camel(s string) string | ||
| 20 | - func (k *Kace) Kebab(s string) string | ||
| 21 | - func (k *Kace) KebabUpper(s string) string | ||
| 22 | - func (k *Kace) Pascal(s string) string | ||
| 23 | - func (k *Kace) Snake(s string) string | ||
| 24 | - func (k *Kace) SnakeUpper(s string) string | ||
| 25 | -``` | ||
| 26 | - | ||
| 27 | -### Setup | ||
| 28 | - | ||
| 29 | -```go | ||
| 30 | -import ( | ||
| 31 | - "fmt" | ||
| 32 | - | ||
| 33 | - "github.com/codemodus/kace" | ||
| 34 | -) | ||
| 35 | - | ||
| 36 | -func main() { | ||
| 37 | - s := "this is a test sql." | ||
| 38 | - | ||
| 39 | - fmt.Println(kace.Camel(s)) | ||
| 40 | - fmt.Println(kace.Pascal(s)) | ||
| 41 | - | ||
| 42 | - fmt.Println(kace.Snake(s)) | ||
| 43 | - fmt.Println(kace.SnakeUpper(s)) | ||
| 44 | - | ||
| 45 | - fmt.Println(kace.Kebab(s)) | ||
| 46 | - fmt.Println(kace.KebabUpper(s)) | ||
| 47 | - | ||
| 48 | - customInitialisms := map[string]bool{ | ||
| 49 | - "THIS": true, | ||
| 50 | - } | ||
| 51 | - k, err := kace.New(customInitialisms) | ||
| 52 | - if err != nil { | ||
| 53 | - // handle error | ||
| 54 | - } | ||
| 55 | - | ||
| 56 | - fmt.Println(k.Camel(s)) | ||
| 57 | - fmt.Println(k.Pascal(s)) | ||
| 58 | - | ||
| 59 | - fmt.Println(k.Snake(s)) | ||
| 60 | - fmt.Println(k.SnakeUpper(s)) | ||
| 61 | - | ||
| 62 | - fmt.Println(k.Kebab(s)) | ||
| 63 | - fmt.Println(k.KebabUpper(s)) | ||
| 64 | - | ||
| 65 | - // Output: | ||
| 66 | - // thisIsATestSQL | ||
| 67 | - // ThisIsATestSQL | ||
| 68 | - // this_is_a_test_sql | ||
| 69 | - // THIS_IS_A_TEST_SQL | ||
| 70 | - // this-is-a-test-sql | ||
| 71 | - // THIS-IS-A-TEST-SQL | ||
| 72 | - // thisIsATestSql | ||
| 73 | - // THISIsATestSql | ||
| 74 | - // this_is_a_test_sql | ||
| 75 | - // THIS_IS_A_TEST_SQL | ||
| 76 | - // this-is-a-test-sql | ||
| 77 | - // THIS-IS-A-TEST-SQL | ||
| 78 | -} | ||
| 79 | -``` | ||
| 80 | - | ||
| 81 | -## More Info | ||
| 82 | - | ||
| 83 | -### TODO | ||
| 84 | - | ||
| 85 | -#### Test Trie | ||
| 86 | - | ||
| 87 | - Test the current trie. | ||
| 88 | - | ||
| 89 | -## Documentation | ||
| 90 | - | ||
| 91 | -View the [GoDoc](http://godoc.org/github.com/codemodus/kace) | ||
| 92 | - | ||
| 93 | -## Benchmarks | ||
| 94 | - | ||
| 95 | - benchmark iter time/iter bytes alloc allocs | ||
| 96 | - --------- ---- --------- ----------- ------ | ||
| 97 | - BenchmarkCamel4 2000000 947.00 ns/op 112 B/op 3 allocs/op | ||
| 98 | - BenchmarkSnake4 2000000 696.00 ns/op 128 B/op 2 allocs/op | ||
| 99 | - BenchmarkSnakeUpper4 2000000 679.00 ns/op 128 B/op 2 allocs/op | ||
| 100 | - BenchmarkKebab4 2000000 691.00 ns/op 128 B/op 2 allocs/op | ||
| 101 | - BenchmarkKebabUpper4 2000000 677.00 ns/op 128 B/op 2 allocs/op |
vendor/github.com/codemodus/kace/go.mod
已删除
100644 → 0
| 1 | -module github.com/codemodus/kace |
vendor/github.com/codemodus/kace/kace.go
已删除
100644 → 0
| 1 | -// Package kace provides common case conversion functions which take into | ||
| 2 | -// consideration common initialisms. | ||
| 3 | -package kace | ||
| 4 | - | ||
| 5 | -import ( | ||
| 6 | - "fmt" | ||
| 7 | - "strings" | ||
| 8 | - "unicode" | ||
| 9 | - | ||
| 10 | - "github.com/codemodus/kace/ktrie" | ||
| 11 | -) | ||
| 12 | - | ||
| 13 | -const ( | ||
| 14 | - kebabDelim = '-' | ||
| 15 | - snakeDelim = '_' | ||
| 16 | - none = rune(-1) | ||
| 17 | -) | ||
| 18 | - | ||
| 19 | -var ( | ||
| 20 | - ciTrie *ktrie.KTrie | ||
| 21 | -) | ||
| 22 | - | ||
| 23 | -func init() { | ||
| 24 | - var err error | ||
| 25 | - if ciTrie, err = ktrie.NewKTrie(ciMap); err != nil { | ||
| 26 | - panic(err) | ||
| 27 | - } | ||
| 28 | -} | ||
| 29 | - | ||
| 30 | -// Camel returns a camelCased string. | ||
| 31 | -func Camel(s string) string { | ||
| 32 | - return camelCase(ciTrie, s, false) | ||
| 33 | -} | ||
| 34 | - | ||
| 35 | -// Pascal returns a PascalCased string. | ||
| 36 | -func Pascal(s string) string { | ||
| 37 | - return camelCase(ciTrie, s, true) | ||
| 38 | -} | ||
| 39 | - | ||
| 40 | -// Kebab returns a kebab-cased string with all lowercase letters. | ||
| 41 | -func Kebab(s string) string { | ||
| 42 | - return delimitedCase(s, kebabDelim, false) | ||
| 43 | -} | ||
| 44 | - | ||
| 45 | -// KebabUpper returns a KEBAB-CASED string with all upper case letters. | ||
| 46 | -func KebabUpper(s string) string { | ||
| 47 | - return delimitedCase(s, kebabDelim, true) | ||
| 48 | -} | ||
| 49 | - | ||
| 50 | -// Snake returns a snake_cased string with all lowercase letters. | ||
| 51 | -func Snake(s string) string { | ||
| 52 | - return delimitedCase(s, snakeDelim, false) | ||
| 53 | -} | ||
| 54 | - | ||
| 55 | -// SnakeUpper returns a SNAKE_CASED string with all upper case letters. | ||
| 56 | -func SnakeUpper(s string) string { | ||
| 57 | - return delimitedCase(s, snakeDelim, true) | ||
| 58 | -} | ||
| 59 | - | ||
| 60 | -// Kace provides common case conversion methods which take into | ||
| 61 | -// consideration common initialisms set by the user. | ||
| 62 | -type Kace struct { | ||
| 63 | - t *ktrie.KTrie | ||
| 64 | -} | ||
| 65 | - | ||
| 66 | -// New returns a pointer to an instance of kace loaded with a common | ||
| 67 | -// initialsms trie based on the provided map. Before conversion to a | ||
| 68 | -// trie, the provided map keys are all upper cased. | ||
| 69 | -func New(initialisms map[string]bool) (*Kace, error) { | ||
| 70 | - ci := initialisms | ||
| 71 | - if ci == nil { | ||
| 72 | - ci = map[string]bool{} | ||
| 73 | - } | ||
| 74 | - | ||
| 75 | - ci = sanitizeCI(ci) | ||
| 76 | - | ||
| 77 | - t, err := ktrie.NewKTrie(ci) | ||
| 78 | - if err != nil { | ||
| 79 | - return nil, fmt.Errorf("kace: cannot create trie: %s", err) | ||
| 80 | - } | ||
| 81 | - | ||
| 82 | - k := &Kace{ | ||
| 83 | - t: t, | ||
| 84 | - } | ||
| 85 | - | ||
| 86 | - return k, nil | ||
| 87 | -} | ||
| 88 | - | ||
| 89 | -// Camel returns a camelCased string. | ||
| 90 | -func (k *Kace) Camel(s string) string { | ||
| 91 | - return camelCase(k.t, s, false) | ||
| 92 | -} | ||
| 93 | - | ||
| 94 | -// Pascal returns a PascalCased string. | ||
| 95 | -func (k *Kace) Pascal(s string) string { | ||
| 96 | - return camelCase(k.t, s, true) | ||
| 97 | -} | ||
| 98 | - | ||
| 99 | -// Snake returns a snake_cased string with all lowercase letters. | ||
| 100 | -func (k *Kace) Snake(s string) string { | ||
| 101 | - return delimitedCase(s, snakeDelim, false) | ||
| 102 | -} | ||
| 103 | - | ||
| 104 | -// SnakeUpper returns a SNAKE_CASED string with all upper case letters. | ||
| 105 | -func (k *Kace) SnakeUpper(s string) string { | ||
| 106 | - return delimitedCase(s, snakeDelim, true) | ||
| 107 | -} | ||
| 108 | - | ||
| 109 | -// Kebab returns a kebab-cased string with all lowercase letters. | ||
| 110 | -func (k *Kace) Kebab(s string) string { | ||
| 111 | - return delimitedCase(s, kebabDelim, false) | ||
| 112 | -} | ||
| 113 | - | ||
| 114 | -// KebabUpper returns a KEBAB-CASED string with all upper case letters. | ||
| 115 | -func (k *Kace) KebabUpper(s string) string { | ||
| 116 | - return delimitedCase(s, kebabDelim, true) | ||
| 117 | -} | ||
| 118 | - | ||
| 119 | -func camelCase(t *ktrie.KTrie, s string, ucFirst bool) string { | ||
| 120 | - rs := []rune(s) | ||
| 121 | - offset := 0 | ||
| 122 | - prev := none | ||
| 123 | - | ||
| 124 | - for i := 0; i < len(rs); i++ { | ||
| 125 | - r := rs[i] | ||
| 126 | - | ||
| 127 | - switch { | ||
| 128 | - case unicode.IsLetter(r): | ||
| 129 | - ucCurr := isToBeUpper(r, prev, ucFirst) | ||
| 130 | - | ||
| 131 | - if ucCurr || isSegmentStart(r, prev) { | ||
| 132 | - prv, skip := updateRunes(rs, i, offset, t, ucCurr) | ||
| 133 | - if skip > 0 { | ||
| 134 | - i += skip - 1 | ||
| 135 | - prev = prv | ||
| 136 | - continue | ||
| 137 | - } | ||
| 138 | - } | ||
| 139 | - | ||
| 140 | - prev = updateRune(rs, i, offset, ucCurr) | ||
| 141 | - continue | ||
| 142 | - | ||
| 143 | - case unicode.IsNumber(r): | ||
| 144 | - prev = updateRune(rs, i, offset, false) | ||
| 145 | - continue | ||
| 146 | - | ||
| 147 | - default: | ||
| 148 | - prev = r | ||
| 149 | - offset-- | ||
| 150 | - } | ||
| 151 | - } | ||
| 152 | - | ||
| 153 | - return string(rs[:len(rs)+offset]) | ||
| 154 | -} | ||
| 155 | - | ||
| 156 | -func isToBeUpper(curr, prev rune, ucFirst bool) bool { | ||
| 157 | - if prev == none { | ||
| 158 | - return ucFirst | ||
| 159 | - } | ||
| 160 | - | ||
| 161 | - return isSegmentStart(curr, prev) | ||
| 162 | -} | ||
| 163 | - | ||
| 164 | -func isSegmentStart(curr, prev rune) bool { | ||
| 165 | - if !unicode.IsLetter(prev) || unicode.IsUpper(curr) && unicode.IsLower(prev) { | ||
| 166 | - return true | ||
| 167 | - } | ||
| 168 | - | ||
| 169 | - return false | ||
| 170 | -} | ||
| 171 | - | ||
| 172 | -func updateRune(rs []rune, i, offset int, upper bool) rune { | ||
| 173 | - r := rs[i] | ||
| 174 | - | ||
| 175 | - dest := i + offset | ||
| 176 | - if dest < 0 || i > len(rs)-1 { | ||
| 177 | - panic("this function has been used or designed incorrectly") | ||
| 178 | - } | ||
| 179 | - | ||
| 180 | - fn := unicode.ToLower | ||
| 181 | - if upper { | ||
| 182 | - fn = unicode.ToUpper | ||
| 183 | - } | ||
| 184 | - | ||
| 185 | - rs[dest] = fn(r) | ||
| 186 | - | ||
| 187 | - return r | ||
| 188 | -} | ||
| 189 | - | ||
| 190 | -func updateRunes(rs []rune, i, offset int, t *ktrie.KTrie, upper bool) (rune, int) { | ||
| 191 | - r := rs[i] | ||
| 192 | - ns := nextSegment(rs, i) | ||
| 193 | - ct := len(ns) | ||
| 194 | - | ||
| 195 | - if ct < t.MinDepth() || ct > t.MaxDepth() || !t.FindAsUpper(ns) { | ||
| 196 | - return r, 0 | ||
| 197 | - } | ||
| 198 | - | ||
| 199 | - for j := i; j < i+ct; j++ { | ||
| 200 | - r = updateRune(rs, j, offset, upper) | ||
| 201 | - } | ||
| 202 | - | ||
| 203 | - return r, ct | ||
| 204 | -} | ||
| 205 | - | ||
| 206 | -func nextSegment(rs []rune, i int) []rune { | ||
| 207 | - for j := i; j < len(rs); j++ { | ||
| 208 | - if !unicode.IsLetter(rs[j]) && !unicode.IsNumber(rs[j]) { | ||
| 209 | - return rs[i:j] | ||
| 210 | - } | ||
| 211 | - | ||
| 212 | - if j == len(rs)-1 { | ||
| 213 | - return rs[i : j+1] | ||
| 214 | - } | ||
| 215 | - } | ||
| 216 | - | ||
| 217 | - return nil | ||
| 218 | -} | ||
| 219 | - | ||
| 220 | -func delimitedCase(s string, delim rune, upper bool) string { | ||
| 221 | - buf := make([]rune, 0, len(s)*2) | ||
| 222 | - | ||
| 223 | - for i := len(s); i > 0; i-- { | ||
| 224 | - switch { | ||
| 225 | - case unicode.IsLetter(rune(s[i-1])): | ||
| 226 | - if i < len(s) && unicode.IsUpper(rune(s[i])) { | ||
| 227 | - if i > 1 && unicode.IsLower(rune(s[i-1])) || i < len(s)-2 && unicode.IsLower(rune(s[i+1])) { | ||
| 228 | - buf = append(buf, delim) | ||
| 229 | - } | ||
| 230 | - } | ||
| 231 | - | ||
| 232 | - buf = appendCased(buf, upper, rune(s[i-1])) | ||
| 233 | - | ||
| 234 | - case unicode.IsNumber(rune(s[i-1])): | ||
| 235 | - if i == len(s) || i == 1 || unicode.IsNumber(rune(s[i])) { | ||
| 236 | - buf = append(buf, rune(s[i-1])) | ||
| 237 | - continue | ||
| 238 | - } | ||
| 239 | - | ||
| 240 | - buf = append(buf, delim, rune(s[i-1])) | ||
| 241 | - | ||
| 242 | - default: | ||
| 243 | - if i == len(s) { | ||
| 244 | - continue | ||
| 245 | - } | ||
| 246 | - | ||
| 247 | - buf = append(buf, delim) | ||
| 248 | - } | ||
| 249 | - } | ||
| 250 | - | ||
| 251 | - reverse(buf) | ||
| 252 | - | ||
| 253 | - return string(buf) | ||
| 254 | -} | ||
| 255 | - | ||
| 256 | -func appendCased(rs []rune, upper bool, r rune) []rune { | ||
| 257 | - if upper { | ||
| 258 | - rs = append(rs, unicode.ToUpper(r)) | ||
| 259 | - return rs | ||
| 260 | - } | ||
| 261 | - | ||
| 262 | - rs = append(rs, unicode.ToLower(r)) | ||
| 263 | - | ||
| 264 | - return rs | ||
| 265 | -} | ||
| 266 | - | ||
| 267 | -func reverse(s []rune) { | ||
| 268 | - for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { | ||
| 269 | - s[i], s[j] = s[j], s[i] | ||
| 270 | - } | ||
| 271 | -} | ||
| 272 | - | ||
| 273 | -var ( | ||
| 274 | - // github.com/golang/lint/blob/master/lint.go | ||
| 275 | - ciMap = map[string]bool{ | ||
| 276 | - "ACL": true, | ||
| 277 | - "API": true, | ||
| 278 | - "ASCII": true, | ||
| 279 | - "CPU": true, | ||
| 280 | - "CSS": true, | ||
| 281 | - "DNS": true, | ||
| 282 | - "EOF": true, | ||
| 283 | - "GUID": true, | ||
| 284 | - "HTML": true, | ||
| 285 | - "HTTP": true, | ||
| 286 | - "HTTPS": true, | ||
| 287 | - "ID": true, | ||
| 288 | - "IP": true, | ||
| 289 | - "JSON": true, | ||
| 290 | - "LHS": true, | ||
| 291 | - "QPS": true, | ||
| 292 | - "RAM": true, | ||
| 293 | - "RHS": true, | ||
| 294 | - "RPC": true, | ||
| 295 | - "SLA": true, | ||
| 296 | - "SMTP": true, | ||
| 297 | - "SQL": true, | ||
| 298 | - "SSH": true, | ||
| 299 | - "TCP": true, | ||
| 300 | - "TLS": true, | ||
| 301 | - "TTL": true, | ||
| 302 | - "UDP": true, | ||
| 303 | - "UI": true, | ||
| 304 | - "UID": true, | ||
| 305 | - "UUID": true, | ||
| 306 | - "URI": true, | ||
| 307 | - "URL": true, | ||
| 308 | - "UTF8": true, | ||
| 309 | - "VM": true, | ||
| 310 | - "XML": true, | ||
| 311 | - "XMPP": true, | ||
| 312 | - "XSRF": true, | ||
| 313 | - "XSS": true, | ||
| 314 | - } | ||
| 315 | -) | ||
| 316 | - | ||
| 317 | -func sanitizeCI(m map[string]bool) map[string]bool { | ||
| 318 | - r := map[string]bool{} | ||
| 319 | - | ||
| 320 | - for k := range m { | ||
| 321 | - fn := func(r rune) rune { | ||
| 322 | - if !unicode.IsLetter(r) && !unicode.IsNumber(r) { | ||
| 323 | - return -1 | ||
| 324 | - } | ||
| 325 | - return r | ||
| 326 | - } | ||
| 327 | - | ||
| 328 | - k = strings.Map(fn, k) | ||
| 329 | - k = strings.ToUpper(k) | ||
| 330 | - | ||
| 331 | - if k == "" { | ||
| 332 | - continue | ||
| 333 | - } | ||
| 334 | - | ||
| 335 | - r[k] = true | ||
| 336 | - } | ||
| 337 | - | ||
| 338 | - return r | ||
| 339 | -} |
| 1 | -package ktrie | ||
| 2 | - | ||
| 3 | -import "unicode" | ||
| 4 | - | ||
| 5 | -// KNode ... | ||
| 6 | -type KNode struct { | ||
| 7 | - val rune | ||
| 8 | - end bool | ||
| 9 | - links []*KNode | ||
| 10 | -} | ||
| 11 | - | ||
| 12 | -// NewKNode ... | ||
| 13 | -func NewKNode(val rune) *KNode { | ||
| 14 | - return &KNode{ | ||
| 15 | - val: val, | ||
| 16 | - links: make([]*KNode, 0), | ||
| 17 | - } | ||
| 18 | -} | ||
| 19 | - | ||
| 20 | -// Add ... | ||
| 21 | -func (n *KNode) Add(rs []rune) { | ||
| 22 | - cur := n | ||
| 23 | - | ||
| 24 | - for k, v := range rs { | ||
| 25 | - link := cur.linkByVal(v) | ||
| 26 | - | ||
| 27 | - if link == nil { | ||
| 28 | - link = NewKNode(v) | ||
| 29 | - cur.links = append(cur.links, link) | ||
| 30 | - } | ||
| 31 | - | ||
| 32 | - if k == len(rs)-1 { | ||
| 33 | - link.end = true | ||
| 34 | - } | ||
| 35 | - | ||
| 36 | - cur = link | ||
| 37 | - } | ||
| 38 | -} | ||
| 39 | - | ||
| 40 | -// Find ... | ||
| 41 | -func (n *KNode) Find(rs []rune) bool { | ||
| 42 | - cur := n | ||
| 43 | - | ||
| 44 | - for _, v := range rs { | ||
| 45 | - cur = cur.linkByVal(v) | ||
| 46 | - | ||
| 47 | - if cur == nil { | ||
| 48 | - return false | ||
| 49 | - } | ||
| 50 | - } | ||
| 51 | - | ||
| 52 | - return cur.end | ||
| 53 | -} | ||
| 54 | - | ||
| 55 | -// FindAsUpper ... | ||
| 56 | -func (n *KNode) FindAsUpper(rs []rune) bool { | ||
| 57 | - cur := n | ||
| 58 | - | ||
| 59 | - for _, v := range rs { | ||
| 60 | - cur = cur.linkByVal(unicode.ToUpper(v)) | ||
| 61 | - | ||
| 62 | - if cur == nil { | ||
| 63 | - return false | ||
| 64 | - } | ||
| 65 | - } | ||
| 66 | - | ||
| 67 | - return cur.end | ||
| 68 | -} | ||
| 69 | - | ||
| 70 | -func (n *KNode) linkByVal(val rune) *KNode { | ||
| 71 | - for _, v := range n.links { | ||
| 72 | - if v.val == val { | ||
| 73 | - return v | ||
| 74 | - } | ||
| 75 | - } | ||
| 76 | - | ||
| 77 | - return nil | ||
| 78 | -} | ||
| 79 | - | ||
| 80 | -// KTrie ... | ||
| 81 | -type KTrie struct { | ||
| 82 | - *KNode | ||
| 83 | - | ||
| 84 | - maxDepth int | ||
| 85 | - minDepth int | ||
| 86 | -} | ||
| 87 | - | ||
| 88 | -// NewKTrie ... | ||
| 89 | -func NewKTrie(data map[string]bool) (*KTrie, error) { | ||
| 90 | - n := NewKNode(0) | ||
| 91 | - | ||
| 92 | - maxDepth := 0 | ||
| 93 | - minDepth := 9001 | ||
| 94 | - | ||
| 95 | - for k := range data { | ||
| 96 | - rs := []rune(k) | ||
| 97 | - l := len(rs) | ||
| 98 | - | ||
| 99 | - n.Add(rs) | ||
| 100 | - | ||
| 101 | - if l > maxDepth { | ||
| 102 | - maxDepth = l | ||
| 103 | - } | ||
| 104 | - if l < minDepth { | ||
| 105 | - minDepth = l | ||
| 106 | - } | ||
| 107 | - } | ||
| 108 | - | ||
| 109 | - t := &KTrie{ | ||
| 110 | - maxDepth: maxDepth, | ||
| 111 | - minDepth: minDepth, | ||
| 112 | - KNode: n, | ||
| 113 | - } | ||
| 114 | - | ||
| 115 | - return t, nil | ||
| 116 | -} | ||
| 117 | - | ||
| 118 | -// MaxDepth ... | ||
| 119 | -func (t *KTrie) MaxDepth() int { | ||
| 120 | - return t.maxDepth | ||
| 121 | -} | ||
| 122 | - | ||
| 123 | -// MinDepth ... | ||
| 124 | -func (t *KTrie) MinDepth() int { | ||
| 125 | - return t.minDepth | ||
| 126 | -} |
| 1 | dist: xenial | 1 | dist: xenial |
| 2 | -sudo: false | ||
| 3 | language: go | 2 | language: go |
| 4 | 3 | ||
| 5 | addons: | 4 | addons: |
| 6 | - postgresql: "9.6" | 5 | + postgresql: '9.6' |
| 7 | 6 | ||
| 8 | go: | 7 | go: |
| 9 | - - 1.13.x | ||
| 10 | - 1.14.x | 8 | - 1.14.x |
| 9 | + - 1.15.x | ||
| 11 | - tip | 10 | - tip |
| 12 | 11 | ||
| 13 | matrix: | 12 | matrix: |
| 14 | allow_failures: | 13 | allow_failures: |
| 15 | - go: tip | 14 | - go: tip |
| 16 | 15 | ||
| 17 | -env: | ||
| 18 | - - GO111MODULE=on | ||
| 19 | - | ||
| 20 | go_import_path: github.com/go-pg/pg | 16 | go_import_path: github.com/go-pg/pg |
| 21 | 17 | ||
| 22 | before_install: | 18 | before_install: |
| 23 | - psql -U postgres -c "CREATE EXTENSION hstore" | 19 | - psql -U postgres -c "CREATE EXTENSION hstore" |
| 24 | - - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.24.0 | 20 | + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- |
| 21 | + -b $(go env GOPATH)/bin v1.28.3 |
| 1 | # Changelog | 1 | # Changelog |
| 2 | 2 | ||
| 3 | -## v10 (unreleased) | 3 | +> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev) |
| 4 | 4 | ||
| 5 | -- Added `pgext.OpenTemetryHook` that adds OpenTelemetry | ||
| 6 | - [instrumentation](https://pg.uptrace.dev/tracing/). | ||
| 7 | -- Added `pgext.DebugHook` that logs queries and errors. | ||
| 8 | -- Added `db.Ping` to check if database is healthy. | ||
| 9 | -- Changed `pg.QueryHook` to return temp byte slice to reduce memory usage. | ||
| 10 | -- `,msgpack` struct tag marshals data in MessagePack format using | ||
| 11 | - https://github.com/vmihailenco/msgpack | ||
| 12 | -- Deprecated types and funcs are removed. | ||
| 13 | - | ||
| 14 | -## v9 | ||
| 15 | - | ||
| 16 | -- `pg:",notnull"` is reworked. Now it means SQL `NOT NULL` constraint and | ||
| 17 | - nothing more. | ||
| 18 | -- Added `pg:",use_zero"` to prevent go-pg from converting Go zero values to SQL | ||
| 19 | - `NULL`. | ||
| 20 | -- UpdateNotNull is renamed to UpdateNotZero. As previously it omits zero Go | ||
| 21 | - values, but it does not take in account if field is nullable or not. | ||
| 22 | -- ORM supports DistinctOn. | ||
| 23 | -- Hooks accept and return context. | ||
| 24 | -- Client respects Context.Deadline when setting net.Conn deadline. | ||
| 25 | -- Client listens on Context.Done while waiting for a connection from the pool | ||
| 26 | - and returns an error when context is cancelled. | ||
| 27 | -- `Query.Column` does not accept relation name any more. Use `Query.Relation` | ||
| 28 | - instead which returns an error if relation does not exist. | ||
| 29 | -- urlvalues package is removed in favor of https://github.com/go-pg/urlstruct. | ||
| 30 | - You can also use struct based filters via `Query.WhereStruct`. | ||
| 31 | -- `NewModel` and `AddModel` methods of `HooklessModel` interface were renamed to | ||
| 32 | - `NextColumnScanner` and `AddColumnScanner` respectively. | ||
| 33 | -- `types.F` and `pg.F` are deprecated in favor of `pg.Ident`. | ||
| 34 | -- `types.Q` is deprecated in favor of `pg.Safe`. | ||
| 35 | -- `pg.Q` is deprecated in favor of `pg.SafeQuery`. | ||
| 36 | -- `TableName` field is deprecated in favor of `tableName`. | ||
| 37 | -- Always use `pg:"..."` struct field tag instead of `sql:"..."`. | ||
| 38 | -- `pg:",override"` is deprecated in favor of `pg:",inherit"`. | ||
| 39 | - | ||
| 40 | -## v8 | ||
| 41 | - | ||
| 42 | -- Added `QueryContext`, `ExecContext`, and `ModelContext` which accept | ||
| 43 | - `context.Context`. Queries are cancelled when context is cancelled. | ||
| 44 | -- Model hooks are changed to accept `context.Context` as first argument. | ||
| 45 | -- Fixed array and hstore parsers to handle multiple single quotes (#1235). | ||
| 46 | - | ||
| 47 | -## v7 | ||
| 48 | - | ||
| 49 | -- DB.OnQueryProcessed is replaced with DB.AddQueryHook. | ||
| 50 | -- Added WhereStruct. | ||
| 51 | -- orm.Pager is moved to urlvalues.Pager. Pager.FromURLValues returns an error if | ||
| 52 | - page or limit params can't be parsed. | ||
| 53 | - | ||
| 54 | -## v6.16 | ||
| 55 | - | ||
| 56 | -- Read buffer is re-worked. Default read buffer is increased to 65kb. | ||
| 57 | - | ||
| 58 | -## v6.15 | ||
| 59 | - | ||
| 60 | -- Added Options.MinIdleConns. | ||
| 61 | -- Options.MaxAge renamed to Options.MaxConnAge. | ||
| 62 | -- PoolStats.FreeConns is renamed to PoolStats.IdleConns. | ||
| 63 | -- New hook BeforeSelectQuery. | ||
| 64 | -- `,override` is renamed to `,inherit`. | ||
| 65 | -- Dialer.KeepAlive is set to 5 minutes by default. | ||
| 66 | -- Added support "scram-sha-256" authentication. | ||
| 67 | - | ||
| 68 | -## v6.14 | ||
| 69 | - | ||
| 70 | -- Fields ignored with `sql:"-"` tag are no longer considered by ORM relation | ||
| 71 | - detector. | ||
| 72 | - | ||
| 73 | -## v6.12 | ||
| 74 | - | ||
| 75 | -- `Insert`, `Update`, and `Delete` can return `pg.ErrNoRows` and | ||
| 76 | - `pg.ErrMultiRows` when `Returning` is used and model expects single row. | ||
| 77 | - | ||
| 78 | -## v6.11 | ||
| 79 | - | ||
| 80 | -- `db.Model(&strct).Update()` and `db.Model(&strct).Delete()` no longer adds | ||
| 81 | - WHERE condition based on primary key when there are no conditions. Instead you | ||
| 82 | - should use `db.Update(&strct)` or `db.Model(&strct).WherePK().Update()`. | ||
| 83 | - | ||
| 84 | -## v6.10 | ||
| 85 | - | ||
| 86 | -- `?Columns` is renamed to `?TableColumns`. `?Columns` is changed to produce | ||
| 87 | - column names without table alias. | ||
| 88 | - | ||
| 89 | -## v6.9 | ||
| 90 | - | ||
| 91 | -- `pg:"fk"` tag now accepts SQL names instead of Go names, e.g. | ||
| 92 | - `pg:"fk:ParentId"` becomes `pg:"fk:parent_id"`. Old code should continue | ||
| 93 | - working in most cases, but it is strongly advised to start using new | ||
| 94 | - convention. | ||
| 95 | -- uint and uint64 SQL type is changed from decimal to bigint according to the | ||
| 96 | - lesser of two evils principle. Use `sql:"type:decimal"` to get old behavior. | ||
| 97 | - | ||
| 98 | -## v6.8 | ||
| 99 | - | ||
| 100 | -- `CreateTable` no longer adds ON DELETE hook by default. To get old behavior | ||
| 101 | - users should add `sql:"on_delete:CASCADE"` tag on foreign key field. | ||
| 102 | - | ||
| 103 | -## v6 | ||
| 104 | - | ||
| 105 | -- `types.Result` is renamed to `orm.Result`. | ||
| 106 | -- Added `OnQueryProcessed` event that can be used to log / report queries | ||
| 107 | - timing. Query logger is removed. | ||
| 108 | -- `orm.URLValues` is renamed to `orm.URLFilters`. It no longer adds ORDER | ||
| 109 | - clause. | ||
| 110 | -- `orm.Pager` is renamed to `orm.Pagination`. | ||
| 111 | -- Support for net.IP and net.IPNet. | ||
| 112 | -- Support for context.Context. | ||
| 113 | -- Bulk/multi updates. | ||
| 114 | -- Query.WhereGroup for enclosing conditions in parentheses. | ||
| 115 | - | ||
| 116 | -## v5 | ||
| 117 | - | ||
| 118 | -- All fields are nullable by default. `,null` tag is replaced with `,notnull`. | ||
| 119 | -- `Result.Affected` renamed to `Result.RowsAffected`. | ||
| 120 | -- Added `Result.RowsReturned`. | ||
| 121 | -- `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate` | ||
| 122 | - to `AfterInsert`. | ||
| 123 | -- Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`. | ||
| 124 | -- Named placeholders are evaluated when query is executed. | ||
| 125 | -- Added Update and Delete hooks. | ||
| 126 | -- Order reworked to quote column names. OrderExpr added to bypass Order quoting | ||
| 127 | - restrictions. | ||
| 128 | -- Group reworked to quote column names. GroupExpr added to bypass Group quoting | ||
| 129 | - restrictions. | ||
| 130 | - | ||
| 131 | -## v4 | ||
| 132 | - | ||
| 133 | -- `Options.Host` and `Options.Port` merged into `Options.Addr`. | ||
| 134 | -- Added `Options.MaxRetries`. Now queries are not retried by default. | ||
| 135 | -- `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`, | ||
| 136 | - LoadColumn renamed to ScanColumn, `NewRecord() interface{}` changed to | ||
| 137 | - `NewModel() ColumnScanner`, `AppendQuery(dst []byte) []byte` changed to | ||
| 138 | - `AppendValue(dst []byte, quote bool) ([]byte, error)`. | ||
| 139 | -- Structs, maps and slices are marshalled to JSON by default. | ||
| 140 | -- Added support for scanning slices, .e.g. scanning `[]int`. | ||
| 141 | -- Added object relational mapping. | 5 | +See https://pg.uptrace.dev/changelog/ |
| 1 | all: | 1 | all: |
| 2 | - go test ./... | ||
| 3 | - go test ./... -short -race | ||
| 4 | - go test ./... -run=NONE -bench=. -benchmem | 2 | + TZ= go test ./... |
| 3 | + TZ= go test ./... -short -race | ||
| 4 | + TZ= go test ./... -run=NONE -bench=. -benchmem | ||
| 5 | env GOOS=linux GOARCH=386 go test ./... | 5 | env GOOS=linux GOARCH=386 go test ./... |
| 6 | + go vet | ||
| 6 | golangci-lint run | 7 | golangci-lint run |
| 8 | + | ||
| 9 | +.PHONY: cleanTest | ||
| 10 | +cleanTest: | ||
| 11 | + docker rm -fv pg || true | ||
| 12 | + | ||
| 13 | +.PHONY: pre-test | ||
| 14 | +pre-test: cleanTest | ||
| 15 | + docker run -d --name pg -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:9.6 | ||
| 16 | + sleep 10 | ||
| 17 | + docker exec pg psql -U postgres -c "CREATE EXTENSION hstore" | ||
| 18 | + | ||
| 19 | +.PHONY: test | ||
| 20 | +test: pre-test | ||
| 21 | + TZ= PGSSLMODE=disable go test ./... -v |
| 1 | # PostgreSQL client and ORM for Golang | 1 | # PostgreSQL client and ORM for Golang |
| 2 | 2 | ||
| 3 | -[](https://travis-ci.org/go-pg/pg) | ||
| 4 | -[](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) | 3 | +[](https://travis-ci.org/go-pg/pg) |
| 4 | +[](https://pkg.go.dev/github.com/go-pg/pg/v10) | ||
| 5 | +[](https://pg.uptrace.dev/) | ||
| 6 | +[](https://discord.gg/rWtp5Aj) | ||
| 5 | 7 | ||
| 6 | -- [Docs](https://pg.uptrace.dev) | 8 | +> :heart: [**Uptrace.dev** - distributed traces, logs, and errors in one place](https://uptrace.dev) |
| 9 | + | ||
| 10 | +- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions. | ||
| 11 | +- [Documentation](https://pg.uptrace.dev) | ||
| 7 | - [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) | 12 | - [Reference](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc) |
| 8 | - [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples) | 13 | - [Examples](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#pkg-examples) |
| 14 | +- Example projects: | ||
| 15 | + - [treemux](https://github.com/uptrace/go-treemux-realworld-example-app) | ||
| 16 | + - [gin](https://github.com/gogjango/gjango) | ||
| 17 | + - [go-kit](https://github.com/Tsovak/rest-api-demo) | ||
| 18 | + - [aah framework](https://github.com/kieusonlam/golamapi) | ||
| 19 | +- [GraphQL Tutorial on YouTube](https://www.youtube.com/playlist?list=PLzQWIQOqeUSNwXcneWYJHUREAIucJ5UZn). | ||
| 9 | 20 | ||
| 10 | ## Ecosystem | 21 | ## Ecosystem |
| 11 | 22 | ||
| 12 | - Migrations by [vmihailenco](https://github.com/go-pg/migrations) and | 23 | - Migrations by [vmihailenco](https://github.com/go-pg/migrations) and |
| 13 | [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations). | 24 | [robinjoseph08](https://github.com/robinjoseph08/go-pg-migrations). |
| 25 | +- [Genna - cli tool for generating go-pg models](https://github.com/dizzyfool/genna). | ||
| 26 | +- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into structs. | ||
| 14 | - [Sharding](https://github.com/go-pg/sharding). | 27 | - [Sharding](https://github.com/go-pg/sharding). |
| 15 | -- [Model generator from SQL tables](https://github.com/dizzyfool/genna). | ||
| 16 | -- [urlstruct](https://github.com/go-pg/urlstruct) to decode `url.Values` into | ||
| 17 | - structs. | ||
| 18 | - | ||
| 19 | -## Sponsors | ||
| 20 | - | ||
| 21 | -- [**Uptrace.dev** - distributed traces and metrics](https://uptrace.dev) | ||
| 22 | 28 | ||
| 23 | ## Features | 29 | ## Features |
| 24 | 30 | ||
| @@ -26,71 +32,200 @@ | @@ -26,71 +32,200 @@ | ||
| 26 | - sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and | 32 | - sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and |
| 27 | [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime). | 33 | [pg.NullTime](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#NullTime). |
| 28 | - [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and | 34 | - [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and |
| 29 | - [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) | ||
| 30 | - interfaces. | 35 | + [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces. |
| 31 | - Structs, maps and arrays are marshalled as JSON by default. | 36 | - Structs, maps and arrays are marshalled as JSON by default. |
| 32 | - PostgreSQL multidimensional Arrays using | 37 | - PostgreSQL multidimensional Arrays using |
| 33 | [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag) | 38 | [array tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-PostgresArrayStructTag) |
| 34 | - and | ||
| 35 | - [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array). | 39 | + and [Array wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Array). |
| 36 | - Hstore using | 40 | - Hstore using |
| 37 | [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag) | 41 | [hstore tag](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HstoreStructTag) |
| 38 | - and | ||
| 39 | - [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore). | 42 | + and [Hstore wrapper](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Hstore). |
| 40 | - [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType). | 43 | - [Composite types](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CompositeType). |
| 41 | -- All struct fields are nullable by default and zero values (empty string, 0, | ||
| 42 | - zero time, empty map or slice, nil ptr) are marshalled as SQL `NULL`. | ||
| 43 | - `pg:",notnull"` is used to add SQL `NOT NULL` constraint and `pg:",use_zero"` | ||
| 44 | - to allow Go zero values. | 44 | +- All struct fields are nullable by default and zero values (empty string, 0, zero time, empty map |
| 45 | + or slice, nil ptr) are marshalled as SQL `NULL`. `pg:",notnull"` is used to add SQL `NOT NULL` | ||
| 46 | + constraint and `pg:",use_zero"` to allow Go zero values. | ||
| 45 | - [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin). | 47 | - [Transactions](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Begin). |
| 46 | - [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare). | 48 | - [Prepared statements](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Prepare). |
| 47 | -- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) | ||
| 48 | - using `LISTEN` and `NOTIFY`. | ||
| 49 | -- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) | ||
| 50 | - using `COPY FROM` and `COPY TO`. | ||
| 51 | -- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and | ||
| 52 | - canceling queries using context.Context. | 49 | +- [Notifications](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Listener) using |
| 50 | + `LISTEN` and `NOTIFY`. | ||
| 51 | +- [Copying data](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CopyFrom) using | ||
| 52 | + `COPY FROM` and `COPY TO`. | ||
| 53 | +- [Timeouts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#Options) and canceling queries using | ||
| 54 | + context.Context. | ||
| 53 | - Automatic connection pooling with | 55 | - Automatic connection pooling with |
| 54 | - [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) | ||
| 55 | - support. | 56 | + [circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. |
| 56 | - Queries retry on network errors. | 57 | - Queries retry on network errors. |
| 57 | - Working with models using | 58 | - Working with models using |
| 58 | - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model) and | ||
| 59 | - [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Query). | 59 | + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model) and |
| 60 | + [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Query). | ||
| 60 | - Scanning variables using | 61 | - Scanning variables using |
| 61 | - [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-SomeColumnsIntoVars) | 62 | + [ORM](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectSomeColumnsIntoVars) |
| 62 | and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan). | 63 | and [SQL](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-Scan). |
| 63 | -- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-SelectOrInsert) | 64 | +- [SelectOrInsert](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertSelectOrInsert) |
| 64 | using on-conflict. | 65 | using on-conflict. |
| 65 | -- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-OnConflictDoUpdate) | 66 | +- [INSERT ... ON CONFLICT DO UPDATE](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-InsertOnConflictDoUpdate) |
| 66 | using ORM. | 67 | using ORM. |
| 67 | - Bulk/batch | 68 | - Bulk/batch |
| 68 | - [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Insert-BulkInsert), | ||
| 69 | - [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Update-BulkUpdate), | ||
| 70 | - and | ||
| 71 | - [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Delete-BulkDelete). | 69 | + [inserts](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkInsert), |
| 70 | + [updates](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkUpdate), and | ||
| 71 | + [deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BulkDelete). | ||
| 72 | - Common table expressions using | 72 | - Common table expressions using |
| 73 | - [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-With) | ||
| 74 | - and | ||
| 75 | - [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Select-WrapWith). | ||
| 76 | -- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-CountEstimate) | 73 | + [WITH](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWith) and |
| 74 | + [WrapWith](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SelectWrapWith). | ||
| 75 | +- [CountEstimate](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CountEstimate) | ||
| 77 | using `EXPLAIN` to get | 76 | using `EXPLAIN` to get |
| 78 | [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). | 77 | [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). |
| 79 | - ORM supports | 78 | - ORM supports |
| 80 | - [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasOne), | ||
| 81 | - [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-BelongsTo), | ||
| 82 | - [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-HasMany), | ||
| 83 | - and | ||
| 84 | - [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ManyToMany) | 79 | + [has one](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasOne), |
| 80 | + [belongs to](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-BelongsTo), | ||
| 81 | + [has many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-HasMany), and | ||
| 82 | + [many to many](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ManyToMany) | ||
| 85 | with composite/multi-column primary keys. | 83 | with composite/multi-column primary keys. |
| 86 | -- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-SoftDelete). | ||
| 87 | -- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-CreateTable). | ||
| 88 | -- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB-Model-ForEach) | ||
| 89 | - that calls a function for each row returned by the query without loading all | ||
| 90 | - rows into the memory. | 84 | +- [Soft deletes](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-SoftDelete). |
| 85 | +- [Creating tables from structs](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-CreateTable). | ||
| 86 | +- [ForEach](https://pkg.go.dev/github.com/go-pg/pg/v10?tab=doc#example-DB.Model-ForEach) that calls | ||
| 87 | + a function for each row returned by the query without loading all rows into the memory. | ||
| 91 | - Works with PgBouncer in transaction pooling mode. | 88 | - Works with PgBouncer in transaction pooling mode. |
| 92 | 89 | ||
| 90 | +## Installation | ||
| 91 | + | ||
| 92 | +go-pg supports 2 last Go versions and requires a Go version with | ||
| 93 | +[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go | ||
| 94 | +module: | ||
| 95 | + | ||
| 96 | +```shell | ||
| 97 | +go mod init github.com/my/repo | ||
| 98 | +``` | ||
| 99 | + | ||
| 100 | +And then install go-pg (note _v10_ in the import; omitting it is a popular mistake): | ||
| 101 | + | ||
| 102 | +```shell | ||
| 103 | +go get github.com/go-pg/pg/v10 | ||
| 104 | +``` | ||
| 105 | + | ||
| 106 | +## Quickstart | ||
| 107 | + | ||
| 108 | +```go | ||
| 109 | +package pg_test | ||
| 110 | + | ||
| 111 | +import ( | ||
| 112 | + "fmt" | ||
| 113 | + | ||
| 114 | + "github.com/go-pg/pg/v10" | ||
| 115 | + "github.com/go-pg/pg/v10/orm" | ||
| 116 | +) | ||
| 117 | + | ||
| 118 | +type User struct { | ||
| 119 | + Id int64 | ||
| 120 | + Name string | ||
| 121 | + Emails []string | ||
| 122 | +} | ||
| 123 | + | ||
| 124 | +func (u User) String() string { | ||
| 125 | + return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails) | ||
| 126 | +} | ||
| 127 | + | ||
| 128 | +type Story struct { | ||
| 129 | + Id int64 | ||
| 130 | + Title string | ||
| 131 | + AuthorId int64 | ||
| 132 | + Author *User `pg:"rel:has-one"` | ||
| 133 | +} | ||
| 134 | + | ||
| 135 | +func (s Story) String() string { | ||
| 136 | + return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author) | ||
| 137 | +} | ||
| 138 | + | ||
| 139 | +func ExampleDB_Model() { | ||
| 140 | + db := pg.Connect(&pg.Options{ | ||
| 141 | + User: "postgres", | ||
| 142 | + }) | ||
| 143 | + defer db.Close() | ||
| 144 | + | ||
| 145 | + err := createSchema(db) | ||
| 146 | + if err != nil { | ||
| 147 | + panic(err) | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + user1 := &User{ | ||
| 151 | + Name: "admin", | ||
| 152 | + Emails: []string{"admin1@admin", "admin2@admin"}, | ||
| 153 | + } | ||
| 154 | + _, err = db.Model(user1).Insert() | ||
| 155 | + if err != nil { | ||
| 156 | + panic(err) | ||
| 157 | + } | ||
| 158 | + | ||
| 159 | + _, err = db.Model(&User{ | ||
| 160 | + Name: "root", | ||
| 161 | + Emails: []string{"root1@root", "root2@root"}, | ||
| 162 | + }).Insert() | ||
| 163 | + if err != nil { | ||
| 164 | + panic(err) | ||
| 165 | + } | ||
| 166 | + | ||
| 167 | + story1 := &Story{ | ||
| 168 | + Title: "Cool story", | ||
| 169 | + AuthorId: user1.Id, | ||
| 170 | + } | ||
| 171 | + _, err = db.Model(story1).Insert() | ||
| 172 | + if err != nil { | ||
| 173 | + panic(err) | ||
| 174 | + } | ||
| 175 | + | ||
| 176 | + // Select user by primary key. | ||
| 177 | + user := &User{Id: user1.Id} | ||
| 178 | + err = db.Model(user).WherePK().Select() | ||
| 179 | + if err != nil { | ||
| 180 | + panic(err) | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + // Select all users. | ||
| 184 | + var users []User | ||
| 185 | + err = db.Model(&users).Select() | ||
| 186 | + if err != nil { | ||
| 187 | + panic(err) | ||
| 188 | + } | ||
| 189 | + | ||
| 190 | + // Select story and associated author in one query. | ||
| 191 | + story := new(Story) | ||
| 192 | + err = db.Model(story). | ||
| 193 | + Relation("Author"). | ||
| 194 | + Where("story.id = ?", story1.Id). | ||
| 195 | + Select() | ||
| 196 | + if err != nil { | ||
| 197 | + panic(err) | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + fmt.Println(user) | ||
| 201 | + fmt.Println(users) | ||
| 202 | + fmt.Println(story) | ||
| 203 | + // Output: User<1 admin [admin1@admin admin2@admin]> | ||
| 204 | + // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] | ||
| 205 | + // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> | ||
| 206 | +} | ||
| 207 | + | ||
| 208 | +// createSchema creates database schema for User and Story models. | ||
| 209 | +func createSchema(db *pg.DB) error { | ||
| 210 | + models := []interface{}{ | ||
| 211 | + (*User)(nil), | ||
| 212 | + (*Story)(nil), | ||
| 213 | + } | ||
| 214 | + | ||
| 215 | + for _, model := range models { | ||
| 216 | + err := db.Model(model).CreateTable(&orm.CreateTableOptions{ | ||
| 217 | + Temp: true, | ||
| 218 | + }) | ||
| 219 | + if err != nil { | ||
| 220 | + return err | ||
| 221 | + } | ||
| 222 | + } | ||
| 223 | + return nil | ||
| 224 | +} | ||
| 225 | +``` | ||
| 226 | + | ||
| 93 | ## See also | 227 | ## See also |
| 94 | 228 | ||
| 229 | +- [Fast and flexible HTTP router](https://github.com/vmihailenco/treemux) | ||
| 95 | - [Golang msgpack](https://github.com/vmihailenco/msgpack) | 230 | - [Golang msgpack](https://github.com/vmihailenco/msgpack) |
| 96 | - [Golang message task queue](https://github.com/vmihailenco/taskq) | 231 | - [Golang message task queue](https://github.com/vmihailenco/taskq) |
| @@ -5,12 +5,13 @@ import ( | @@ -5,12 +5,13 @@ import ( | ||
| 5 | "io" | 5 | "io" |
| 6 | "time" | 6 | "time" |
| 7 | 7 | ||
| 8 | - "go.opentelemetry.io/otel/api/kv" | ||
| 9 | - "go.opentelemetry.io/otel/api/trace" | 8 | + "go.opentelemetry.io/otel/label" |
| 9 | + "go.opentelemetry.io/otel/trace" | ||
| 10 | 10 | ||
| 11 | "github.com/go-pg/pg/v10/internal" | 11 | "github.com/go-pg/pg/v10/internal" |
| 12 | "github.com/go-pg/pg/v10/internal/pool" | 12 | "github.com/go-pg/pg/v10/internal/pool" |
| 13 | "github.com/go-pg/pg/v10/orm" | 13 | "github.com/go-pg/pg/v10/orm" |
| 14 | + "github.com/go-pg/pg/v10/types" | ||
| 14 | ) | 15 | ) |
| 15 | 16 | ||
| 16 | type baseDB struct { | 17 | type baseDB struct { |
| @@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { | @@ -83,14 +84,14 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { | ||
| 83 | return cn, nil | 84 | return cn, nil |
| 84 | } | 85 | } |
| 85 | 86 | ||
| 86 | - err = internal.WithSpan(ctx, "init_conn", func(ctx context.Context, span trace.Span) error { | 87 | + err = internal.WithSpan(ctx, "pg.init_conn", func(ctx context.Context, span trace.Span) error { |
| 87 | return db.initConn(ctx, cn) | 88 | return db.initConn(ctx, cn) |
| 88 | }) | 89 | }) |
| 89 | if err != nil { | 90 | if err != nil { |
| 90 | - db.pool.Remove(cn, err) | ||
| 91 | - // It is safe to reset SingleConnPool if conn can't be initialized. | ||
| 92 | - if p, ok := db.pool.(*pool.SingleConnPool); ok { | ||
| 93 | - _ = p.Reset() | 91 | + db.pool.Remove(ctx, cn, err) |
| 92 | + // It is safe to reset StickyConnPool if conn can't be initialized. | ||
| 93 | + if p, ok := db.pool.(*pool.StickyConnPool); ok { | ||
| 94 | + _ = p.Reset(ctx) | ||
| 94 | } | 95 | } |
| 95 | if err := internal.Unwrap(err); err != nil { | 96 | if err := internal.Unwrap(err); err != nil { |
| 96 | return nil, err | 97 | return nil, err |
| @@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { | @@ -101,45 +102,44 @@ func (db *baseDB) getConn(ctx context.Context) (*pool.Conn, error) { | ||
| 101 | return cn, nil | 102 | return cn, nil |
| 102 | } | 103 | } |
| 103 | 104 | ||
| 104 | -func (db *baseDB) initConn(c context.Context, cn *pool.Conn) error { | 105 | +func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error { |
| 105 | if cn.Inited { | 106 | if cn.Inited { |
| 106 | return nil | 107 | return nil |
| 107 | } | 108 | } |
| 108 | cn.Inited = true | 109 | cn.Inited = true |
| 109 | 110 | ||
| 110 | if db.opt.TLSConfig != nil { | 111 | if db.opt.TLSConfig != nil { |
| 111 | - err := db.enableSSL(c, cn, db.opt.TLSConfig) | 112 | + err := db.enableSSL(ctx, cn, db.opt.TLSConfig) |
| 112 | if err != nil { | 113 | if err != nil { |
| 113 | return err | 114 | return err |
| 114 | } | 115 | } |
| 115 | } | 116 | } |
| 116 | 117 | ||
| 117 | - err := db.startup(c, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) | 118 | + err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) |
| 118 | if err != nil { | 119 | if err != nil { |
| 119 | return err | 120 | return err |
| 120 | } | 121 | } |
| 121 | 122 | ||
| 122 | if db.opt.OnConnect != nil { | 123 | if db.opt.OnConnect != nil { |
| 123 | - p := pool.NewSingleConnPool(nil) | ||
| 124 | - p.SetConn(cn) | ||
| 125 | - return db.opt.OnConnect(newConn(c, db.withPool(p))) | 124 | + p := pool.NewSingleConnPool(db.pool, cn) |
| 125 | + return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p))) | ||
| 126 | } | 126 | } |
| 127 | 127 | ||
| 128 | return nil | 128 | return nil |
| 129 | } | 129 | } |
| 130 | 130 | ||
| 131 | -func (db *baseDB) releaseConn(cn *pool.Conn, err error) { | 131 | +func (db *baseDB) releaseConn(ctx context.Context, cn *pool.Conn, err error) { |
| 132 | if isBadConn(err, false) { | 132 | if isBadConn(err, false) { |
| 133 | - db.pool.Remove(cn, err) | 133 | + db.pool.Remove(ctx, cn, err) |
| 134 | } else { | 134 | } else { |
| 135 | - db.pool.Put(cn) | 135 | + db.pool.Put(ctx, cn) |
| 136 | } | 136 | } |
| 137 | } | 137 | } |
| 138 | 138 | ||
| 139 | func (db *baseDB) withConn( | 139 | func (db *baseDB) withConn( |
| 140 | ctx context.Context, fn func(context.Context, *pool.Conn) error, | 140 | ctx context.Context, fn func(context.Context, *pool.Conn) error, |
| 141 | ) error { | 141 | ) error { |
| 142 | - return internal.WithSpan(ctx, "with_conn", func(ctx context.Context, span trace.Span) error { | 142 | + return internal.WithSpan(ctx, "pg.with_conn", func(ctx context.Context, span trace.Span) error { |
| 143 | cn, err := db.getConn(ctx) | 143 | cn, err := db.getConn(ctx) |
| 144 | if err != nil { | 144 | if err != nil { |
| 145 | return err | 145 | return err |
| @@ -154,7 +154,7 @@ func (db *baseDB) withConn( | @@ -154,7 +154,7 @@ func (db *baseDB) withConn( | ||
| 154 | case <-ctx.Done(): | 154 | case <-ctx.Done(): |
| 155 | err := db.cancelRequest(cn.ProcessID, cn.SecretKey) | 155 | err := db.cancelRequest(cn.ProcessID, cn.SecretKey) |
| 156 | if err != nil { | 156 | if err != nil { |
| 157 | - internal.Logger.Printf("cancelRequest failed: %s", err) | 157 | + internal.Logger.Printf(ctx, "cancelRequest failed: %s", err) |
| 158 | } | 158 | } |
| 159 | // Signal end of conn use. | 159 | // Signal end of conn use. |
| 160 | fnDone <- struct{}{} | 160 | fnDone <- struct{}{} |
| @@ -169,7 +169,7 @@ func (db *baseDB) withConn( | @@ -169,7 +169,7 @@ func (db *baseDB) withConn( | ||
| 169 | case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine | 169 | case fnDone <- struct{}{}: // signal fn finish, skip cancel goroutine |
| 170 | } | 170 | } |
| 171 | } | 171 | } |
| 172 | - db.releaseConn(cn, err) | 172 | + db.releaseConn(ctx, cn, err) |
| 173 | }() | 173 | }() |
| 174 | 174 | ||
| 175 | err = fn(ctx, cn) | 175 | err = fn(ctx, cn) |
| @@ -179,9 +179,12 @@ func (db *baseDB) withConn( | @@ -179,9 +179,12 @@ func (db *baseDB) withConn( | ||
| 179 | 179 | ||
| 180 | func (db *baseDB) shouldRetry(err error) bool { | 180 | func (db *baseDB) shouldRetry(err error) bool { |
| 181 | switch err { | 181 | switch err { |
| 182 | + case io.EOF, io.ErrUnexpectedEOF: | ||
| 183 | + return true | ||
| 182 | case nil, context.Canceled, context.DeadlineExceeded: | 184 | case nil, context.Canceled, context.DeadlineExceeded: |
| 183 | return false | 185 | return false |
| 184 | } | 186 | } |
| 187 | + | ||
| 185 | if pgerr, ok := err.(Error); ok { | 188 | if pgerr, ok := err.(Error); ok { |
| 186 | switch pgerr.Field('C') { | 189 | switch pgerr.Field('C') { |
| 187 | case "40001", // serialization_failure | 190 | case "40001", // serialization_failure |
| @@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool { | @@ -194,7 +197,12 @@ func (db *baseDB) shouldRetry(err error) bool { | ||
| 194 | return false | 197 | return false |
| 195 | } | 198 | } |
| 196 | } | 199 | } |
| 197 | - return isNetworkError(err) | 200 | + |
| 201 | + if _, ok := err.(timeoutError); ok { | ||
| 202 | + return true | ||
| 203 | + } | ||
| 204 | + | ||
| 205 | + return false | ||
| 198 | } | 206 | } |
| 199 | 207 | ||
| 200 | // Close closes the database client, releasing any open resources. | 208 | // Close closes the database client, releasing any open resources. |
| @@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa | @@ -233,9 +241,9 @@ func (db *baseDB) exec(ctx context.Context, query interface{}, params ...interfa | ||
| 233 | for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { | 241 | for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { |
| 234 | attempt := attempt | 242 | attempt := attempt |
| 235 | 243 | ||
| 236 | - lastErr = internal.WithSpan(ctx, "exec", func(ctx context.Context, span trace.Span) error { | 244 | + lastErr = internal.WithSpan(ctx, "pg.exec", func(ctx context.Context, span trace.Span) error { |
| 237 | if attempt > 0 { | 245 | if attempt > 0 { |
| 238 | - span.SetAttributes(kv.Int("retry", attempt)) | 246 | + span.SetAttributes(label.Int("retry", attempt)) |
| 239 | 247 | ||
| 240 | if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { | 248 | if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { |
| 241 | return err | 249 | return err |
| @@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params .. | @@ -311,9 +319,9 @@ func (db *baseDB) query(ctx context.Context, model, query interface{}, params .. | ||
| 311 | for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { | 319 | for attempt := 0; attempt <= db.opt.MaxRetries; attempt++ { |
| 312 | attempt := attempt | 320 | attempt := attempt |
| 313 | 321 | ||
| 314 | - lastErr = internal.WithSpan(ctx, "query", func(ctx context.Context, span trace.Span) error { | 322 | + lastErr = internal.WithSpan(ctx, "pg.query", func(ctx context.Context, span trace.Span) error { |
| 315 | if attempt > 0 { | 323 | if attempt > 0 { |
| 316 | - span.SetAttributes(kv.Int("retry", attempt)) | 324 | + span.SetAttributes(label.Int("retry", attempt)) |
| 317 | 325 | ||
| 318 | if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { | 326 | if err := internal.Sleep(ctx, db.retryBackoff(attempt-1)); err != nil { |
| 319 | return err | 327 | return err |
| @@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{} | @@ -373,7 +381,7 @@ func (db *baseDB) CopyFrom(r io.Reader, query interface{}, params ...interface{} | ||
| 373 | return res, err | 381 | return res, err |
| 374 | } | 382 | } |
| 375 | 383 | ||
| 376 | -// TODO: don't get/put conn in the pool | 384 | +// TODO: don't get/put conn in the pool. |
| 377 | func (db *baseDB) copyFrom( | 385 | func (db *baseDB) copyFrom( |
| 378 | ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}, | 386 | ctx context.Context, cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}, |
| 379 | ) (res Result, err error) { | 387 | ) (res Result, err error) { |
| @@ -396,6 +404,7 @@ func (db *baseDB) copyFrom( | @@ -396,6 +404,7 @@ func (db *baseDB) copyFrom( | ||
| 396 | return nil, err | 404 | return nil, err |
| 397 | } | 405 | } |
| 398 | 406 | ||
| 407 | + // Note that afterQuery uses the err. | ||
| 399 | defer func() { | 408 | defer func() { |
| 400 | if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { | 409 | if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { |
| 401 | err = afterQueryErr | 410 | err = afterQueryErr |
| @@ -434,7 +443,7 @@ func (db *baseDB) copyFrom( | @@ -434,7 +443,7 @@ func (db *baseDB) copyFrom( | ||
| 434 | return nil, err | 443 | return nil, err |
| 435 | } | 444 | } |
| 436 | 445 | ||
| 437 | - err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 446 | + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 438 | res, err = readReadyForQuery(rd) | 447 | res, err = readReadyForQuery(rd) |
| 439 | return err | 448 | return err |
| 440 | }) | 449 | }) |
| @@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{}) | @@ -456,7 +465,7 @@ func (db *baseDB) CopyTo(w io.Writer, query interface{}, params ...interface{}) | ||
| 456 | } | 465 | } |
| 457 | 466 | ||
| 458 | func (db *baseDB) copyTo( | 467 | func (db *baseDB) copyTo( |
| 459 | - c context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{}, | 468 | + ctx context.Context, cn *pool.Conn, w io.Writer, query interface{}, params ...interface{}, |
| 460 | ) (res Result, err error) { | 469 | ) (res Result, err error) { |
| 461 | var evt *QueryEvent | 470 | var evt *QueryEvent |
| 462 | 471 | ||
| @@ -472,25 +481,26 @@ func (db *baseDB) copyTo( | @@ -472,25 +481,26 @@ func (db *baseDB) copyTo( | ||
| 472 | model, _ = params[len(params)-1].(orm.TableModel) | 481 | model, _ = params[len(params)-1].(orm.TableModel) |
| 473 | } | 482 | } |
| 474 | 483 | ||
| 475 | - c, evt, err = db.beforeQuery(c, db.db, model, query, params, wb.Query()) | 484 | + ctx, evt, err = db.beforeQuery(ctx, db.db, model, query, params, wb.Query()) |
| 476 | if err != nil { | 485 | if err != nil { |
| 477 | return nil, err | 486 | return nil, err |
| 478 | } | 487 | } |
| 479 | 488 | ||
| 489 | + // Note that afterQuery uses the err. | ||
| 480 | defer func() { | 490 | defer func() { |
| 481 | - if afterQueryErr := db.afterQuery(c, evt, res, err); afterQueryErr != nil { | 491 | + if afterQueryErr := db.afterQuery(ctx, evt, res, err); afterQueryErr != nil { |
| 482 | err = afterQueryErr | 492 | err = afterQueryErr |
| 483 | } | 493 | } |
| 484 | }() | 494 | }() |
| 485 | 495 | ||
| 486 | - err = cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | 496 | + err = cn.WithWriter(ctx, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { |
| 487 | return writeQueryMsg(wb, db.fmter, query, params...) | 497 | return writeQueryMsg(wb, db.fmter, query, params...) |
| 488 | }) | 498 | }) |
| 489 | if err != nil { | 499 | if err != nil { |
| 490 | return nil, err | 500 | return nil, err |
| 491 | } | 501 | } |
| 492 | 502 | ||
| 493 | - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 503 | + err = cn.WithReader(ctx, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 494 | err := readCopyOutResponse(rd) | 504 | err := readCopyOutResponse(rd) |
| 495 | if err != nil { | 505 | if err != nil { |
| 496 | return err | 506 | return err |
| @@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que | @@ -522,52 +532,6 @@ func (db *baseDB) ModelContext(c context.Context, model ...interface{}) *orm.Que | ||
| 522 | return orm.NewQueryContext(c, db.db, model...) | 532 | return orm.NewQueryContext(c, db.db, model...) |
| 523 | } | 533 | } |
| 524 | 534 | ||
| 525 | -// Select selects the model by primary key. | ||
| 526 | -func (db *baseDB) Select(model interface{}) error { | ||
| 527 | - return orm.Select(db.db, model) | ||
| 528 | -} | ||
| 529 | - | ||
| 530 | -// Insert inserts the model updating primary keys if they are empty. | ||
| 531 | -func (db *baseDB) Insert(model ...interface{}) error { | ||
| 532 | - return orm.Insert(db.db, model...) | ||
| 533 | -} | ||
| 534 | - | ||
| 535 | -// Update updates the model by primary key. | ||
| 536 | -func (db *baseDB) Update(model interface{}) error { | ||
| 537 | - return orm.Update(db.db, model) | ||
| 538 | -} | ||
| 539 | - | ||
| 540 | -// Delete deletes the model by primary key. | ||
| 541 | -func (db *baseDB) Delete(model interface{}) error { | ||
| 542 | - return orm.Delete(db.db, model) | ||
| 543 | -} | ||
| 544 | - | ||
| 545 | -// Delete forces delete of the model with deleted_at column. | ||
| 546 | -func (db *baseDB) ForceDelete(model interface{}) error { | ||
| 547 | - return orm.ForceDelete(db.db, model) | ||
| 548 | -} | ||
| 549 | - | ||
| 550 | -// CreateTable creates table for the model. It recognizes following field tags: | ||
| 551 | -// - notnull - sets NOT NULL constraint. | ||
| 552 | -// - unique - sets UNIQUE constraint. | ||
| 553 | -// - default:value - sets default value. | ||
| 554 | -func (db *baseDB) CreateTable(model interface{}, opt *orm.CreateTableOptions) error { | ||
| 555 | - return orm.CreateTable(db.db, model, opt) | ||
| 556 | -} | ||
| 557 | - | ||
| 558 | -// DropTable drops table for the model. | ||
| 559 | -func (db *baseDB) DropTable(model interface{}, opt *orm.DropTableOptions) error { | ||
| 560 | - return orm.DropTable(db.db, model, opt) | ||
| 561 | -} | ||
| 562 | - | ||
| 563 | -func (db *baseDB) CreateComposite(model interface{}, opt *orm.CreateCompositeOptions) error { | ||
| 564 | - return orm.CreateComposite(db.db, model, opt) | ||
| 565 | -} | ||
| 566 | - | ||
| 567 | -func (db *baseDB) DropComposite(model interface{}, opt *orm.DropCompositeOptions) error { | ||
| 568 | - return orm.DropComposite(db.db, model, opt) | ||
| 569 | -} | ||
| 570 | - | ||
| 571 | func (db *baseDB) Formatter() orm.QueryFormatter { | 535 | func (db *baseDB) Formatter() orm.QueryFormatter { |
| 572 | return db.fmter | 536 | return db.fmter |
| 573 | } | 537 | } |
| @@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery( | @@ -597,7 +561,7 @@ func (db *baseDB) simpleQuery( | ||
| 597 | } | 561 | } |
| 598 | 562 | ||
| 599 | var res *result | 563 | var res *result |
| 600 | - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 564 | + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 601 | var err error | 565 | var err error |
| 602 | res, err = readSimpleQuery(rd) | 566 | res, err = readSimpleQuery(rd) |
| 603 | return err | 567 | return err |
| @@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData( | @@ -616,7 +580,7 @@ func (db *baseDB) simpleQueryData( | ||
| 616 | } | 580 | } |
| 617 | 581 | ||
| 618 | var res *result | 582 | var res *result |
| 619 | - if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 583 | + if err := cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 620 | var err error | 584 | var err error |
| 621 | res, err = readSimpleQueryData(c, rd, model) | 585 | res, err = readSimpleQueryData(c, rd, model) |
| 622 | return err | 586 | return err |
| @@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData( | @@ -631,12 +595,12 @@ func (db *baseDB) simpleQueryData( | ||
| 631 | // executions. Multiple queries or executions may be run concurrently | 595 | // executions. Multiple queries or executions may be run concurrently |
| 632 | // from the returned statement. | 596 | // from the returned statement. |
| 633 | func (db *baseDB) Prepare(q string) (*Stmt, error) { | 597 | func (db *baseDB) Prepare(q string) (*Stmt, error) { |
| 634 | - return prepareStmt(db.withPool(pool.NewSingleConnPool(db.pool)), q) | 598 | + return prepareStmt(db.withPool(pool.NewStickyConnPool(db.pool)), q) |
| 635 | } | 599 | } |
| 636 | 600 | ||
| 637 | func (db *baseDB) prepare( | 601 | func (db *baseDB) prepare( |
| 638 | c context.Context, cn *pool.Conn, q string, | 602 | c context.Context, cn *pool.Conn, q string, |
| 639 | -) (string, [][]byte, error) { | 603 | +) (string, []types.ColumnInfo, error) { |
| 640 | name := cn.NextID() | 604 | name := cn.NextID() |
| 641 | err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | 605 | err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { |
| 642 | writeParseDescribeSyncMsg(wb, name, q) | 606 | writeParseDescribeSyncMsg(wb, name, q) |
| @@ -646,8 +610,8 @@ func (db *baseDB) prepare( | @@ -646,8 +610,8 @@ func (db *baseDB) prepare( | ||
| 646 | return "", nil, err | 610 | return "", nil, err |
| 647 | } | 611 | } |
| 648 | 612 | ||
| 649 | - var columns [][]byte | ||
| 650 | - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 613 | + var columns []types.ColumnInfo |
| 614 | + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { | ||
| 651 | columns, err = readParseDescribeSync(rd) | 615 | columns, err = readParseDescribeSync(rd) |
| 652 | return err | 616 | return err |
| 653 | }) | 617 | }) |
| @@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB { | @@ -75,12 +75,12 @@ func (db *DB) WithParam(param string, value interface{}) *DB { | ||
| 75 | } | 75 | } |
| 76 | 76 | ||
| 77 | // Listen listens for notifications sent with NOTIFY command. | 77 | // Listen listens for notifications sent with NOTIFY command. |
| 78 | -func (db *DB) Listen(channels ...string) *Listener { | 78 | +func (db *DB) Listen(ctx context.Context, channels ...string) *Listener { |
| 79 | ln := &Listener{ | 79 | ln := &Listener{ |
| 80 | db: db, | 80 | db: db, |
| 81 | } | 81 | } |
| 82 | ln.init() | 82 | ln.init() |
| 83 | - _ = ln.Listen(channels...) | 83 | + _ = ln.Listen(ctx, channels...) |
| 84 | return ln | 84 | return ln |
| 85 | } | 85 | } |
| 86 | 86 | ||
| @@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil) | @@ -105,7 +105,7 @@ var _ orm.DB = (*Conn)(nil) | ||
| 105 | // Every Conn must be returned to the database pool after use by | 105 | // Every Conn must be returned to the database pool after use by |
| 106 | // calling Conn.Close. | 106 | // calling Conn.Close. |
| 107 | func (db *DB) Conn() *Conn { | 107 | func (db *DB) Conn() *Conn { |
| 108 | - return newConn(db.ctx, db.baseDB.withPool(pool.NewSingleConnPool(db.pool))) | 108 | + return newConn(db.ctx, db.baseDB.withPool(pool.NewStickyConnPool(db.pool))) |
| 109 | } | 109 | } |
| 110 | 110 | ||
| 111 | func newConn(ctx context.Context, baseDB *baseDB) *Conn { | 111 | func newConn(ctx context.Context, baseDB *baseDB) *Conn { |
| 1 | package pg | 1 | package pg |
| 2 | 2 | ||
| 3 | import ( | 3 | import ( |
| 4 | - "io" | ||
| 5 | "net" | 4 | "net" |
| 6 | 5 | ||
| 7 | "github.com/go-pg/pg/v10/internal" | 6 | "github.com/go-pg/pg/v10/internal" |
| @@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows | @@ -22,10 +21,10 @@ var ErrMultiRows = internal.ErrMultiRows | ||
| 22 | type Error interface { | 21 | type Error interface { |
| 23 | error | 22 | error |
| 24 | 23 | ||
| 25 | - // Field returns a string value associated with an error code. | 24 | + // Field returns a string value associated with an error field. |
| 26 | // | 25 | // |
| 27 | // https://www.postgresql.org/docs/10/static/protocol-error-fields.html | 26 | // https://www.postgresql.org/docs/10/static/protocol-error-fields.html |
| 28 | - Field(byte) string | 27 | + Field(field byte) string |
| 29 | 28 | ||
| 30 | // IntegrityViolation reports whether an error is a part of | 29 | // IntegrityViolation reports whether an error is a part of |
| 31 | // Integrity Constraint Violation class of errors. | 30 | // Integrity Constraint Violation class of errors. |
| @@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool { | @@ -43,21 +42,19 @@ func isBadConn(err error, allowTimeout bool) bool { | ||
| 43 | if _, ok := err.(internal.Error); ok { | 42 | if _, ok := err.(internal.Error); ok { |
| 44 | return false | 43 | return false |
| 45 | } | 44 | } |
| 46 | - if pgErr, ok := err.(Error); ok && pgErr.Field('S') != "FATAL" { | ||
| 47 | - return false | 45 | + if pgErr, ok := err.(Error); ok { |
| 46 | + return pgErr.Field('S') == "FATAL" | ||
| 48 | } | 47 | } |
| 49 | if allowTimeout { | 48 | if allowTimeout { |
| 50 | if netErr, ok := err.(net.Error); ok && netErr.Timeout() { | 49 | if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
| 51 | - return false | 50 | + return !netErr.Temporary() |
| 52 | } | 51 | } |
| 53 | } | 52 | } |
| 54 | return true | 53 | return true |
| 55 | } | 54 | } |
| 56 | 55 | ||
| 57 | -func isNetworkError(err error) bool { | ||
| 58 | - if err == io.EOF { | ||
| 59 | - return true | ||
| 60 | - } | ||
| 61 | - _, ok := err.(net.Error) | ||
| 62 | - return ok | 56 | +//------------------------------------------------------------------------------ |
| 57 | + | ||
| 58 | +type timeoutError interface { | ||
| 59 | + Timeout() bool | ||
| 63 | } | 60 | } |
| @@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10 | @@ -3,25 +3,24 @@ module github.com/go-pg/pg/v10 | ||
| 3 | go 1.11 | 3 | go 1.11 |
| 4 | 4 | ||
| 5 | require ( | 5 | require ( |
| 6 | - github.com/go-pg/pg/v9 v9.1.6 // indirect | ||
| 7 | - github.com/go-pg/urlstruct v0.4.0 | ||
| 8 | - github.com/go-pg/zerochecker v0.1.1 | ||
| 9 | - github.com/golang/protobuf v1.4.2 // indirect | 6 | + github.com/go-pg/zerochecker v0.2.0 |
| 7 | + github.com/golang/protobuf v1.4.3 // indirect | ||
| 10 | github.com/jinzhu/inflection v1.0.0 | 8 | github.com/jinzhu/inflection v1.0.0 |
| 11 | - github.com/onsi/ginkgo v1.10.1 | ||
| 12 | - github.com/onsi/gomega v1.7.0 | ||
| 13 | - github.com/segmentio/encoding v0.1.13 | 9 | + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect |
| 10 | + github.com/onsi/ginkgo v1.14.2 | ||
| 11 | + github.com/onsi/gomega v1.10.3 | ||
| 12 | + github.com/stretchr/testify v1.6.1 | ||
| 14 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc | 13 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc |
| 15 | github.com/vmihailenco/bufpool v0.1.11 | 14 | github.com/vmihailenco/bufpool v0.1.11 |
| 16 | - github.com/vmihailenco/msgpack/v5 v5.0.0-beta.1 | ||
| 17 | - github.com/vmihailenco/tagparser v0.1.1 | ||
| 18 | - go.opentelemetry.io/otel v0.6.0 | ||
| 19 | - golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 // indirect | ||
| 20 | - golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect | ||
| 21 | - golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 // indirect | ||
| 22 | - google.golang.org/appengine v1.6.6 // indirect | ||
| 23 | - google.golang.org/grpc v1.29.1 | ||
| 24 | - google.golang.org/protobuf v1.24.0 // indirect | ||
| 25 | - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 | 15 | + github.com/vmihailenco/msgpack/v4 v4.3.11 // indirect |
| 16 | + github.com/vmihailenco/msgpack/v5 v5.0.0 | ||
| 17 | + github.com/vmihailenco/tagparser v0.1.2 | ||
| 18 | + go.opentelemetry.io/otel v0.14.0 | ||
| 19 | + golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 // indirect | ||
| 20 | + golang.org/x/net v0.0.0-20201110031124-69a78807bb2b // indirect | ||
| 21 | + golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect | ||
| 22 | + google.golang.org/appengine v1.6.7 // indirect | ||
| 23 | + google.golang.org/protobuf v1.25.0 // indirect | ||
| 24 | + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f | ||
| 26 | mellium.im/sasl v0.2.1 | 25 | mellium.im/sasl v0.2.1 |
| 27 | ) | 26 | ) |
| @@ -8,15 +8,17 @@ import ( | @@ -8,15 +8,17 @@ import ( | ||
| 8 | "github.com/go-pg/pg/v10/orm" | 8 | "github.com/go-pg/pg/v10/orm" |
| 9 | ) | 9 | ) |
| 10 | 10 | ||
| 11 | -type BeforeScanHook = orm.BeforeScanHook | ||
| 12 | -type AfterScanHook = orm.AfterScanHook | ||
| 13 | -type AfterSelectHook = orm.AfterSelectHook | ||
| 14 | -type BeforeInsertHook = orm.BeforeInsertHook | ||
| 15 | -type AfterInsertHook = orm.AfterInsertHook | ||
| 16 | -type BeforeUpdateHook = orm.BeforeUpdateHook | ||
| 17 | -type AfterUpdateHook = orm.AfterUpdateHook | ||
| 18 | -type BeforeDeleteHook = orm.BeforeDeleteHook | ||
| 19 | -type AfterDeleteHook = orm.AfterDeleteHook | 11 | +type ( |
| 12 | + BeforeScanHook = orm.BeforeScanHook | ||
| 13 | + AfterScanHook = orm.AfterScanHook | ||
| 14 | + AfterSelectHook = orm.AfterSelectHook | ||
| 15 | + BeforeInsertHook = orm.BeforeInsertHook | ||
| 16 | + AfterInsertHook = orm.AfterInsertHook | ||
| 17 | + BeforeUpdateHook = orm.BeforeUpdateHook | ||
| 18 | + AfterUpdateHook = orm.AfterUpdateHook | ||
| 19 | + BeforeDeleteHook = orm.BeforeDeleteHook | ||
| 20 | + AfterDeleteHook = orm.AfterDeleteHook | ||
| 21 | +) | ||
| 20 | 22 | ||
| 21 | //------------------------------------------------------------------------------ | 23 | //------------------------------------------------------------------------------ |
| 22 | 24 | ||
| @@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery( | @@ -94,11 +96,14 @@ func (db *baseDB) beforeQuery( | ||
| 94 | fmtedQuery: fmtedQuery, | 96 | fmtedQuery: fmtedQuery, |
| 95 | } | 97 | } |
| 96 | 98 | ||
| 97 | - for _, hook := range db.queryHooks { | 99 | + for i, hook := range db.queryHooks { |
| 98 | var err error | 100 | var err error |
| 99 | ctx, err = hook.BeforeQuery(ctx, event) | 101 | ctx, err = hook.BeforeQuery(ctx, event) |
| 100 | if err != nil { | 102 | if err != nil { |
| 101 | - return nil, nil, err | 103 | + if err := db.afterQueryFromIndex(ctx, event, i); err != nil { |
| 104 | + return ctx, nil, err | ||
| 105 | + } | ||
| 106 | + return ctx, nil, err | ||
| 102 | } | 107 | } |
| 103 | } | 108 | } |
| 104 | 109 | ||
| @@ -117,14 +122,15 @@ func (db *baseDB) afterQuery( | @@ -117,14 +122,15 @@ func (db *baseDB) afterQuery( | ||
| 117 | 122 | ||
| 118 | event.Err = err | 123 | event.Err = err |
| 119 | event.Result = res | 124 | event.Result = res |
| 125 | + return db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1) | ||
| 126 | +} | ||
| 120 | 127 | ||
| 121 | - for _, hook := range db.queryHooks { | ||
| 122 | - err := hook.AfterQuery(ctx, event) | ||
| 123 | - if err != nil { | 128 | +func (db *baseDB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) error { |
| 129 | + for ; hookIndex >= 0; hookIndex-- { | ||
| 130 | + if err := db.queryHooks[hookIndex].AfterQuery(ctx, event); err != nil { | ||
| 124 | return err | 131 | return err |
| 125 | } | 132 | } |
| 126 | } | 133 | } |
| 127 | - | ||
| 128 | return nil | 134 | return nil |
| 129 | } | 135 | } |
| 130 | 136 |
| @@ -4,8 +4,10 @@ import ( | @@ -4,8 +4,10 @@ import ( | ||
| 4 | "fmt" | 4 | "fmt" |
| 5 | ) | 5 | ) |
| 6 | 6 | ||
| 7 | -var ErrNoRows = Errorf("pg: no rows in result set") | ||
| 8 | -var ErrMultiRows = Errorf("pg: multiple rows in result set") | 7 | +var ( |
| 8 | + ErrNoRows = Errorf("pg: no rows in result set") | ||
| 9 | + ErrMultiRows = Errorf("pg: multiple rows in result set") | ||
| 10 | +) | ||
| 9 | 11 | ||
| 10 | type Error struct { | 12 | type Error struct { |
| 11 | s string | 13 | s string |
| @@ -8,20 +8,20 @@ import ( | @@ -8,20 +8,20 @@ import ( | ||
| 8 | "time" | 8 | "time" |
| 9 | ) | 9 | ) |
| 10 | 10 | ||
| 11 | -// Retry backoff with jitter sleep to prevent overloaded conditions during intervals | ||
| 12 | -// https://www.awsarchitectureblog.com/2015/03/backoff.html | ||
| 13 | func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { | 11 | func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { |
| 14 | if retry < 0 { | 12 | if retry < 0 { |
| 15 | - retry = 0 | 13 | + panic("not reached") |
| 16 | } | 14 | } |
| 17 | - | ||
| 18 | - backoff := minBackoff << uint(retry) | ||
| 19 | - if backoff > maxBackoff || backoff < minBackoff { | ||
| 20 | - backoff = maxBackoff | 15 | + if minBackoff == 0 { |
| 16 | + return 0 | ||
| 21 | } | 17 | } |
| 22 | 18 | ||
| 23 | - if backoff == 0 { | ||
| 24 | - return 0 | 19 | + d := minBackoff << uint(retry) |
| 20 | + d = minBackoff + time.Duration(rand.Int63n(int64(d))) | ||
| 21 | + | ||
| 22 | + if d > maxBackoff || d < minBackoff { | ||
| 23 | + d = maxBackoff | ||
| 25 | } | 24 | } |
| 26 | - return time.Duration(rand.Int63n(int64(backoff))) | 25 | + |
| 26 | + return d | ||
| 27 | } | 27 | } |
| 1 | package internal | 1 | package internal |
| 2 | 2 | ||
| 3 | import ( | 3 | import ( |
| 4 | + "context" | ||
| 5 | + "fmt" | ||
| 4 | "log" | 6 | "log" |
| 5 | "os" | 7 | "os" |
| 6 | ) | 8 | ) |
| 7 | 9 | ||
| 8 | -var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags|log.Lshortfile) | 10 | +var Warn = log.New(os.Stderr, "WARN: pg: ", log.LstdFlags) |
| 11 | + | ||
| 12 | +var Deprecated = log.New(os.Stderr, "DEPRECATED: pg: ", log.LstdFlags) | ||
| 13 | + | ||
| 14 | +type Logging interface { | ||
| 15 | + Printf(ctx context.Context, format string, v ...interface{}) | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +type logger struct { | ||
| 19 | + log *log.Logger | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { | ||
| 23 | + _ = l.log.Output(2, fmt.Sprintf(format, v...)) | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +var Logger Logging = &logger{ | ||
| 27 | + log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), | ||
| 28 | +} |
| @@ -8,16 +8,15 @@ import ( | @@ -8,16 +8,15 @@ import ( | ||
| 8 | "time" | 8 | "time" |
| 9 | 9 | ||
| 10 | "github.com/go-pg/pg/v10/internal" | 10 | "github.com/go-pg/pg/v10/internal" |
| 11 | - "go.opentelemetry.io/otel/api/kv" | ||
| 12 | - "go.opentelemetry.io/otel/api/trace" | 11 | + "go.opentelemetry.io/otel/label" |
| 12 | + "go.opentelemetry.io/otel/trace" | ||
| 13 | ) | 13 | ) |
| 14 | 14 | ||
| 15 | var noDeadline = time.Time{} | 15 | var noDeadline = time.Time{} |
| 16 | 16 | ||
| 17 | type Conn struct { | 17 | type Conn struct { |
| 18 | netConn net.Conn | 18 | netConn net.Conn |
| 19 | - | ||
| 20 | - rd *BufReader | 19 | + rd *ReaderContext |
| 21 | 20 | ||
| 22 | ProcessID int32 | 21 | ProcessID int32 |
| 23 | SecretKey int32 | 22 | SecretKey int32 |
| @@ -31,8 +30,6 @@ type Conn struct { | @@ -31,8 +30,6 @@ type Conn struct { | ||
| 31 | 30 | ||
| 32 | func NewConn(netConn net.Conn) *Conn { | 31 | func NewConn(netConn net.Conn) *Conn { |
| 33 | cn := &Conn{ | 32 | cn := &Conn{ |
| 34 | - rd: NewBufReader(netConn), | ||
| 35 | - | ||
| 36 | createdAt: time.Now(), | 33 | createdAt: time.Now(), |
| 37 | } | 34 | } |
| 38 | cn.SetNetConn(netConn) | 35 | cn.SetNetConn(netConn) |
| @@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr { | @@ -55,7 +52,17 @@ func (cn *Conn) RemoteAddr() net.Addr { | ||
| 55 | 52 | ||
| 56 | func (cn *Conn) SetNetConn(netConn net.Conn) { | 53 | func (cn *Conn) SetNetConn(netConn net.Conn) { |
| 57 | cn.netConn = netConn | 54 | cn.netConn = netConn |
| 58 | - cn.rd.Reset(netConn) | 55 | + if cn.rd != nil { |
| 56 | + cn.rd.Reset(netConn) | ||
| 57 | + } | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +func (cn *Conn) LockReader() { | ||
| 61 | + if cn.rd != nil { | ||
| 62 | + panic("not reached") | ||
| 63 | + } | ||
| 64 | + cn.rd = NewReaderContext() | ||
| 65 | + cn.rd.Reset(cn.netConn) | ||
| 59 | } | 66 | } |
| 60 | 67 | ||
| 61 | func (cn *Conn) NetConn() net.Conn { | 68 | func (cn *Conn) NetConn() net.Conn { |
| @@ -68,30 +75,44 @@ func (cn *Conn) NextID() string { | @@ -68,30 +75,44 @@ func (cn *Conn) NextID() string { | ||
| 68 | } | 75 | } |
| 69 | 76 | ||
| 70 | func (cn *Conn) WithReader( | 77 | func (cn *Conn) WithReader( |
| 71 | - ctx context.Context, timeout time.Duration, fn func(rd *BufReader) error, | 78 | + ctx context.Context, timeout time.Duration, fn func(rd *ReaderContext) error, |
| 72 | ) error { | 79 | ) error { |
| 73 | - return internal.WithSpan(ctx, "with_reader", func(ctx context.Context, span trace.Span) error { | ||
| 74 | - err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)) | ||
| 75 | - if err != nil { | 80 | + return internal.WithSpan(ctx, "pg.with_reader", func(ctx context.Context, span trace.Span) error { |
| 81 | + if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { | ||
| 82 | + span.RecordError(err) | ||
| 76 | return err | 83 | return err |
| 77 | } | 84 | } |
| 78 | 85 | ||
| 79 | - cn.rd.bytesRead = 0 | ||
| 80 | - err = fn(cn.rd) | ||
| 81 | - span.SetAttributes(kv.Int64("net.read_bytes", cn.rd.bytesRead)) | 86 | + rd := cn.rd |
| 87 | + if rd == nil { | ||
| 88 | + rd = GetReaderContext() | ||
| 89 | + defer PutReaderContext(rd) | ||
| 82 | 90 | ||
| 83 | - return err | 91 | + rd.Reset(cn.netConn) |
| 92 | + } | ||
| 93 | + | ||
| 94 | + rd.bytesRead = 0 | ||
| 95 | + | ||
| 96 | + if err := fn(rd); err != nil { | ||
| 97 | + span.RecordError(err) | ||
| 98 | + return err | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + span.SetAttributes(label.Int64("net.read_bytes", rd.bytesRead)) | ||
| 102 | + | ||
| 103 | + return nil | ||
| 84 | }) | 104 | }) |
| 85 | } | 105 | } |
| 86 | 106 | ||
| 87 | func (cn *Conn) WithWriter( | 107 | func (cn *Conn) WithWriter( |
| 88 | ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error, | 108 | ctx context.Context, timeout time.Duration, fn func(wb *WriteBuffer) error, |
| 89 | ) error { | 109 | ) error { |
| 90 | - return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error { | 110 | + return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error { |
| 91 | wb := GetWriteBuffer() | 111 | wb := GetWriteBuffer() |
| 92 | defer PutWriteBuffer(wb) | 112 | defer PutWriteBuffer(wb) |
| 93 | 113 | ||
| 94 | if err := fn(wb); err != nil { | 114 | if err := fn(wb); err != nil { |
| 115 | + span.RecordError(err) | ||
| 95 | return err | 116 | return err |
| 96 | } | 117 | } |
| 97 | 118 | ||
| @@ -100,7 +121,7 @@ func (cn *Conn) WithWriter( | @@ -100,7 +121,7 @@ func (cn *Conn) WithWriter( | ||
| 100 | } | 121 | } |
| 101 | 122 | ||
| 102 | func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error { | 123 | func (cn *Conn) WriteBuffer(ctx context.Context, timeout time.Duration, wb *WriteBuffer) error { |
| 103 | - return internal.WithSpan(ctx, "with_writer", func(ctx context.Context, span trace.Span) error { | 124 | + return internal.WithSpan(ctx, "pg.with_writer", func(ctx context.Context, span trace.Span) error { |
| 104 | return cn.writeBuffer(ctx, span, timeout, wb) | 125 | return cn.writeBuffer(ctx, span, timeout, wb) |
| 105 | }) | 126 | }) |
| 106 | } | 127 | } |
| @@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer( | @@ -111,14 +132,19 @@ func (cn *Conn) writeBuffer( | ||
| 111 | timeout time.Duration, | 132 | timeout time.Duration, |
| 112 | wb *WriteBuffer, | 133 | wb *WriteBuffer, |
| 113 | ) error { | 134 | ) error { |
| 114 | - err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)) | ||
| 115 | - if err != nil { | 135 | + if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { |
| 136 | + span.RecordError(err) | ||
| 137 | + return err | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + span.SetAttributes(label.Int("net.wrote_bytes", len(wb.Bytes))) | ||
| 141 | + | ||
| 142 | + if _, err := cn.netConn.Write(wb.Bytes); err != nil { | ||
| 143 | + span.RecordError(err) | ||
| 116 | return err | 144 | return err |
| 117 | } | 145 | } |
| 118 | 146 | ||
| 119 | - span.SetAttributes(kv.Int("net.wrote_bytes", len(wb.Bytes))) | ||
| 120 | - _, err = cn.netConn.Write(wb.Bytes) | ||
| 121 | - return err | 147 | + return nil |
| 122 | } | 148 | } |
| 123 | 149 | ||
| 124 | func (cn *Conn) Close() error { | 150 | func (cn *Conn) Close() error { |
| @@ -11,8 +11,10 @@ import ( | @@ -11,8 +11,10 @@ import ( | ||
| 11 | "github.com/go-pg/pg/v10/internal" | 11 | "github.com/go-pg/pg/v10/internal" |
| 12 | ) | 12 | ) |
| 13 | 13 | ||
| 14 | -var ErrClosed = errors.New("pg: database is closed") | ||
| 15 | -var ErrPoolTimeout = errors.New("pg: connection pool timeout") | 14 | +var ( |
| 15 | + ErrClosed = errors.New("pg: database is closed") | ||
| 16 | + ErrPoolTimeout = errors.New("pg: connection pool timeout") | ||
| 17 | +) | ||
| 16 | 18 | ||
| 17 | var timers = sync.Pool{ | 19 | var timers = sync.Pool{ |
| 18 | New: func() interface{} { | 20 | New: func() interface{} { |
| @@ -38,8 +40,8 @@ type Pooler interface { | @@ -38,8 +40,8 @@ type Pooler interface { | ||
| 38 | CloseConn(*Conn) error | 40 | CloseConn(*Conn) error |
| 39 | 41 | ||
| 40 | Get(context.Context) (*Conn, error) | 42 | Get(context.Context) (*Conn, error) |
| 41 | - Put(*Conn) | ||
| 42 | - Remove(*Conn, error) | 43 | + Put(context.Context, *Conn) |
| 44 | + Remove(context.Context, *Conn, error) | ||
| 43 | 45 | ||
| 44 | Len() int | 46 | Len() int |
| 45 | IdleLen() int | 47 | IdleLen() int |
| @@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error { | @@ -216,12 +218,12 @@ func (p *ConnPool) getLastDialError() error { | ||
| 216 | } | 218 | } |
| 217 | 219 | ||
| 218 | // Get returns existed connection from the pool or creates a new one. | 220 | // Get returns existed connection from the pool or creates a new one. |
| 219 | -func (p *ConnPool) Get(c context.Context) (*Conn, error) { | 221 | +func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { |
| 220 | if p.closed() { | 222 | if p.closed() { |
| 221 | return nil, ErrClosed | 223 | return nil, ErrClosed |
| 222 | } | 224 | } |
| 223 | 225 | ||
| 224 | - err := p.waitTurn(c) | 226 | + err := p.waitTurn(ctx) |
| 225 | if err != nil { | 227 | if err != nil { |
| 226 | return nil, err | 228 | return nil, err |
| 227 | } | 229 | } |
| @@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) { | @@ -246,7 +248,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) { | ||
| 246 | 248 | ||
| 247 | atomic.AddUint32(&p.stats.Misses, 1) | 249 | atomic.AddUint32(&p.stats.Misses, 1) |
| 248 | 250 | ||
| 249 | - newcn, err := p.newConn(c, true) | 251 | + newcn, err := p.newConn(ctx, true) |
| 250 | if err != nil { | 252 | if err != nil { |
| 251 | p.freeTurn() | 253 | p.freeTurn() |
| 252 | return nil, err | 254 | return nil, err |
| @@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn { | @@ -312,15 +314,9 @@ func (p *ConnPool) popIdle() *Conn { | ||
| 312 | return cn | 314 | return cn |
| 313 | } | 315 | } |
| 314 | 316 | ||
| 315 | -func (p *ConnPool) Put(cn *Conn) { | ||
| 316 | - if cn.rd.Buffered() > 0 { | ||
| 317 | - internal.Logger.Printf("Conn has unread data") | ||
| 318 | - p.Remove(cn, BadConnError{}) | ||
| 319 | - return | ||
| 320 | - } | ||
| 321 | - | 317 | +func (p *ConnPool) Put(ctx context.Context, cn *Conn) { |
| 322 | if !cn.pooled { | 318 | if !cn.pooled { |
| 323 | - p.Remove(cn, nil) | 319 | + p.Remove(ctx, cn, nil) |
| 324 | return | 320 | return |
| 325 | } | 321 | } |
| 326 | 322 | ||
| @@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) { | @@ -331,7 +327,7 @@ func (p *ConnPool) Put(cn *Conn) { | ||
| 331 | p.freeTurn() | 327 | p.freeTurn() |
| 332 | } | 328 | } |
| 333 | 329 | ||
| 334 | -func (p *ConnPool) Remove(cn *Conn, reason error) { | 330 | +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { |
| 335 | p.removeConnWithLock(cn) | 331 | p.removeConnWithLock(cn) |
| 336 | p.freeTurn() | 332 | p.freeTurn() |
| 337 | _ = p.closeConn(cn) | 333 | _ = p.closeConn(cn) |
| @@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) { | @@ -446,7 +442,7 @@ func (p *ConnPool) reaper(frequency time.Duration) { | ||
| 446 | } | 442 | } |
| 447 | n, err := p.ReapStaleConns() | 443 | n, err := p.ReapStaleConns() |
| 448 | if err != nil { | 444 | if err != nil { |
| 449 | - internal.Logger.Printf("ReapStaleConns failed: %s", err) | 445 | + internal.Logger.Printf(context.TODO(), "ReapStaleConns failed: %s", err) |
| 450 | continue | 446 | continue |
| 451 | } | 447 | } |
| 452 | atomic.AddUint32(&p.stats.StaleConns, uint32(n)) | 448 | atomic.AddUint32(&p.stats.StaleConns, uint32(n)) |
| 1 | package pool | 1 | package pool |
| 2 | 2 | ||
| 3 | -import ( | ||
| 4 | - "context" | ||
| 5 | - "errors" | ||
| 6 | - "fmt" | ||
| 7 | - "sync/atomic" | ||
| 8 | -) | ||
| 9 | - | ||
| 10 | -const ( | ||
| 11 | - stateDefault = 0 | ||
| 12 | - stateInited = 1 | ||
| 13 | - stateClosed = 2 | ||
| 14 | -) | ||
| 15 | - | ||
| 16 | -type BadConnError struct { | ||
| 17 | - wrapped error | ||
| 18 | -} | ||
| 19 | - | ||
| 20 | -var _ error = (*BadConnError)(nil) | ||
| 21 | - | ||
| 22 | -func (e BadConnError) Error() string { | ||
| 23 | - s := "pg: Conn is in a bad state" | ||
| 24 | - if e.wrapped != nil { | ||
| 25 | - s += ": " + e.wrapped.Error() | ||
| 26 | - } | ||
| 27 | - return s | ||
| 28 | -} | ||
| 29 | - | ||
| 30 | -func (e BadConnError) Unwrap() error { | ||
| 31 | - return e.wrapped | ||
| 32 | -} | 3 | +import "context" |
| 33 | 4 | ||
| 34 | type SingleConnPool struct { | 5 | type SingleConnPool struct { |
| 35 | - pool Pooler | ||
| 36 | - level int32 // atomic | ||
| 37 | - | ||
| 38 | - state uint32 // atomic | ||
| 39 | - ch chan *Conn | ||
| 40 | - | ||
| 41 | - _badConnError atomic.Value | 6 | + pool Pooler |
| 7 | + cn *Conn | ||
| 8 | + stickyErr error | ||
| 42 | } | 9 | } |
| 43 | 10 | ||
| 44 | var _ Pooler = (*SingleConnPool)(nil) | 11 | var _ Pooler = (*SingleConnPool)(nil) |
| 45 | 12 | ||
| 46 | -func NewSingleConnPool(pool Pooler) *SingleConnPool { | ||
| 47 | - p, ok := pool.(*SingleConnPool) | ||
| 48 | - if !ok { | ||
| 49 | - p = &SingleConnPool{ | ||
| 50 | - pool: pool, | ||
| 51 | - ch: make(chan *Conn, 1), | ||
| 52 | - } | 13 | +func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { |
| 14 | + return &SingleConnPool{ | ||
| 15 | + pool: pool, | ||
| 16 | + cn: cn, | ||
| 53 | } | 17 | } |
| 54 | - atomic.AddInt32(&p.level, 1) | ||
| 55 | - return p | ||
| 56 | } | 18 | } |
| 57 | 19 | ||
| 58 | -func (p *SingleConnPool) SetConn(cn *Conn) { | ||
| 59 | - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { | ||
| 60 | - p.ch <- cn | ||
| 61 | - } else { | ||
| 62 | - panic("not reached") | ||
| 63 | - } | ||
| 64 | -} | ||
| 65 | - | ||
| 66 | -func (p *SingleConnPool) NewConn(c context.Context) (*Conn, error) { | ||
| 67 | - return p.pool.NewConn(c) | 20 | +func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) { |
| 21 | + return p.pool.NewConn(ctx) | ||
| 68 | } | 22 | } |
| 69 | 23 | ||
| 70 | func (p *SingleConnPool) CloseConn(cn *Conn) error { | 24 | func (p *SingleConnPool) CloseConn(cn *Conn) error { |
| 71 | return p.pool.CloseConn(cn) | 25 | return p.pool.CloseConn(cn) |
| 72 | } | 26 | } |
| 73 | 27 | ||
| 74 | -func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { | ||
| 75 | - // In worst case this races with Close which is not a very common operation. | ||
| 76 | - for i := 0; i < 1000; i++ { | ||
| 77 | - switch atomic.LoadUint32(&p.state) { | ||
| 78 | - case stateDefault: | ||
| 79 | - cn, err := p.pool.Get(c) | ||
| 80 | - if err != nil { | ||
| 81 | - return nil, err | ||
| 82 | - } | ||
| 83 | - if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { | ||
| 84 | - return cn, nil | ||
| 85 | - } | ||
| 86 | - p.pool.Remove(cn, ErrClosed) | ||
| 87 | - case stateInited: | ||
| 88 | - if err := p.badConnError(); err != nil { | ||
| 89 | - return nil, err | ||
| 90 | - } | ||
| 91 | - cn, ok := <-p.ch | ||
| 92 | - if !ok { | ||
| 93 | - return nil, ErrClosed | ||
| 94 | - } | ||
| 95 | - return cn, nil | ||
| 96 | - case stateClosed: | ||
| 97 | - return nil, ErrClosed | ||
| 98 | - default: | ||
| 99 | - panic("not reached") | ||
| 100 | - } | 28 | +func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { |
| 29 | + if p.stickyErr != nil { | ||
| 30 | + return nil, p.stickyErr | ||
| 101 | } | 31 | } |
| 102 | - return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop") | 32 | + return p.cn, nil |
| 103 | } | 33 | } |
| 104 | 34 | ||
| 105 | -func (p *SingleConnPool) Put(cn *Conn) { | ||
| 106 | - defer func() { | ||
| 107 | - if recover() != nil { | ||
| 108 | - p.freeConn(cn) | ||
| 109 | - } | ||
| 110 | - }() | ||
| 111 | - p.ch <- cn | ||
| 112 | -} | 35 | +func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} |
| 113 | 36 | ||
| 114 | -func (p *SingleConnPool) freeConn(cn *Conn) { | ||
| 115 | - if err := p.badConnError(); err != nil { | ||
| 116 | - p.pool.Remove(cn, err) | ||
| 117 | - } else { | ||
| 118 | - p.pool.Put(cn) | ||
| 119 | - } | 37 | +func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { |
| 38 | + p.cn = nil | ||
| 39 | + p.stickyErr = reason | ||
| 120 | } | 40 | } |
| 121 | 41 | ||
| 122 | -func (p *SingleConnPool) Remove(cn *Conn, reason error) { | ||
| 123 | - defer func() { | ||
| 124 | - if recover() != nil { | ||
| 125 | - p.pool.Remove(cn, ErrClosed) | ||
| 126 | - } | ||
| 127 | - }() | ||
| 128 | - p._badConnError.Store(BadConnError{wrapped: reason}) | ||
| 129 | - p.ch <- cn | 42 | +func (p *SingleConnPool) Close() error { |
| 43 | + p.cn = nil | ||
| 44 | + p.stickyErr = ErrClosed | ||
| 45 | + return nil | ||
| 130 | } | 46 | } |
| 131 | 47 | ||
| 132 | func (p *SingleConnPool) Len() int { | 48 | func (p *SingleConnPool) Len() int { |
| 133 | - switch atomic.LoadUint32(&p.state) { | ||
| 134 | - case stateDefault: | ||
| 135 | - return 0 | ||
| 136 | - case stateInited: | ||
| 137 | - return 1 | ||
| 138 | - case stateClosed: | ||
| 139 | - return 0 | ||
| 140 | - default: | ||
| 141 | - panic("not reached") | ||
| 142 | - } | 49 | + return 0 |
| 143 | } | 50 | } |
| 144 | 51 | ||
| 145 | func (p *SingleConnPool) IdleLen() int { | 52 | func (p *SingleConnPool) IdleLen() int { |
| 146 | - return len(p.ch) | 53 | + return 0 |
| 147 | } | 54 | } |
| 148 | 55 | ||
| 149 | func (p *SingleConnPool) Stats() *Stats { | 56 | func (p *SingleConnPool) Stats() *Stats { |
| 150 | return &Stats{} | 57 | return &Stats{} |
| 151 | } | 58 | } |
| 152 | - | ||
| 153 | -func (p *SingleConnPool) Close() error { | ||
| 154 | - level := atomic.AddInt32(&p.level, -1) | ||
| 155 | - if level > 0 { | ||
| 156 | - return nil | ||
| 157 | - } | ||
| 158 | - | ||
| 159 | - for i := 0; i < 1000; i++ { | ||
| 160 | - state := atomic.LoadUint32(&p.state) | ||
| 161 | - if state == stateClosed { | ||
| 162 | - return ErrClosed | ||
| 163 | - } | ||
| 164 | - if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { | ||
| 165 | - close(p.ch) | ||
| 166 | - cn, ok := <-p.ch | ||
| 167 | - if ok { | ||
| 168 | - p.freeConn(cn) | ||
| 169 | - } | ||
| 170 | - return nil | ||
| 171 | - } | ||
| 172 | - } | ||
| 173 | - | ||
| 174 | - return errors.New("pg: SingleConnPool.Close: infinite loop") | ||
| 175 | -} | ||
| 176 | - | ||
| 177 | -func (p *SingleConnPool) Reset() error { | ||
| 178 | - if p.badConnError() == nil { | ||
| 179 | - return nil | ||
| 180 | - } | ||
| 181 | - | ||
| 182 | - select { | ||
| 183 | - case cn, ok := <-p.ch: | ||
| 184 | - if !ok { | ||
| 185 | - return ErrClosed | ||
| 186 | - } | ||
| 187 | - p.pool.Remove(cn, ErrClosed) | ||
| 188 | - p._badConnError.Store(BadConnError{wrapped: nil}) | ||
| 189 | - default: | ||
| 190 | - return errors.New("pg: SingleConnPool does not have a Conn") | ||
| 191 | - } | ||
| 192 | - | ||
| 193 | - if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { | ||
| 194 | - state := atomic.LoadUint32(&p.state) | ||
| 195 | - return fmt.Errorf("pg: invalid SingleConnPool state: %d", state) | ||
| 196 | - } | ||
| 197 | - | ||
| 198 | - return nil | ||
| 199 | -} | ||
| 200 | - | ||
| 201 | -func (p *SingleConnPool) badConnError() error { | ||
| 202 | - if v := p._badConnError.Load(); v != nil { | ||
| 203 | - err := v.(BadConnError) | ||
| 204 | - if err.wrapped != nil { | ||
| 205 | - return err | ||
| 206 | - } | ||
| 207 | - } | ||
| 208 | - return nil | ||
| 209 | -} |
| 1 | +package pool | ||
| 2 | + | ||
| 3 | +import ( | ||
| 4 | + "context" | ||
| 5 | + "errors" | ||
| 6 | + "fmt" | ||
| 7 | + "sync/atomic" | ||
| 8 | +) | ||
| 9 | + | ||
| 10 | +const ( | ||
| 11 | + stateDefault = 0 | ||
| 12 | + stateInited = 1 | ||
| 13 | + stateClosed = 2 | ||
| 14 | +) | ||
| 15 | + | ||
| 16 | +type BadConnError struct { | ||
| 17 | + wrapped error | ||
| 18 | +} | ||
| 19 | + | ||
| 20 | +var _ error = (*BadConnError)(nil) | ||
| 21 | + | ||
| 22 | +func (e BadConnError) Error() string { | ||
| 23 | + s := "pg: Conn is in a bad state" | ||
| 24 | + if e.wrapped != nil { | ||
| 25 | + s += ": " + e.wrapped.Error() | ||
| 26 | + } | ||
| 27 | + return s | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +func (e BadConnError) Unwrap() error { | ||
| 31 | + return e.wrapped | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +//------------------------------------------------------------------------------ | ||
| 35 | + | ||
| 36 | +type StickyConnPool struct { | ||
| 37 | + pool Pooler | ||
| 38 | + shared int32 // atomic | ||
| 39 | + | ||
| 40 | + state uint32 // atomic | ||
| 41 | + ch chan *Conn | ||
| 42 | + | ||
| 43 | + _badConnError atomic.Value | ||
| 44 | +} | ||
| 45 | + | ||
| 46 | +var _ Pooler = (*StickyConnPool)(nil) | ||
| 47 | + | ||
| 48 | +func NewStickyConnPool(pool Pooler) *StickyConnPool { | ||
| 49 | + p, ok := pool.(*StickyConnPool) | ||
| 50 | + if !ok { | ||
| 51 | + p = &StickyConnPool{ | ||
| 52 | + pool: pool, | ||
| 53 | + ch: make(chan *Conn, 1), | ||
| 54 | + } | ||
| 55 | + } | ||
| 56 | + atomic.AddInt32(&p.shared, 1) | ||
| 57 | + return p | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { | ||
| 61 | + return p.pool.NewConn(ctx) | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +func (p *StickyConnPool) CloseConn(cn *Conn) error { | ||
| 65 | + return p.pool.CloseConn(cn) | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { | ||
| 69 | + // In worst case this races with Close which is not a very common operation. | ||
| 70 | + for i := 0; i < 1000; i++ { | ||
| 71 | + switch atomic.LoadUint32(&p.state) { | ||
| 72 | + case stateDefault: | ||
| 73 | + cn, err := p.pool.Get(ctx) | ||
| 74 | + if err != nil { | ||
| 75 | + return nil, err | ||
| 76 | + } | ||
| 77 | + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { | ||
| 78 | + return cn, nil | ||
| 79 | + } | ||
| 80 | + p.pool.Remove(ctx, cn, ErrClosed) | ||
| 81 | + case stateInited: | ||
| 82 | + if err := p.badConnError(); err != nil { | ||
| 83 | + return nil, err | ||
| 84 | + } | ||
| 85 | + cn, ok := <-p.ch | ||
| 86 | + if !ok { | ||
| 87 | + return nil, ErrClosed | ||
| 88 | + } | ||
| 89 | + return cn, nil | ||
| 90 | + case stateClosed: | ||
| 91 | + return nil, ErrClosed | ||
| 92 | + default: | ||
| 93 | + panic("not reached") | ||
| 94 | + } | ||
| 95 | + } | ||
| 96 | + return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop") | ||
| 97 | +} | ||
| 98 | + | ||
| 99 | +func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { | ||
| 100 | + defer func() { | ||
| 101 | + if recover() != nil { | ||
| 102 | + p.freeConn(ctx, cn) | ||
| 103 | + } | ||
| 104 | + }() | ||
| 105 | + p.ch <- cn | ||
| 106 | +} | ||
| 107 | + | ||
| 108 | +func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { | ||
| 109 | + if err := p.badConnError(); err != nil { | ||
| 110 | + p.pool.Remove(ctx, cn, err) | ||
| 111 | + } else { | ||
| 112 | + p.pool.Put(ctx, cn) | ||
| 113 | + } | ||
| 114 | +} | ||
| 115 | + | ||
| 116 | +func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { | ||
| 117 | + defer func() { | ||
| 118 | + if recover() != nil { | ||
| 119 | + p.pool.Remove(ctx, cn, ErrClosed) | ||
| 120 | + } | ||
| 121 | + }() | ||
| 122 | + p._badConnError.Store(BadConnError{wrapped: reason}) | ||
| 123 | + p.ch <- cn | ||
| 124 | +} | ||
| 125 | + | ||
| 126 | +func (p *StickyConnPool) Close() error { | ||
| 127 | + if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { | ||
| 128 | + return nil | ||
| 129 | + } | ||
| 130 | + | ||
| 131 | + for i := 0; i < 1000; i++ { | ||
| 132 | + state := atomic.LoadUint32(&p.state) | ||
| 133 | + if state == stateClosed { | ||
| 134 | + return ErrClosed | ||
| 135 | + } | ||
| 136 | + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { | ||
| 137 | + close(p.ch) | ||
| 138 | + cn, ok := <-p.ch | ||
| 139 | + if ok { | ||
| 140 | + p.freeConn(context.TODO(), cn) | ||
| 141 | + } | ||
| 142 | + return nil | ||
| 143 | + } | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + return errors.New("pg: StickyConnPool.Close: infinite loop") | ||
| 147 | +} | ||
| 148 | + | ||
| 149 | +func (p *StickyConnPool) Reset(ctx context.Context) error { | ||
| 150 | + if p.badConnError() == nil { | ||
| 151 | + return nil | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + select { | ||
| 155 | + case cn, ok := <-p.ch: | ||
| 156 | + if !ok { | ||
| 157 | + return ErrClosed | ||
| 158 | + } | ||
| 159 | + p.pool.Remove(ctx, cn, ErrClosed) | ||
| 160 | + p._badConnError.Store(BadConnError{wrapped: nil}) | ||
| 161 | + default: | ||
| 162 | + return errors.New("pg: StickyConnPool does not have a Conn") | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { | ||
| 166 | + state := atomic.LoadUint32(&p.state) | ||
| 167 | + return fmt.Errorf("pg: invalid StickyConnPool state: %d", state) | ||
| 168 | + } | ||
| 169 | + | ||
| 170 | + return nil | ||
| 171 | +} | ||
| 172 | + | ||
| 173 | +func (p *StickyConnPool) badConnError() error { | ||
| 174 | + if v := p._badConnError.Load(); v != nil { | ||
| 175 | + err := v.(BadConnError) | ||
| 176 | + if err.wrapped != nil { | ||
| 177 | + return err | ||
| 178 | + } | ||
| 179 | + } | ||
| 180 | + return nil | ||
| 181 | +} | ||
| 182 | + | ||
| 183 | +func (p *StickyConnPool) Len() int { | ||
| 184 | + switch atomic.LoadUint32(&p.state) { | ||
| 185 | + case stateDefault: | ||
| 186 | + return 0 | ||
| 187 | + case stateInited: | ||
| 188 | + return 1 | ||
| 189 | + case stateClosed: | ||
| 190 | + return 0 | ||
| 191 | + default: | ||
| 192 | + panic("not reached") | ||
| 193 | + } | ||
| 194 | +} | ||
| 195 | + | ||
| 196 | +func (p *StickyConnPool) IdleLen() int { | ||
| 197 | + return len(p.ch) | ||
| 198 | +} | ||
| 199 | + | ||
| 200 | +func (p *StickyConnPool) Stats() *Stats { | ||
| 201 | + return &Stats{} | ||
| 202 | +} |
| 1 | package pool | 1 | package pool |
| 2 | 2 | ||
| 3 | +import ( | ||
| 4 | + "sync" | ||
| 5 | +) | ||
| 6 | + | ||
| 3 | type Reader interface { | 7 | type Reader interface { |
| 4 | Buffered() int | 8 | Buffered() int |
| 5 | 9 | ||
| @@ -10,8 +14,67 @@ type Reader interface { | @@ -10,8 +14,67 @@ type Reader interface { | ||
| 10 | ReadSlice(byte) ([]byte, error) | 14 | ReadSlice(byte) ([]byte, error) |
| 11 | Discard(int) (int, error) | 15 | Discard(int) (int, error) |
| 12 | 16 | ||
| 13 | - //ReadBytes(fn func(byte) bool) ([]byte, error) | ||
| 14 | - //ReadN(int) ([]byte, error) | 17 | + // ReadBytes(fn func(byte) bool) ([]byte, error) |
| 18 | + // ReadN(int) ([]byte, error) | ||
| 15 | ReadFull() ([]byte, error) | 19 | ReadFull() ([]byte, error) |
| 16 | ReadFullTemp() ([]byte, error) | 20 | ReadFullTemp() ([]byte, error) |
| 17 | } | 21 | } |
| 22 | + | ||
| 23 | +type ColumnInfo struct { | ||
| 24 | + Index int16 | ||
| 25 | + DataType int32 | ||
| 26 | + Name string | ||
| 27 | +} | ||
| 28 | + | ||
| 29 | +type ColumnAlloc struct { | ||
| 30 | + columns []ColumnInfo | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +func NewColumnAlloc() *ColumnAlloc { | ||
| 34 | + return new(ColumnAlloc) | ||
| 35 | +} | ||
| 36 | + | ||
| 37 | +func (c *ColumnAlloc) Reset() { | ||
| 38 | + c.columns = c.columns[:0] | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | +func (c *ColumnAlloc) New(index int16, name []byte) *ColumnInfo { | ||
| 42 | + c.columns = append(c.columns, ColumnInfo{ | ||
| 43 | + Index: index, | ||
| 44 | + Name: string(name), | ||
| 45 | + }) | ||
| 46 | + return &c.columns[len(c.columns)-1] | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +func (c *ColumnAlloc) Columns() []ColumnInfo { | ||
| 50 | + return c.columns | ||
| 51 | +} | ||
| 52 | + | ||
| 53 | +type ReaderContext struct { | ||
| 54 | + *BufReader | ||
| 55 | + ColumnAlloc *ColumnAlloc | ||
| 56 | +} | ||
| 57 | + | ||
| 58 | +func NewReaderContext() *ReaderContext { | ||
| 59 | + const bufSize = 1 << 20 // 1mb | ||
| 60 | + return &ReaderContext{ | ||
| 61 | + BufReader: NewBufReader(bufSize), | ||
| 62 | + ColumnAlloc: NewColumnAlloc(), | ||
| 63 | + } | ||
| 64 | +} | ||
| 65 | + | ||
| 66 | +var readerPool = sync.Pool{ | ||
| 67 | + New: func() interface{} { | ||
| 68 | + return NewReaderContext() | ||
| 69 | + }, | ||
| 70 | +} | ||
| 71 | + | ||
| 72 | +func GetReaderContext() *ReaderContext { | ||
| 73 | + rd := readerPool.Get().(*ReaderContext) | ||
| 74 | + return rd | ||
| 75 | +} | ||
| 76 | + | ||
| 77 | +func PutReaderContext(rd *ReaderContext) { | ||
| 78 | + rd.ColumnAlloc.Reset() | ||
| 79 | + readerPool.Put(rd) | ||
| 80 | +} |
| @@ -10,11 +10,7 @@ import ( | @@ -10,11 +10,7 @@ import ( | ||
| 10 | "io" | 10 | "io" |
| 11 | ) | 11 | ) |
| 12 | 12 | ||
| 13 | -const defaultBufSize = 65536 | ||
| 14 | - | ||
| 15 | type BufReader struct { | 13 | type BufReader struct { |
| 16 | - Columns [][]byte | ||
| 17 | - | ||
| 18 | rd io.Reader // reader provided by the client | 14 | rd io.Reader // reader provided by the client |
| 19 | 15 | ||
| 20 | buf []byte | 16 | buf []byte |
| @@ -24,25 +20,24 @@ type BufReader struct { | @@ -24,25 +20,24 @@ type BufReader struct { | ||
| 24 | err error | 20 | err error |
| 25 | 21 | ||
| 26 | available int // bytes available for reading | 22 | available int // bytes available for reading |
| 27 | - bytesRd BytesReader // reusable bytes reader | 23 | + brd BytesReader // reusable bytes reader |
| 28 | } | 24 | } |
| 29 | 25 | ||
| 30 | -func NewBufReader(rd io.Reader) *BufReader { | 26 | +func NewBufReader(bufSize int) *BufReader { |
| 31 | return &BufReader{ | 27 | return &BufReader{ |
| 32 | - rd: rd, | ||
| 33 | - buf: make([]byte, defaultBufSize), | 28 | + buf: make([]byte, bufSize), |
| 34 | available: -1, | 29 | available: -1, |
| 35 | } | 30 | } |
| 36 | } | 31 | } |
| 37 | 32 | ||
| 38 | func (b *BufReader) BytesReader(n int) *BytesReader { | 33 | func (b *BufReader) BytesReader(n int) *BytesReader { |
| 39 | - if b.Buffered() < n { | ||
| 40 | - return nil | 34 | + if n == -1 { |
| 35 | + n = 0 | ||
| 41 | } | 36 | } |
| 42 | buf := b.buf[b.r : b.r+n] | 37 | buf := b.buf[b.r : b.r+n] |
| 43 | b.r += n | 38 | b.r += n |
| 44 | - b.bytesRd.Reset(buf) | ||
| 45 | - return &b.bytesRd | 39 | + b.brd.Reset(buf) |
| 40 | + return &b.brd | ||
| 46 | } | 41 | } |
| 47 | 42 | ||
| 48 | func (b *BufReader) SetAvailable(n int) { | 43 | func (b *BufReader) SetAvailable(n int) { |
| @@ -67,11 +62,11 @@ func (b *BufReader) Reset(rd io.Reader) { | @@ -67,11 +62,11 @@ func (b *BufReader) Reset(rd io.Reader) { | ||
| 67 | 62 | ||
| 68 | // Buffered returns the number of bytes that can be read from the current buffer. | 63 | // Buffered returns the number of bytes that can be read from the current buffer. |
| 69 | func (b *BufReader) Buffered() int { | 64 | func (b *BufReader) Buffered() int { |
| 70 | - d := b.w - b.r | ||
| 71 | - if b.available != -1 && d > b.available { | ||
| 72 | - return b.available | 65 | + buffered := b.w - b.r |
| 66 | + if b.available == -1 || buffered <= b.available { | ||
| 67 | + return buffered | ||
| 73 | } | 68 | } |
| 74 | - return d | 69 | + return b.available |
| 75 | } | 70 | } |
| 76 | 71 | ||
| 77 | func (b *BufReader) Bytes() []byte { | 72 | func (b *BufReader) Bytes() []byte { |
| @@ -122,7 +117,7 @@ func (b *BufReader) fill() { | @@ -122,7 +117,7 @@ func (b *BufReader) fill() { | ||
| 122 | // Read new data: try a limited number of times. | 117 | // Read new data: try a limited number of times. |
| 123 | const maxConsecutiveEmptyReads = 100 | 118 | const maxConsecutiveEmptyReads = 100 |
| 124 | for i := maxConsecutiveEmptyReads; i > 0; i-- { | 119 | for i := maxConsecutiveEmptyReads; i > 0; i-- { |
| 125 | - n, err := b.readDirectly(b.buf[b.w:]) | 120 | + n, err := b.read(b.buf[b.w:]) |
| 126 | b.w += n | 121 | b.w += n |
| 127 | if err != nil { | 122 | if err != nil { |
| 128 | b.err = err | 123 | b.err = err |
| @@ -163,7 +158,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { | @@ -163,7 +158,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { | ||
| 163 | if len(p) >= len(b.buf) { | 158 | if len(p) >= len(b.buf) { |
| 164 | // Large read, empty buffer. | 159 | // Large read, empty buffer. |
| 165 | // Read directly into p to avoid copy. | 160 | // Read directly into p to avoid copy. |
| 166 | - n, err = b.readDirectly(p) | 161 | + n, err = b.read(p) |
| 167 | if n > 0 { | 162 | if n > 0 { |
| 168 | b.changeAvailable(-n) | 163 | b.changeAvailable(-n) |
| 169 | b.lastByte = int(p[n-1]) | 164 | b.lastByte = int(p[n-1]) |
| @@ -175,7 +170,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { | @@ -175,7 +170,7 @@ func (b *BufReader) Read(p []byte) (n int, err error) { | ||
| 175 | // Do not use b.fill, which will loop. | 170 | // Do not use b.fill, which will loop. |
| 176 | b.r = 0 | 171 | b.r = 0 |
| 177 | b.w = 0 | 172 | b.w = 0 |
| 178 | - n, b.err = b.readDirectly(b.buf) | 173 | + n, b.err = b.read(b.buf) |
| 179 | if n == 0 { | 174 | if n == 0 { |
| 180 | return 0, b.readErr() | 175 | return 0, b.readErr() |
| 181 | } | 176 | } |
| @@ -259,7 +254,7 @@ func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) { | @@ -259,7 +254,7 @@ func (b *BufReader) ReadBytes(fn func(byte) bool) (line []byte, err error) { | ||
| 259 | 254 | ||
| 260 | // Pending error? | 255 | // Pending error? |
| 261 | if b.err != nil { | 256 | if b.err != nil { |
| 262 | - line = b.flush() //nolint | 257 | + line = b.flush() |
| 263 | err = b.readErr() | 258 | err = b.readErr() |
| 264 | break | 259 | break |
| 265 | } | 260 | } |
| @@ -429,7 +424,7 @@ func (b *BufReader) ReadFullTemp() ([]byte, error) { | @@ -429,7 +424,7 @@ func (b *BufReader) ReadFullTemp() ([]byte, error) { | ||
| 429 | return b.ReadFull() | 424 | return b.ReadFull() |
| 430 | } | 425 | } |
| 431 | 426 | ||
| 432 | -func (b *BufReader) readDirectly(buf []byte) (int, error) { | 427 | +func (b *BufReader) read(buf []byte) (int, error) { |
| 433 | n, err := b.rd.Read(buf) | 428 | n, err := b.rd.Read(buf) |
| 434 | b.bytesRead += int64(n) | 429 | b.bytesRead += int64(n) |
| 435 | return n, err | 430 | return n, err |
| @@ -6,20 +6,22 @@ import ( | @@ -6,20 +6,22 @@ import ( | ||
| 6 | "sync" | 6 | "sync" |
| 7 | ) | 7 | ) |
| 8 | 8 | ||
| 9 | -var pool = sync.Pool{ | 9 | +const defaultBufSize = 65 << 10 // 65kb |
| 10 | + | ||
| 11 | +var wbPool = sync.Pool{ | ||
| 10 | New: func() interface{} { | 12 | New: func() interface{} { |
| 11 | return NewWriteBuffer() | 13 | return NewWriteBuffer() |
| 12 | }, | 14 | }, |
| 13 | } | 15 | } |
| 14 | 16 | ||
| 15 | func GetWriteBuffer() *WriteBuffer { | 17 | func GetWriteBuffer() *WriteBuffer { |
| 16 | - wb := pool.Get().(*WriteBuffer) | ||
| 17 | - wb.Reset() | 18 | + wb := wbPool.Get().(*WriteBuffer) |
| 18 | return wb | 19 | return wb |
| 19 | } | 20 | } |
| 20 | 21 | ||
| 21 | func PutWriteBuffer(wb *WriteBuffer) { | 22 | func PutWriteBuffer(wb *WriteBuffer) { |
| 22 | - pool.Put(wb) | 23 | + wb.Reset() |
| 24 | + wbPool.Put(wb) | ||
| 23 | } | 25 | } |
| 24 | 26 | ||
| 25 | type WriteBuffer struct { | 27 | type WriteBuffer struct { |
| @@ -39,10 +41,6 @@ func (buf *WriteBuffer) Reset() { | @@ -39,10 +41,6 @@ func (buf *WriteBuffer) Reset() { | ||
| 39 | buf.Bytes = buf.Bytes[:0] | 41 | buf.Bytes = buf.Bytes[:0] |
| 40 | } | 42 | } |
| 41 | 43 | ||
| 42 | -func (buf *WriteBuffer) ResetBuffer(b []byte) { | ||
| 43 | - buf.Bytes = b[:0] | ||
| 44 | -} | ||
| 45 | - | ||
| 46 | func (buf *WriteBuffer) StartMessage(c byte) { | 44 | func (buf *WriteBuffer) StartMessage(c byte) { |
| 47 | if c == 0 { | 45 | if c == 0 { |
| 48 | buf.msgStart = len(buf.Bytes) | 46 | buf.msgStart = len(buf.Bytes) |
| @@ -5,12 +5,14 @@ import ( | @@ -5,12 +5,14 @@ import ( | ||
| 5 | "reflect" | 5 | "reflect" |
| 6 | "time" | 6 | "time" |
| 7 | 7 | ||
| 8 | - "go.opentelemetry.io/otel/api/global" | ||
| 9 | - "go.opentelemetry.io/otel/api/trace" | 8 | + "go.opentelemetry.io/otel" |
| 9 | + "go.opentelemetry.io/otel/trace" | ||
| 10 | ) | 10 | ) |
| 11 | 11 | ||
| 12 | +var tracer = otel.Tracer("github.com/go-pg/pg") | ||
| 13 | + | ||
| 12 | func Sleep(ctx context.Context, dur time.Duration) error { | 14 | func Sleep(ctx context.Context, dur time.Duration) error { |
| 13 | - return WithSpan(ctx, "sleep", func(ctx context.Context, span trace.Span) error { | 15 | + return WithSpan(ctx, "time.Sleep", func(ctx context.Context, span trace.Span) error { |
| 14 | t := time.NewTimer(dur) | 16 | t := time.NewTimer(dur) |
| 15 | defer t.Stop() | 17 | defer t.Stop() |
| 16 | 18 | ||
| @@ -80,11 +82,11 @@ func WithSpan( | @@ -80,11 +82,11 @@ func WithSpan( | ||
| 80 | name string, | 82 | name string, |
| 81 | fn func(context.Context, trace.Span) error, | 83 | fn func(context.Context, trace.Span) error, |
| 82 | ) error { | 84 | ) error { |
| 83 | - if !trace.SpanFromContext(ctx).IsRecording() { | ||
| 84 | - return fn(ctx, trace.NoopSpan{}) | 85 | + if span := trace.SpanFromContext(ctx); !span.IsRecording() { |
| 86 | + return fn(ctx, span) | ||
| 85 | } | 87 | } |
| 86 | 88 | ||
| 87 | - ctx, span := global.Tracer("go-pg").Start(ctx, name) | 89 | + ctx, span := tracer.Start(ctx, name) |
| 88 | defer span.End() | 90 | defer span.End() |
| 89 | 91 | ||
| 90 | return fn(ctx, span) | 92 | return fn(ctx, span) |
| @@ -15,8 +15,10 @@ import ( | @@ -15,8 +15,10 @@ import ( | ||
| 15 | 15 | ||
| 16 | const gopgChannel = "gopg:ping" | 16 | const gopgChannel = "gopg:ping" |
| 17 | 17 | ||
| 18 | -var errListenerClosed = errors.New("pg: listener is closed") | ||
| 19 | -var errPingTimeout = errors.New("pg: ping timeout") | 18 | +var ( |
| 19 | + errListenerClosed = errors.New("pg: listener is closed") | ||
| 20 | + errPingTimeout = errors.New("pg: ping timeout") | ||
| 21 | +) | ||
| 20 | 22 | ||
| 21 | // Notification which is received with LISTEN command. | 23 | // Notification which is received with LISTEN command. |
| 22 | type Notification struct { | 24 | type Notification struct { |
| @@ -38,11 +40,14 @@ type Listener struct { | @@ -38,11 +40,14 @@ type Listener struct { | ||
| 38 | closed bool | 40 | closed bool |
| 39 | 41 | ||
| 40 | chOnce sync.Once | 42 | chOnce sync.Once |
| 41 | - ch chan *Notification | 43 | + ch chan Notification |
| 42 | pingCh chan struct{} | 44 | pingCh chan struct{} |
| 43 | } | 45 | } |
| 44 | 46 | ||
| 45 | func (ln *Listener) String() string { | 47 | func (ln *Listener) String() string { |
| 48 | + ln.mu.Lock() | ||
| 49 | + defer ln.mu.Unlock() | ||
| 50 | + | ||
| 46 | return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", ")) | 51 | return fmt.Sprintf("Listener(%s)", strings.Join(ln.channels, ", ")) |
| 47 | } | 52 | } |
| 48 | 53 | ||
| @@ -50,9 +55,9 @@ func (ln *Listener) init() { | @@ -50,9 +55,9 @@ func (ln *Listener) init() { | ||
| 50 | ln.exit = make(chan struct{}) | 55 | ln.exit = make(chan struct{}) |
| 51 | } | 56 | } |
| 52 | 57 | ||
| 53 | -func (ln *Listener) connWithLock() (*pool.Conn, error) { | 58 | +func (ln *Listener) connWithLock(ctx context.Context) (*pool.Conn, error) { |
| 54 | ln.mu.Lock() | 59 | ln.mu.Lock() |
| 55 | - cn, err := ln.conn() | 60 | + cn, err := ln.conn(ctx) |
| 56 | ln.mu.Unlock() | 61 | ln.mu.Unlock() |
| 57 | 62 | ||
| 58 | switch err { | 63 | switch err { |
| @@ -64,12 +69,12 @@ func (ln *Listener) connWithLock() (*pool.Conn, error) { | @@ -64,12 +69,12 @@ func (ln *Listener) connWithLock() (*pool.Conn, error) { | ||
| 64 | _ = ln.Close() | 69 | _ = ln.Close() |
| 65 | return nil, errListenerClosed | 70 | return nil, errListenerClosed |
| 66 | default: | 71 | default: |
| 67 | - internal.Logger.Printf("pg: Listen failed: %s", err) | 72 | + internal.Logger.Printf(ctx, "pg: Listen failed: %s", err) |
| 68 | return nil, err | 73 | return nil, err |
| 69 | } | 74 | } |
| 70 | } | 75 | } |
| 71 | 76 | ||
| 72 | -func (ln *Listener) conn() (*pool.Conn, error) { | 77 | +func (ln *Listener) conn(ctx context.Context) (*pool.Conn, error) { |
| 73 | if ln.closed { | 78 | if ln.closed { |
| 74 | return nil, errListenerClosed | 79 | return nil, errListenerClosed |
| 75 | } | 80 | } |
| @@ -78,21 +83,20 @@ func (ln *Listener) conn() (*pool.Conn, error) { | @@ -78,21 +83,20 @@ func (ln *Listener) conn() (*pool.Conn, error) { | ||
| 78 | return ln.cn, nil | 83 | return ln.cn, nil |
| 79 | } | 84 | } |
| 80 | 85 | ||
| 81 | - c := context.TODO() | ||
| 82 | - | ||
| 83 | - cn, err := ln.db.pool.NewConn(c) | 86 | + cn, err := ln.db.pool.NewConn(ctx) |
| 84 | if err != nil { | 87 | if err != nil { |
| 85 | return nil, err | 88 | return nil, err |
| 86 | } | 89 | } |
| 87 | 90 | ||
| 88 | - err = ln.db.initConn(c, cn) | ||
| 89 | - if err != nil { | 91 | + if err := ln.db.initConn(ctx, cn); err != nil { |
| 90 | _ = ln.db.pool.CloseConn(cn) | 92 | _ = ln.db.pool.CloseConn(cn) |
| 91 | return nil, err | 93 | return nil, err |
| 92 | } | 94 | } |
| 93 | 95 | ||
| 96 | + cn.LockReader() | ||
| 97 | + | ||
| 94 | if len(ln.channels) > 0 { | 98 | if len(ln.channels) > 0 { |
| 95 | - err := ln.listen(c, cn, ln.channels...) | 99 | + err := ln.listen(ctx, cn, ln.channels...) |
| 96 | if err != nil { | 100 | if err != nil { |
| 97 | _ = ln.db.pool.CloseConn(cn) | 101 | _ = ln.db.pool.CloseConn(cn) |
| 98 | return nil, err | 102 | return nil, err |
| @@ -103,19 +107,19 @@ func (ln *Listener) conn() (*pool.Conn, error) { | @@ -103,19 +107,19 @@ func (ln *Listener) conn() (*pool.Conn, error) { | ||
| 103 | return cn, nil | 107 | return cn, nil |
| 104 | } | 108 | } |
| 105 | 109 | ||
| 106 | -func (ln *Listener) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { | 110 | +func (ln *Listener) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) { |
| 107 | ln.mu.Lock() | 111 | ln.mu.Lock() |
| 108 | if ln.cn == cn { | 112 | if ln.cn == cn { |
| 109 | if isBadConn(err, allowTimeout) { | 113 | if isBadConn(err, allowTimeout) { |
| 110 | - ln.reconnect(err) | 114 | + ln.reconnect(ctx, err) |
| 111 | } | 115 | } |
| 112 | } | 116 | } |
| 113 | ln.mu.Unlock() | 117 | ln.mu.Unlock() |
| 114 | } | 118 | } |
| 115 | 119 | ||
| 116 | -func (ln *Listener) reconnect(reason error) { | 120 | +func (ln *Listener) reconnect(ctx context.Context, reason error) { |
| 117 | _ = ln.closeTheCn(reason) | 121 | _ = ln.closeTheCn(reason) |
| 118 | - _, _ = ln.conn() | 122 | + _, _ = ln.conn(ctx) |
| 119 | } | 123 | } |
| 120 | 124 | ||
| 121 | func (ln *Listener) closeTheCn(reason error) error { | 125 | func (ln *Listener) closeTheCn(reason error) error { |
| @@ -123,7 +127,7 @@ func (ln *Listener) closeTheCn(reason error) error { | @@ -123,7 +127,7 @@ func (ln *Listener) closeTheCn(reason error) error { | ||
| 123 | return nil | 127 | return nil |
| 124 | } | 128 | } |
| 125 | if !ln.closed { | 129 | if !ln.closed { |
| 126 | - internal.Logger.Printf("pg: discarding bad listener connection: %s", reason) | 130 | + internal.Logger.Printf(ln.db.ctx, "pg: discarding bad listener connection: %s", reason) |
| 127 | } | 131 | } |
| 128 | 132 | ||
| 129 | err := ln.db.pool.CloseConn(ln.cn) | 133 | err := ln.db.pool.CloseConn(ln.cn) |
| @@ -146,29 +150,60 @@ func (ln *Listener) Close() error { | @@ -146,29 +150,60 @@ func (ln *Listener) Close() error { | ||
| 146 | } | 150 | } |
| 147 | 151 | ||
| 148 | // Listen starts listening for notifications on channels. | 152 | // Listen starts listening for notifications on channels. |
| 149 | -func (ln *Listener) Listen(channels ...string) error { | 153 | +func (ln *Listener) Listen(ctx context.Context, channels ...string) error { |
| 150 | // Always append channels so DB.Listen works correctly. | 154 | // Always append channels so DB.Listen works correctly. |
| 155 | + ln.mu.Lock() | ||
| 151 | ln.channels = appendIfNotExists(ln.channels, channels...) | 156 | ln.channels = appendIfNotExists(ln.channels, channels...) |
| 157 | + ln.mu.Unlock() | ||
| 152 | 158 | ||
| 153 | - cn, err := ln.connWithLock() | 159 | + cn, err := ln.connWithLock(ctx) |
| 154 | if err != nil { | 160 | if err != nil { |
| 155 | return err | 161 | return err |
| 156 | } | 162 | } |
| 157 | 163 | ||
| 158 | - err = ln.listen(context.TODO(), cn, channels...) | 164 | + if err := ln.listen(ctx, cn, channels...); err != nil { |
| 165 | + ln.releaseConn(ctx, cn, err, false) | ||
| 166 | + return err | ||
| 167 | + } | ||
| 168 | + | ||
| 169 | + return nil | ||
| 170 | +} | ||
| 171 | + | ||
| 172 | +func (ln *Listener) listen(ctx context.Context, cn *pool.Conn, channels ...string) error { | ||
| 173 | + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | ||
| 174 | + for _, channel := range channels { | ||
| 175 | + if err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)); err != nil { | ||
| 176 | + return err | ||
| 177 | + } | ||
| 178 | + } | ||
| 179 | + return nil | ||
| 180 | + }) | ||
| 181 | + return err | ||
| 182 | +} | ||
| 183 | + | ||
| 184 | +// Unlisten stops listening for notifications on channels. | ||
| 185 | +func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error { | ||
| 186 | + ln.mu.Lock() | ||
| 187 | + ln.channels = removeIfExists(ln.channels, channels...) | ||
| 188 | + ln.mu.Unlock() | ||
| 189 | + | ||
| 190 | + cn, err := ln.conn(ctx) | ||
| 159 | if err != nil { | 191 | if err != nil { |
| 160 | - ln.releaseConn(cn, err, false) | 192 | + return err |
| 193 | + } | ||
| 194 | + | ||
| 195 | + if err := ln.unlisten(ctx, cn, channels...); err != nil { | ||
| 196 | + ln.releaseConn(ctx, cn, err, false) | ||
| 161 | return err | 197 | return err |
| 162 | } | 198 | } |
| 163 | 199 | ||
| 164 | return nil | 200 | return nil |
| 165 | } | 201 | } |
| 166 | 202 | ||
| 167 | -func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) error { | ||
| 168 | - err := cn.WithWriter(c, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | 203 | +func (ln *Listener) unlisten(ctx context.Context, cn *pool.Conn, channels ...string) error { |
| 204 | + err := cn.WithWriter(ctx, ln.db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | ||
| 169 | for _, channel := range channels { | 205 | for _, channel := range channels { |
| 170 | - err := writeQueryMsg(wb, ln.db.fmter, "LISTEN ?", pgChan(channel)) | ||
| 171 | - if err != nil { | 206 | + if err := writeQueryMsg(wb, ln.db.fmter, "UNLISTEN ?", pgChan(channel)); err != nil { |
| 172 | return err | 207 | return err |
| 173 | } | 208 | } |
| 174 | } | 209 | } |
| @@ -179,24 +214,26 @@ func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) | @@ -179,24 +214,26 @@ func (ln *Listener) listen(c context.Context, cn *pool.Conn, channels ...string) | ||
| 179 | 214 | ||
| 180 | // Receive indefinitely waits for a notification. This is low-level API | 215 | // Receive indefinitely waits for a notification. This is low-level API |
| 181 | // and in most cases Channel should be used instead. | 216 | // and in most cases Channel should be used instead. |
| 182 | -func (ln *Listener) Receive() (channel string, payload string, err error) { | ||
| 183 | - return ln.ReceiveTimeout(0) | 217 | +func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) { |
| 218 | + return ln.ReceiveTimeout(ctx, 0) | ||
| 184 | } | 219 | } |
| 185 | 220 | ||
| 186 | // ReceiveTimeout waits for a notification until timeout is reached. | 221 | // ReceiveTimeout waits for a notification until timeout is reached. |
| 187 | // This is low-level API and in most cases Channel should be used instead. | 222 | // This is low-level API and in most cases Channel should be used instead. |
| 188 | -func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload string, err error) { | ||
| 189 | - cn, err := ln.connWithLock() | 223 | +func (ln *Listener) ReceiveTimeout( |
| 224 | + ctx context.Context, timeout time.Duration, | ||
| 225 | +) (channel, payload string, err error) { | ||
| 226 | + cn, err := ln.connWithLock(ctx) | ||
| 190 | if err != nil { | 227 | if err != nil { |
| 191 | return "", "", err | 228 | return "", "", err |
| 192 | } | 229 | } |
| 193 | 230 | ||
| 194 | - err = cn.WithReader(context.TODO(), timeout, func(rd *pool.BufReader) error { | 231 | + err = cn.WithReader(ctx, timeout, func(rd *pool.ReaderContext) error { |
| 195 | channel, payload, err = readNotification(rd) | 232 | channel, payload, err = readNotification(rd) |
| 196 | return err | 233 | return err |
| 197 | }) | 234 | }) |
| 198 | if err != nil { | 235 | if err != nil { |
| 199 | - ln.releaseConn(cn, err, timeout > 0) | 236 | + ln.releaseConn(ctx, cn, err, timeout > 0) |
| 200 | return "", "", err | 237 | return "", "", err |
| 201 | } | 238 | } |
| 202 | 239 | ||
| @@ -208,17 +245,17 @@ func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload stri | @@ -208,17 +245,17 @@ func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload stri | ||
| 208 | // | 245 | // |
| 209 | // The channel is closed with Listener. Receive* APIs can not be used | 246 | // The channel is closed with Listener. Receive* APIs can not be used |
| 210 | // after channel is created. | 247 | // after channel is created. |
| 211 | -func (ln *Listener) Channel() <-chan *Notification { | 248 | +func (ln *Listener) Channel() <-chan Notification { |
| 212 | return ln.channel(100) | 249 | return ln.channel(100) |
| 213 | } | 250 | } |
| 214 | 251 | ||
| 215 | // ChannelSize is like Channel, but creates a Go channel | 252 | // ChannelSize is like Channel, but creates a Go channel |
| 216 | // with specified buffer size. | 253 | // with specified buffer size. |
| 217 | -func (ln *Listener) ChannelSize(size int) <-chan *Notification { | 254 | +func (ln *Listener) ChannelSize(size int) <-chan Notification { |
| 218 | return ln.channel(size) | 255 | return ln.channel(size) |
| 219 | } | 256 | } |
| 220 | 257 | ||
| 221 | -func (ln *Listener) channel(size int) <-chan *Notification { | 258 | +func (ln *Listener) channel(size int) <-chan Notification { |
| 222 | ln.chOnce.Do(func() { | 259 | ln.chOnce.Do(func() { |
| 223 | ln.initChannel(size) | 260 | ln.initChannel(size) |
| 224 | }) | 261 | }) |
| @@ -230,29 +267,33 @@ func (ln *Listener) channel(size int) <-chan *Notification { | @@ -230,29 +267,33 @@ func (ln *Listener) channel(size int) <-chan *Notification { | ||
| 230 | } | 267 | } |
| 231 | 268 | ||
| 232 | func (ln *Listener) initChannel(size int) { | 269 | func (ln *Listener) initChannel(size int) { |
| 233 | - const timeout = 30 * time.Second | 270 | + const pingTimeout = time.Second |
| 271 | + const chanSendTimeout = time.Minute | ||
| 234 | 272 | ||
| 235 | - _ = ln.Listen(gopgChannel) | 273 | + ctx := ln.db.ctx |
| 274 | + _ = ln.Listen(ctx, gopgChannel) | ||
| 236 | 275 | ||
| 237 | - ln.ch = make(chan *Notification, size) | 276 | + ln.ch = make(chan Notification, size) |
| 238 | ln.pingCh = make(chan struct{}, 1) | 277 | ln.pingCh = make(chan struct{}, 1) |
| 239 | 278 | ||
| 240 | go func() { | 279 | go func() { |
| 241 | - timer := time.NewTimer(timeout) | 280 | + timer := time.NewTimer(time.Minute) |
| 242 | timer.Stop() | 281 | timer.Stop() |
| 243 | 282 | ||
| 244 | var errCount int | 283 | var errCount int |
| 245 | for { | 284 | for { |
| 246 | - channel, payload, err := ln.Receive() | 285 | + channel, payload, err := ln.Receive(ctx) |
| 247 | if err != nil { | 286 | if err != nil { |
| 248 | if err == errListenerClosed { | 287 | if err == errListenerClosed { |
| 249 | close(ln.ch) | 288 | close(ln.ch) |
| 250 | return | 289 | return |
| 251 | } | 290 | } |
| 291 | + | ||
| 252 | if errCount > 0 { | 292 | if errCount > 0 { |
| 253 | - time.Sleep(ln.db.retryBackoff(errCount)) | 293 | + time.Sleep(500 * time.Millisecond) |
| 254 | } | 294 | } |
| 255 | errCount++ | 295 | errCount++ |
| 296 | + | ||
| 256 | continue | 297 | continue |
| 257 | } | 298 | } |
| 258 | 299 | ||
| @@ -268,28 +309,31 @@ func (ln *Listener) initChannel(size int) { | @@ -268,28 +309,31 @@ func (ln *Listener) initChannel(size int) { | ||
| 268 | case gopgChannel: | 309 | case gopgChannel: |
| 269 | // ignore | 310 | // ignore |
| 270 | default: | 311 | default: |
| 271 | - timer.Reset(timeout) | 312 | + timer.Reset(chanSendTimeout) |
| 272 | select { | 313 | select { |
| 273 | - case ln.ch <- &Notification{channel, payload}: | 314 | + case ln.ch <- Notification{channel, payload}: |
| 274 | if !timer.Stop() { | 315 | if !timer.Stop() { |
| 275 | <-timer.C | 316 | <-timer.C |
| 276 | } | 317 | } |
| 277 | case <-timer.C: | 318 | case <-timer.C: |
| 278 | internal.Logger.Printf( | 319 | internal.Logger.Printf( |
| 320 | + ctx, | ||
| 279 | "pg: %s channel is full for %s (notification is dropped)", | 321 | "pg: %s channel is full for %s (notification is dropped)", |
| 280 | - ln, timeout) | 322 | + ln, |
| 323 | + chanSendTimeout, | ||
| 324 | + ) | ||
| 281 | } | 325 | } |
| 282 | } | 326 | } |
| 283 | } | 327 | } |
| 284 | }() | 328 | }() |
| 285 | 329 | ||
| 286 | go func() { | 330 | go func() { |
| 287 | - timer := time.NewTimer(timeout) | 331 | + timer := time.NewTimer(time.Minute) |
| 288 | timer.Stop() | 332 | timer.Stop() |
| 289 | 333 | ||
| 290 | healthy := true | 334 | healthy := true |
| 291 | for { | 335 | for { |
| 292 | - timer.Reset(timeout) | 336 | + timer.Reset(pingTimeout) |
| 293 | select { | 337 | select { |
| 294 | case <-ln.pingCh: | 338 | case <-ln.pingCh: |
| 295 | healthy = true | 339 | healthy = true |
| @@ -305,7 +349,7 @@ func (ln *Listener) initChannel(size int) { | @@ -305,7 +349,7 @@ func (ln *Listener) initChannel(size int) { | ||
| 305 | pingErr = errPingTimeout | 349 | pingErr = errPingTimeout |
| 306 | } | 350 | } |
| 307 | ln.mu.Lock() | 351 | ln.mu.Lock() |
| 308 | - ln.reconnect(pingErr) | 352 | + ln.reconnect(ctx, pingErr) |
| 309 | ln.mu.Unlock() | 353 | ln.mu.Unlock() |
| 310 | } | 354 | } |
| 311 | case <-ln.exit: | 355 | case <-ln.exit: |
| @@ -333,6 +377,20 @@ loop: | @@ -333,6 +377,20 @@ loop: | ||
| 333 | return ss | 377 | return ss |
| 334 | } | 378 | } |
| 335 | 379 | ||
| 380 | +func removeIfExists(ss []string, es ...string) []string { | ||
| 381 | + for _, e := range es { | ||
| 382 | + for i, s := range ss { | ||
| 383 | + if s == e { | ||
| 384 | + last := len(ss) - 1 | ||
| 385 | + ss[i] = ss[last] | ||
| 386 | + ss = ss[:last] | ||
| 387 | + break | ||
| 388 | + } | ||
| 389 | + } | ||
| 390 | + } | ||
| 391 | + return ss | ||
| 392 | +} | ||
| 393 | + | ||
| 336 | type pgChan string | 394 | type pgChan string |
| 337 | 395 | ||
| 338 | var _ types.ValueAppender = pgChan("") | 396 | var _ types.ValueAppender = pgChan("") |
| @@ -19,6 +19,7 @@ import ( | @@ -19,6 +19,7 @@ import ( | ||
| 19 | "github.com/go-pg/pg/v10/types" | 19 | "github.com/go-pg/pg/v10/types" |
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | +// https://www.postgresql.org/docs/current/protocol-message-formats.html | ||
| 22 | const ( | 23 | const ( |
| 23 | commandCompleteMsg = 'C' | 24 | commandCompleteMsg = 'C' |
| 24 | errorResponseMsg = 'E' | 25 | errorResponseMsg = 'E' |
| @@ -84,7 +85,7 @@ func (db *baseDB) startup( | @@ -84,7 +85,7 @@ func (db *baseDB) startup( | ||
| 84 | return err | 85 | return err |
| 85 | } | 86 | } |
| 86 | 87 | ||
| 87 | - return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 88 | + return cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 88 | for { | 89 | for { |
| 89 | typ, msgLen, err := readMessageType(rd) | 90 | typ, msgLen, err := readMessageType(rd) |
| 90 | if err != nil { | 91 | if err != nil { |
| @@ -137,7 +138,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi | @@ -137,7 +138,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi | ||
| 137 | return err | 138 | return err |
| 138 | } | 139 | } |
| 139 | 140 | ||
| 140 | - err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.BufReader) error { | 141 | + err = cn.WithReader(c, db.opt.ReadTimeout, func(rd *pool.ReaderContext) error { |
| 141 | c, err := rd.ReadByte() | 142 | c, err := rd.ReadByte() |
| 142 | if err != nil { | 143 | if err != nil { |
| 143 | return err | 144 | return err |
| @@ -156,7 +157,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi | @@ -156,7 +157,7 @@ func (db *baseDB) enableSSL(c context.Context, cn *pool.Conn, tlsConf *tls.Confi | ||
| 156 | } | 157 | } |
| 157 | 158 | ||
| 158 | func (db *baseDB) auth( | 159 | func (db *baseDB) auth( |
| 159 | - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, | 160 | + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, |
| 160 | ) error { | 161 | ) error { |
| 161 | num, err := readInt32(rd) | 162 | num, err := readInt32(rd) |
| 162 | if err != nil { | 163 | if err != nil { |
| @@ -178,7 +179,7 @@ func (db *baseDB) auth( | @@ -178,7 +179,7 @@ func (db *baseDB) auth( | ||
| 178 | } | 179 | } |
| 179 | 180 | ||
| 180 | func (db *baseDB) authCleartext( | 181 | func (db *baseDB) authCleartext( |
| 181 | - c context.Context, cn *pool.Conn, rd *pool.BufReader, password string, | 182 | + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, password string, |
| 182 | ) error { | 183 | ) error { |
| 183 | err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { | 184 | err := cn.WithWriter(c, db.opt.WriteTimeout, func(wb *pool.WriteBuffer) error { |
| 184 | writePasswordMsg(wb, password) | 185 | writePasswordMsg(wb, password) |
| @@ -191,7 +192,7 @@ func (db *baseDB) authCleartext( | @@ -191,7 +192,7 @@ func (db *baseDB) authCleartext( | ||
| 191 | } | 192 | } |
| 192 | 193 | ||
| 193 | func (db *baseDB) authMD5( | 194 | func (db *baseDB) authMD5( |
| 194 | - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, | 195 | + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, |
| 195 | ) error { | 196 | ) error { |
| 196 | b, err := rd.ReadN(4) | 197 | b, err := rd.ReadN(4) |
| 197 | if err != nil { | 198 | if err != nil { |
| @@ -210,7 +211,7 @@ func (db *baseDB) authMD5( | @@ -210,7 +211,7 @@ func (db *baseDB) authMD5( | ||
| 210 | return readAuthOK(rd) | 211 | return readAuthOK(rd) |
| 211 | } | 212 | } |
| 212 | 213 | ||
| 213 | -func readAuthOK(rd *pool.BufReader) error { | 214 | +func readAuthOK(rd *pool.ReaderContext) error { |
| 214 | c, _, err := readMessageType(rd) | 215 | c, _, err := readMessageType(rd) |
| 215 | if err != nil { | 216 | if err != nil { |
| 216 | return err | 217 | return err |
| @@ -238,7 +239,7 @@ func readAuthOK(rd *pool.BufReader) error { | @@ -238,7 +239,7 @@ func readAuthOK(rd *pool.BufReader) error { | ||
| 238 | } | 239 | } |
| 239 | 240 | ||
| 240 | func (db *baseDB) authSASL( | 241 | func (db *baseDB) authSASL( |
| 241 | - c context.Context, cn *pool.Conn, rd *pool.BufReader, user, password string, | 242 | + c context.Context, cn *pool.Conn, rd *pool.ReaderContext, user, password string, |
| 242 | ) error { | 243 | ) error { |
| 243 | s, err := readString(rd) | 244 | s, err := readString(rd) |
| 244 | if err != nil { | 245 | if err != nil { |
| @@ -332,7 +333,7 @@ func (db *baseDB) authSASL( | @@ -332,7 +333,7 @@ func (db *baseDB) authSASL( | ||
| 332 | } | 333 | } |
| 333 | } | 334 | } |
| 334 | 335 | ||
| 335 | -func readAuthSASLFinal(rd *pool.BufReader, client *sasl.Negotiator) error { | 336 | +func readAuthSASLFinal(rd *pool.ReaderContext, client *sasl.Negotiator) error { |
| 336 | c, n, err := readMessageType(rd) | 337 | c, n, err := readMessageType(rd) |
| 337 | if err != nil { | 338 | if err != nil { |
| 338 | return err | 339 | return err |
| @@ -485,8 +486,8 @@ func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { | @@ -485,8 +486,8 @@ func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { | ||
| 485 | writeSyncMsg(buf) | 486 | writeSyncMsg(buf) |
| 486 | } | 487 | } |
| 487 | 488 | ||
| 488 | -func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) { | ||
| 489 | - var columns [][]byte | 489 | +func readParseDescribeSync(rd *pool.ReaderContext) ([]types.ColumnInfo, error) { |
| 490 | + var columns []types.ColumnInfo | ||
| 490 | var firstErr error | 491 | var firstErr error |
| 491 | for { | 492 | for { |
| 492 | c, msgLen, err := readMessageType(rd) | 493 | c, msgLen, err := readMessageType(rd) |
| @@ -500,7 +501,7 @@ func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) { | @@ -500,7 +501,7 @@ func readParseDescribeSync(rd *pool.BufReader) ([][]byte, error) { | ||
| 500 | return nil, err | 501 | return nil, err |
| 501 | } | 502 | } |
| 502 | case rowDescriptionMsg: // Response to the DESCRIBE message. | 503 | case rowDescriptionMsg: // Response to the DESCRIBE message. |
| 503 | - columns, err = readRowDescription(rd, nil) | 504 | + columns, err = readRowDescription(rd, pool.NewColumnAlloc()) |
| 504 | if err != nil { | 505 | if err != nil { |
| 505 | return nil, err | 506 | return nil, err |
| 506 | } | 507 | } |
| @@ -582,7 +583,7 @@ func writeCloseMsg(buf *pool.WriteBuffer, name string) { | @@ -582,7 +583,7 @@ func writeCloseMsg(buf *pool.WriteBuffer, name string) { | ||
| 582 | buf.FinishMessage() | 583 | buf.FinishMessage() |
| 583 | } | 584 | } |
| 584 | 585 | ||
| 585 | -func readCloseCompleteMsg(rd *pool.BufReader) error { | 586 | +func readCloseCompleteMsg(rd *pool.ReaderContext) error { |
| 586 | for { | 587 | for { |
| 587 | c, msgLen, err := readMessageType(rd) | 588 | c, msgLen, err := readMessageType(rd) |
| 588 | if err != nil { | 589 | if err != nil { |
| @@ -612,7 +613,7 @@ func readCloseCompleteMsg(rd *pool.BufReader) error { | @@ -612,7 +613,7 @@ func readCloseCompleteMsg(rd *pool.BufReader) error { | ||
| 612 | } | 613 | } |
| 613 | } | 614 | } |
| 614 | 615 | ||
| 615 | -func readSimpleQuery(rd *pool.BufReader) (*result, error) { | 616 | +func readSimpleQuery(rd *pool.ReaderContext) (*result, error) { |
| 616 | var res result | 617 | var res result |
| 617 | var firstErr error | 618 | var firstErr error |
| 618 | for { | 619 | for { |
| @@ -675,7 +676,7 @@ func readSimpleQuery(rd *pool.BufReader) (*result, error) { | @@ -675,7 +676,7 @@ func readSimpleQuery(rd *pool.BufReader) (*result, error) { | ||
| 675 | } | 676 | } |
| 676 | } | 677 | } |
| 677 | 678 | ||
| 678 | -func readExtQuery(rd *pool.BufReader) (*result, error) { | 679 | +func readExtQuery(rd *pool.ReaderContext) (*result, error) { |
| 679 | var res result | 680 | var res result |
| 680 | var firstErr error | 681 | var firstErr error |
| 681 | for { | 682 | for { |
| @@ -739,42 +740,47 @@ func readExtQuery(rd *pool.BufReader) (*result, error) { | @@ -739,42 +740,47 @@ func readExtQuery(rd *pool.BufReader) (*result, error) { | ||
| 739 | } | 740 | } |
| 740 | } | 741 | } |
| 741 | 742 | ||
| 742 | -func readRowDescription(rd *pool.BufReader, columns [][]byte) ([][]byte, error) { | ||
| 743 | - colNum, err := readInt16(rd) | 743 | +func readRowDescription( |
| 744 | + rd *pool.ReaderContext, columnAlloc *pool.ColumnAlloc, | ||
| 745 | +) ([]types.ColumnInfo, error) { | ||
| 746 | + numCol, err := readInt16(rd) | ||
| 744 | if err != nil { | 747 | if err != nil { |
| 745 | return nil, err | 748 | return nil, err |
| 746 | } | 749 | } |
| 747 | 750 | ||
| 748 | - columns = setByteSliceLen(columns, int(colNum)) | ||
| 749 | - for i := 0; i < int(colNum); i++ { | 751 | + for i := 0; i < int(numCol); i++ { |
| 750 | b, err := rd.ReadSlice(0) | 752 | b, err := rd.ReadSlice(0) |
| 751 | if err != nil { | 753 | if err != nil { |
| 752 | return nil, err | 754 | return nil, err |
| 753 | } | 755 | } |
| 754 | - columns[i] = append(columns[i][:0], b[:len(b)-1]...) | ||
| 755 | 756 | ||
| 756 | - _, err = rd.ReadN(18) | ||
| 757 | - if err != nil { | 757 | + col := columnAlloc.New(int16(i), b[:len(b)-1]) |
| 758 | + | ||
| 759 | + if _, err := rd.ReadN(6); err != nil { | ||
| 758 | return nil, err | 760 | return nil, err |
| 759 | } | 761 | } |
| 760 | - } | ||
| 761 | 762 | ||
| 762 | - return columns, nil | ||
| 763 | -} | 763 | + dataType, err := readInt32(rd) |
| 764 | + if err != nil { | ||
| 765 | + return nil, err | ||
| 766 | + } | ||
| 767 | + col.DataType = dataType | ||
| 764 | 768 | ||
| 765 | -func setByteSliceLen(b [][]byte, n int) [][]byte { | ||
| 766 | - if n <= cap(b) { | ||
| 767 | - return b[:n] | 769 | + if _, err := rd.ReadN(8); err != nil { |
| 770 | + return nil, err | ||
| 771 | + } | ||
| 768 | } | 772 | } |
| 769 | - b = b[:cap(b)] | ||
| 770 | - b = append(b, make([][]byte, n-cap(b))...) | ||
| 771 | - return b | 773 | + |
| 774 | + return columnAlloc.Columns(), nil | ||
| 772 | } | 775 | } |
| 773 | 776 | ||
| 774 | func readDataRow( | 777 | func readDataRow( |
| 775 | - ctx context.Context, rd *pool.BufReader, scanner orm.ColumnScanner, columns [][]byte, | 778 | + ctx context.Context, |
| 779 | + rd *pool.ReaderContext, | ||
| 780 | + columns []types.ColumnInfo, | ||
| 781 | + scanner orm.ColumnScanner, | ||
| 776 | ) error { | 782 | ) error { |
| 777 | - colNum, err := readInt16(rd) | 783 | + numCol, err := readInt16(rd) |
| 778 | if err != nil { | 784 | if err != nil { |
| 779 | return err | 785 | return err |
| 780 | } | 786 | } |
| @@ -787,35 +793,28 @@ func readDataRow( | @@ -787,35 +793,28 @@ func readDataRow( | ||
| 787 | 793 | ||
| 788 | var firstErr error | 794 | var firstErr error |
| 789 | 795 | ||
| 790 | - for colIdx := int16(0); colIdx < colNum; colIdx++ { | 796 | + for colIdx := int16(0); colIdx < numCol; colIdx++ { |
| 791 | n, err := readInt32(rd) | 797 | n, err := readInt32(rd) |
| 792 | if err != nil { | 798 | if err != nil { |
| 793 | return err | 799 | return err |
| 794 | } | 800 | } |
| 795 | 801 | ||
| 796 | - column := internal.BytesToString(columns[colIdx]) | ||
| 797 | var colRd types.Reader | 802 | var colRd types.Reader |
| 798 | - if n >= 0 { | ||
| 799 | - bytesRd := rd.BytesReader(int(n)) | ||
| 800 | - if bytesRd != nil { | ||
| 801 | - colRd = bytesRd | ||
| 802 | - } else { | ||
| 803 | - rd.SetAvailable(int(n)) | ||
| 804 | - colRd = rd | ||
| 805 | - } | 803 | + if int(n) <= rd.Buffered() { |
| 804 | + colRd = rd.BytesReader(int(n)) | ||
| 806 | } else { | 805 | } else { |
| 807 | - colRd = rd.BytesReader(0) | 806 | + rd.SetAvailable(int(n)) |
| 807 | + colRd = rd | ||
| 808 | } | 808 | } |
| 809 | 809 | ||
| 810 | - err = scanner.ScanColumn(int(colIdx), column, colRd, int(n)) | ||
| 811 | - if err != nil && firstErr == nil { | 810 | + column := columns[colIdx] |
| 811 | + if err := scanner.ScanColumn(column, colRd, int(n)); err != nil && firstErr == nil { | ||
| 812 | firstErr = internal.Errorf(err.Error()) | 812 | firstErr = internal.Errorf(err.Error()) |
| 813 | } | 813 | } |
| 814 | 814 | ||
| 815 | if rd == colRd { | 815 | if rd == colRd { |
| 816 | if rd.Available() > 0 { | 816 | if rd.Available() > 0 { |
| 817 | - _, err = rd.Discard(rd.Available()) | ||
| 818 | - if err != nil && firstErr == nil { | 817 | + if _, err := rd.Discard(rd.Available()); err != nil && firstErr == nil { |
| 819 | firstErr = err | 818 | firstErr = err |
| 820 | } | 819 | } |
| 821 | } | 820 | } |
| @@ -841,8 +840,9 @@ func newModel(mod interface{}) (orm.Model, error) { | @@ -841,8 +840,9 @@ func newModel(mod interface{}) (orm.Model, error) { | ||
| 841 | } | 840 | } |
| 842 | 841 | ||
| 843 | func readSimpleQueryData( | 842 | func readSimpleQueryData( |
| 844 | - ctx context.Context, rd *pool.BufReader, mod interface{}, | 843 | + ctx context.Context, rd *pool.ReaderContext, mod interface{}, |
| 845 | ) (*result, error) { | 844 | ) (*result, error) { |
| 845 | + var columns []types.ColumnInfo | ||
| 846 | var res result | 846 | var res result |
| 847 | var firstErr error | 847 | var firstErr error |
| 848 | for { | 848 | for { |
| @@ -853,7 +853,7 @@ func readSimpleQueryData( | @@ -853,7 +853,7 @@ func readSimpleQueryData( | ||
| 853 | 853 | ||
| 854 | switch c { | 854 | switch c { |
| 855 | case rowDescriptionMsg: | 855 | case rowDescriptionMsg: |
| 856 | - rd.Columns, err = readRowDescription(rd, rd.Columns[:0]) | 856 | + columns, err = readRowDescription(rd, rd.ColumnAlloc) |
| 857 | if err != nil { | 857 | if err != nil { |
| 858 | return nil, err | 858 | return nil, err |
| 859 | } | 859 | } |
| @@ -870,7 +870,7 @@ func readSimpleQueryData( | @@ -870,7 +870,7 @@ func readSimpleQueryData( | ||
| 870 | } | 870 | } |
| 871 | case dataRowMsg: | 871 | case dataRowMsg: |
| 872 | scanner := res.model.NextColumnScanner() | 872 | scanner := res.model.NextColumnScanner() |
| 873 | - if err := readDataRow(ctx, rd, scanner, rd.Columns); err != nil { | 873 | + if err := readDataRow(ctx, rd, columns, scanner); err != nil { |
| 874 | if firstErr == nil { | 874 | if firstErr == nil { |
| 875 | firstErr = err | 875 | firstErr = err |
| 876 | } | 876 | } |
| @@ -925,7 +925,7 @@ func readSimpleQueryData( | @@ -925,7 +925,7 @@ func readSimpleQueryData( | ||
| 925 | } | 925 | } |
| 926 | 926 | ||
| 927 | func readExtQueryData( | 927 | func readExtQueryData( |
| 928 | - ctx context.Context, rd *pool.BufReader, mod interface{}, columns [][]byte, | 928 | + ctx context.Context, rd *pool.ReaderContext, mod interface{}, columns []types.ColumnInfo, |
| 929 | ) (*result, error) { | 929 | ) (*result, error) { |
| 930 | var res result | 930 | var res result |
| 931 | var firstErr error | 931 | var firstErr error |
| @@ -954,7 +954,7 @@ func readExtQueryData( | @@ -954,7 +954,7 @@ func readExtQueryData( | ||
| 954 | } | 954 | } |
| 955 | 955 | ||
| 956 | scanner := res.model.NextColumnScanner() | 956 | scanner := res.model.NextColumnScanner() |
| 957 | - if err := readDataRow(ctx, rd, scanner, columns); err != nil { | 957 | + if err := readDataRow(ctx, rd, columns, scanner); err != nil { |
| 958 | if firstErr == nil { | 958 | if firstErr == nil { |
| 959 | firstErr = err | 959 | firstErr = err |
| 960 | } | 960 | } |
| @@ -1004,7 +1004,7 @@ func readExtQueryData( | @@ -1004,7 +1004,7 @@ func readExtQueryData( | ||
| 1004 | } | 1004 | } |
| 1005 | } | 1005 | } |
| 1006 | 1006 | ||
| 1007 | -func readCopyInResponse(rd *pool.BufReader) error { | 1007 | +func readCopyInResponse(rd *pool.ReaderContext) error { |
| 1008 | var firstErr error | 1008 | var firstErr error |
| 1009 | for { | 1009 | for { |
| 1010 | c, msgLen, err := readMessageType(rd) | 1010 | c, msgLen, err := readMessageType(rd) |
| @@ -1044,7 +1044,7 @@ func readCopyInResponse(rd *pool.BufReader) error { | @@ -1044,7 +1044,7 @@ func readCopyInResponse(rd *pool.BufReader) error { | ||
| 1044 | } | 1044 | } |
| 1045 | } | 1045 | } |
| 1046 | 1046 | ||
| 1047 | -func readCopyOutResponse(rd *pool.BufReader) error { | 1047 | +func readCopyOutResponse(rd *pool.ReaderContext) error { |
| 1048 | var firstErr error | 1048 | var firstErr error |
| 1049 | for { | 1049 | for { |
| 1050 | c, msgLen, err := readMessageType(rd) | 1050 | c, msgLen, err := readMessageType(rd) |
| @@ -1084,7 +1084,7 @@ func readCopyOutResponse(rd *pool.BufReader) error { | @@ -1084,7 +1084,7 @@ func readCopyOutResponse(rd *pool.BufReader) error { | ||
| 1084 | } | 1084 | } |
| 1085 | } | 1085 | } |
| 1086 | 1086 | ||
| 1087 | -func readCopyData(rd *pool.BufReader, w io.Writer) (*result, error) { | 1087 | +func readCopyData(rd *pool.ReaderContext, w io.Writer) (*result, error) { |
| 1088 | var res result | 1088 | var res result |
| 1089 | var firstErr error | 1089 | var firstErr error |
| 1090 | for { | 1090 | for { |
| @@ -1162,7 +1162,7 @@ func writeCopyDone(buf *pool.WriteBuffer) { | @@ -1162,7 +1162,7 @@ func writeCopyDone(buf *pool.WriteBuffer) { | ||
| 1162 | buf.FinishMessage() | 1162 | buf.FinishMessage() |
| 1163 | } | 1163 | } |
| 1164 | 1164 | ||
| 1165 | -func readReadyForQuery(rd *pool.BufReader) (*result, error) { | 1165 | +func readReadyForQuery(rd *pool.ReaderContext) (*result, error) { |
| 1166 | var res result | 1166 | var res result |
| 1167 | var firstErr error | 1167 | var firstErr error |
| 1168 | for { | 1168 | for { |
| @@ -1211,7 +1211,7 @@ func readReadyForQuery(rd *pool.BufReader) (*result, error) { | @@ -1211,7 +1211,7 @@ func readReadyForQuery(rd *pool.BufReader) (*result, error) { | ||
| 1211 | } | 1211 | } |
| 1212 | } | 1212 | } |
| 1213 | 1213 | ||
| 1214 | -func readNotification(rd *pool.BufReader) (channel, payload string, err error) { | 1214 | +func readNotification(rd *pool.ReaderContext) (channel, payload string, err error) { |
| 1215 | for { | 1215 | for { |
| 1216 | c, msgLen, err := readMessageType(rd) | 1216 | c, msgLen, err := readMessageType(rd) |
| 1217 | if err != nil { | 1217 | if err != nil { |
| @@ -1269,17 +1269,17 @@ func terminateConn(cn *pool.Conn) error { | @@ -1269,17 +1269,17 @@ func terminateConn(cn *pool.Conn) error { | ||
| 1269 | 1269 | ||
| 1270 | //------------------------------------------------------------------------------ | 1270 | //------------------------------------------------------------------------------ |
| 1271 | 1271 | ||
| 1272 | -func logNotice(rd *pool.BufReader, msgLen int) error { | 1272 | +func logNotice(rd *pool.ReaderContext, msgLen int) error { |
| 1273 | _, err := rd.ReadN(msgLen) | 1273 | _, err := rd.ReadN(msgLen) |
| 1274 | return err | 1274 | return err |
| 1275 | } | 1275 | } |
| 1276 | 1276 | ||
| 1277 | -func logParameterStatus(rd *pool.BufReader, msgLen int) error { | 1277 | +func logParameterStatus(rd *pool.ReaderContext, msgLen int) error { |
| 1278 | _, err := rd.ReadN(msgLen) | 1278 | _, err := rd.ReadN(msgLen) |
| 1279 | return err | 1279 | return err |
| 1280 | } | 1280 | } |
| 1281 | 1281 | ||
| 1282 | -func readInt16(rd *pool.BufReader) (int16, error) { | 1282 | +func readInt16(rd *pool.ReaderContext) (int16, error) { |
| 1283 | b, err := rd.ReadN(2) | 1283 | b, err := rd.ReadN(2) |
| 1284 | if err != nil { | 1284 | if err != nil { |
| 1285 | return 0, err | 1285 | return 0, err |
| @@ -1287,7 +1287,7 @@ func readInt16(rd *pool.BufReader) (int16, error) { | @@ -1287,7 +1287,7 @@ func readInt16(rd *pool.BufReader) (int16, error) { | ||
| 1287 | return int16(binary.BigEndian.Uint16(b)), nil | 1287 | return int16(binary.BigEndian.Uint16(b)), nil |
| 1288 | } | 1288 | } |
| 1289 | 1289 | ||
| 1290 | -func readInt32(rd *pool.BufReader) (int32, error) { | 1290 | +func readInt32(rd *pool.ReaderContext) (int32, error) { |
| 1291 | b, err := rd.ReadN(4) | 1291 | b, err := rd.ReadN(4) |
| 1292 | if err != nil { | 1292 | if err != nil { |
| 1293 | return 0, err | 1293 | return 0, err |
| @@ -1295,7 +1295,7 @@ func readInt32(rd *pool.BufReader) (int32, error) { | @@ -1295,7 +1295,7 @@ func readInt32(rd *pool.BufReader) (int32, error) { | ||
| 1295 | return int32(binary.BigEndian.Uint32(b)), nil | 1295 | return int32(binary.BigEndian.Uint32(b)), nil |
| 1296 | } | 1296 | } |
| 1297 | 1297 | ||
| 1298 | -func readString(rd *pool.BufReader) (string, error) { | 1298 | +func readString(rd *pool.ReaderContext) (string, error) { |
| 1299 | b, err := rd.ReadSlice(0) | 1299 | b, err := rd.ReadSlice(0) |
| 1300 | if err != nil { | 1300 | if err != nil { |
| 1301 | return "", err | 1301 | return "", err |
| @@ -1303,7 +1303,7 @@ func readString(rd *pool.BufReader) (string, error) { | @@ -1303,7 +1303,7 @@ func readString(rd *pool.BufReader) (string, error) { | ||
| 1303 | return string(b[:len(b)-1]), nil | 1303 | return string(b[:len(b)-1]), nil |
| 1304 | } | 1304 | } |
| 1305 | 1305 | ||
| 1306 | -func readError(rd *pool.BufReader) (error, error) { | 1306 | +func readError(rd *pool.ReaderContext) (error, error) { |
| 1307 | m := make(map[byte]string) | 1307 | m := make(map[byte]string) |
| 1308 | for { | 1308 | for { |
| 1309 | c, err := rd.ReadByte() | 1309 | c, err := rd.ReadByte() |
| @@ -1322,7 +1322,7 @@ func readError(rd *pool.BufReader) (error, error) { | @@ -1322,7 +1322,7 @@ func readError(rd *pool.BufReader) (error, error) { | ||
| 1322 | return internal.NewPGError(m), nil | 1322 | return internal.NewPGError(m), nil |
| 1323 | } | 1323 | } |
| 1324 | 1324 | ||
| 1325 | -func readMessageType(rd *pool.BufReader) (byte, int, error) { | 1325 | +func readMessageType(rd *pool.ReaderContext) (byte, int, error) { |
| 1326 | c, err := rd.ReadByte() | 1326 | c, err := rd.ReadByte() |
| 1327 | if err != nil { | 1327 | if err != nil { |
| 1328 | return 0, 0, err | 1328 | return 0, 0, err |
| @@ -13,7 +13,8 @@ import ( | @@ -13,7 +13,8 @@ import ( | ||
| 13 | "strings" | 13 | "strings" |
| 14 | "time" | 14 | "time" |
| 15 | 15 | ||
| 16 | - "go.opentelemetry.io/otel/api/trace" | 16 | + "go.opentelemetry.io/otel/label" |
| 17 | + "go.opentelemetry.io/otel/trace" | ||
| 17 | 18 | ||
| 18 | "github.com/go-pg/pg/v10/internal" | 19 | "github.com/go-pg/pg/v10/internal" |
| 19 | "github.com/go-pg/pg/v10/internal/pool" | 20 | "github.com/go-pg/pg/v10/internal/pool" |
| @@ -31,6 +32,10 @@ type Options struct { | @@ -31,6 +32,10 @@ type Options struct { | ||
| 31 | // Network and Addr options. | 32 | // Network and Addr options. |
| 32 | Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | 33 | Dialer func(ctx context.Context, network, addr string) (net.Conn, error) |
| 33 | 34 | ||
| 35 | + // Hook that is called after new connection is established | ||
| 36 | + // and user is authenticated. | ||
| 37 | + OnConnect func(ctx context.Context, cn *Conn) error | ||
| 38 | + | ||
| 34 | User string | 39 | User string |
| 35 | Password string | 40 | Password string |
| 36 | Database string | 41 | Database string |
| @@ -53,10 +58,6 @@ type Options struct { | @@ -53,10 +58,6 @@ type Options struct { | ||
| 53 | // with a timeout instead of blocking. | 58 | // with a timeout instead of blocking. |
| 54 | WriteTimeout time.Duration | 59 | WriteTimeout time.Duration |
| 55 | 60 | ||
| 56 | - // Hook that is called after new connection is established | ||
| 57 | - // and user is authenticated. | ||
| 58 | - OnConnect func(*Conn) error | ||
| 59 | - | ||
| 60 | // Maximum number of retries before giving up. | 61 | // Maximum number of retries before giving up. |
| 61 | // Default is to not retry failed queries. | 62 | // Default is to not retry failed queries. |
| 62 | MaxRetries int | 63 | MaxRetries int |
| @@ -110,6 +111,10 @@ func (opt *Options) init() { | @@ -110,6 +111,10 @@ func (opt *Options) init() { | ||
| 110 | opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" | 111 | opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" |
| 111 | } | 112 | } |
| 112 | } | 113 | } |
| 114 | + | ||
| 115 | + if opt.DialTimeout == 0 { | ||
| 116 | + opt.DialTimeout = 5 * time.Second | ||
| 117 | + } | ||
| 113 | if opt.Dialer == nil { | 118 | if opt.Dialer == nil { |
| 114 | opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { | 119 | opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { |
| 115 | netDialer := &net.Dialer{ | 120 | netDialer := &net.Dialer{ |
| @@ -140,10 +145,6 @@ func (opt *Options) init() { | @@ -140,10 +145,6 @@ func (opt *Options) init() { | ||
| 140 | } | 145 | } |
| 141 | } | 146 | } |
| 142 | 147 | ||
| 143 | - if opt.DialTimeout == 0 { | ||
| 144 | - opt.DialTimeout = 5 * time.Second | ||
| 145 | - } | ||
| 146 | - | ||
| 147 | if opt.IdleTimeout == 0 { | 148 | if opt.IdleTimeout == 0 { |
| 148 | opt.IdleTimeout = 5 * time.Minute | 149 | opt.IdleTimeout = 5 * time.Minute |
| 149 | } | 150 | } |
| @@ -262,9 +263,16 @@ func ParseURL(sURL string) (*Options, error) { | @@ -262,9 +263,16 @@ func ParseURL(sURL string) (*Options, error) { | ||
| 262 | func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { | 263 | func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { |
| 263 | return func(ctx context.Context) (net.Conn, error) { | 264 | return func(ctx context.Context) (net.Conn, error) { |
| 264 | var conn net.Conn | 265 | var conn net.Conn |
| 265 | - err := internal.WithSpan(ctx, "dialer", func(ctx context.Context, span trace.Span) error { | 266 | + err := internal.WithSpan(ctx, "pg.dial", func(ctx context.Context, span trace.Span) error { |
| 267 | + span.SetAttributes( | ||
| 268 | + label.String("db.connection_string", opt.Addr), | ||
| 269 | + ) | ||
| 270 | + | ||
| 266 | var err error | 271 | var err error |
| 267 | conn, err = opt.Dialer(ctx, opt.Network, opt.Addr) | 272 | conn, err = opt.Dialer(ctx, opt.Network, opt.Addr) |
| 273 | + if err != nil { | ||
| 274 | + span.RecordError(err) | ||
| 275 | + } | ||
| 268 | return err | 276 | return err |
| 269 | }) | 277 | }) |
| 270 | return conn, err | 278 | return conn, err |
| @@ -8,50 +8,62 @@ type CreateCompositeOptions struct { | @@ -8,50 +8,62 @@ type CreateCompositeOptions struct { | ||
| 8 | Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` | 8 | Varchar int // replaces PostgreSQL data type `text` with `varchar(n)` |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | -func CreateComposite(db DB, model interface{}, opt *CreateCompositeOptions) error { | ||
| 12 | - q := NewQuery(db, model) | ||
| 13 | - _, err := q.db.Exec(&createCompositeQuery{ | 11 | +type CreateCompositeQuery struct { |
| 12 | + q *Query | ||
| 13 | + opt *CreateCompositeOptions | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +var ( | ||
| 17 | + _ QueryAppender = (*CreateCompositeQuery)(nil) | ||
| 18 | + _ QueryCommand = (*CreateCompositeQuery)(nil) | ||
| 19 | +) | ||
| 20 | + | ||
| 21 | +func NewCreateCompositeQuery(q *Query, opt *CreateCompositeOptions) *CreateCompositeQuery { | ||
| 22 | + return &CreateCompositeQuery{ | ||
| 14 | q: q, | 23 | q: q, |
| 15 | opt: opt, | 24 | opt: opt, |
| 16 | - }) | ||
| 17 | - return err | 25 | + } |
| 18 | } | 26 | } |
| 19 | 27 | ||
| 20 | -type createCompositeQuery struct { | ||
| 21 | - q *Query | ||
| 22 | - opt *CreateCompositeOptions | 28 | +func (q *CreateCompositeQuery) String() string { |
| 29 | + b, err := q.AppendQuery(defaultFmter, nil) | ||
| 30 | + if err != nil { | ||
| 31 | + panic(err) | ||
| 32 | + } | ||
| 33 | + return string(b) | ||
| 23 | } | 34 | } |
| 24 | 35 | ||
| 25 | -var _ QueryAppender = (*createCompositeQuery)(nil) | ||
| 26 | -var _ queryCommand = (*createCompositeQuery)(nil) | 36 | +func (q *CreateCompositeQuery) Operation() QueryOp { |
| 37 | + return CreateCompositeOp | ||
| 38 | +} | ||
| 27 | 39 | ||
| 28 | -func (q *createCompositeQuery) Clone() queryCommand { | ||
| 29 | - return &createCompositeQuery{ | 40 | +func (q *CreateCompositeQuery) Clone() QueryCommand { |
| 41 | + return &CreateCompositeQuery{ | ||
| 30 | q: q.q.Clone(), | 42 | q: q.q.Clone(), |
| 31 | opt: q.opt, | 43 | opt: q.opt, |
| 32 | } | 44 | } |
| 33 | } | 45 | } |
| 34 | 46 | ||
| 35 | -func (q *createCompositeQuery) Query() *Query { | 47 | +func (q *CreateCompositeQuery) Query() *Query { |
| 36 | return q.q | 48 | return q.q |
| 37 | } | 49 | } |
| 38 | 50 | ||
| 39 | -func (q *createCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { | 51 | +func (q *CreateCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { |
| 40 | return q.AppendQuery(dummyFormatter{}, b) | 52 | return q.AppendQuery(dummyFormatter{}, b) |
| 41 | } | 53 | } |
| 42 | 54 | ||
| 43 | -func (q *createCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { | 55 | +func (q *CreateCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { |
| 44 | if q.q.stickyErr != nil { | 56 | if q.q.stickyErr != nil { |
| 45 | return nil, q.q.stickyErr | 57 | return nil, q.q.stickyErr |
| 46 | } | 58 | } |
| 47 | - if q.q.model == nil { | 59 | + if q.q.tableModel == nil { |
| 48 | return nil, errModelNil | 60 | return nil, errModelNil |
| 49 | } | 61 | } |
| 50 | 62 | ||
| 51 | - table := q.q.model.Table() | 63 | + table := q.q.tableModel.Table() |
| 52 | 64 | ||
| 53 | b = append(b, "CREATE TYPE "...) | 65 | b = append(b, "CREATE TYPE "...) |
| 54 | - b = append(b, q.q.model.Table().Alias...) | 66 | + b = append(b, table.Alias...) |
| 55 | b = append(b, " AS ("...) | 67 | b = append(b, " AS ("...) |
| 56 | 68 | ||
| 57 | for i, field := range table.Fields { | 69 | for i, field := range table.Fields { |
| @@ -5,43 +5,55 @@ type DropCompositeOptions struct { | @@ -5,43 +5,55 @@ type DropCompositeOptions struct { | ||
| 5 | Cascade bool | 5 | Cascade bool |
| 6 | } | 6 | } |
| 7 | 7 | ||
| 8 | -func DropComposite(db DB, model interface{}, opt *DropCompositeOptions) error { | ||
| 9 | - q := NewQuery(db, model) | ||
| 10 | - _, err := q.db.Exec(&dropCompositeQuery{ | 8 | +type DropCompositeQuery struct { |
| 9 | + q *Query | ||
| 10 | + opt *DropCompositeOptions | ||
| 11 | +} | ||
| 12 | + | ||
| 13 | +var ( | ||
| 14 | + _ QueryAppender = (*DropCompositeQuery)(nil) | ||
| 15 | + _ QueryCommand = (*DropCompositeQuery)(nil) | ||
| 16 | +) | ||
| 17 | + | ||
| 18 | +func NewDropCompositeQuery(q *Query, opt *DropCompositeOptions) *DropCompositeQuery { | ||
| 19 | + return &DropCompositeQuery{ | ||
| 11 | q: q, | 20 | q: q, |
| 12 | opt: opt, | 21 | opt: opt, |
| 13 | - }) | ||
| 14 | - return err | 22 | + } |
| 15 | } | 23 | } |
| 16 | 24 | ||
| 17 | -type dropCompositeQuery struct { | ||
| 18 | - q *Query | ||
| 19 | - opt *DropCompositeOptions | 25 | +func (q *DropCompositeQuery) String() string { |
| 26 | + b, err := q.AppendQuery(defaultFmter, nil) | ||
| 27 | + if err != nil { | ||
| 28 | + panic(err) | ||
| 29 | + } | ||
| 30 | + return string(b) | ||
| 20 | } | 31 | } |
| 21 | 32 | ||
| 22 | -var _ QueryAppender = (*dropCompositeQuery)(nil) | ||
| 23 | -var _ queryCommand = (*dropCompositeQuery)(nil) | 33 | +func (q *DropCompositeQuery) Operation() QueryOp { |
| 34 | + return DropCompositeOp | ||
| 35 | +} | ||
| 24 | 36 | ||
| 25 | -func (q *dropCompositeQuery) Clone() queryCommand { | ||
| 26 | - return &dropCompositeQuery{ | 37 | +func (q *DropCompositeQuery) Clone() QueryCommand { |
| 38 | + return &DropCompositeQuery{ | ||
| 27 | q: q.q.Clone(), | 39 | q: q.q.Clone(), |
| 28 | opt: q.opt, | 40 | opt: q.opt, |
| 29 | } | 41 | } |
| 30 | } | 42 | } |
| 31 | 43 | ||
| 32 | -func (q *dropCompositeQuery) Query() *Query { | 44 | +func (q *DropCompositeQuery) Query() *Query { |
| 33 | return q.q | 45 | return q.q |
| 34 | } | 46 | } |
| 35 | 47 | ||
| 36 | -func (q *dropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { | 48 | +func (q *DropCompositeQuery) AppendTemplate(b []byte) ([]byte, error) { |
| 37 | return q.AppendQuery(dummyFormatter{}, b) | 49 | return q.AppendQuery(dummyFormatter{}, b) |
| 38 | } | 50 | } |
| 39 | 51 | ||
| 40 | -func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { | 52 | +func (q *DropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte, error) { |
| 41 | if q.q.stickyErr != nil { | 53 | if q.q.stickyErr != nil { |
| 42 | return nil, q.q.stickyErr | 54 | return nil, q.q.stickyErr |
| 43 | } | 55 | } |
| 44 | - if q.q.model == nil { | 56 | + if q.q.tableModel == nil { |
| 45 | return nil, errModelNil | 57 | return nil, errModelNil |
| 46 | } | 58 | } |
| 47 | 59 | ||
| @@ -49,7 +61,7 @@ func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte | @@ -49,7 +61,7 @@ func (q *dropCompositeQuery) AppendQuery(fmter QueryFormatter, b []byte) ([]byte | ||
| 49 | if q.opt != nil && q.opt.IfExists { | 61 | if q.opt != nil && q.opt.IfExists { |
| 50 | b = append(b, "IF EXISTS "...) | 62 | b = append(b, "IF EXISTS "...) |
| 51 | } | 63 | } |
| 52 | - b = append(b, q.q.model.Table().Alias...) | 64 | + b = append(b, q.q.tableModel.Table().Alias...) |
| 53 | if q.opt != nil && q.opt.Cascade { | 65 | if q.opt != nil && q.opt.Cascade { |
| 54 | b = append(b, " CASCADE"...) | 66 | b = append(b, " CASCADE"...) |
| 55 | } | 67 | } |
| 1 | package orm | 1 | package orm |
| 2 | 2 | ||
| 3 | import ( | 3 | import ( |
| 4 | - "github.com/go-pg/pg/v10/internal" | ||
| 5 | -) | ||
| 6 | - | ||
| 7 | -// Delete deletes a given model from the db | ||
| 8 | -func Delete(db DB, model interface{}) error { | ||
| 9 | - res, err := NewQuery(db, model).WherePK().Delete() | ||
| 10 | - if err != nil { | ||
| 11 | - return err | ||
| 12 | - } | ||
| 13 | - return internal.AssertOneRow(res.RowsAffected()) | ||
| 14 | -} | 4 | + "reflect" |
| 15 | 5 | ||
| 16 | -// ForceDelete force deletes a given model from the db | ||
| 17 | -func ForceDelete(db DB, model interface{}) error { | ||
| 18 | - res, err := NewQuery(db, model).WherePK().ForceDelete() | ||
| 19 | - if err != nil { | ||
| 20 | - return err | ||
| 21 | - } | ||
| 22 | - return internal.AssertOneRow(res.RowsAffected()) | ||
| 23 | -} | 6 | + "github.com/go-pg/pg/v10/types" |
| 7 | +) | ||
| 24 | 8 | ||
| 25 | -type deleteQuery struct { | 9 | +type DeleteQuery struct { |
| 26 | q *Query | 10 | q *Query |
| 27 | placeholder bool | 11 | placeholder bool |
| 28 | } | 12 | } |
| 29 | 13 | ||
| 30 | -var _ QueryAppender = (*deleteQuery)(nil) | ||
| 31 | -var _ queryCommand = (*deleteQuery)(nil) | 14 | +var ( |
| 15 | + _ QueryAppender = (*DeleteQuery)(nil) | ||
| 16 | + _ QueryCommand = (*DeleteQuery)(nil) | ||
| 17 | +) | ||
| 32 | 18 | ||
| 33 | -func newDeleteQuery(q *Query) *deleteQuery { | ||
| 34 | - return &deleteQuery{ | 19 | +func NewDeleteQuery(q *Query) *DeleteQuery { |
| 20 | + return &DeleteQuery{ | ||
| 35 | q: q, | 21 | q: q, |
| 36 | } | 22 | } |
| 37 | } | 23 | } |
| 38 | 24 | ||
| 39 | -func (q *deleteQuery) Operation() string { | 25 | +func (q *DeleteQuery) String() string { |
| 26 | + b, err := q.AppendQuery(defaultFmter, nil) | ||
| 27 | + if err != nil { | ||
| 28 | + panic(err) | ||
| 29 | + } | ||
| 30 | + return string(b) | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +func (q *DeleteQuery) Operation() QueryOp { | ||
| 40 | return DeleteOp | 34 | return DeleteOp |
| 41 | } | 35 | } |
| 42 | 36 | ||
| 43 | -func (q *deleteQuery) Clone() queryCommand { | ||
| 44 | - return &deleteQuery{ | 37 | +func (q *DeleteQuery) Clone() QueryCommand { |
| 38 | + return &DeleteQuery{ | ||
| 45 | q: q.q.Clone(), | 39 | q: q.q.Clone(), |
| 46 | placeholder: q.placeholder, | 40 | placeholder: q.placeholder, |
| 47 | } | 41 | } |
| 48 | } | 42 | } |
| 49 | 43 | ||
| 50 | -func (q *deleteQuery) Query() *Query { | 44 | +func (q *DeleteQuery) Query() *Query { |
| 51 | return q.q | 45 | return q.q |
| 52 | } | 46 | } |
| 53 | 47 | ||
| 54 | -func (q *deleteQuery) AppendTemplate(b []byte) ([]byte, error) { | ||
| 55 | - cp := q.Clone().(*deleteQuery) | 48 | +func (q *DeleteQuery) AppendTemplate(b []byte) ([]byte, error) { |
| 49 | + cp := q.Clone().(*DeleteQuery) | ||
| 56 | cp.placeholder = true | 50 | cp.placeholder = true |
| 57 | return cp.AppendQuery(dummyFormatter{}, b) | 51 | return cp.AppendQuery(dummyFormatter{}, b) |
| 58 | } | 52 | } |
| 59 | 53 | ||
| 60 | -func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { | 54 | +func (q *DeleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { |
| 61 | if q.q.stickyErr != nil { | 55 | if q.q.stickyErr != nil { |
| 62 | return nil, q.q.stickyErr | 56 | return nil, q.q.stickyErr |
| 63 | } | 57 | } |
| @@ -84,7 +78,8 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -84,7 +78,8 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 84 | } | 78 | } |
| 85 | 79 | ||
| 86 | b = append(b, " WHERE "...) | 80 | b = append(b, " WHERE "...) |
| 87 | - value := q.q.model.Value() | 81 | + value := q.q.tableModel.Value() |
| 82 | + | ||
| 88 | if q.q.isSliceModelWithData() { | 83 | if q.q.isSliceModelWithData() { |
| 89 | if len(q.q.where) > 0 { | 84 | if len(q.q.where) > 0 { |
| 90 | b, err = q.q.appendWhere(fmter, b) | 85 | b, err = q.q.appendWhere(fmter, b) |
| @@ -92,7 +87,7 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -92,7 +87,7 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 92 | return nil, err | 87 | return nil, err |
| 93 | } | 88 | } |
| 94 | } else { | 89 | } else { |
| 95 | - table := q.q.model.Table() | 90 | + table := q.q.tableModel.Table() |
| 96 | err = table.checkPKs() | 91 | err = table.checkPKs() |
| 97 | if err != nil { | 92 | if err != nil { |
| 98 | return nil, err | 93 | return nil, err |
| @@ -116,3 +111,48 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -116,3 +111,48 @@ func (q *deleteQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 116 | 111 | ||
| 117 | return b, q.q.stickyErr | 112 | return b, q.q.stickyErr |
| 118 | } | 113 | } |
| 114 | + | ||
| 115 | +func appendColumnAndSliceValue( | ||
| 116 | + fmter QueryFormatter, b []byte, slice reflect.Value, alias types.Safe, fields []*Field, | ||
| 117 | +) []byte { | ||
| 118 | + if len(fields) > 1 { | ||
| 119 | + b = append(b, '(') | ||
| 120 | + } | ||
| 121 | + b = appendColumns(b, alias, fields) | ||
| 122 | + if len(fields) > 1 { | ||
| 123 | + b = append(b, ')') | ||
| 124 | + } | ||
| 125 | + | ||
| 126 | + b = append(b, " IN ("...) | ||
| 127 | + | ||
| 128 | + isPlaceholder := isPlaceholderFormatter(fmter) | ||
| 129 | + sliceLen := slice.Len() | ||
| 130 | + for i := 0; i < sliceLen; i++ { | ||
| 131 | + if i > 0 { | ||
| 132 | + b = append(b, ", "...) | ||
| 133 | + } | ||
| 134 | + | ||
| 135 | + el := indirect(slice.Index(i)) | ||
| 136 | + | ||
| 137 | + if len(fields) > 1 { | ||
| 138 | + b = append(b, '(') | ||
| 139 | + } | ||
| 140 | + for i, f := range fields { | ||
| 141 | + if i > 0 { | ||
| 142 | + b = append(b, ", "...) | ||
| 143 | + } | ||
| 144 | + if isPlaceholder { | ||
| 145 | + b = append(b, '?') | ||
| 146 | + } else { | ||
| 147 | + b = f.AppendValue(b, el, 1) | ||
| 148 | + } | ||
| 149 | + } | ||
| 150 | + if len(fields) > 1 { | ||
| 151 | + b = append(b, ')') | ||
| 152 | + } | ||
| 153 | + } | ||
| 154 | + | ||
| 155 | + b = append(b, ')') | ||
| 156 | + | ||
| 157 | + return b | ||
| 158 | +} |
| @@ -70,10 +70,10 @@ func (f *Field) Value(strct reflect.Value) reflect.Value { | @@ -70,10 +70,10 @@ func (f *Field) Value(strct reflect.Value) reflect.Value { | ||
| 70 | } | 70 | } |
| 71 | 71 | ||
| 72 | func (f *Field) HasZeroValue(strct reflect.Value) bool { | 72 | func (f *Field) HasZeroValue(strct reflect.Value) bool { |
| 73 | - return f.hasZeroField(strct, f.Index) | 73 | + return f.hasZeroValue(strct, f.Index) |
| 74 | } | 74 | } |
| 75 | 75 | ||
| 76 | -func (f *Field) hasZeroField(v reflect.Value, index []int) bool { | 76 | +func (f *Field) hasZeroValue(v reflect.Value, index []int) bool { |
| 77 | for _, idx := range index { | 77 | for _, idx := range index { |
| 78 | if v.Kind() == reflect.Ptr { | 78 | if v.Kind() == reflect.Ptr { |
| 79 | if v.IsNil() { | 79 | if v.IsNil() { |
| @@ -106,10 +106,21 @@ func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { | @@ -106,10 +106,21 @@ func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { | ||
| 106 | } | 106 | } |
| 107 | 107 | ||
| 108 | func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { | 108 | func (f *Field) ScanValue(strct reflect.Value, rd types.Reader, n int) error { |
| 109 | - fv := f.Value(strct) | ||
| 110 | if f.scan == nil { | 109 | if f.scan == nil { |
| 111 | - return fmt.Errorf("pg: ScanValue(unsupported %s)", fv.Type()) | 110 | + return fmt.Errorf("pg: ScanValue(unsupported %s)", f.Type) |
| 112 | } | 111 | } |
| 112 | + | ||
| 113 | + var fv reflect.Value | ||
| 114 | + if n == -1 { | ||
| 115 | + var ok bool | ||
| 116 | + fv, ok = fieldByIndex(strct, f.Index) | ||
| 117 | + if !ok { | ||
| 118 | + return nil | ||
| 119 | + } | ||
| 120 | + } else { | ||
| 121 | + fv = fieldByIndexAlloc(strct, f.Index) | ||
| 122 | + } | ||
| 123 | + | ||
| 113 | return f.scan(fv, rd, n) | 124 | return f.scan(fv, rd, n) |
| 114 | } | 125 | } |
| 115 | 126 |
| @@ -26,8 +26,10 @@ type SafeQueryAppender struct { | @@ -26,8 +26,10 @@ type SafeQueryAppender struct { | ||
| 26 | params []interface{} | 26 | params []interface{} |
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | -var _ QueryAppender = (*SafeQueryAppender)(nil) | ||
| 30 | -var _ types.ValueAppender = (*SafeQueryAppender)(nil) | 29 | +var ( |
| 30 | + _ QueryAppender = (*SafeQueryAppender)(nil) | ||
| 31 | + _ types.ValueAppender = (*SafeQueryAppender)(nil) | ||
| 32 | +) | ||
| 31 | 33 | ||
| 32 | //nolint | 34 | //nolint |
| 33 | func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { | 35 | func SafeQuery(query string, params ...interface{}) *SafeQueryAppender { |
| @@ -57,8 +59,10 @@ type condGroupAppender struct { | @@ -57,8 +59,10 @@ type condGroupAppender struct { | ||
| 57 | cond []queryWithSepAppender | 59 | cond []queryWithSepAppender |
| 58 | } | 60 | } |
| 59 | 61 | ||
| 60 | -var _ QueryAppender = (*condAppender)(nil) | ||
| 61 | -var _ queryWithSepAppender = (*condAppender)(nil) | 62 | +var ( |
| 63 | + _ QueryAppender = (*condAppender)(nil) | ||
| 64 | + _ queryWithSepAppender = (*condAppender)(nil) | ||
| 65 | +) | ||
| 62 | 66 | ||
| 63 | func (q *condGroupAppender) AppendSep(b []byte) []byte { | 67 | func (q *condGroupAppender) AppendSep(b []byte) []byte { |
| 64 | return append(b, q.sep...) | 68 | return append(b, q.sep...) |
| @@ -87,8 +91,10 @@ type condAppender struct { | @@ -87,8 +91,10 @@ type condAppender struct { | ||
| 87 | params []interface{} | 91 | params []interface{} |
| 88 | } | 92 | } |
| 89 | 93 | ||
| 90 | -var _ QueryAppender = (*condAppender)(nil) | ||
| 91 | -var _ queryWithSepAppender = (*condAppender)(nil) | 94 | +var ( |
| 95 | + _ QueryAppender = (*condAppender)(nil) | ||
| 96 | + _ queryWithSepAppender = (*condAppender)(nil) | ||
| 97 | +) | ||
| 92 | 98 | ||
| 93 | func (q *condAppender) AppendSep(b []byte) []byte { | 99 | func (q *condAppender) AppendSep(b []byte) []byte { |
| 94 | return append(b, q.sep...) | 100 | return append(b, q.sep...) |
| @@ -192,9 +198,9 @@ func (f *Formatter) WithModel(model interface{}) *Formatter { | @@ -192,9 +198,9 @@ func (f *Formatter) WithModel(model interface{}) *Formatter { | ||
| 192 | case TableModel: | 198 | case TableModel: |
| 193 | return f.WithTableModel(model) | 199 | return f.WithTableModel(model) |
| 194 | case *Query: | 200 | case *Query: |
| 195 | - return f.WithTableModel(model.model) | ||
| 196 | - case queryCommand: | ||
| 197 | - return f.WithTableModel(model.Query().model) | 201 | + return f.WithTableModel(model.tableModel) |
| 202 | + case QueryCommand: | ||
| 203 | + return f.WithTableModel(model.Query().tableModel) | ||
| 198 | default: | 204 | default: |
| 199 | panic(fmt.Errorf("pg: unsupported model %T", model)) | 205 | panic(fmt.Errorf("pg: unsupported model %T", model)) |
| 200 | } | 206 | } |
| @@ -7,44 +7,51 @@ import ( | @@ -7,44 +7,51 @@ import ( | ||
| 7 | 7 | ||
| 8 | type hookStubs struct{} | 8 | type hookStubs struct{} |
| 9 | 9 | ||
| 10 | -var _ AfterSelectHook = (*hookStubs)(nil) | ||
| 11 | -var _ BeforeInsertHook = (*hookStubs)(nil) | ||
| 12 | -var _ AfterInsertHook = (*hookStubs)(nil) | ||
| 13 | -var _ BeforeUpdateHook = (*hookStubs)(nil) | ||
| 14 | -var _ AfterUpdateHook = (*hookStubs)(nil) | ||
| 15 | -var _ BeforeDeleteHook = (*hookStubs)(nil) | ||
| 16 | -var _ AfterDeleteHook = (*hookStubs)(nil) | ||
| 17 | - | ||
| 18 | -func (hookStubs) AfterSelect(c context.Context) error { | 10 | +var ( |
| 11 | + _ AfterScanHook = (*hookStubs)(nil) | ||
| 12 | + _ AfterSelectHook = (*hookStubs)(nil) | ||
| 13 | + _ BeforeInsertHook = (*hookStubs)(nil) | ||
| 14 | + _ AfterInsertHook = (*hookStubs)(nil) | ||
| 15 | + _ BeforeUpdateHook = (*hookStubs)(nil) | ||
| 16 | + _ AfterUpdateHook = (*hookStubs)(nil) | ||
| 17 | + _ BeforeDeleteHook = (*hookStubs)(nil) | ||
| 18 | + _ AfterDeleteHook = (*hookStubs)(nil) | ||
| 19 | +) | ||
| 20 | + | ||
| 21 | +func (hookStubs) AfterScan(ctx context.Context) error { | ||
| 22 | + return nil | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +func (hookStubs) AfterSelect(ctx context.Context) error { | ||
| 19 | return nil | 26 | return nil |
| 20 | } | 27 | } |
| 21 | 28 | ||
| 22 | -func (hookStubs) BeforeInsert(c context.Context) (context.Context, error) { | ||
| 23 | - return c, nil | 29 | +func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) { |
| 30 | + return ctx, nil | ||
| 24 | } | 31 | } |
| 25 | 32 | ||
| 26 | -func (hookStubs) AfterInsert(c context.Context) error { | 33 | +func (hookStubs) AfterInsert(ctx context.Context) error { |
| 27 | return nil | 34 | return nil |
| 28 | } | 35 | } |
| 29 | 36 | ||
| 30 | -func (hookStubs) BeforeUpdate(c context.Context) (context.Context, error) { | ||
| 31 | - return c, nil | 37 | +func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) { |
| 38 | + return ctx, nil | ||
| 32 | } | 39 | } |
| 33 | 40 | ||
| 34 | -func (hookStubs) AfterUpdate(c context.Context) error { | 41 | +func (hookStubs) AfterUpdate(ctx context.Context) error { |
| 35 | return nil | 42 | return nil |
| 36 | } | 43 | } |
| 37 | 44 | ||
| 38 | -func (hookStubs) BeforeDelete(c context.Context) (context.Context, error) { | ||
| 39 | - return c, nil | 45 | +func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) { |
| 46 | + return ctx, nil | ||
| 40 | } | 47 | } |
| 41 | 48 | ||
| 42 | -func (hookStubs) AfterDelete(c context.Context) error { | 49 | +func (hookStubs) AfterDelete(ctx context.Context) error { |
| 43 | return nil | 50 | return nil |
| 44 | } | 51 | } |
| 45 | 52 | ||
| 46 | func callHookSlice( | 53 | func callHookSlice( |
| 47 | - c context.Context, | 54 | + ctx context.Context, |
| 48 | slice reflect.Value, | 55 | slice reflect.Value, |
| 49 | ptr bool, | 56 | ptr bool, |
| 50 | hook func(context.Context, reflect.Value) (context.Context, error), | 57 | hook func(context.Context, reflect.Value) (context.Context, error), |
| @@ -58,16 +65,16 @@ func callHookSlice( | @@ -58,16 +65,16 @@ func callHookSlice( | ||
| 58 | } | 65 | } |
| 59 | 66 | ||
| 60 | var err error | 67 | var err error |
| 61 | - c, err = hook(c, v) | 68 | + ctx, err = hook(ctx, v) |
| 62 | if err != nil && firstErr == nil { | 69 | if err != nil && firstErr == nil { |
| 63 | firstErr = err | 70 | firstErr = err |
| 64 | } | 71 | } |
| 65 | } | 72 | } |
| 66 | - return c, firstErr | 73 | + return ctx, firstErr |
| 67 | } | 74 | } |
| 68 | 75 | ||
| 69 | func callHookSlice2( | 76 | func callHookSlice2( |
| 70 | - c context.Context, | 77 | + ctx context.Context, |
| 71 | slice reflect.Value, | 78 | slice reflect.Value, |
| 72 | ptr bool, | 79 | ptr bool, |
| 73 | hook func(context.Context, reflect.Value) error, | 80 | hook func(context.Context, reflect.Value) error, |
| @@ -81,7 +88,7 @@ func callHookSlice2( | @@ -81,7 +88,7 @@ func callHookSlice2( | ||
| 81 | v = v.Addr() | 88 | v = v.Addr() |
| 82 | } | 89 | } |
| 83 | 90 | ||
| 84 | - err := hook(c, v) | 91 | + err := hook(ctx, v) |
| 85 | if err != nil && firstErr == nil { | 92 | if err != nil && firstErr == nil { |
| 86 | firstErr = err | 93 | firstErr = err |
| 87 | } | 94 | } |
| @@ -98,8 +105,8 @@ type BeforeScanHook interface { | @@ -98,8 +105,8 @@ type BeforeScanHook interface { | ||
| 98 | 105 | ||
| 99 | var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() | 106 | var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem() |
| 100 | 107 | ||
| 101 | -func callBeforeScanHook(c context.Context, v reflect.Value) error { | ||
| 102 | - return v.Interface().(BeforeScanHook).BeforeScan(c) | 108 | +func callBeforeScanHook(ctx context.Context, v reflect.Value) error { |
| 109 | + return v.Interface().(BeforeScanHook).BeforeScan(ctx) | ||
| 103 | } | 110 | } |
| 104 | 111 | ||
| 105 | //------------------------------------------------------------------------------ | 112 | //------------------------------------------------------------------------------ |
| @@ -110,8 +117,8 @@ type AfterScanHook interface { | @@ -110,8 +117,8 @@ type AfterScanHook interface { | ||
| 110 | 117 | ||
| 111 | var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() | 118 | var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem() |
| 112 | 119 | ||
| 113 | -func callAfterScanHook(c context.Context, v reflect.Value) error { | ||
| 114 | - return v.Interface().(AfterScanHook).AfterScan(c) | 120 | +func callAfterScanHook(ctx context.Context, v reflect.Value) error { |
| 121 | + return v.Interface().(AfterScanHook).AfterScan(ctx) | ||
| 115 | } | 122 | } |
| 116 | 123 | ||
| 117 | //------------------------------------------------------------------------------ | 124 | //------------------------------------------------------------------------------ |
| @@ -122,14 +129,14 @@ type AfterSelectHook interface { | @@ -122,14 +129,14 @@ type AfterSelectHook interface { | ||
| 122 | 129 | ||
| 123 | var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() | 130 | var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() |
| 124 | 131 | ||
| 125 | -func callAfterSelectHook(c context.Context, v reflect.Value) error { | ||
| 126 | - return v.Interface().(AfterSelectHook).AfterSelect(c) | 132 | +func callAfterSelectHook(ctx context.Context, v reflect.Value) error { |
| 133 | + return v.Interface().(AfterSelectHook).AfterSelect(ctx) | ||
| 127 | } | 134 | } |
| 128 | 135 | ||
| 129 | func callAfterSelectHookSlice( | 136 | func callAfterSelectHookSlice( |
| 130 | - c context.Context, slice reflect.Value, ptr bool, | 137 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 131 | ) error { | 138 | ) error { |
| 132 | - return callHookSlice2(c, slice, ptr, callAfterSelectHook) | 139 | + return callHookSlice2(ctx, slice, ptr, callAfterSelectHook) |
| 133 | } | 140 | } |
| 134 | 141 | ||
| 135 | //------------------------------------------------------------------------------ | 142 | //------------------------------------------------------------------------------ |
| @@ -140,14 +147,14 @@ type BeforeInsertHook interface { | @@ -140,14 +147,14 @@ type BeforeInsertHook interface { | ||
| 140 | 147 | ||
| 141 | var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() | 148 | var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() |
| 142 | 149 | ||
| 143 | -func callBeforeInsertHook(c context.Context, v reflect.Value) (context.Context, error) { | ||
| 144 | - return v.Interface().(BeforeInsertHook).BeforeInsert(c) | 150 | +func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) { |
| 151 | + return v.Interface().(BeforeInsertHook).BeforeInsert(ctx) | ||
| 145 | } | 152 | } |
| 146 | 153 | ||
| 147 | func callBeforeInsertHookSlice( | 154 | func callBeforeInsertHookSlice( |
| 148 | - c context.Context, slice reflect.Value, ptr bool, | 155 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 149 | ) (context.Context, error) { | 156 | ) (context.Context, error) { |
| 150 | - return callHookSlice(c, slice, ptr, callBeforeInsertHook) | 157 | + return callHookSlice(ctx, slice, ptr, callBeforeInsertHook) |
| 151 | } | 158 | } |
| 152 | 159 | ||
| 153 | //------------------------------------------------------------------------------ | 160 | //------------------------------------------------------------------------------ |
| @@ -158,14 +165,14 @@ type AfterInsertHook interface { | @@ -158,14 +165,14 @@ type AfterInsertHook interface { | ||
| 158 | 165 | ||
| 159 | var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() | 166 | var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() |
| 160 | 167 | ||
| 161 | -func callAfterInsertHook(c context.Context, v reflect.Value) error { | ||
| 162 | - return v.Interface().(AfterInsertHook).AfterInsert(c) | 168 | +func callAfterInsertHook(ctx context.Context, v reflect.Value) error { |
| 169 | + return v.Interface().(AfterInsertHook).AfterInsert(ctx) | ||
| 163 | } | 170 | } |
| 164 | 171 | ||
| 165 | func callAfterInsertHookSlice( | 172 | func callAfterInsertHookSlice( |
| 166 | - c context.Context, slice reflect.Value, ptr bool, | 173 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 167 | ) error { | 174 | ) error { |
| 168 | - return callHookSlice2(c, slice, ptr, callAfterInsertHook) | 175 | + return callHookSlice2(ctx, slice, ptr, callAfterInsertHook) |
| 169 | } | 176 | } |
| 170 | 177 | ||
| 171 | //------------------------------------------------------------------------------ | 178 | //------------------------------------------------------------------------------ |
| @@ -176,14 +183,14 @@ type BeforeUpdateHook interface { | @@ -176,14 +183,14 @@ type BeforeUpdateHook interface { | ||
| 176 | 183 | ||
| 177 | var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() | 184 | var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() |
| 178 | 185 | ||
| 179 | -func callBeforeUpdateHook(c context.Context, v reflect.Value) (context.Context, error) { | ||
| 180 | - return v.Interface().(BeforeUpdateHook).BeforeUpdate(c) | 186 | +func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) { |
| 187 | + return v.Interface().(BeforeUpdateHook).BeforeUpdate(ctx) | ||
| 181 | } | 188 | } |
| 182 | 189 | ||
| 183 | func callBeforeUpdateHookSlice( | 190 | func callBeforeUpdateHookSlice( |
| 184 | - c context.Context, slice reflect.Value, ptr bool, | 191 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 185 | ) (context.Context, error) { | 192 | ) (context.Context, error) { |
| 186 | - return callHookSlice(c, slice, ptr, callBeforeUpdateHook) | 193 | + return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook) |
| 187 | } | 194 | } |
| 188 | 195 | ||
| 189 | //------------------------------------------------------------------------------ | 196 | //------------------------------------------------------------------------------ |
| @@ -194,14 +201,14 @@ type AfterUpdateHook interface { | @@ -194,14 +201,14 @@ type AfterUpdateHook interface { | ||
| 194 | 201 | ||
| 195 | var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() | 202 | var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() |
| 196 | 203 | ||
| 197 | -func callAfterUpdateHook(c context.Context, v reflect.Value) error { | ||
| 198 | - return v.Interface().(AfterUpdateHook).AfterUpdate(c) | 204 | +func callAfterUpdateHook(ctx context.Context, v reflect.Value) error { |
| 205 | + return v.Interface().(AfterUpdateHook).AfterUpdate(ctx) | ||
| 199 | } | 206 | } |
| 200 | 207 | ||
| 201 | func callAfterUpdateHookSlice( | 208 | func callAfterUpdateHookSlice( |
| 202 | - c context.Context, slice reflect.Value, ptr bool, | 209 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 203 | ) error { | 210 | ) error { |
| 204 | - return callHookSlice2(c, slice, ptr, callAfterUpdateHook) | 211 | + return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook) |
| 205 | } | 212 | } |
| 206 | 213 | ||
| 207 | //------------------------------------------------------------------------------ | 214 | //------------------------------------------------------------------------------ |
| @@ -212,14 +219,14 @@ type BeforeDeleteHook interface { | @@ -212,14 +219,14 @@ type BeforeDeleteHook interface { | ||
| 212 | 219 | ||
| 213 | var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() | 220 | var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem() |
| 214 | 221 | ||
| 215 | -func callBeforeDeleteHook(c context.Context, v reflect.Value) (context.Context, error) { | ||
| 216 | - return v.Interface().(BeforeDeleteHook).BeforeDelete(c) | 222 | +func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) { |
| 223 | + return v.Interface().(BeforeDeleteHook).BeforeDelete(ctx) | ||
| 217 | } | 224 | } |
| 218 | 225 | ||
| 219 | func callBeforeDeleteHookSlice( | 226 | func callBeforeDeleteHookSlice( |
| 220 | - c context.Context, slice reflect.Value, ptr bool, | 227 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 221 | ) (context.Context, error) { | 228 | ) (context.Context, error) { |
| 222 | - return callHookSlice(c, slice, ptr, callBeforeDeleteHook) | 229 | + return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook) |
| 223 | } | 230 | } |
| 224 | 231 | ||
| 225 | //------------------------------------------------------------------------------ | 232 | //------------------------------------------------------------------------------ |
| @@ -230,12 +237,12 @@ type AfterDeleteHook interface { | @@ -230,12 +237,12 @@ type AfterDeleteHook interface { | ||
| 230 | 237 | ||
| 231 | var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() | 238 | var afterDeleteHookType = reflect.TypeOf((*AfterDeleteHook)(nil)).Elem() |
| 232 | 239 | ||
| 233 | -func callAfterDeleteHook(c context.Context, v reflect.Value) error { | ||
| 234 | - return v.Interface().(AfterDeleteHook).AfterDelete(c) | 240 | +func callAfterDeleteHook(ctx context.Context, v reflect.Value) error { |
| 241 | + return v.Interface().(AfterDeleteHook).AfterDelete(ctx) | ||
| 235 | } | 242 | } |
| 236 | 243 | ||
| 237 | func callAfterDeleteHookSlice( | 244 | func callAfterDeleteHookSlice( |
| 238 | - c context.Context, slice reflect.Value, ptr bool, | 245 | + ctx context.Context, slice reflect.Value, ptr bool, |
| 239 | ) error { | 246 | ) error { |
| 240 | - return callHookSlice2(c, slice, ptr, callAfterDeleteHook) | 247 | + return callHookSlice2(ctx, slice, ptr, callAfterDeleteHook) |
| 241 | } | 248 | } |
| @@ -3,55 +3,59 @@ package orm | @@ -3,55 +3,59 @@ package orm | ||
| 3 | import ( | 3 | import ( |
| 4 | "fmt" | 4 | "fmt" |
| 5 | "reflect" | 5 | "reflect" |
| 6 | + "sort" | ||
| 6 | 7 | ||
| 7 | "github.com/go-pg/pg/v10/types" | 8 | "github.com/go-pg/pg/v10/types" |
| 8 | ) | 9 | ) |
| 9 | 10 | ||
| 10 | -func Insert(db DB, model ...interface{}) error { | ||
| 11 | - _, err := NewQuery(db, model...).Insert() | ||
| 12 | - return err | ||
| 13 | -} | ||
| 14 | - | ||
| 15 | -type insertQuery struct { | 11 | +type InsertQuery struct { |
| 16 | q *Query | 12 | q *Query |
| 17 | returningFields []*Field | 13 | returningFields []*Field |
| 18 | placeholder bool | 14 | placeholder bool |
| 19 | } | 15 | } |
| 20 | 16 | ||
| 21 | -var _ queryCommand = (*insertQuery)(nil) | 17 | +var _ QueryCommand = (*InsertQuery)(nil) |
| 22 | 18 | ||
| 23 | -func newInsertQuery(q *Query) *insertQuery { | ||
| 24 | - return &insertQuery{ | 19 | +func NewInsertQuery(q *Query) *InsertQuery { |
| 20 | + return &InsertQuery{ | ||
| 25 | q: q, | 21 | q: q, |
| 26 | } | 22 | } |
| 27 | } | 23 | } |
| 28 | 24 | ||
| 29 | -func (q *insertQuery) Operation() string { | 25 | +func (q *InsertQuery) String() string { |
| 26 | + b, err := q.AppendQuery(defaultFmter, nil) | ||
| 27 | + if err != nil { | ||
| 28 | + panic(err) | ||
| 29 | + } | ||
| 30 | + return string(b) | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +func (q *InsertQuery) Operation() QueryOp { | ||
| 30 | return InsertOp | 34 | return InsertOp |
| 31 | } | 35 | } |
| 32 | 36 | ||
| 33 | -func (q *insertQuery) Clone() queryCommand { | ||
| 34 | - return &insertQuery{ | 37 | +func (q *InsertQuery) Clone() QueryCommand { |
| 38 | + return &InsertQuery{ | ||
| 35 | q: q.q.Clone(), | 39 | q: q.q.Clone(), |
| 36 | placeholder: q.placeholder, | 40 | placeholder: q.placeholder, |
| 37 | } | 41 | } |
| 38 | } | 42 | } |
| 39 | 43 | ||
| 40 | -func (q *insertQuery) Query() *Query { | 44 | +func (q *InsertQuery) Query() *Query { |
| 41 | return q.q | 45 | return q.q |
| 42 | } | 46 | } |
| 43 | 47 | ||
| 44 | -var _ TemplateAppender = (*insertQuery)(nil) | 48 | +var _ TemplateAppender = (*InsertQuery)(nil) |
| 45 | 49 | ||
| 46 | -func (q *insertQuery) AppendTemplate(b []byte) ([]byte, error) { | ||
| 47 | - cp := q.Clone().(*insertQuery) | 50 | +func (q *InsertQuery) AppendTemplate(b []byte) ([]byte, error) { |
| 51 | + cp := q.Clone().(*InsertQuery) | ||
| 48 | cp.placeholder = true | 52 | cp.placeholder = true |
| 49 | return cp.AppendQuery(dummyFormatter{}, b) | 53 | return cp.AppendQuery(dummyFormatter{}, b) |
| 50 | } | 54 | } |
| 51 | 55 | ||
| 52 | -var _ QueryAppender = (*insertQuery)(nil) | 56 | +var _ QueryAppender = (*InsertQuery)(nil) |
| 53 | 57 | ||
| 54 | -func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { | 58 | +func (q *InsertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err error) { |
| 55 | if q.q.stickyErr != nil { | 59 | if q.q.stickyErr != nil { |
| 56 | return nil, q.q.stickyErr | 60 | return nil, q.q.stickyErr |
| 57 | } | 61 | } |
| @@ -73,54 +77,9 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -73,54 +77,9 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 73 | return nil, err | 77 | return nil, err |
| 74 | } | 78 | } |
| 75 | 79 | ||
| 76 | - if q.q.hasMultiTables() { | ||
| 77 | - if q.q.columns != nil { | ||
| 78 | - b = append(b, " ("...) | ||
| 79 | - b, err = q.q.appendColumns(fmter, b) | ||
| 80 | - if err != nil { | ||
| 81 | - return nil, err | ||
| 82 | - } | ||
| 83 | - b = append(b, ")"...) | ||
| 84 | - } | ||
| 85 | - b = append(b, " SELECT * FROM "...) | ||
| 86 | - b, err = q.q.appendOtherTables(fmter, b) | ||
| 87 | - if err != nil { | ||
| 88 | - return nil, err | ||
| 89 | - } | ||
| 90 | - } else { | ||
| 91 | - if !q.q.hasModel() { | ||
| 92 | - return nil, errModelNil | ||
| 93 | - } | ||
| 94 | - | ||
| 95 | - fields, err := q.q.getFields() | ||
| 96 | - if err != nil { | ||
| 97 | - return nil, err | ||
| 98 | - } | ||
| 99 | - | ||
| 100 | - if len(fields) == 0 { | ||
| 101 | - fields = q.q.model.Table().Fields | ||
| 102 | - } | ||
| 103 | - value := q.q.model.Value() | ||
| 104 | - | ||
| 105 | - b = append(b, " ("...) | ||
| 106 | - b = q.appendColumns(b, fields) | ||
| 107 | - b = append(b, ") VALUES ("...) | ||
| 108 | - if m, ok := q.q.model.(*sliceTableModel); ok { | ||
| 109 | - if m.sliceLen == 0 { | ||
| 110 | - err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type()) | ||
| 111 | - return nil, err | ||
| 112 | - } | ||
| 113 | - b, err = q.appendSliceValues(fmter, b, fields, value) | ||
| 114 | - if err != nil { | ||
| 115 | - return nil, err | ||
| 116 | - } | ||
| 117 | - } else { | ||
| 118 | - b, err = q.appendValues(fmter, b, fields, value) | ||
| 119 | - if err != nil { | ||
| 120 | - return nil, err | ||
| 121 | - } | ||
| 122 | - } | ||
| 123 | - b = append(b, ")"...) | 80 | + b, err = q.appendColumnsValues(fmter, b) |
| 81 | + if err != nil { | ||
| 82 | + return nil, err | ||
| 124 | } | 83 | } |
| 125 | 84 | ||
| 126 | if q.q.onConflict != nil { | 85 | if q.q.onConflict != nil { |
| @@ -143,7 +102,7 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -143,7 +102,7 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 143 | } | 102 | } |
| 144 | 103 | ||
| 145 | if len(fields) == 0 { | 104 | if len(fields) == 0 { |
| 146 | - fields = q.q.model.Table().DataFields | 105 | + fields = q.q.tableModel.Table().DataFields |
| 147 | } | 106 | } |
| 148 | 107 | ||
| 149 | b = q.appendSetExcluded(b, fields) | 108 | b = q.appendSetExcluded(b, fields) |
| @@ -171,7 +130,103 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | @@ -171,7 +130,103 @@ func (q *insertQuery) AppendQuery(fmter QueryFormatter, b []byte) (_ []byte, err | ||
| 171 | return b, q.q.stickyErr | 130 | return b, q.q.stickyErr |
| 172 | } | 131 | } |
| 173 | 132 | ||
| 174 | -func (q *insertQuery) appendValues( | 133 | +func (q *InsertQuery) appendColumnsValues(fmter QueryFormatter, b []byte) (_ []byte, err error) { |
| 134 | + if q.q.hasMultiTables() { | ||
| 135 | + if q.q.columns != nil { | ||
| 136 | + b = append(b, " ("...) | ||
| 137 | + b, err = q.q.appendColumns(fmter, b) | ||
| 138 | + if err != nil { | ||
| 139 | + return nil, err | ||
| 140 | + } | ||
| 141 | + b = append(b, ")"...) | ||
| 142 | + } | ||
| 143 | + | ||
| 144 | + b = append(b, " SELECT * FROM "...) | ||
| 145 | + b, err = q.q.appendOtherTables(fmter, b) | ||
| 146 | + if err != nil { | ||
| 147 | + return nil, err | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + return b, nil | ||
| 151 | + } | ||
| 152 | + | ||
| 153 | + if m, ok := q.q.model.(*mapModel); ok { | ||
| 154 | + return q.appendMapColumnsValues(b, m.m), nil | ||
| 155 | + } | ||
| 156 | + | ||
| 157 | + if !q.q.hasTableModel() { | ||
| 158 | + return nil, errModelNil | ||
| 159 | + } | ||
| 160 | + | ||
| 161 | + fields, err := q.q.getFields() | ||
| 162 | + if err != nil { | ||
| 163 | + return nil, err | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + if len(fields) == 0 { | ||
| 167 | + fields = q.q.tableModel.Table().Fields | ||
| 168 | + } | ||
| 169 | + value := q.q.tableModel.Value() | ||
| 170 | + | ||
| 171 | + b = append(b, " ("...) | ||
| 172 | + b = q.appendColumns(b, fields) | ||
| 173 | + b = append(b, ") VALUES ("...) | ||
| 174 | + if m, ok := q.q.tableModel.(*sliceTableModel); ok { | ||
| 175 | + if m.sliceLen == 0 { | ||
| 176 | + err = fmt.Errorf("pg: can't bulk-insert empty slice %s", value.Type()) | ||
| 177 | + return nil, err | ||
| 178 | + } | ||
| 179 | + b, err = q.appendSliceValues(fmter, b, fields, value) | ||
| 180 | + if err != nil { | ||
| 181 | + return nil, err | ||
| 182 | + } | ||
| 183 | + } else { | ||
| 184 | + b, err = q.appendValues(fmter, b, fields, value) | ||
| 185 | + if err != nil { | ||
| 186 | + return nil, err | ||
| 187 | + } | ||
| 188 | + } | ||
| 189 | + b = append(b, ")"...) | ||
| 190 | + | ||
| 191 | + return b, nil | ||
| 192 | +} | ||
| 193 | + | ||
| 194 | +func (q *InsertQuery) appendMapColumnsValues(b []byte, m map[string]interface{}) []byte { | ||
| 195 | + keys := make([]string, 0, len(m)) | ||
| 196 | + | ||
| 197 | + for k := range m { | ||
| 198 | + keys = append(keys, k) | ||
| 199 | + } | ||
| 200 | + sort.Strings(keys) | ||
| 201 | + | ||
| 202 | + b = append(b, " ("...) | ||
| 203 | + | ||
| 204 | + for i, k := range keys { | ||
| 205 | + if i > 0 { | ||
| 206 | + b = append(b, ", "...) | ||
| 207 | + } | ||
| 208 | + b = types.AppendIdent(b, k, 1) | ||
| 209 | + } | ||
| 210 | + | ||
| 211 | + b = append(b, ") VALUES ("...) | ||
| 212 | + | ||
| 213 | + for i, k := range keys { | ||
| 214 | + if i > 0 { | ||
| 215 | + b = append(b, ", "...) | ||
| 216 | + } | ||
| 217 | + if q.placeholder { | ||
| 218 | + b = append(b, '?') | ||
| 219 | + } else { | ||
| 220 | + b = types.Append(b, m[k], 1) | ||
| 221 | + } | ||
| 222 | + } | ||
| 223 | + | ||
| 224 | + b = append(b, ")"...) | ||
| 225 | + | ||
| 226 | + return b | ||
| 227 | +} | ||
| 228 | + | ||
| 229 | +func (q *InsertQuery) appendValues( | ||
| 175 | fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, | 230 | fmter QueryFormatter, b []byte, fields []*Field, strct reflect.Value, |
| 176 | ) (_ []byte, err error) { | 231 | ) (_ []byte, err error) { |
| 177 | for i, f := range fields { | 232 | for i, f := range fields { |
| @@ -214,7 +269,7 @@ func (q *insertQuery) appendValues( | @@ -214,7 +269,7 @@ func (q *insertQuery) appendValues( | ||
| 214 | return b, nil | 269 | return b, nil |
| 215 | } | 270 | } |
| 216 | 271 | ||
| 217 | -func (q *insertQuery) appendSliceValues( | 272 | +func (q *InsertQuery) appendSliceValues( |
| 218 | fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, | 273 | fmter QueryFormatter, b []byte, fields []*Field, slice reflect.Value, |
| 219 | ) (_ []byte, err error) { | 274 | ) (_ []byte, err error) { |
| 220 | if q.placeholder { | 275 | if q.placeholder { |
| @@ -247,7 +302,7 @@ func (q *insertQuery) appendSliceValues( | @@ -247,7 +302,7 @@ func (q *insertQuery) appendSliceValues( | ||
| 247 | return b, nil | 302 | return b, nil |
| 248 | } | 303 | } |
| 249 | 304 | ||
| 250 | -func (q *insertQuery) addReturningField(field *Field) { | 305 | +func (q *InsertQuery) addReturningField(field *Field) { |
| 251 | if len(q.q.returning) > 0 { | 306 | if len(q.q.returning) > 0 { |
| 252 | return | 307 | return |
| 253 | } | 308 | } |
| @@ -259,7 +314,7 @@ func (q *insertQuery) addReturningField(field *Field) { | @@ -259,7 +314,7 @@ func (q *insertQuery) addReturningField(field *Field) { | ||
| 259 | q.returningFields = append(q.returningFields, field) | 314 | q.returningFields = append(q.returningFields, field) |
| 260 | } | 315 | } |
| 261 | 316 | ||
| 262 | -func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { | 317 | +func (q *InsertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { |
| 263 | b = append(b, " SET "...) | 318 | b = append(b, " SET "...) |
| 264 | for i, f := range fields { | 319 | for i, f := range fields { |
| 265 | if i > 0 { | 320 | if i > 0 { |
| @@ -272,7 +327,7 @@ func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { | @@ -272,7 +327,7 @@ func (q *insertQuery) appendSetExcluded(b []byte, fields []*Field) []byte { | ||
| 272 | return b | 327 | return b |
| 273 | } | 328 | } |
| 274 | 329 | ||
| 275 | -func (q *insertQuery) appendColumns(b []byte, fields []*Field) []byte { | 330 | +func (q *InsertQuery) appendColumns(b []byte, fields []*Field) []byte { |
| 276 | b = appendColumns(b, "", fields) | 331 | b = appendColumns(b, "", fields) |
| 277 | for i, v := range q.q.extraValues { | 332 | for i, v := range q.q.extraValues { |
| 278 | if i > 0 || len(fields) > 0 { | 333 | if i > 0 || len(fields) > 0 { |
| @@ -64,16 +64,16 @@ func (j *join) manyQuery(q *Query) (*Query, error) { | @@ -64,16 +64,16 @@ func (j *join) manyQuery(q *Query) (*Query, error) { | ||
| 64 | 64 | ||
| 65 | baseTable := j.BaseModel.Table() | 65 | baseTable := j.BaseModel.Table() |
| 66 | var where []byte | 66 | var where []byte |
| 67 | - if len(j.Rel.FKs) > 1 { | 67 | + if len(j.Rel.JoinFKs) > 1 { |
| 68 | where = append(where, '(') | 68 | where = append(where, '(') |
| 69 | } | 69 | } |
| 70 | - where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.FKs) | ||
| 71 | - if len(j.Rel.FKs) > 1 { | 70 | + where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.JoinFKs) |
| 71 | + if len(j.Rel.JoinFKs) > 1 { | ||
| 72 | where = append(where, ')') | 72 | where = append(where, ')') |
| 73 | } | 73 | } |
| 74 | where = append(where, " IN ("...) | 74 | where = append(where, " IN ("...) |
| 75 | where = appendChildValues( | 75 | where = appendChildValues( |
| 76 | - where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.FKValues) | 76 | + where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.BaseFKs) |
| 77 | where = append(where, ")"...) | 77 | where = append(where, ")"...) |
| 78 | q = q.Where(internal.BytesToString(where)) | 78 | q = q.Where(internal.BytesToString(where)) |
| 79 | 79 | ||
| @@ -126,7 +126,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { | @@ -126,7 +126,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { | ||
| 126 | join = append(join, " AS "...) | 126 | join = append(join, " AS "...) |
| 127 | join = append(join, j.Rel.M2MTableAlias...) | 127 | join = append(join, j.Rel.M2MTableAlias...) |
| 128 | join = append(join, " ON ("...) | 128 | join = append(join, " ON ("...) |
| 129 | - for i, col := range j.Rel.BaseFKs { | 129 | + for i, col := range j.Rel.M2MBaseFKs { |
| 130 | if i > 0 { | 130 | if i > 0 { |
| 131 | join = append(join, ", "...) | 131 | join = append(join, ", "...) |
| 132 | } | 132 | } |
| @@ -140,10 +140,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { | @@ -140,10 +140,7 @@ func (j *join) m2mQuery(fmter QueryFormatter, q *Query) (*Query, error) { | ||
| 140 | q = q.Join(internal.BytesToString(join)) | 140 | q = q.Join(internal.BytesToString(join)) |
| 141 | 141 | ||
| 142 | joinTable := j.JoinModel.Table() | 142 | joinTable := j.JoinModel.Table() |
| 143 | - for i, col := range j.Rel.JoinFKs { | ||
| 144 | - if i >= len(joinTable.PKs) { | ||
| 145 | - break | ||
| 146 | - } | 143 | + for i, col := range j.Rel.M2MJoinFKs { |
| 147 | pk := joinTable.PKs[i] | 144 | pk := joinTable.PKs[i] |
| 148 | q = q.Where("?.? = ?.?", | 145 | q = q.Where("?.? = ?.?", |
| 149 | joinTable.Alias, pk.Column, | 146 | joinTable.Alias, pk.Column, |
| @@ -242,7 +239,7 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b | @@ -242,7 +239,7 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b | ||
| 242 | isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) | 239 | isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.hasFlag(allWithDeletedFlag) |
| 243 | 240 | ||
| 244 | b = append(b, "LEFT JOIN "...) | 241 | b = append(b, "LEFT JOIN "...) |
| 245 | - b = fmter.FormatQuery(b, string(j.JoinModel.Table().FullNameForSelects)) | 242 | + b = fmter.FormatQuery(b, string(j.JoinModel.Table().SQLNameForSelects)) |
| 246 | b = append(b, " AS "...) | 243 | b = append(b, " AS "...) |
| 247 | b = j.appendAlias(b) | 244 | b = j.appendAlias(b) |
| 248 | 245 | ||
| @@ -252,38 +249,22 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b | @@ -252,38 +249,22 @@ func (j *join) appendHasOneJoin(fmter QueryFormatter, b []byte, q *Query) (_ []b | ||
| 252 | b = append(b, '(') | 249 | b = append(b, '(') |
| 253 | } | 250 | } |
| 254 | 251 | ||
| 255 | - if len(j.Rel.FKs) > 1 { | 252 | + if len(j.Rel.BaseFKs) > 1 { |
| 256 | b = append(b, '(') | 253 | b = append(b, '(') |
| 257 | } | 254 | } |
| 258 | - if j.Rel.Type == HasOneRelation { | ||
| 259 | - for i, fk := range j.Rel.FKs { | ||
| 260 | - if i > 0 { | ||
| 261 | - b = append(b, " AND "...) | ||
| 262 | - } | ||
| 263 | - b = j.appendAlias(b) | ||
| 264 | - b = append(b, '.') | ||
| 265 | - b = append(b, j.Rel.JoinTable.PKs[i].Column...) | ||
| 266 | - b = append(b, " = "...) | ||
| 267 | - b = j.appendBaseAlias(b) | ||
| 268 | - b = append(b, '.') | ||
| 269 | - b = append(b, fk.Column...) | ||
| 270 | - } | ||
| 271 | - } else { | ||
| 272 | - baseTable := j.BaseModel.Table() | ||
| 273 | - for i, fk := range j.Rel.FKs { | ||
| 274 | - if i > 0 { | ||
| 275 | - b = append(b, " AND "...) | ||
| 276 | - } | ||
| 277 | - b = j.appendAlias(b) | ||
| 278 | - b = append(b, '.') | ||
| 279 | - b = append(b, fk.Column...) | ||
| 280 | - b = append(b, " = "...) | ||
| 281 | - b = j.appendBaseAlias(b) | ||
| 282 | - b = append(b, '.') | ||
| 283 | - b = append(b, baseTable.PKs[i].Column...) | 255 | + for i, baseFK := range j.Rel.BaseFKs { |
| 256 | + if i > 0 { | ||
| 257 | + b = append(b, " AND "...) | ||
| 284 | } | 258 | } |
| 259 | + b = j.appendAlias(b) | ||
| 260 | + b = append(b, '.') | ||
| 261 | + b = append(b, j.Rel.JoinFKs[i].Column...) | ||
| 262 | + b = append(b, " = "...) | ||
| 263 | + b = j.appendBaseAlias(b) | ||
| 264 | + b = append(b, '.') | ||
| 265 | + b = append(b, baseFK.Column...) | ||
| 285 | } | 266 | } |
| 286 | - if len(j.Rel.FKs) > 1 { | 267 | + if len(j.Rel.BaseFKs) > 1 { |
| 287 | b = append(b, ')') | 268 | b = append(b, ')') |
| 288 | } | 269 | } |
| 289 | 270 |
| @@ -31,6 +31,7 @@ type HooklessModel interface { | @@ -31,6 +31,7 @@ type HooklessModel interface { | ||
| 31 | type Model interface { | 31 | type Model interface { |
| 32 | HooklessModel | 32 | HooklessModel |
| 33 | 33 | ||
| 34 | + AfterScanHook | ||
| 34 | AfterSelectHook | 35 | AfterSelectHook |
| 35 | 36 | ||
| 36 | BeforeInsertHook | 37 | BeforeInsertHook |
| @@ -43,45 +44,90 @@ type Model interface { | @@ -43,45 +44,90 @@ type Model interface { | ||
| 43 | AfterDeleteHook | 44 | AfterDeleteHook |
| 44 | } | 45 | } |
| 45 | 46 | ||
| 46 | -func NewModel(values ...interface{}) (Model, error) { | 47 | +func NewModel(value interface{}) (Model, error) { |
| 48 | + return newModel(value, false) | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +func newScanModel(values []interface{}) (Model, error) { | ||
| 47 | if len(values) > 1 { | 52 | if len(values) > 1 { |
| 48 | return Scan(values...), nil | 53 | return Scan(values...), nil |
| 49 | } | 54 | } |
| 55 | + return newModel(values[0], true) | ||
| 56 | +} | ||
| 50 | 57 | ||
| 51 | - v0 := values[0] | ||
| 52 | - switch v0 := v0.(type) { | 58 | +func newModel(value interface{}, scan bool) (Model, error) { |
| 59 | + switch value := value.(type) { | ||
| 53 | case Model: | 60 | case Model: |
| 54 | - return v0, nil | 61 | + return value, nil |
| 55 | case HooklessModel: | 62 | case HooklessModel: |
| 56 | - return newModelWithHookStubs(v0), nil | 63 | + return newModelWithHookStubs(value), nil |
| 57 | case types.ValueScanner, sql.Scanner: | 64 | case types.ValueScanner, sql.Scanner: |
| 58 | - return Scan(v0), nil | 65 | + if !scan { |
| 66 | + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) | ||
| 67 | + } | ||
| 68 | + return Scan(value), nil | ||
| 59 | } | 69 | } |
| 60 | 70 | ||
| 61 | - v := reflect.ValueOf(v0) | 71 | + v := reflect.ValueOf(value) |
| 62 | if !v.IsValid() { | 72 | if !v.IsValid() { |
| 63 | return nil, errModelNil | 73 | return nil, errModelNil |
| 64 | } | 74 | } |
| 65 | if v.Kind() != reflect.Ptr { | 75 | if v.Kind() != reflect.Ptr { |
| 66 | - return nil, fmt.Errorf("pg: Model(non-pointer %T)", v0) | 76 | + return nil, fmt.Errorf("pg: Model(non-pointer %T)", value) |
| 77 | + } | ||
| 78 | + | ||
| 79 | + if v.IsNil() { | ||
| 80 | + typ := v.Type().Elem() | ||
| 81 | + if typ.Kind() == reflect.Struct { | ||
| 82 | + return newStructTableModel(GetTable(typ)), nil | ||
| 83 | + } | ||
| 84 | + return nil, errModelNil | ||
| 67 | } | 85 | } |
| 86 | + | ||
| 68 | v = v.Elem() | 87 | v = v.Elem() |
| 69 | 88 | ||
| 89 | + if v.Kind() == reflect.Interface { | ||
| 90 | + if !v.IsNil() { | ||
| 91 | + v = v.Elem() | ||
| 92 | + if v.Kind() != reflect.Ptr { | ||
| 93 | + return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String()) | ||
| 94 | + } | ||
| 95 | + } | ||
| 96 | + } | ||
| 97 | + | ||
| 70 | switch v.Kind() { | 98 | switch v.Kind() { |
| 71 | case reflect.Struct: | 99 | case reflect.Struct: |
| 72 | if v.Type() != timeType { | 100 | if v.Type() != timeType { |
| 73 | return newStructTableModelValue(v), nil | 101 | return newStructTableModelValue(v), nil |
| 74 | } | 102 | } |
| 75 | case reflect.Slice: | 103 | case reflect.Slice: |
| 76 | - typ := v.Type() | ||
| 77 | - elemType := indirectType(typ.Elem()) | ||
| 78 | - if elemType.Kind() == reflect.Struct && elemType != timeType { | ||
| 79 | - return newSliceTableModel(v, elemType), nil | 104 | + elemType := sliceElemType(v) |
| 105 | + switch elemType.Kind() { | ||
| 106 | + case reflect.Struct: | ||
| 107 | + if elemType != timeType { | ||
| 108 | + return newSliceTableModel(v, elemType), nil | ||
| 109 | + } | ||
| 110 | + case reflect.Map: | ||
| 111 | + if err := validMap(elemType); err != nil { | ||
| 112 | + return nil, err | ||
| 113 | + } | ||
| 114 | + slicePtr := v.Addr().Interface().(*[]map[string]interface{}) | ||
| 115 | + return newMapSliceModel(slicePtr), nil | ||
| 80 | } | 116 | } |
| 81 | return newSliceModel(v, elemType), nil | 117 | return newSliceModel(v, elemType), nil |
| 118 | + case reflect.Map: | ||
| 119 | + typ := v.Type() | ||
| 120 | + if err := validMap(typ); err != nil { | ||
| 121 | + return nil, err | ||
| 122 | + } | ||
| 123 | + mapPtr := v.Addr().Interface().(*map[string]interface{}) | ||
| 124 | + return newMapModel(mapPtr), nil | ||
| 82 | } | 125 | } |
| 83 | 126 | ||
| 84 | - return Scan(v0), nil | 127 | + if !scan { |
| 128 | + return nil, fmt.Errorf("pg: Model(unsupported %T)", value) | ||
| 129 | + } | ||
| 130 | + return Scan(value), nil | ||
| 85 | } | 131 | } |
| 86 | 132 | ||
| 87 | type modelWithHookStubs struct { | 133 | type modelWithHookStubs struct { |
| @@ -94,3 +140,11 @@ func newModelWithHookStubs(m HooklessModel) Model { | @@ -94,3 +140,11 @@ func newModelWithHookStubs(m HooklessModel) Model { | ||
| 94 | HooklessModel: m, | 140 | HooklessModel: m, |
| 95 | } | 141 | } |
| 96 | } | 142 | } |
| 143 | + | ||
| 144 | +func validMap(typ reflect.Type) error { | ||
| 145 | + if typ.Key().Kind() != reflect.String || typ.Elem().Kind() != reflect.Interface { | ||
| 146 | + return fmt.Errorf("pg: Model(unsupported %s, expected *map[string]interface{})", | ||
| 147 | + typ.String()) | ||
| 148 | + } | ||
| 149 | + return nil | ||
| 150 | +} |
| @@ -22,6 +22,6 @@ func (m Discard) AddColumnScanner(ColumnScanner) error { | @@ -22,6 +22,6 @@ func (m Discard) AddColumnScanner(ColumnScanner) error { | ||
| 22 | return nil | 22 | return nil |
| 23 | } | 23 | } |
| 24 | 24 | ||
| 25 | -func (m Discard) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error { | 25 | +func (m Discard) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { |
| 26 | return nil | 26 | return nil |
| 27 | } | 27 | } |
| 1 | +package orm | ||
| 2 | + | ||
| 3 | +import ( | ||
| 4 | + "github.com/go-pg/pg/v10/types" | ||
| 5 | +) | ||
| 6 | + | ||
| 7 | +type mapModel struct { | ||
| 8 | + hookStubs | ||
| 9 | + ptr *map[string]interface{} | ||
| 10 | + m map[string]interface{} | ||
| 11 | +} | ||
| 12 | + | ||
| 13 | +var _ Model = (*mapModel)(nil) | ||
| 14 | + | ||
| 15 | +func newMapModel(ptr *map[string]interface{}) *mapModel { | ||
| 16 | + model := &mapModel{ | ||
| 17 | + ptr: ptr, | ||
| 18 | + } | ||
| 19 | + if ptr != nil { | ||
| 20 | + model.m = *ptr | ||
| 21 | + } | ||
| 22 | + return model | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +func (m *mapModel) Init() error { | ||
| 26 | + return nil | ||
| 27 | +} | ||
| 28 | + | ||
| 29 | +func (m *mapModel) NextColumnScanner() ColumnScanner { | ||
| 30 | + if m.m == nil { | ||
| 31 | + m.m = make(map[string]interface{}) | ||
| 32 | + *m.ptr = m.m | ||
| 33 | + } | ||
| 34 | + return m | ||
| 35 | +} | ||
| 36 | + | ||
| 37 | +func (m mapModel) AddColumnScanner(ColumnScanner) error { | ||
| 38 | + return nil | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | +func (m *mapModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { | ||
| 42 | + val, err := types.ReadColumnValue(col, rd, n) | ||
| 43 | + if err != nil { | ||
| 44 | + return err | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + m.m[col.Name] = val | ||
| 48 | + return nil | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +func (mapModel) useQueryOne() bool { | ||
| 52 | + return true | ||
| 53 | +} |
| 1 | +package orm | ||
| 2 | + | ||
| 3 | +type mapSliceModel struct { | ||
| 4 | + mapModel | ||
| 5 | + slice *[]map[string]interface{} | ||
| 6 | +} | ||
| 7 | + | ||
| 8 | +var _ Model = (*mapSliceModel)(nil) | ||
| 9 | + | ||
| 10 | +func newMapSliceModel(ptr *[]map[string]interface{}) *mapSliceModel { | ||
| 11 | + return &mapSliceModel{ | ||
| 12 | + slice: ptr, | ||
| 13 | + } | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +func (m *mapSliceModel) Init() error { | ||
| 17 | + slice := *m.slice | ||
| 18 | + if len(slice) > 0 { | ||
| 19 | + *m.slice = slice[:0] | ||
| 20 | + } | ||
| 21 | + return nil | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +func (m *mapSliceModel) NextColumnScanner() ColumnScanner { | ||
| 25 | + slice := *m.slice | ||
| 26 | + if len(slice) == cap(slice) { | ||
| 27 | + m.mapModel.m = make(map[string]interface{}) | ||
| 28 | + *m.slice = append(slice, m.mapModel.m) //nolint:gocritic | ||
| 29 | + return m | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + slice = slice[:len(slice)+1] | ||
| 33 | + el := slice[len(slice)-1] | ||
| 34 | + if el != nil { | ||
| 35 | + m.mapModel.m = el | ||
| 36 | + } else { | ||
| 37 | + el = make(map[string]interface{}) | ||
| 38 | + slice[len(slice)-1] = el | ||
| 39 | + m.mapModel.m = el | ||
| 40 | + } | ||
| 41 | + *m.slice = slice | ||
| 42 | + return m | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +func (mapSliceModel) useQueryOne() {} //nolint:unused |
| @@ -29,12 +29,12 @@ func (m scanValuesModel) NextColumnScanner() ColumnScanner { | @@ -29,12 +29,12 @@ func (m scanValuesModel) NextColumnScanner() ColumnScanner { | ||
| 29 | return m | 29 | return m |
| 30 | } | 30 | } |
| 31 | 31 | ||
| 32 | -func (m scanValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error { | ||
| 33 | - if colIdx >= len(m.values) { | 32 | +func (m scanValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { |
| 33 | + if int(col.Index) >= len(m.values) { | ||
| 34 | return fmt.Errorf("pg: no Scan var for column index=%d name=%q", | 34 | return fmt.Errorf("pg: no Scan var for column index=%d name=%q", |
| 35 | - colIdx, colName) | 35 | + col.Index, col.Name) |
| 36 | } | 36 | } |
| 37 | - return types.Scan(m.values[colIdx], rd, n) | 37 | + return types.Scan(m.values[col.Index], rd, n) |
| 38 | } | 38 | } |
| 39 | 39 | ||
| 40 | //------------------------------------------------------------------------------ | 40 | //------------------------------------------------------------------------------ |
| @@ -60,10 +60,10 @@ func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { | @@ -60,10 +60,10 @@ func (m scanReflectValuesModel) NextColumnScanner() ColumnScanner { | ||
| 60 | return m | 60 | return m |
| 61 | } | 61 | } |
| 62 | 62 | ||
| 63 | -func (m scanReflectValuesModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error { | ||
| 64 | - if colIdx >= len(m.values) { | 63 | +func (m scanReflectValuesModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { |
| 64 | + if int(col.Index) >= len(m.values) { | ||
| 65 | return fmt.Errorf("pg: no Scan var for column index=%d name=%q", | 65 | return fmt.Errorf("pg: no Scan var for column index=%d name=%q", |
| 66 | - colIdx, colName) | 66 | + col.Index, col.Name) |
| 67 | } | 67 | } |
| 68 | - return types.ScanValue(m.values[colIdx], rd, n) | 68 | + return types.ScanValue(m.values[col.Index], rd, n) |
| 69 | } | 69 | } |
| @@ -34,7 +34,7 @@ func (m *sliceModel) NextColumnScanner() ColumnScanner { | @@ -34,7 +34,7 @@ func (m *sliceModel) NextColumnScanner() ColumnScanner { | ||
| 34 | return m | 34 | return m |
| 35 | } | 35 | } |
| 36 | 36 | ||
| 37 | -func (m *sliceModel) ScanColumn(colIdx int, _ string, rd types.Reader, n int) error { | 37 | +func (m *sliceModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { |
| 38 | if m.nextElem == nil { | 38 | if m.nextElem == nil { |
| 39 | m.nextElem = internal.MakeSliceNextElemFunc(m.slice) | 39 | m.nextElem = internal.MakeSliceNextElemFunc(m.slice) |
| 40 | } | 40 | } |
| @@ -27,56 +27,8 @@ type TableModel interface { | @@ -27,56 +27,8 @@ type TableModel interface { | ||
| 27 | Kind() reflect.Kind | 27 | Kind() reflect.Kind |
| 28 | Value() reflect.Value | 28 | Value() reflect.Value |
| 29 | 29 | ||
| 30 | - setSoftDeleteField() | ||
| 31 | - scanColumn(int, string, types.Reader, int) (bool, error) | ||
| 32 | -} | ||
| 33 | - | ||
| 34 | -func newTableModel(value interface{}) (TableModel, error) { | ||
| 35 | - if value, ok := value.(TableModel); ok { | ||
| 36 | - return value, nil | ||
| 37 | - } | ||
| 38 | - | ||
| 39 | - v := reflect.ValueOf(value) | ||
| 40 | - if !v.IsValid() { | ||
| 41 | - return nil, errModelNil | ||
| 42 | - } | ||
| 43 | - if v.Kind() != reflect.Ptr { | ||
| 44 | - return nil, fmt.Errorf("pg: Model(non-pointer %T)", value) | ||
| 45 | - } | ||
| 46 | - | ||
| 47 | - if v.IsNil() { | ||
| 48 | - typ := v.Type().Elem() | ||
| 49 | - if typ.Kind() == reflect.Struct { | ||
| 50 | - return newStructTableModel(GetTable(typ)), nil | ||
| 51 | - } | ||
| 52 | - return nil, errModelNil | ||
| 53 | - } | ||
| 54 | - | ||
| 55 | - v = v.Elem() | ||
| 56 | - if v.Kind() == reflect.Interface { | ||
| 57 | - if !v.IsNil() { | ||
| 58 | - v = v.Elem() | ||
| 59 | - if v.Kind() != reflect.Ptr { | ||
| 60 | - return nil, fmt.Errorf("pg: Model(non-pointer %s)", v.Type().String()) | ||
| 61 | - } | ||
| 62 | - } | ||
| 63 | - } | ||
| 64 | - | ||
| 65 | - return newTableModelValue(v) | ||
| 66 | -} | ||
| 67 | - | ||
| 68 | -func newTableModelValue(v reflect.Value) (TableModel, error) { | ||
| 69 | - switch v.Kind() { | ||
| 70 | - case reflect.Struct: | ||
| 71 | - return newStructTableModelValue(v), nil | ||
| 72 | - case reflect.Slice: | ||
| 73 | - elemType := sliceElemType(v) | ||
| 74 | - if elemType.Kind() == reflect.Struct { | ||
| 75 | - return newSliceTableModel(v, elemType), nil | ||
| 76 | - } | ||
| 77 | - } | ||
| 78 | - | ||
| 79 | - return nil, fmt.Errorf("pg: Model(unsupported %s)", v.Type()) | 30 | + setSoftDeleteField() error |
| 31 | + scanColumn(types.ColumnInfo, types.Reader, int) (bool, error) | ||
| 80 | } | 32 | } |
| 81 | 33 | ||
| 82 | func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { | 34 | func newTableModelIndex(typ reflect.Type, root reflect.Value, index []int, rel *Relation) (TableModel, error) { |
| @@ -4,6 +4,7 @@ import ( | @@ -4,6 +4,7 @@ import ( | ||
| 4 | "fmt" | 4 | "fmt" |
| 5 | "reflect" | 5 | "reflect" |
| 6 | 6 | ||
| 7 | + "github.com/go-pg/pg/v10/internal/pool" | ||
| 7 | "github.com/go-pg/pg/v10/types" | 8 | "github.com/go-pg/pg/v10/types" |
| 8 | ) | 9 | ) |
| 9 | 10 | ||
| @@ -60,7 +61,7 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { | @@ -60,7 +61,7 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { | ||
| 60 | dstValues, ok := m.dstValues[string(buf)] | 61 | dstValues, ok := m.dstValues[string(buf)] |
| 61 | if !ok { | 62 | if !ok { |
| 62 | return fmt.Errorf( | 63 | return fmt.Errorf( |
| 63 | - "pg: relation=%q has no base %s with id=%q (check join conditions)", | 64 | + "pg: relation=%q does not have base %s with id=%q (check join conditions)", |
| 64 | m.rel.Field.GoName, m.baseTable, buf) | 65 | m.rel.Field.GoName, m.baseTable, buf) |
| 65 | } | 66 | } |
| 66 | 67 | ||
| @@ -76,31 +77,35 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { | @@ -76,31 +77,35 @@ func (m *m2mModel) AddColumnScanner(_ ColumnScanner) error { | ||
| 76 | } | 77 | } |
| 77 | 78 | ||
| 78 | func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { | 79 | func (m *m2mModel) modelIDMap(b []byte) ([]byte, error) { |
| 79 | - for i, col := range m.rel.BaseFKs { | 80 | + for i, col := range m.rel.M2MBaseFKs { |
| 80 | if i > 0 { | 81 | if i > 0 { |
| 81 | b = append(b, ',') | 82 | b = append(b, ',') |
| 82 | } | 83 | } |
| 83 | if s, ok := m.columns[col]; ok { | 84 | if s, ok := m.columns[col]; ok { |
| 84 | b = append(b, s...) | 85 | b = append(b, s...) |
| 85 | } else { | 86 | } else { |
| 86 | - return nil, fmt.Errorf("pg: %s has no column=%q", | 87 | + return nil, fmt.Errorf("pg: %s does not have column=%q", |
| 87 | m.sliceTableModel, col) | 88 | m.sliceTableModel, col) |
| 88 | } | 89 | } |
| 89 | } | 90 | } |
| 90 | return b, nil | 91 | return b, nil |
| 91 | } | 92 | } |
| 92 | 93 | ||
| 93 | -func (m *m2mModel) ScanColumn(colIdx int, colName string, rd types.Reader, n int) error { | ||
| 94 | - ok, err := m.sliceTableModel.scanColumn(colIdx, colName, rd, n) | ||
| 95 | - if ok { | ||
| 96 | - return err | 94 | +func (m *m2mModel) ScanColumn(col types.ColumnInfo, rd types.Reader, n int) error { |
| 95 | + if n > 0 { | ||
| 96 | + b, err := rd.ReadFullTemp() | ||
| 97 | + if err != nil { | ||
| 98 | + return err | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + m.columns[col.Name] = string(b) | ||
| 102 | + rd = pool.NewBytesReader(b) | ||
| 103 | + } else { | ||
| 104 | + m.columns[col.Name] = "" | ||
| 97 | } | 105 | } |
| 98 | 106 | ||
| 99 | - tmp, err := rd.ReadFullTemp() | ||
| 100 | - if err != nil { | 107 | + if ok, err := m.sliceTableModel.scanColumn(col, rd, n); ok { |
| 101 | return err | 108 | return err |
| 102 | } | 109 | } |
| 103 | - | ||
| 104 | - m.columns[colName] = string(tmp) | ||
| 105 | return nil | 110 | return nil |
| 106 | } | 111 | } |
-
请 注册 或 登录 后发表评论