tables.go 2.3 KB
package orm

import (
	"fmt"
	"reflect"
	"sync"

	"github.com/go-pg/pg/v10/types"
)

// _tables is the package-level table registry that backs GetTable and
// RegisterTable.
var _tables = newTables()

// tableInProgress pairs a Table with the sync.Once guards that make each of
// its two initialization steps (init1 and init2) run at most once.
type tableInProgress struct {
	// table is the Table whose init1/init2 methods are invoked.
	table *Table

	// init1Once and init2Once guard table.init1 and table.init2 respectively.
	init1Once sync.Once
	init2Once sync.Once
}

// newTableInProgress wraps table in a fresh tableInProgress whose
// initialization guards have not fired yet.
func newTableInProgress(table *Table) *tableInProgress {
	inp := new(tableInProgress)
	inp.table = table
	return inp
}

// init1 runs the wrapped table's init1 step at most once across all callers.
// It reports whether this particular call was the one that executed the step.
func (inp *tableInProgress) init1() bool {
	ranHere := false
	step := func() {
		inp.table.init1()
		ranHere = true
	}
	inp.init1Once.Do(step)
	return ranHere
}

// init2 runs the wrapped table's init2 step at most once across all callers.
// It reports whether this particular call was the one that executed the step.
func (inp *tableInProgress) init2() bool {
	ranHere := false
	step := func() {
		inp.table.init2()
		ranHere = true
	}
	inp.init2Once.Do(step)
	return ranHere
}

// GetTable returns the Table for the struct type typ, looking it up in the
// package-level registry and creating it on first use.
// It panics if typ is not of kind reflect.Struct.
func GetTable(typ reflect.Type) *Table {
	return _tables.Get(typ)
}

// RegisterTable registers a struct as an SQL table in the package-level
// registry. It is usually used to register the intermediate table of a
// many-to-many relationship.
//
// strct may be a struct value or a pointer to a struct.
func RegisterTable(strct interface{}) {
	_tables.Register(strct)
}

// tables is a registry that maps a struct's reflect.Type to its Table.
type tables struct {
	// tables maps reflect.Type to *Table. In this file an entry is stored
	// only after the table's init2 step has run (see get).
	tables sync.Map

	// mu guards inProgress, which holds tables that have been created but
	// not yet stored in the tables map. get also holds mu while it moves a
	// finished table from inProgress into tables.
	mu         sync.RWMutex
	inProgress map[reflect.Type]*tableInProgress
}

// newTables builds an empty registry with its in-progress map allocated so
// that it can be written to immediately.
func newTables() *tables {
	reg := new(tables)
	reg.inProgress = make(map[reflect.Type]*tableInProgress)
	return reg
}

// Register makes sure the registry holds a Table for the type of strct.
// strct may be a struct value or a pointer to a struct; one level of pointer
// is stripped before the lookup. The resulting Table is discarded because
// only the registration side effect matters here.
func (t *tables) Register(strct interface{}) {
	structType := reflect.TypeOf(strct)
	if structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}
	t.Get(structType)
}

// get returns the Table for the struct type typ, creating and initializing it
// on first use. It panics if typ is not of kind reflect.Struct.
//
// Tables that are still being built are tracked in t.inProgress; finished
// tables live in t.tables. When allowInProgress is true, the table is returned
// right after the init1 step, without running init2 and without storing the
// table in t.tables. Otherwise init2 is run as well, and the call that
// actually executed init2 moves the table from t.inProgress into t.tables.
func (t *tables) get(typ reflect.Type, allowInProgress bool) *Table {
	if typ.Kind() != reflect.Struct {
		panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
	}

	// Fast path: the table is already registered; t.mu is not taken.
	if v, ok := t.tables.Load(typ); ok {
		return v.(*Table)
	}

	t.mu.Lock()

	// Re-check under the lock: another caller may have stored the table
	// between the Load above and acquiring t.mu.
	if v, ok := t.tables.Load(typ); ok {
		t.mu.Unlock()
		return v.(*Table)
	}

	var table *Table

	// Reuse an existing in-progress entry so that every caller shares the
	// same Table and the same pair of Once guards; otherwise create one.
	inProgress := t.inProgress[typ]
	if inProgress == nil {
		table = newTable(typ)
		inProgress = newTableInProgress(table)
		t.inProgress[typ] = inProgress
	} else {
		table = inProgress.table
	}

	// The init steps below run without holding t.mu.
	// NOTE(review): presumably so that table initialization can call back
	// into get without blocking on t.mu — confirm against Table.init1/init2.
	t.mu.Unlock()

	inProgress.init1()
	if allowInProgress {
		return table
	}

	// init2 reports true only for the call that executed the step, so the
	// bookkeeping below happens once per table.
	if inProgress.init2() {
		t.mu.Lock()
		delete(t.inProgress, typ)
		t.tables.Store(typ, table)
		t.mu.Unlock()
	}

	return table
}

// Get returns the Table for the struct type typ, creating it on first use.
// Unlike the internal lookup, it never asks for a table that has only
// completed its first initialization step.
func (t *tables) Get(typ reflect.Type) *Table {
	const allowInProgress = false
	return t.get(typ, allowInProgress)
}

// getByName scans the registered tables and returns the first one whose
// SQLName equals name, or nil when no registered table matches.
func (t *tables) getByName(name types.Safe) *Table {
	var match *Table
	t.tables.Range(func(_, value interface{}) bool {
		candidate := value.(*Table)
		if candidate.SQLName != name {
			// Not the one we want; keep iterating.
			return true
		}
		match = candidate
		return false
	})
	return match
}