middleware.go 4.5 KB
package middleware

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/astaxie/beego/context"
	//"github.com/opentracing/opentracing-go"
	"github.com/tiptok/gocomm/common"
	"github.com/tiptok/gocomm/pkg/log"
	"gitlab.fjmaimaimai.com/mmm-go/godevp/pkg/application/cachex"
	"gitlab.fjmaimaimai.com/mmm-go/godevp/pkg/protocol"
)

// errAuthorization is the message returned when a request carries no usable
// credentials or the authenticated user is denied access to the resource.
var errAuthorization = errors.New("无访问权限")

// errAuthorizationExpire is the message returned when the JWT in the
// Authorization header cannot be parsed/validated; the text asks the user to
// log in again (permission expired).
var errAuthorizationExpire = errors.New("权限已过期,请重新登录")

// CheckAuthorization authenticates the request from its Authorization header.
//
// The header may contain a bare token or a "<scheme> <token>" pair such as
// "Bearer <token>"; the token is parsed with common.ParseJWTToken. On success
// the numeric user id taken from the claim's Username is stored under
// "x-mmm-id" (as an int) and the claim's "UserName" extra data under
// "x-mmm-uname".
//
// On failure nothing is stored and a protocol.ResponseMessage with code 2 is
// written as JSON:
//   - missing/blank header, or a Username that is not an integer: errAuthorization
//   - a token that common.ParseJWTToken rejects: errAuthorizationExpire
func CheckAuthorization(ctx *context.Context) {
	var msg *protocol.ResponseMessage
	defer func() {
		if msg != nil {
			ctx.Output.JSON(msg, false, false)
		}
	}()

	// strings.Fields tolerates repeated/leading/trailing whitespace, so
	// "Bearer  <token>" does not yield an empty token.
	fields := strings.Fields(ctx.Input.Header("Authorization"))
	if len(fields) == 0 {
		msg = protocol.NewResponseMessage(2, errAuthorization.Error())
		return
	}
	// Accept both a bare token and "<scheme> <token>".
	token := fields[0]
	if len(fields) > 1 {
		token = fields[1]
	}

	claim, err := common.ParseJWTToken(token)
	if err != nil {
		msg = protocol.NewResponseMessage(2, errAuthorizationExpire.Error())
		return
	}
	// The subject must be a numeric user id. Ignoring the conversion error
	// would authenticate the request as user 0.
	userID, err := strconv.Atoi(claim.Username)
	if err != nil {
		msg = protocol.NewResponseMessage(2, errAuthorization.Error())
		return
	}
	ctx.Input.SetData("x-mmm-id", userID)
	ctx.Input.SetData("x-mmm-uname", claim.AddData["UserName"])
	//ctx.Input.SetData("x-mmm-phone", claim.AddData["Phone"])
}

// CheckRoleAccess checks, via cachex.CacheService.ValidUserAccess, whether the
// user stored under "x-mmm-id" (set by CheckAuthorization) may perform method
// on object (InspectRoleAccess passes either a parent module name or the
// request URL).
//
// If no usable user id is present, or ValidUserAccess does not report ok, a
// protocol.ResponseMessage with code -1 and errAuthorization is written as JSON.
func CheckRoleAccess(ctx *context.Context, object, method string) {
	var msg *protocol.ResponseMessage
	defer func() {
		if msg != nil {
			ctx.Output.JSON(msg, false, false)
		}
	}()

	var userID int64
	switch v := ctx.Input.GetData("x-mmm-id").(type) {
	case int:
		userID = int64(v)
	case int64:
		userID = v
	default:
		// Missing (nil) or an unexpected type: deny instead of panicking on a
		// bare type assertion.
		msg = protocol.NewResponseMessage(-1, errAuthorization.Error())
		return
	}

	validUserRole := cachex.CacheService{}
	// Only the boolean result decides; the second return value is ignored, so
	// any lookup failure that yields !ok denies access (fail closed).
	if ok, _ := validUserRole.ValidUserAccess(userID, object, method); !ok {
		msg = protocol.NewResponseMessage(-1, errAuthorization.Error())
	}
}

// InspectRoleAccess returns a beego filter that authenticates the request with
// CheckAuthorization and then authorizes it with CheckRoleAccess.
//
// The permission object passed to CheckRoleAccess is:
//   - parentObject, when the request URL matches any skipUrl pattern
//     (cachex.KeyMatch3);
//   - parentObject, when parentObject is non-empty and no skipUrl patterns were
//     given (every route under this filter uses the parent module's permission);
//   - otherwise the request URL itself.
//
// If authentication fails, the filter stops after CheckAuthorization has
// written its error response, so the client receives a single JSON body.
func InspectRoleAccess(parentObject string, skipUrl ...string) func(*context.Context) {
	return func(c *context.Context) {
		requestURL := c.Input.URL()

		// With no skip patterns, a non-empty parentObject covers every route.
		useParent := len(parentObject) > 0 && len(skipUrl) == 0
		for _, pattern := range skipUrl {
			if cachex.KeyMatch3(requestURL, pattern) {
				useParent = true
				break
			}
		}

		CheckAuthorization(c)
		// CheckAuthorization writes its own error response on failure;
		// continuing would append a second JSON document from CheckRoleAccess.
		if c.ResponseWriter.Started {
			return
		}

		object := requestURL
		if useParent {
			object = parentObject
		}
		CheckRoleAccess(c, object, c.Input.Method())
	}
}

