// partition.go (8.8 KB)
package database

import (
	"fmt"
	"github.com/jinzhu/now"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/migrator"
	"gorm.io/gorm/schema"
	"strings"
	"time"
)

// Partitioning strategies accepted by PartitionOptions.Type.
//
// These are constants so they cannot be reassigned at runtime.
const (
	// PartitionByRangeTime partitions by ranges of Unix timestamps (seconds).
	PartitionByRangeTime = 1
	// PartitionByHash partitions by the database's hash of the partition column.
	PartitionByHash = 2
	// PartitionByList partitions by explicit lists of values.
	PartitionByList = 3
)

// PartitionTable is the minimal model contract PartitionMigrator needs:
// the name of the parent (partitioned) table.
type PartitionTable interface {
	// TableName returns the parent table's name. Child partitions are
	// named after it.
	TableName() string
}

// PartitionMigrator creates partitioned parent tables and their child
// partitions.
type PartitionMigrator struct {
	// ServiceName namespaces the Redis lock key used while creating
	// partitions ("<ServiceName>:auto-partition:<table>").
	ServiceName string
	// DB is the gorm handle all DDL statements are executed through.
	DB *gorm.DB
	// Redis backs the lock that serializes partition creation per table.
	Redis *redis.Redis
}

// NewPartitionMigrator returns a PartitionMigrator that executes DDL through
// db and takes its partition-creation lock from rds. The lock keys are
// prefixed with serviceName.
func NewPartitionMigrator(serviceName string, db *gorm.DB, rds *redis.Redis) *PartitionMigrator {
	pm := new(PartitionMigrator)
	pm.ServiceName = serviceName
	pm.DB = db
	pm.Redis = rds
	return pm
}

// AutoMigrate creates the partitioned parent table for t if it does not exist
// yet, then creates the child partitions described by option.
//
// The parent table is created with CreatePartitionTable. Child partitions are
// created with "CREATE TABLE IF NOT EXISTS ... PARTITION OF ...". This runs
// while holding a Redis lock keyed "<ServiceName>:auto-partition:<table>".
// If the lock is already held elsewhere, AutoMigrate returns nil and leaves
// partition creation to the holder.
//
// Child partition names:
//   - PartitionByRangeTime: <table>_<FormatTimeSubFunc(start)>. One partition
//     is created every TimeSpanMonth months, starting at TimeBegin, for as long
//     as the start is not after TimeEnd. Bounds are Unix seconds
//     [start, start+TimeSpanMonth).
//   - PartitionByHash: <table>_<i> for i in [0, Modulus).
//   - PartitionByList: <table>_<i>, one per ListRange entry.
//
// Errors from table creation, lock acquisition, invalid range settings and
// partition DDL are returned to the caller (table creation used to panic).
func (c *PartitionMigrator) AutoMigrate(t PartitionTable, option ...PartitionOptionFunc) error {
	options := NewPartitionOptions()
	for _, apply := range option {
		apply(options)
	}

	tableName := t.TableName()
	if !c.DB.Migrator().HasTable(tableName) {
		m := Migrator{Migrator: migrator.Migrator{Config: migrator.Config{
			CreateIndexAfterCreateTable: true,
			DB:                          c.DB,
			Dialector:                   c.DB.Dialector,
		}}}
		if err := m.CreatePartitionTable(options, t); err != nil {
			return fmt.Errorf("create partitioned table %s: %w", tableName, err)
		}
	}

	lockKey := fmt.Sprintf("%s:auto-partition:%s", c.ServiceName, tableName)
	lock := redis.NewRedisLock(c.Redis, lockKey)
	acquired, err := lock.Acquire()
	if err != nil {
		return fmt.Errorf("acquire partition lock %s: %w", lockKey, err)
	}
	if !acquired {
		// Someone else holds the lock and is creating the partitions.
		return nil
	}
	// Release errors are ignored, as in the original implementation.
	defer lock.Release()

	// createPartition attaches child table pTable to the parent with the
	// given "FOR VALUES ..." bound specification.
	createPartition := func(pTable, bound string) error {
		sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s PARTITION OF %s %s;", pTable, tableName, bound)
		if err := c.DB.Exec(sql).Error; err != nil {
			return fmt.Errorf("create partition %s: %w", pTable, err)
		}
		c.log(t, pTable)
		return nil
	}

	switch options.Type {
	case PartitionByRangeTime:
		// A non-positive span would never advance the loop below.
		if options.TimeSpanMonth <= 0 {
			return fmt.Errorf("partition table %s: TimeSpanMonth must be positive, got %d", tableName, options.TimeSpanMonth)
		}
		for begin := options.TimeBegin; begin.Unix() <= options.TimeEnd.Unix(); {
			next := begin.AddDate(0, options.TimeSpanMonth, 0)
			pTable := fmt.Sprintf("%s_%s", tableName, options.FormatTimeSubFunc(begin))
			bound := fmt.Sprintf("FOR VALUES FROM (%d) TO (%d)", begin.Unix(), next.Unix())
			if err := createPartition(pTable, bound); err != nil {
				return err
			}
			begin = next
		}
	case PartitionByHash:
		for i := 0; i < options.Modulus; i++ {
			pTable := fmt.Sprintf("%s_%d", tableName, i)
			bound := fmt.Sprintf("FOR VALUES WITH (MODULUS %d, REMAINDER %d)", options.Modulus, i)
			if err := createPartition(pTable, bound); err != nil {
				return err
			}
		}
	case PartitionByList:
		for i, listValues := range options.ListRange {
			pTable := fmt.Sprintf("%s_%d", tableName, i)
			bound := fmt.Sprintf("FOR VALUES IN %s", InArgs(listValues))
			if err := createPartition(pTable, bound); err != nil {
				return err
			}
		}
	default:
		// Unknown strategy: nothing to partition.
		return nil
	}

	return nil
}

