pool_sticky.go 3.8 KB
package pool

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
)

// Lifecycle states kept in StickyConnPool.state.
const (
	// stateDefault is the zero value: Get fetches a conn from the underlying pool.
	stateDefault = iota
	// stateInited means a conn has been obtained; Get receives it from the channel.
	stateInited
	// stateClosed is set by Close; Get returns ErrClosed.
	stateClosed
)

// BadConnError reports that a Conn is in a bad state. It optionally wraps the
// error that caused the Conn to be considered bad.
type BadConnError struct {
	wrapped error
}

// Compile-time check that *BadConnError satisfies error.
var _ error = (*BadConnError)(nil)

// Error returns the fixed bad-state message, followed by ": " and the wrapped
// error's message when a wrapped error is present.
func (e BadConnError) Error() string {
	const msg = "pg: Conn is in a bad state"
	if e.wrapped == nil {
		return msg
	}
	return msg + ": " + e.wrapped.Error()
}

// Unwrap returns the wrapped error, which may be nil.
func (e BadConnError) Unwrap() error {
	return e.wrapped
}

//------------------------------------------------------------------------------

// StickyConnPool is a Pooler layered on top of another Pooler. It obtains one
// Conn from the underlying pool on the first Get and afterwards passes that
// same Conn between callers through a one-slot channel.
type StickyConnPool struct {
	// pool is the underlying Pooler that conns are taken from and released to.
	pool Pooler

	// shared is incremented by NewStickyConnPool and decremented by Close.
	// It is accessed atomically.
	shared int32

	// state holds stateDefault, stateInited or stateClosed.
	// It is accessed atomically.
	state uint32

	// ch parks the Conn between Put/Remove and the next Get.
	ch chan *Conn

	// _badConnError holds a BadConnError value; read it through badConnError.
	_badConnError atomic.Value
}

// Compile-time check that *StickyConnPool satisfies Pooler.
var _ Pooler = (*StickyConnPool)(nil)

// NewStickyConnPool returns a StickyConnPool backed by pool. If pool is
// already a *StickyConnPool, that same instance is returned instead of being
// wrapped again. In both cases the returned pool's shared counter ends up one
// higher, so each call is balanced by one Close.
func NewStickyConnPool(pool Pooler) *StickyConnPool {
	if existing, ok := pool.(*StickyConnPool); ok {
		atomic.AddInt32(&existing.shared, 1)
		return existing
	}

	created := &StickyConnPool{
		pool: pool,
		ch:   make(chan *Conn, 1),
	}
	atomic.AddInt32(&created.shared, 1)
	return created
}

// NewConn delegates to the underlying pool's NewConn and returns its result.
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
	cn, err := p.pool.NewConn(ctx)
	return cn, err
}

// CloseConn delegates to the underlying pool's CloseConn and returns its error.
func (p *StickyConnPool) CloseConn(cn *Conn) error {
	err := p.pool.CloseConn(cn)
	return err
}

// Get returns the sticky Conn.
//
// In stateDefault it fetches a Conn from the underlying pool and switches the
// pool to stateInited. In stateInited it returns the stored bad-conn error if
// there is one, otherwise it receives the Conn from the channel, waiting until
// a Put/Remove parks it, the pool is closed (ErrClosed), or ctx is done
// (ctx.Err()). In stateClosed it returns ErrClosed.
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
	// In worst case this races with Close which is not a very common operation.
	for i := 0; i < 1000; i++ {
		switch atomic.LoadUint32(&p.state) {
		case stateDefault:
			cn, err := p.pool.Get(ctx)
			if err != nil {
				return nil, err
			}
			if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
				return cn, nil
			}
			// The state changed while the conn was being fetched; give the
			// conn up and re-evaluate the new state.
			p.pool.Remove(ctx, cn, ErrClosed)
		case stateInited:
			if err := p.badConnError(); err != nil {
				return nil, err
			}

			// Fast path: a parked conn is handed out without consulting ctx,
			// so an available conn is still returned after ctx is done.
			select {
			case cn, ok := <-p.ch:
				if !ok {
					return nil, ErrClosed
				}
				return cn, nil
			default:
			}

			// Slow path: the conn is checked out elsewhere. Wait for it, but
			// stop waiting once ctx is cancelled or its deadline passes.
			select {
			case cn, ok := <-p.ch:
				if !ok {
					return nil, ErrClosed
				}
				return cn, nil
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		case stateClosed:
			return nil, ErrClosed
		default:
			panic("not reached")
		}
	}
	return nil, fmt.Errorf("pg: StickyConnPool.Get: infinite loop")
}

