pool_single.go 2.9 KB
package pool

import (
	"fmt"
	"sync/atomic"
)

// Lifecycle states stored in SingleConnPool.state.
const (
	// stateDefault: no connection has been attached yet; Get fetches one
	// from the underlying pool and SetConn may install one.
	stateDefault = iota
	// stateInited: a connection has been attached and is exchanged through
	// the pool's channel.
	stateInited
	// stateClosed: Close has run and the channel is closed.
	stateClosed
)

// SingleConnPool is a Pooler that wraps another Pooler and passes a single
// connection between callers through a channel.
type SingleConnPool struct {
	// pool is the wrapped Pooler that connections are fetched from and
	// handed back to.
	pool Pooler

	// state holds one of stateDefault, stateInited or stateClosed and is
	// only accessed through sync/atomic. ch carries the connection while it
	// is not checked out; NewSingleConnPool creates it with capacity 1 and
	// Close closes it.
	state uint32 // atomic
	ch    chan *Conn

	// level is incremented by NewSingleConnPool and decremented by Close;
	// Close only tears the pool down once it is no longer positive.
	// hasBadConn is set to 1 by Remove and is consulted by Close and Reset.
	// Both fields are only accessed through sync/atomic.
	level      int32  // atomic
	hasBadConn uint32 // atomic
}

// Compile-time assertion that *SingleConnPool satisfies the Pooler interface.
var _ Pooler = (*SingleConnPool)(nil)

// NewSingleConnPool returns a SingleConnPool backed by pool.
//
// If pool already is a *SingleConnPool it is reused rather than wrapped
// again. In both cases the level counter is incremented, so each call to
// NewSingleConnPool is expected to be matched by one call to Close.
func NewSingleConnPool(pool Pooler) *SingleConnPool {
	var single *SingleConnPool
	switch v := pool.(type) {
	case *SingleConnPool:
		single = v
	default:
		single = &SingleConnPool{
			pool: pool,
			ch:   make(chan *Conn, 1),
		}
	}
	atomic.AddInt32(&single.level, 1)
	return single
}

// Clone passes the Pooler wrapped by p to NewSingleConnPool and returns the
// result.
func (p *SingleConnPool) Clone() *SingleConnPool {
	wrapped := p.pool
	return NewSingleConnPool(wrapped)
}

// SetConn installs cn as the pool's connection, moving the pool from
// stateDefault to stateInited. It panics if the pool is in any other state.
func (p *SingleConnPool) SetConn(cn *Conn) {
	if !atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
		panic("not reached")
	}
	p.ch <- cn
}

// NewConn delegates to the wrapped pool's NewConn and returns its result
// unchanged.
func (p *SingleConnPool) NewConn() (*Conn, error) {
	cn, err := p.pool.NewConn()
	return cn, err
}

// CloseConn delegates to the wrapped pool's CloseConn and returns its error
// unchanged.
func (p *SingleConnPool) CloseConn(cn *Conn) error {
	err := p.pool.CloseConn(cn)
	return err
}

// Get returns the pool's single connection.
//
//   - In stateDefault a connection is fetched from the wrapped pool and the
//     pool moves to stateInited.
//   - In stateInited Get blocks until the connection is sent on the channel
//     (by Put, Remove or SetConn); it returns ErrClosed if the channel is
//     closed instead.
//   - In stateClosed it returns ErrClosed.
//
// The connection is fetched before the state transition. If the fetch
// fails, the pool therefore stays in stateDefault and a later Get can retry;
// switching to stateInited first would leave the channel empty forever and
// make every subsequent Get block.
func (p *SingleConnPool) Get() (*Conn, error) {
	// The loop is bounded so that losing the state race repeatedly cannot
	// spin forever.
	for i := 0; i < 1e6; i++ {
		switch atomic.LoadUint32(&p.state) {
		case stateDefault:
			cn, err := p.pool.Get()
			if err != nil {
				return nil, err
			}
			if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
				return cn, nil
			}
			// The state changed while the connection was being fetched
			// (concurrent Get, SetConn or Close). Hand the connection back
			// to the wrapped pool and re-evaluate the new state.
			p.pool.Put(cn)
		case stateInited:
			cn, ok := <-p.ch
			if !ok {
				return nil, ErrClosed
			}
			return cn, nil
		case stateClosed:
			return nil, ErrClosed
		default:
			panic("not reached")
		}
	}
	return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop")
}