// log writes a line to stdout announcing child partition pTable of t's
// parent table. The Chinese prefix reads "[auto partition]".
func (c *PartitionMigrator) log(t PartitionTable, pTable string) {
	parent := t.TableName()
	fmt.Println("【自动分区】 create partition table", pTable, "on table", parent)
}

// PartitionOptions describes how a partitioned table and its child
// partitions are created. Obtain one from NewPartitionOptions and adjust it
// with the With* option functions.
type PartitionOptions struct {
	// Type is the partitioning strategy: PartitionByRangeTime (1),
	// PartitionByHash (2) or PartitionByList (3).
	Type int

	// Column is the partition key placed inside "PARTITION BY <strategy>(...)".
	Column string

	// Modulus is the number of hash partitions (PartitionByHash only).
	Modulus int

	// ListRange holds one entry per list partition (PartitionByList only).
	// Each entry is rendered with InArgs into the "FOR VALUES IN" list.
	ListRange []interface{}

	// Range-by-time settings (PartitionByRangeTime only). One partition is
	// created every TimeSpanMonth months starting at TimeBegin, for as long as
	// the start is not after TimeEnd. Bounds are Unix timestamps in seconds.
	TimeBegin     time.Time
	TimeEnd       time.Time
	TimeSpanMonth int

	// FormatTimeSubFunc renders a range partition's start time into the
	// child table-name suffix.
	FormatTimeSubFunc func(time.Time) string

	// DisablePrimaryKey omits the PRIMARY KEY clause when the table is
	// created. Set it when the partition key is a function expression. In
	// that case a custom unique-ID scheme must be used instead of a primary key.
	DisablePrimaryKey bool
}

// NewPartitionOptions returns options preset to PartitionByRangeTime. Child
// partitions get a year-month suffix (e.g. "202401"). Every other field is
// left at its zero value.
func NewPartitionOptions() *PartitionOptions {
	opts := new(PartitionOptions)
	opts.Type = PartitionByRangeTime
	opts.FormatTimeSubFunc = func(start time.Time) string {
		const yearMonth = "200601"
		return start.Format(yearMonth)
	}
	return opts
}

// Sql returns the "PARTITION BY ..." clause for the configured Type and
// Column. It returns "" when Type is not a known strategy.
func (c *PartitionOptions) Sql() string {
	switch c.Type {
	case PartitionByHash:
		return fmt.Sprintf("PARTITION BY HASH(%s)", c.Column)
	case PartitionByRangeTime:
		return fmt.Sprintf("PARTITION BY RANGE(%s)", c.Column)
	case PartitionByList:
		return fmt.Sprintf("PARTITION BY LIST(%s)", c.Column)
	default:
		return ""
	}
}

// PartitionOptionFunc mutates a PartitionOptions in place (functional option).
type PartitionOptionFunc func(options *PartitionOptions)

// WithPartitionType sets the partitioning strategy (PartitionByRangeTime,
// PartitionByHash or PartitionByList).
func WithPartitionType(t int) PartitionOptionFunc {
	return func(o *PartitionOptions) { o.Type = t }
}

// WithPartitionColumn sets the partition key used in "PARTITION BY <strategy>(c)".
func WithPartitionColumn(c string) PartitionOptionFunc {
	return func(o *PartitionOptions) { o.Column = c }
}

// WithPartitionHash sets the number of hash partitions. AutoMigrate creates
// <table>_0 through <table>_<modulus-1>.
func WithPartitionHash(modulus int) PartitionOptionFunc {
	return func(o *PartitionOptions) { o.Modulus = modulus }
}

// WithPartitionRangeTime configures range-by-time partitioning. One
// partition is created every spanMonth months, starting at begin, for as
// long as the start is not after end. spanMonth should be positive.
func WithPartitionRangeTime(begin, end time.Time, spanMonth int) PartitionOptionFunc {
	return func(o *PartitionOptions) {
		o.TimeSpanMonth = spanMonth
		o.TimeBegin, o.TimeEnd = begin, end
	}
}

// WithPartitionList sets one list partition per argument. AutoMigrate names
// them <table>_<index>. The slice is stored as given, without copying.
func WithPartitionList(list ...interface{}) PartitionOptionFunc {
	return func(o *PartitionOptions) { o.ListRange = list }
}