// Put parks cn on the channel so that the next Get can receive it. If the
// send panics (Close closes the channel), the panic is recovered and cn is
// released to the underlying pool through freeConn.
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.freeConn(ctx, cn)
	}()

	p.ch <- cn
}

// freeConn releases cn to the underlying pool: Put when no bad-conn error is
// stored, otherwise Remove with the stored error as the reason.
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
	badErr := p.badConnError()
	if badErr == nil {
		p.pool.Put(ctx, cn)
		return
	}
	p.pool.Remove(ctx, cn, badErr)
}

// Remove stores reason wrapped in a BadConnError and then parks cn on the
// channel. If the send panics (Close closes the channel), the panic is
// recovered and cn is removed from the underlying pool with ErrClosed.
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.pool.Remove(ctx, cn, ErrClosed)
	}()

	bad := BadConnError{wrapped: reason}
	p._badConnError.Store(bad)
	p.ch <- cn
}

// Close decrements the shared counter and returns nil while the counter is
// still positive. Otherwise it moves the pool to stateClosed, closes the
// channel, and releases a parked Conn (if any) through freeConn. It returns
// ErrClosed when the pool is already closed.
func (p *StickyConnPool) Close() error {
	remaining := atomic.AddInt32(&p.shared, -1)
	if remaining > 0 {
		return nil
	}

	for attempt := 0; attempt < 1000; attempt++ {
		current := atomic.LoadUint32(&p.state)
		if current == stateClosed {
			return ErrClosed
		}
		if !atomic.CompareAndSwapUint32(&p.state, current, stateClosed) {
			// The state changed between the load and the swap; retry.
			continue
		}

		close(p.ch)
		// Receiving from the closed channel still yields a buffered conn.
		if cn, ok := <-p.ch; ok {
			p.freeConn(context.TODO(), cn)
		}
		return nil
	}

	return errors.New("pg: StickyConnPool.Close: infinite loop")
}

// Reset is a no-op returning nil when no bad-conn error is stored. Otherwise
// it takes the parked Conn without blocking, removes it from the underlying
// pool, clears the stored error, and switches the pool from stateInited back
// to stateDefault. It returns ErrClosed if the channel is closed, and an error
// if no Conn is parked or the state is not stateInited.
func (p *StickyConnPool) Reset(ctx context.Context) error {
	if p.badConnError() == nil {
		return nil
	}

	var parked *Conn
	select {
	case cn, open := <-p.ch:
		if !open {
			return ErrClosed
		}
		parked = cn
	default:
		return errors.New("pg: StickyConnPool does not have a Conn")
	}

	p.pool.Remove(ctx, parked, ErrClosed)
	p._badConnError.Store(BadConnError{})

	if atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
		return nil
	}
	return fmt.Errorf("pg: invalid StickyConnPool state: %d", atomic.LoadUint32(&p.state))
}

// badConnError returns the stored BadConnError when it wraps a non-nil error.
// It returns nil when nothing has been stored or the stored value wraps nil.
func (p *StickyConnPool) badConnError() error {
	stored := p._badConnError.Load()
	if stored == nil {
		return nil
	}
	if bad := stored.(BadConnError); bad.wrapped != nil {
		return bad
	}
	return nil
}

// Len returns 1 while the pool is in stateInited and 0 in stateDefault or
// stateClosed. It panics on any other state value.
func (p *StickyConnPool) Len() int {
	current := atomic.LoadUint32(&p.state)
	switch current {
	case stateInited:
		return 1
	case stateDefault, stateClosed:
		return 0
	}
	panic("not reached")
}

// IdleLen returns the number of conns currently buffered in the channel.
func (p *StickyConnPool) IdleLen() int {
	parked := len(p.ch)
	return parked
}

// Stats returns a pointer to a newly created, empty Stats value.
func (p *StickyConnPool) Stats() *Stats {
	empty := &Stats{}
	return empty
}