// sqlTool.go 6.5 KB

package utils

import (
	"errors"
	"fmt"
	"oppmg/common/log"
	"oppmg/protocol"
	"reflect"
	"strings"

	"github.com/astaxie/beego/orm"
)

// PrintLogSql writes a SQL statement and its parameters to the debug log.
// Every parameter is rendered with %v and wrapped in its own pair of square
// brackets, in the order given.
func PrintLogSql(sql string, param ...interface{}) {
	var rendered strings.Builder
	for _, p := range param {
		fmt.Fprintf(&rendered, `[%v]`, p)
	}
	log.Debug(`SQL EXCEUTE:[%s]-%s`, sql, rendered.String())
}

//ExecuteQueryOne 执行原生sql查询单条记录;结果用结构体接收
func ExecuteQueryOne(result interface{}, sqlstr string, param ...interface{}) error {

	var err error
	o := orm.NewOrm()
	err = ExecuteQueryOneWithOrmer(o, result, sqlstr, param)
	return err
}

// ExecuteQueryOneWithOrmer runs a raw SQL query through the supplied Ormer and
// stores a single record in result via QueryRow.
//
// The statement and its parameters are logged first. param is spread into
// Raw with `...` so each value is passed as a separate argument, matching how
// ExecuteSQLWithOrmer passes its parameters.
//
// It returns the error from QueryRow, or nil on success.
func ExecuteQueryOneWithOrmer(o orm.Ormer, result interface{}, sqlstr string, param ...interface{}) error {
	PrintLogSql(sqlstr, param...)
	return o.Raw(sqlstr, param...).QueryRow(result)
}

//ExecuteQuerySql 执行原生sql查询多条记录
func ExecuteQueryAll(result interface{}, sqlstr string, param ...interface{}) error {

	var err error
	o := orm.NewOrm()
	err = ExecuteQueryAllWithOrmer(o, result, sqlstr, param)
	return err
}

// ExecuteQueryAllWithOrmer runs a raw SQL query through the supplied Ormer and
// stores the records in result via QueryRows.
//
// The statement and its parameters are logged first. param is spread into
// Raw with `...` so each value is passed as a separate argument, matching how
// ExecuteSQLWithOrmer passes its parameters. The row count returned by
// QueryRows is discarded.
//
// It returns the error from QueryRows, or nil on success.
func ExecuteQueryAllWithOrmer(o orm.Ormer, result interface{}, sqlstr string, param ...interface{}) error {
	PrintLogSql(sqlstr, param...)
	_, err := o.Raw(sqlstr, param...).QueryRows(result)
	return err
}

// ExecuteSQLWithOrmer executes a raw SQL statement through the supplied Ormer.
// The statement and its parameters are logged before execution, and the
// number of affected rows is written to the debug log afterwards.
//
// It returns the error from Exec, or nil on success.
func ExecuteSQLWithOrmer(o orm.Ormer, sqlstr string, param ...interface{}) error {
	PrintLogSql(sqlstr, param...)
	res, execErr := o.Raw(sqlstr, param...).Exec()
	if execErr != nil {
		return execErr
	}
	// The error from RowsAffected is ignored: the count is only used for the
	// debug line below.
	affected, _ := res.RowsAffected()
	log.Debug("RowsAffected:%d", affected)
	return nil
}

// ExecuteQueryValue runs a raw SQL query on an Ormer obtained from orm.NewOrm
// and stores the records in result as orm.Params entries via Values.
//
// The statement and its parameters are logged first. param is spread into
// Raw with `...` so each value is passed as a separate argument, matching how
// ExecuteSQLWithOrmer passes its parameters. The row count returned by Values
// is discarded.
//
// It returns the error from Values, or nil on success.
func ExecuteQueryValue(result *[]orm.Params, sqlstr string, param ...interface{}) error {
	PrintLogSql(sqlstr, param...)
	_, err := orm.NewOrm().Raw(sqlstr, param...).Values(result)
	return err
}

// QueryDataByPage bundles what a paginated raw-SQL query needs: a counting
// statement, a data statement, the parameters shared by both, and the page
// window set through LimitPage.
type QueryDataByPage struct {
	// CountSql is the statement whose single result Query scans into the total.
	CountSql string
	// DataSql is the statement Query appends a "limit" clause to in order to
	// fetch one page of records.
	DataSql  string
	// Param holds the parameters passed to both CountSql and DataSql.
	Param    []interface{}
	// offset is the page number set by LimitPage; Query treats it as 1-based
	// when computing the first row of the page.
	offset   int
	// num is the number of rows per page set by LimitPage.
	num      int
}