// CreateRequestLogFilter returns a filter that tags each request with a
// request id and logs the request body at debug level ("{}" when empty).
//
// The id has the form "<method>.<uri>.<unix-nanos> " and is stored in the
// request's "requestId" header, where CreateResponseLogFilter reads it back.
func CreateRequestLogFilter() func(ctx *context.Context) {
	return func(ctx *context.Context) {
		requestID := fmt.Sprintf("%v.%v.%v ", ctx.Input.Method(), ctx.Input.URI(), time.Now().UnixNano())
		// Set, not Add: with Add, a client-supplied "requestId" header would stay
		// the first value and be what Header.Get returns in the response filter.
		ctx.Request.Header.Set("requestId", requestID)

		body := "{}"
		if len(ctx.Input.RequestBody) > 0 {
			body = string(ctx.Input.RequestBody)
		}
		log.Debug(fmt.Sprintf("====>Recv  RequestId:%s \nBodyData:%s", requestID, body))
	}
}

// CreateResponseLogFilter returns a filter that logs, at debug level, the
// current user id ("x-mmm-id"), the request id set by CreateRequestLogFilter,
// and the JSON encoding of the "outputData" context value, truncated to at
// most 1000 bytes.
func CreateResponseLogFilter() func(ctx *context.Context) {
	const maxLoggedBytes = 1000
	return func(ctx *context.Context) {
		requestID := ctx.Request.Header.Get("requestId")

		body, err := json.Marshal(ctx.Input.GetData("outputData"))
		if err != nil {
			// Record why the payload is missing instead of logging an empty body.
			body = []byte(fmt.Sprintf("<marshal outputData: %v>", err))
		}
		if len(body) > maxLoggedBytes {
			// Cut on a rune boundary so multi-byte text (e.g. Chinese messages)
			// is not split into invalid UTF-8 in the log.
			cut := maxLoggedBytes
			for cut > 0 && !utf8.RuneStart(body[cut]) {
				cut--
			}
			body = body[:cut]
		}
		log.Debug(fmt.Sprintf("<====Send User:%v RequestId:%v \nResponseData:%s", ctx.Input.GetData("x-mmm-id"), requestID, body))
	}
}

// AllowCors returns a filter that adds permissive CORS headers to every
// response and answers OPTIONS (preflight) requests directly with status 200
// and the body "options support".
func AllowCors() func(ctx *context.Context) {
	return func(ctx *context.Context) {
		ctx.Output.Header("Access-Control-Allow-Methods", "OPTIONS,DELETE,POST,GET,PUT,PATCH")
		//ctx.Output.Header("Access-Control-Max-Age", "3600")

		// Per the Fetch spec, "*" in Access-Control-Allow-Headers never covers the
		// Authorization header (which CheckAuthorization requires) and is not a
		// wildcard at all for credentialed requests. Echo the headers the browser
		// asked for in the preflight so they are allowed explicitly.
		allowHeaders := "*"
		if requested := ctx.Input.Header("Access-Control-Request-Headers"); requested != "" {
			allowHeaders = requested
			ctx.Output.Header("Vary", "Access-Control-Request-Headers")
		}
		ctx.Output.Header("Access-Control-Allow-Headers", allowHeaders)
		ctx.Output.Header("Access-Control-Allow-Credentials", "true")
		ctx.Output.Header("Access-Control-Allow-Origin", "*") //origin

		if ctx.Input.Method() == http.MethodOptions {
			// Answer the preflight here with 200 rather than routing it.
			ctx.Output.SetStatus(http.StatusOK)
			_ = ctx.Output.Body([]byte("options support"))
		}
	}
}

//func OpenTracingAdapter(ctx *context.Context) {
//	var sp opentracing.Span
//	opName := ctx.Input.URL()
//	// Attempt to join a trace by getting trace context from the headers.
//	wireContext, err := opentracing.GlobalTracer().Extract(
//		opentracing.HTTPHeaders,
//		opentracing.HTTPHeadersCarrier(ctx.Request.Header))
//	if err != nil {
//		// If for whatever reason we can't join, go ahead an start a new root span.
//		sp = opentracing.StartSpan(opName)
//	} else {
//		sp = opentracing.StartSpan(opName, opentracing.ChildOf(wireContext))
//	}
//	defer sp.Finish()
//}