// WithDisablePrimaryKey controls whether the PRIMARY KEY clause is omitted
// when the partitioned table is created.
func WithDisablePrimaryKey(disablePrimaryKey bool) PartitionOptionFunc {
	return func(o *PartitionOptions) { o.DisablePrimaryKey = disablePrimaryKey }
}

// Date parses date with github.com/jinzhu/now's MustParse, which panics when
// the string cannot be parsed.
func Date(date string) time.Time {
	parsed := now.MustParse(date)
	return parsed
}

// Migrator wraps gorm's generic migrator so that CreatePartitionTable can be
// added alongside the inherited migration methods.
type Migrator struct {
	migrator.Migrator // provides ReorderModels, RunWithValue, CurrentTable, Config, ...
}

// CreatePartitionTable creates one table per model in values. When options
// is non-nil, options.Sql() (the PARTITION BY clause) is appended after the
// column list.
//
// Columns, the primary key, indexes, foreign keys and check constraints are
// emitted in the same steps as gorm's generic Migrator.CreateTable, with one
// difference: the PRIMARY KEY clause is skipped when options.DisablePrimaryKey
// is set. This is needed when the partition key is a function expression.
// When CreateIndexAfterCreateTable is true, indexes are created by separate
// statements after the CREATE TABLE succeeds; otherwise they are inlined.
//
// A nil options is accepted: no partition clause is added and the primary
// key is kept. Previously this case panicked on options.DisablePrimaryKey.
func (m Migrator) CreatePartitionTable(options *PartitionOptions, values ...interface{}) error {
	disablePrimaryKey := options != nil && options.DisablePrimaryKey

	for _, value := range m.ReorderModels(values, false) {
		tx := m.DB.Session(&gorm.Session{})
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
			var (
				createTableSQL          = "CREATE TABLE ? ("
				args                    = []interface{}{m.CurrentTable(stmt)}
				hasPrimaryKeyInDataType bool
			)

			// Column definitions: "<name> <full data type>".
			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.FieldsByDBName[dbName]
				if field.IgnoreMigration {
					continue
				}
				createTableSQL += "? ?,"
				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
				args = append(args, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
			}

			// Table-level primary key. It is skipped when a column type already
			// declares one, or when the caller disabled it.
			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 && !disablePrimaryKey {
				createTableSQL += "PRIMARY KEY ?,"
				primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
				for _, field := range stmt.Schema.PrimaryFields {
					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
				}
				args = append(args, primaryKeys)
			}

			for _, idx := range stmt.Schema.ParseIndexes() {
				if m.CreateIndexAfterCreateTable {
					// Deferred so it runs after the CREATE TABLE below, and
					// only while no error has been recorded in errr.
					defer func(value interface{}, name string) {
						if errr == nil {
							errr = tx.Migrator().CreateIndex(value, name)
						}
					}(value, idx.Name)
					continue
				}

				// Inline (MySQL-style) index definition.
				if idx.Class != "" {
					createTableSQL += idx.Class + " "
				}
				createTableSQL += "INDEX ? ?"

				if idx.Comment != "" {
					createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
				}

				if idx.Option != "" {
					createTableSQL += " " + idx.Option
				}

				createTableSQL += ","
				args = append(args, clause.Column{Name: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
			}

			// Foreign keys owned by this schema.
			if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
				for _, rel := range stmt.Schema.Relationships.Relations {
					if rel.Field.IgnoreMigration {
						continue
					}
					if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema {
						sql, vars := buildConstraint(constraint)
						createTableSQL += sql + ","
						args = append(args, vars...)
					}
				}
			}

			for _, chk := range stmt.Schema.ParseCheckConstraints() {
				createTableSQL += "CONSTRAINT ? CHECK (?),"
				args = append(args, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
			}

			createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")"

			if options != nil {
				createTableSQL += options.Sql()
			}

			if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
				createTableSQL += fmt.Sprint(tableOption)
			}

			errr = tx.Exec(createTableSQL, args...).Error
			return errr
		}); err != nil {
			return err
		}
	}
	return nil
}

// buildConstraint renders a foreign-key constraint as the SQL fragment
// "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??", optionally followed by
// ON DELETE / ON UPDATE actions. The fragment is returned together with its
// four bind values: the constraint name, the local columns, the referenced
// table, and the referenced columns.
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
	columnsOf := func(fields []*schema.Field) []interface{} {
		var cols []interface{}
		for _, f := range fields {
			cols = append(cols, clause.Column{Name: f.DBName})
		}
		return cols
	}

	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
	if onDelete := constraint.OnDelete; onDelete != "" {
		sql += " ON DELETE " + onDelete
	}
	if onUpdate := constraint.OnUpdate; onUpdate != "" {
		sql += " ON UPDATE " + onUpdate
	}

	results = []interface{}{
		clause.Table{Name: constraint.Name},
		columnsOf(constraint.ForeignKeys),
		clause.Table{Name: constraint.ReferenceSchema.Table},
		columnsOf(constraint.References),
	}
	return sql, results
}