// NewQueryDataByPage returns a QueryDataByPage initialised with the given
// counting and data statements. Parameters and the page window are left at
// their zero values.
func NewQueryDataByPage(countsql, datasql string) *QueryDataByPage {
	pager := new(QueryDataByPage)
	pager.CountSql = countsql
	pager.DataSql = datasql
	return pager
}

// AddParam sets the condition parameters that Query passes to both the count
// and the data statement. Despite the name, each call replaces q.Param with
// the given values instead of appending to what was set before.
func (q *QueryDataByPage) AddParam(param ...interface{}) {
	q.Param = param
}

// LimitPage records the page window used by Query: offset is the page number
// and num is the number of rows per page.
func (q *QueryDataByPage) LimitPage(offset, num int) {
	q.offset, q.num = offset, num
}

// Query runs the paginated query: it first executes CountSql to obtain the
// total, then, if the total is non-zero, executes DataSql with a
// "limit <first row>,<num>" clause appended and stores the records in result.
//
// The first row is (offset-1)*num, clamped to 0, so offset is treated as a
// 1-based page number. Both statements receive q.Param.
//
// The limit clause is built into a local statement and q.DataSql is left
// untouched, so calling Query more than once on the same value does not
// stack several limit clauses onto the statement.
//
// On success it returns a ResponsePageInfo with CurrentPage set to offset and
// TotalPage set to the value read by CountSql. If either statement fails, the
// zero ResponsePageInfo and the error are returned.
//
// NOTE(review): TotalPage carries the total read from CountSql rather than a
// computed number of pages — confirm this is what consumers expect.
func (q *QueryDataByPage) Query(result interface{}) (pageinfo protocol.ResponsePageInfo, err error) {
	pagebegin := (q.offset - 1) * q.num
	if pagebegin < 0 {
		pagebegin = 0
	}

	var total int
	o := orm.NewOrm()
	if err = ExecuteQueryOneWithOrmer(o, &total, q.CountSql, q.Param...); err != nil {
		return protocol.ResponsePageInfo{}, err
	}
	if total == 0 {
		// Nothing to fetch; skip the data statement entirely.
		return protocol.ResponsePageInfo{CurrentPage: q.offset, TotalPage: total}, nil
	}

	pagedSql := fmt.Sprintf("%s limit %d,%d", q.DataSql, pagebegin, q.num)
	if err = ExecuteQueryAllWithOrmer(o, result, pagedSql, q.Param...); err != nil {
		return protocol.ResponsePageInfo{}, err
	}
	return protocol.ResponsePageInfo{CurrentPage: q.offset, TotalPage: total}, nil
}

// UpdateTableByMap updates selected columns of one table row using an Ormer
// obtained from orm.NewOrm.
//
// tabeleStruct must be a pointer to the model struct. changeMap maps struct
// field names to their new values; the values are written into the struct and
// only those columns are updated. An empty changeMap is a no-op that returns
// nil.
//
// The arguments are validated here before an Ormer is created, and the actual
// work is delegated to UpdateTableByMapWithOrmer so both entry points share a
// single implementation.
func UpdateTableByMap(tabeleStruct interface{}, changeMap map[string]interface{}) error {
	// reflect.ValueOf(...).Kind() reports Invalid for a nil interface, whereas
	// reflect.TypeOf(nil) is nil and calling Kind on it panics.
	if reflect.ValueOf(tabeleStruct).Kind() != reflect.Ptr {
		err := errors.New("UpdateTableByMap: tableStruct must ptr")
		log.Error(err.Error())
		return err
	}
	if len(changeMap) < 1 {
		log.Info("changeMap is nil")
		return nil
	}
	return UpdateTableByMapWithOrmer(orm.NewOrm(), tabeleStruct, changeMap)
}