// Put hands cn back by sending it on the pool's channel. If the send panics
// (sending on a closed channel does), the panic is recovered and cn is
// passed to the wrapped pool's Put instead.
func (p *SingleConnPool) Put(cn *Conn) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.pool.Put(cn)
	}()

	p.ch <- cn
}

// Remove flags the pool as holding a bad connection and then sends cn on the
// pool's channel. If the send panics (sending on a closed channel does), the
// panic is recovered and cn is passed to the wrapped pool's Remove instead.
func (p *SingleConnPool) Remove(cn *Conn) {
	// The flag is raised before the send so it is already visible to
	// whoever receives cn from the channel.
	atomic.StoreUint32(&p.hasBadConn, 1)

	defer func() {
		if r := recover(); r == nil {
			return
		}
		p.pool.Remove(cn)
	}()

	p.ch <- cn
}

// Len reports 1 while the pool is in stateInited and 0 in stateDefault or
// stateClosed. It panics on any other state value.
func (p *SingleConnPool) Len() int {
	state := atomic.LoadUint32(&p.state)
	if state == stateInited {
		return 1
	}
	if state == stateDefault || state == stateClosed {
		return 0
	}
	panic("not reached")
}

// IdleLen reports how many connections are currently buffered in the pool's
// channel, i.e. handed back and not yet picked up again.
func (p *SingleConnPool) IdleLen() int {
	idle := len(p.ch)
	return idle
}

// Stats returns a pointer to a freshly created, empty Stats value; nothing
// from the pool is recorded in it.
func (p *SingleConnPool) Stats() *Stats {
	empty := Stats{}
	return &empty
}

// Close decrements the level counter and, once it is no longer positive,
// moves the pool to stateClosed and closes the channel. A connection still
// buffered in the channel is handed to the wrapped pool: via Remove when the
// bad-connection flag is set, via Put otherwise.
//
// Close returns nil when the pool is still referenced, when it was already
// closed, or after closing it. An error is returned only if the state
// transition keeps failing for the whole bounded retry loop.
func (p *SingleConnPool) Close() error {
	if remaining := atomic.AddInt32(&p.level, -1); remaining > 0 {
		return nil
	}

	for attempt := 0; attempt < 1e6; attempt++ {
		current := atomic.LoadUint32(&p.state)
		if current == stateClosed {
			return nil
		}
		if !atomic.CompareAndSwapUint32(&p.state, current, stateClosed) {
			// The state changed underneath us; reload and try again.
			continue
		}

		close(p.ch)

		// Drain the connection, if one was sitting in the buffer.
		cn, ok := <-p.ch
		if !ok {
			return nil
		}
		if atomic.LoadUint32(&p.hasBadConn) == 1 {
			p.pool.Remove(cn)
		} else {
			p.pool.Put(cn)
		}
		return nil
	}

	return fmt.Errorf("pg: SingleConnPool.Close: infinite loop")
}

// Reset discards a connection that Remove flagged as bad and returns the
// pool to stateDefault, so that the next Get fetches a fresh connection from
// the wrapped pool. It is a no-op (returning nil) when no bad connection is
// flagged.
//
// Reset blocks until the connection is available on the pool's channel. It
// returns an error if the channel has been closed, or if the pool is not in
// stateInited when the transition back to stateDefault is attempted.
func (p *SingleConnPool) Reset() error {
	if atomic.LoadUint32(&p.hasBadConn) == 0 {
		return nil
	}

	cn, ok := <-p.ch
	if !ok {
		return fmt.Errorf("pg: SingleConnPool does not have a Conn")
	}
	p.pool.Remove(cn)

	// The bad connection is gone, so clear the flag. Leaving it set would
	// make a later Reset block on the now-empty channel and would make Close
	// Remove a healthy replacement connection instead of Put-ing it back.
	// This must happen before the state returns to stateDefault: from then
	// on a new connection can be acquired and flagged by Remove, and that
	// newer flag must not be overwritten.
	atomic.StoreUint32(&p.hasBadConn, 0)

	if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
		state := atomic.LoadUint32(&p.state)
		return fmt.Errorf("pg: invalid SingleConnPool state: %d", state)
	}

	return nil
}