redirect.go 3.6 KB
package service_gateway

import (
	"fmt"
	"github.com/beego/beego/v2/client/httplib"
	"github.com/beego/beego/v2/server/web"
	"github.com/beego/beego/v2/server/web/context"
	"github.com/linmadan/egglib-go/log"
	"github.com/linmadan/egglib-go/utils/json"
	"gitlab.fjmaimaimai.com/allied-creation/allied-creation-manufacture/pkg/constant"
	"io/ioutil"
	"net/http"
	"strings"
)

// internalService is the view of an upstream (internal) service that the
// gateway helpers in this file need in order to proxy a request to it.
type internalService interface {
	// GetResponseData is called with the decoded upstream Response and a
	// pointer to the destination value; a non-nil error aborts the proxying.
	// NOTE(review): presumably it unpacks the payload of result into data and
	// reports upstream business errors — confirm against the implementations.
	GetResponseData(result Response, data interface{}) error
	// CreateRequest builds the outgoing HTTP request. Callers in this file
	// pass the full upstream URL and the lower-cased HTTP method.
	CreateRequest(url string, method string) *httplib.BeegoHTTPRequest
	// Host returns the string that callers in this file prepend to the
	// request path to form the upstream URL.
	Host() string
}

// RedirectInternalService returns a beego filter that proxies every request
// whose RequestURI starts with prefix to the internal service svr.
//
// The prefix is stripped from the URI and the remainder is appended to
// svr.Host() to form the upstream URL. The incoming request body is forwarded
// and the default company/organisation identifiers from the constant package
// are sent as the "companyId" and "orgId" headers.
//
// Response handling:
//   - Upstream Content-Type exactly "application/json": the body is decoded
//     into a Response, unwrapped through svr.GetResponseData and written back
//     as {"msg": "成功", "code": 0, "data": ...} with HTTP 200.
//   - Any other Content-Type: status, headers and body are copied through
//     unchanged.
//   - Any failure (request error, non-200 upstream status, read/decode error,
//     GetResponseData error): HTTP 200 with {"msg": <error>, "code": 1,
//     "data": {}}.
//
// Requests that do not match prefix are left untouched.
func RedirectInternalService(prefix string, svr internalService, log log.Logger) web.FilterFunc {
	return func(ctx *context.Context) {
		if !strings.HasPrefix(ctx.Request.RequestURI, prefix) {
			return
		}
		var err error
		var byteResult []byte
		var data = make(map[string]interface{})
		// Registered first so it runs last: whatever error the code below
		// leaves in err is reported to the client in the standard envelope.
		defer func() {
			if err != nil {
				ctx.Output.SetStatus(http.StatusOK)
				ctx.Output.JSON(map[string]interface{}{
					"msg":  err.Error(),
					"code": 1,
					"data": struct{}{},
				}, false, false)
			}
		}()
		method := strings.ToLower(ctx.Request.Method)
		// The HasPrefix guard above guarantees the prefix sits at the start.
		url := strings.TrimPrefix(ctx.Request.RequestURI, prefix)
		target := svr.Host() + url
		req := svr.CreateRequest(target, method)
		log.Debug(method + "  请求url:" + target)
		// Pass along the current login information (configurable).
		req.Header("companyId", fmt.Sprintf("%v", constant.MANUFACTURE_DEFAULT_COMPANYID))
		req.Header("orgId", fmt.Sprintf("%v", constant.MANUFACTURE_DEFAULT_ORGID))
		req.Body(ctx.Input.RequestBody)

		// Plain assignment (not :=) so the deferred reporter sees this err.
		var response *http.Response
		response, err = req.Response()
		if err != nil {
			return
		}
		// Close on every exit path, including the non-200 and read-error
		// returns below, so the upstream connection is never leaked.
		defer response.Body.Close()

		if response.StatusCode != http.StatusOK {
			err = fmt.Errorf("%v", response.Status)
			return
		}

		byteResult, err = ioutil.ReadAll(response.Body)
		if err != nil {
			return
		}

		// Pass non-JSON data straight through.
		// NOTE(review): this is an exact match, so a value such as
		// "application/json; charset=utf-8" is treated as non-JSON and
		// passed through as-is — confirm that this is intended.
		contentType := response.Header.Get("Content-Type")
		if contentType != "application/json" {
			copyResponse(ctx, response, byteResult)
			return
		}

		var result Response
		err = json.Unmarshal(byteResult, &result)
		if err != nil {
			return
		}

		err = svr.GetResponseData(result, &data)
		if err != nil {
			return
		}
		ctx.Output.SetStatus(http.StatusOK)
		ctx.Output.JSON(map[string]interface{}{
			"msg":  "成功",
			"code": 0,
			"data": data,
		}, false, false)
	}
}

// InvokeInternalService forwards the current request to the internal service
// svr and returns the Data field of the decoded upstream Response.
//
// The upstream URL is svr.Host() followed by the unmodified incoming
// RequestURI, and the method is the lower-cased incoming HTTP method. No
// headers or body are copied automatically: when wrapper is non-nil it is
// called with ctx and the outgoing request before it is sent, so the caller
// can attach them (for example the companyId/orgId headers and the request
// body).
//
// On success the result is result.Data converted to []byte. An error is
// returned when the request fails, the upstream status is not 200 (the error
// text is the upstream status line), the body cannot be read, or the body
// does not decode into a Response.
func InvokeInternalService(ctx *context.Context, svr internalService, log log.Logger, wrapper func(ctx *context.Context, req *httplib.BeegoHTTPRequest)) (interface{}, error) {
	method := strings.ToLower(ctx.Request.Method)
	target := svr.Host() + ctx.Request.RequestURI
	req := svr.CreateRequest(target, method)
	log.Debug(method + "  请求url:" + target)
	// Let the caller decorate the outgoing request (headers, body, ...).
	if wrapper != nil {
		wrapper(ctx, req)
	}

	response, err := req.Response()
	if err != nil {
		return nil, err
	}
	// Close on every exit path, including the non-200 and read-error
	// returns below, so the upstream connection is never leaked.
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%v", response.Status)
	}

	byteResult, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}

	var result Response
	if err = json.Unmarshal(byteResult, &result); err != nil {
		return nil, err
	}
	return ([]byte)(result.Data), nil
}

// copyResponse relays an upstream response to the client unchanged: it copies
// the upstream headers, sets the upstream status code and writes data as the
// body.
//
// Header names with no values are skipped. Any value already set on the
// outgoing response for a copied header name is replaced. Each upstream value
// is written as its own header line rather than being comma-joined, because
// some headers (notably Set-Cookie) cannot be folded into one line.
func copyResponse(ctx *context.Context, response *http.Response, data []byte) {
	outgoing := ctx.ResponseWriter.Header()
	for name, values := range response.Header {
		if len(values) == 0 {
			continue
		}
		// Drop whatever was there first so the upstream values win.
		outgoing.Del(name)
		for _, value := range values {
			outgoing.Add(name, value)
		}
	}
	ctx.Output.SetStatus(response.StatusCode)
	ctx.Output.Body(data)
}