// UpdateTableByMapWithOrmer updates selected columns of one table row through
// the supplied Ormer.
//
// tabeleStruct must be a pointer to the model struct. changeMap maps struct
// field names to their new values: each value is written into the struct with
// SetStructValueByType and the key is added to the column list passed to
// o.Update. An empty changeMap is a no-op that returns nil.
//
// If a value cannot be applied, the error is logged and returned without
// calling o.Update; fields set by earlier map entries stay modified on the
// struct. Errors from o.Update are logged and returned as well.
func UpdateTableByMapWithOrmer(o orm.Ormer, tabeleStruct interface{}, changeMap map[string]interface{}) error {
	// reflect.ValueOf(...).Kind() reports Invalid for a nil interface, whereas
	// reflect.TypeOf(nil) is nil and calling Kind on it panics.
	if reflect.ValueOf(tabeleStruct).Kind() != reflect.Ptr {
		err := errors.New("UpdateTableByMap: tableStruct must ptr")
		log.Error(err.Error())
		return err
	}
	if len(changeMap) < 1 {
		log.Info("changeMap is nil")
		return nil
	}

	changeColumn := make([]string, 0, len(changeMap))
	for name, value := range changeMap {
		changeColumn = append(changeColumn, name)
		if err := SetStructValueByType(tabeleStruct, name, value); err != nil {
			log.Error("err:%v key:%v value:%v", err.Error(), name, value)
			return err
		}
	}

	num, err := o.Update(tabeleStruct, changeColumn...)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info(fmt.Sprintf("UpdateTableByMap: table:%s effect records:%d column:%v", GetTableName(tabeleStruct), num, changeColumn))
	return nil
}

// GetTableName returns the table name of a model by calling its TableName
// method through reflection.
//
// It returns "unknown" when tableStruct is nil, when the value has no
// TableName method in its method set, or when the method returns nothing.
// Otherwise the first return value is rendered with reflect.Value.String.
func GetTableName(tableStruct interface{}) string {
	const fallback = "unknown"

	v := reflect.ValueOf(tableStruct)
	if !v.IsValid() {
		// A nil interface yields the zero Value; MethodByName would panic on it.
		return fallback
	}
	method := v.MethodByName("TableName")
	if !method.IsValid() {
		return fallback
	}
	if out := method.Call(nil); len(out) > 0 {
		return out[0].String()
	}
	return fallback
}

// SetStructValueByType assigns columnValue to the field named columnType of
// the struct that s points to.
//
// s must be a non-nil pointer to a struct. columnType is matched against the
// names of the struct's direct fields. The field must be settable (exported)
// and its kind must equal the kind of columnValue. Supported kinds are the
// signed and unsigned integers, floats, strings, bools and structs; struct
// values must additionally be assignable to the field's type.
//
// An error is returned, and the struct left unchanged, when any of these
// conditions is not met.
func SetStructValueByType(s interface{}, columnType string, columnValue interface{}) error {
	target := reflect.ValueOf(s)
	// Elem and NumField panic on anything but a non-nil pointer to a struct,
	// so reject such input up front.
	if target.Kind() != reflect.Ptr || target.IsNil() || target.Elem().Kind() != reflect.Struct {
		return errors.New("SetStructValueByType: s must be a non-nil pointer to a struct")
	}
	structValue := target.Elem()
	structType := structValue.Type()

	var setValue reflect.Value
	found := false
	for i, n := 0, structType.NumField(); i < n; i++ {
		if structType.Field(i).Name == columnType {
			setValue = structValue.Field(i)
			found = true
			break
		}
	}

	columnValueV := reflect.ValueOf(columnValue)
	switch {
	case !found:
		return errors.New("struct is not type:" + columnType)
	case !setValue.CanSet():
		return errors.New("setValue.CanSet is false")
	case setValue.Kind() != columnValueV.Kind():
		return errors.New("struct field and value of type is error")
	}

	// The kinds are equal at this point, so the sized setters cannot overflow.
	switch columnValueV.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		setValue.SetInt(columnValueV.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		setValue.SetUint(columnValueV.Uint())
	case reflect.Float32, reflect.Float64:
		setValue.SetFloat(columnValueV.Float())
	case reflect.String:
		setValue.SetString(columnValueV.String())
	case reflect.Bool:
		setValue.SetBool(columnValueV.Bool())
	case reflect.Struct:
		// Two different struct types share the Struct kind; Set would panic
		// unless the value is assignable to the field's type.
		if !columnValueV.Type().AssignableTo(setValue.Type()) {
			return errors.New("struct field and value of type is error")
		}
		setValue.Set(columnValueV)
	default:
		return errors.New("columnValue err for:" + columnType)
	}
	return nil
}