package urlstruct import ( "database/sql" "encoding" "fmt" "reflect" "strconv" "time" ) var ( textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() timeType = reflect.TypeOf((*time.Time)(nil)).Elem() durationType = reflect.TypeOf((*time.Duration)(nil)).Elem() nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() nullInt64Type = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() nullFloat64Type = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() mapStringStringType = reflect.TypeOf((*map[string]string)(nil)).Elem() ) type scannerFunc func(v reflect.Value, values []string) error func scanner(typ reflect.Type) scannerFunc { if typ == timeType { return scanTime } if typ.Implements(textUnmarshalerType) { return scanTextUnmarshaler } if reflect.PtrTo(typ).Implements(textUnmarshalerType) { return scanTextUnmarshalerAddr } switch typ { case durationType: return scanDuration case nullBoolType: return scanNullBool case nullInt64Type: return scanNullInt64 case nullFloat64Type: return scanNullFloat64 case nullStringType: return scanNullString case mapStringStringType: return scanMapStringString } switch typ.Kind() { case reflect.Bool: return scanBool case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return scanInt64 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return scanUint64 case reflect.Float32: return scanFloat32 case reflect.Float64: return scanFloat64 case reflect.String: return scanString } return nil } func sliceScanner(typ reflect.Type) scannerFunc { switch typ.Elem().Kind() { case reflect.Int: return scanIntSlice case reflect.Int32: return scanInt32Slice case reflect.Int64: return scanInt64Slice case reflect.String: return scanStringSlice } return nil } func scanTextUnmarshaler(v reflect.Value, values []string) error { if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } u := v.Interface().(encoding.TextUnmarshaler) return u.UnmarshalText([]byte(values[0])) } func scanTextUnmarshalerAddr(v reflect.Value, values []string) error { if !v.CanAddr() { return fmt.Errorf("pg: Scan(nonsettable %s)", v.Type()) } u := v.Addr().Interface().(encoding.TextUnmarshaler) return u.UnmarshalText([]byte(values[0])) } func scanBool(v reflect.Value, values []string) error { f, err := strconv.ParseBool(values[0]) if err != nil { return err } v.SetBool(f) return nil } func scanInt64(v reflect.Value, values []string) error { n, err := strconv.ParseInt(values[0], 10, 64) if err != nil { return err } v.SetInt(n) return nil } func scanUint64(v reflect.Value, values []string) error { n, err := strconv.ParseUint(values[0], 10, 64) if err != nil { return err } v.SetUint(n) return nil } func scanFloat32(v reflect.Value, values []string) error { return scanFloat(v, values, 32) } func scanFloat64(v reflect.Value, values []string) error { return scanFloat(v, values, 64) } func scanFloat(v reflect.Value, values []string, bits int) error { n, err := strconv.ParseFloat(values[0], bits) if err != nil { return err } v.SetFloat(n) return nil } func scanString(v reflect.Value, values []string) error { v.SetString(values[0]) return nil } func scanTime(v reflect.Value, values []string) error { tm, err := parseTime(values[0]) if err != nil { return err } v.Set(reflect.ValueOf(tm)) return nil } func parseTime(s string) (time.Time, error) { n, err := strconv.ParseInt(s, 10, 64) if err == nil { return time.Unix(n, 0), nil } return time.Parse(time.RFC3339Nano, s) } func scanDuration(v reflect.Value, values []string) error { dur, err := time.ParseDuration(values[0]) if err != nil { return err } v.SetInt(int64(dur)) return nil } func scanNullBool(v reflect.Value, values []string) error { value := sql.NullBool{ Valid: true, } s := values[0] if s == "" { v.Set(reflect.ValueOf(value)) return nil } f, err := strconv.ParseBool(s) if err != nil { return err } value.Bool = f v.Set(reflect.ValueOf(value)) return nil } func scanNullInt64(v reflect.Value, values []string) error { value := sql.NullInt64{ Valid: true, } s := values[0] if s == "" { v.Set(reflect.ValueOf(value)) return nil } n, err := strconv.ParseInt(s, 10, 64) if err != nil { return err } value.Int64 = n v.Set(reflect.ValueOf(value)) return nil } func scanNullFloat64(v reflect.Value, values []string) error { value := sql.NullFloat64{ Valid: true, } s := values[0] if s == "" { v.Set(reflect.ValueOf(value)) return nil } n, err := strconv.ParseFloat(s, 64) if err != nil { return err } value.Float64 = n v.Set(reflect.ValueOf(value)) return nil } func scanNullString(v reflect.Value, values []string) error { value := sql.NullString{ Valid: true, } s := values[0] if s == "" { v.Set(reflect.ValueOf(value)) return nil } value.String = s v.Set(reflect.ValueOf(value)) return nil } func scanMapStringString(v reflect.Value, values []string) error { if len(values)%2 != 0 { return nil } m := make(map[string]string) for i := 0; i < len(values); i += 2 { m[values[i]] = values[i+1] } v.Set(reflect.ValueOf(m)) return nil } func scanIntSlice(v reflect.Value, values []string) error { nn := make([]int, 0, len(values)) for _, s := range values { n, err := strconv.Atoi(s) if err != nil { return err } nn = append(nn, n) } v.Set(reflect.ValueOf(nn)) return nil } func scanInt32Slice(v reflect.Value, values []string) error { nn := make([]int32, 0, len(values)) for _, s := range values { n, err := strconv.ParseInt(s, 10, 32) if err != nil { return err } nn = append(nn, int32(n)) } v.Set(reflect.ValueOf(nn)) return nil } func scanInt64Slice(v reflect.Value, values []string) error { nn := make([]int64, 0, len(values)) for _, s := range values { n, err := strconv.ParseInt(s, 10, 64) if err != nil { return err } nn = append(nn, n) } v.Set(reflect.ValueOf(nn)) return nil } func scanStringSlice(v reflect.Value, values []string) error { v.Set(reflect.ValueOf(values)) return nil }