coco/router_register.go

490 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package coco
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"runtime"
"runtime/debug"
"strings"
json "github.com/bytedance/sonic"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/gorilla/schema"
"github.com/sirupsen/logrus"
)
// ImplGetRouterMap is implemented by API structs that can be passed to
// Register.RegisterStruct. GetRouterMap supplies the route configuration
// (base URL and per-method API nodes) used to register the struct.
type ImplGetRouterMap interface {
	// GetRouterMap returns the struct's route configuration; RegisterStruct
	// skips registration when it returns nil.
	GetRouterMap() *Routers
}
// Result is the JSON envelope written by the generated route handlers.
//
// ErrCode and ErrMsg are always serialized; Hint and Data are omitted when
// they hold their zero value.
type Result struct {
	ErrCode int         `json:"err_code"`
	ErrMsg  string      `json:"err_msg"`
	Hint    string      `json:"hint,omitempty"`
	Data    interface{} `json:"data,omitempty"`
}
// Register mounts API structs onto a gin engine as route groups.
//
// Always create it with NewRegister: the zero value has a nil queryDecoder,
// so decoding query parameters would panic.
type Register struct {
	// queryDecoder decodes URL query parameters (GET-style parameters) into
	// request structs; NewRegister configures it.
	queryDecoder *schema.Decoder
	// engine is the underlying gin engine, created by DefaultRouter or lazily
	// by RegisterStruct when still nil.
	engine *gin.Engine
	// addr is the listen address used by Run (":8080" when empty).
	addr string
	// defaultContentType is the response Content-Type applied when a handler
	// sets none ("application/json" when empty).
	defaultContentType string
	// paramsValidate gates bindAndValidate: while false (the zero value),
	// request parameters are neither bound nor validated.
	// NOTE(review): WithDisableRequestParamsValidate sets this to true, i.e.
	// it enables binding/validation despite its name — confirm intent.
	paramsValidate bool
}
// RegisterOptions is a functional option that DefaultRouter applies to a
// Register, in the order the options are given.
type RegisterOptions func(*Register)
// WithGinMode returns an option that calls gin.SetMode(mode).
//
// gin's mode is a process-wide setting, so it affects every gin engine and
// not only the Register the option is applied to.
func WithGinMode(mode string) RegisterOptions {
	var opt RegisterOptions = func(*Register) {
		gin.SetMode(mode)
	}
	return opt
}
// WithRecovery returns an option that installs gin.Recovery() on the
// Register's engine, so panics escaping handlers are recovered.
// The engine must already exist when the option runs, which DefaultRouter
// guarantees by creating it before applying options.
func WithRecovery() RegisterOptions {
	return func(reg *Register) {
		recovery := gin.Recovery()
		reg.engine.Use(recovery)
	}
}
// WithCors returns an option that installs the CorsFilter middleware on the
// Register's engine to allow cross-origin requests.
// Like WithRecovery, it expects the engine to exist (see DefaultRouter).
func WithCors() RegisterOptions {
	return func(reg *Register) {
		cors := CorsFilter()
		reg.engine.Use(cors)
	}
}
// WithListenAddress returns an option that sets the address Run listens on
// (for example ":9090"). Without it, Run falls back to ":8080".
func WithListenAddress(addr string) RegisterOptions {
	return func(reg *Register) { reg.addr = addr }
}
// WithDisableRequestParamsValidate returns an option that sets the
// Register's paramsValidate flag to true.
//
// NOTE(review): despite "Disable" in the name, setting paramsValidate to true
// is what makes bindAndValidate bind and validate request parameters; with the
// flag left false (NewRegister never sets it) bindAndValidate returns
// immediately without binding anything. Confirm the intended semantics before
// relying on the name.
func WithDisableRequestParamsValidate() RegisterOptions {
	enable := func(reg *Register) {
		reg.paramsValidate = true
	}
	return enable
}
// WithDefaultContentType returns an option that sets the Content-Type applied
// to responses whose handler did not set one. Passing an empty ct keeps the
// "application/json" fallback.
func WithDefaultContentType(ct string) RegisterOptions {
	return func(reg *Register) { reg.defaultContentType = ct }
}
// NewRegister creates a Register whose query decoder ignores unknown query
// keys and resolves field names through `json` struct tags.
//
// The gin engine is not created here: call DefaultRouter to build it with
// options, or let the first RegisterStruct call create a release-mode one.
func NewRegister() *Register {
	decoder := schema.NewDecoder()
	decoder.IgnoreUnknownKeys(true)
	decoder.SetAliasTag("json")
	return &Register{queryDecoder: decoder}
}
// DefaultRouter (re)creates r's engine with gin.New(), which attaches no
// middleware, then applies opts in order and returns r for chaining.
//
// Options such as WithRecovery and WithCors install middleware on the new
// engine, so they must be passed here rather than applied beforehand.
func (r *Register) DefaultRouter(opts ...RegisterOptions) *Register {
	r.engine = gin.New()
	for i := range opts {
		opts[i](r)
	}
	return r
}
// RegisterStruct registers the APIs described by drv's router map as gin
// routes under a single router group.
//
// drv must implement ImplGetRouterMap. Its GetRouterMap() result (typically
// the generated AutoGenXXXMap next to the proto) supplies the group's BaseURL
// and one RouterNode per API. Each exported method of drv whose name appears
// in that map is registered; mws are middlewares applied to the whole group.
//
// A nil drv, a nil router map or an empty API list is logged and skipped.
// Configured APIs without a matching exported method on drv are logged as a
// warning (pass a pointer if the methods have pointer receivers).
// Misconfiguration is not tolerated: drv not implementing ImplGetRouterMap,
// or a method with an unsupported signature or HTTP method, panics via
// logrus.Panicf.
func (r *Register) RegisterStruct(drv interface{}, mws ...gin.HandlerFunc) {
	if drv == nil {
		logrus.Warnf("struct nil, skip register")
		return
	}
	implGetRouters, ok := drv.(ImplGetRouterMap)
	if !ok {
		logrus.Panicf("struct not impl ImplGetRouterMap interface")
	}
	routers := implGetRouters.GetRouterMap()
	if routers == nil {
		logrus.Warnf("routers empty, skip register")
		return
	}
	if len(routers.Apis) == 0 {
		logrus.Warnf("%s api list empty, skip register", routers.StructName)
		return
	}
	// Lazily create the engine when the caller never invoked DefaultRouter.
	if r.engine == nil {
		r.DefaultRouter(WithGinMode(gin.ReleaseMode))
	}
	group := r.engine.Group(routers.BaseURL, mws...)
	refVal := reflect.ValueOf(drv)
	refTyp := refVal.Type()

	registered := 0
	// For a non-interface type, reflect only enumerates exported methods.
	for i := 0; i < refTyp.NumMethod(); i++ {
		method := refTyp.Method(i)
		node, exist := routers.Apis[method.Name]
		if !exist {
			continue
		}
		// method.Func takes the receiver as its first argument, so refVal is
		// handed over to be passed on every call.
		if err := r.registerHandle(group, node, method.Func, refVal); err != nil {
			logrus.Panicf("register handle failed: %+v", err)
		}
		registered++
	}
	if missing := len(routers.Apis) - registered; missing > 0 {
		logrus.Warnf("%s: %d configured api(s) have no matching exported method on %T, not registered",
			routers.StructName, missing, drv)
	}
}
// PreRun invokes f and returns its error, giving callers a hook for setup
// work before Run. A nil f is a no-op that returns nil.
func (r *Register) PreRun(f func() error) error {
	if f != nil {
		return f()
	}
	return nil
}
// RawEngine exposes the underlying gin engine for direct customization.
// It is nil until DefaultRouter (or the first RegisterStruct call) creates it.
func (r *Register) RawEngine() *gin.Engine {
	engine := r.engine
	return engine
}
// Run starts the HTTP server on the configured address (":8080" when none was
// set via WithListenAddress) and blocks while the server runs.
//
// If no engine exists yet (neither DefaultRouter nor RegisterStruct was
// called), a release-mode engine is created first, exactly as RegisterStruct
// does, instead of dereferencing a nil engine. A server error is logged
// rather than returned.
func (r *Register) Run() {
	if r.engine == nil {
		r.DefaultRouter(WithGinMode(gin.ReleaseMode))
	}
	if r.addr == "" {
		r.addr = ":8080"
	}
	if err := r.engine.Run(r.addr); err != nil {
		logrus.Errorf("run server failed: %v", err)
	}
}
// registerHandle wraps one API method into a gin handler (via getCallFunc) and
// mounts it on router at rc.API for HTTP method rc.Method, after the route's
// own rc.Middlewares.
//
// rFunc is the method's reflect.Method.Func, which takes the receiver as its
// first argument; rGroup is that receiver. rc.Method must be one of POST, GET,
// DELETE, PATCH, PUT, OPTIONS, HEAD, or "ANY" for all methods.
// Failures are returned with the route in the message rather than logged
// here; the caller decides how to report them.
func (r *Register) registerHandle(router gin.IRouter, rc *RouterNode, rFunc, rGroup reflect.Value) error {
	call, err := r.getCallFunc(rFunc, rGroup)
	if err != nil {
		return fmt.Errorf("api [%v %v]: %w", rc.Method, rc.API, err)
	}
	if call == nil {
		logrus.Warnf("register handle failed: handle nil")
		return nil
	}
	// Route-specific middlewares run first, the API handler last.
	hfs := make([]gin.HandlerFunc, 0, len(rc.Middlewares)+1)
	hfs = append(hfs, rc.Middlewares...)
	hfs = append(hfs, call)

	switch rc.Method {
	case http.MethodPost:
		router.POST(rc.API, hfs...)
	case http.MethodGet:
		router.GET(rc.API, hfs...)
	case http.MethodDelete:
		router.DELETE(rc.API, hfs...)
	case http.MethodPatch:
		router.PATCH(rc.API, hfs...)
	case http.MethodPut:
		router.PUT(rc.API, hfs...)
	case http.MethodOptions:
		router.OPTIONS(rc.API, hfs...)
	case http.MethodHead:
		router.HEAD(rc.API, hfs...)
	case "ANY":
		router.Any(rc.API, hfs...)
	default:
		return fmt.Errorf("api [%v]: method: [%v] not support", rc.API, rc.Method)
	}
	return nil
}
// getCallFunc checks that rFunc (a reflect.Method.Func, so In(0) is the
// receiver) has the shape
//
//	func (recv T) Name(ctx *Context, req *Req) (Resp, error)
//
// and returns a gin handler that, per request:
//   - allocates a fresh *Req and fills it via bindAndValidate, answering with
//     an ErrorBadRequest Result on failure;
//   - calls the method on rGroup;
//   - writes the return values as a JSON Result via processReturnValue, unless
//     the handler already produced its own response (redirect or a written
//     body) or the response Content-Type is not "application/json" after the
//     default has been applied.
//
// Panics raised while serving are recovered and logged with a stack trace;
// this handler writes no error response for them.
func (r *Register) getCallFunc(rFunc, rGroup reflect.Value) (gin.HandlerFunc, error) {
	typ := rFunc.Type()
	// Three inputs: receiver, ctx, req.
	if typ.NumIn() != 3 {
		return nil, fmt.Errorf("func need two request param, (ctx, req)")
	}
	if typ.NumOut() != 2 {
		return nil, fmt.Errorf("func need two response param, (resp, error)")
	}
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	if returnType := typ.Out(1); returnType != errorType {
		return nil, fmt.Errorf("method : %v , returns[1] %v not error",
			runtime.FuncForPC(rFunc.Pointer()).Name(), returnType.String())
	}
	ctxType, reqType := typ.In(1), typ.In(2)
	// Name the expected type from the type itself so the message always
	// matches this package's Context.
	wantCtxType := reflect.TypeOf(&Context{})
	if ctxType != wantCtxType {
		return nil, fmt.Errorf("first param must %v", wantCtxType)
	}
	if reqType.Kind() != reflect.Ptr {
		return nil, fmt.Errorf("req type not ptr")
	}
	return func(c *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				logrus.Errorf("err: %+v\nstack: %+v", err, string(debug.Stack()))
			}
		}()
		req := reflect.New(reqType.Elem())
		if err := r.bindAndValidate(c, req.Interface()); err != nil {
			c.JSON(http.StatusOK, Result{
				ErrCode: ErrorBadRequest,
				ErrMsg:  err.Error(),
			})
			return
		}
		returnValues := rFunc.Call([]reflect.Value{rGroup, reflect.ValueOf(&Context{Context: c}), req})
		// Redirect: the handler's response is already complete.
		if r.isStatusFoundOrMoved(c) {
			return
		}
		// The handler wrote its own body (e.g. a file download); appending the
		// JSON envelope would corrupt it.
		if c.Writer.Written() {
			return
		}
		// Apply the default Content-Type BEFORE the JSON check: in the reverse
		// order a handler that set no Content-Type was never treated as JSON,
		// so its result was silently dropped and the default never took effect.
		r.setDefaultContentType(c)
		// Non-JSON Content-Type: the handler owns the response format.
		if !r.isApplicationJson(c) {
			return
		}
		r.processReturnValue(c, returnValues)
	}, nil
}
// mergeRequestParam copies every entry of a into b, overwriting keys already
// present in b (so a takes priority), and returns b.
// b is modified in place and must be non-nil whenever a is non-empty.
func (r *Register) mergeRequestParam(a, b map[string]interface{}) map[string]interface{} {
	merged := b
	for key := range a {
		merged[key] = a[key]
	}
	return merged
}
// bindAndValidate 绑定并校验参数
func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error {
if !r.paramsValidate {
return nil
}
// 非 application/json 直接透传
if c.GetHeader("Content-Type") != "application/json" {
// 如果只有query有数据则获取query中的参数
if len(c.Request.URL.Query()) != 0 {
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
if err != nil {
logrus.Errorf("unmarshal query failed: %+v", err)
return err
}
}
return nil
}
bodyBytes, _ := io.ReadAll(c.Request.Body)
_ = c.Request.Body.Close()
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 如果只有query有数据则获取query中的参数
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) == 0 {
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
if err != nil {
logrus.Errorf("unmarshal query failed: %+v", err)
return err
}
}
// 如果只有body中有数据获取body中的参数
if len(c.Request.URL.Query()) == 0 && len(bodyBytes) != 0 {
err := json.Unmarshal(bodyBytes, req)
if err != nil {
logrus.Errorf("unmarshal body failed: %+v", err)
return err
}
}
// query和body中都有数据 合并他们 并且query有高优先级
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) != 0 {
var queryMap = make(map[string]interface{})
err := r.queryDecoder.Decode(queryMap, c.Request.URL.Query())
if err != nil {
logrus.Errorf("unmarshal query failed: %+v", err)
return err
}
var bodyMap = make(map[string]interface{})
err = json.Unmarshal(bodyBytes, bodyMap)
if err != nil {
logrus.Errorf("unmarshal body failed: %+v", err)
return err
}
mergeMap := r.mergeRequestParam(queryMap, bodyMap)
body, err := json.Marshal(mergeMap)
if err != nil {
logrus.Errorf("merge query and body failed: %v", err)
return err
}
if err := json.Unmarshal(body, req); err != nil {
logrus.Errorf("unmarshal query and body failed: %v", err)
return err
}
}
var validate = validator.New()
validate.RegisterTagNameFunc(func(field reflect.StructField) string {
name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
if name == "" {
return field.Name
}
if name == "-" {
return ""
}
return name
})
err := validate.Struct(req)
if err != nil {
var invalidValidationError *validator.InvalidValidationError
if errors.As(err, &invalidValidationError) {
logrus.Errorf("validate failed: %+v", err)
return err
}
for _, val := range err.(validator.ValidationErrors) {
return fmt.Errorf("invalid request, field %s, rule %s", val.Field(), val.Tag())
}
return err
}
return nil
}
// processReturnValue writes a handler's (resp, error) return values as a JSON
// Result with HTTP status 200.
//
// On success the envelope carries ErrorNil with resp as Data. On failure the
// code and message come from the error itself when it is an *ErrMsg with
// Autonomy set, otherwise from GetErrMsgCode / GetErrMsgMsg; resp is still
// included as Data. Anything other than exactly two return values is ignored.
func (r *Register) processReturnValue(c *gin.Context, returnValues []reflect.Value) {
	if len(returnValues) != 2 {
		return
	}
	result := Result{
		ErrCode: ErrorNil,
		ErrMsg:  GetErrMsgMsg(ErrorNil),
		Data:    returnValues[0].Interface(),
	}
	if rerr := returnValues[1].Interface(); rerr != nil {
		// Assert the concrete type instead of comparing reflect type names: the
		// hard-coded "*core.ErrMsg" could never match this package's type, so
		// autonomous errors were always routed through the generic lookup.
		if e, ok := rerr.(*ErrMsg); ok && e.Autonomy {
			result.ErrCode = int(e.ErrCode)
			result.ErrMsg = e.ErrMsg
		} else {
			result.ErrCode = GetErrMsgCode(rerr.(error))
			result.ErrMsg = GetErrMsgMsg(int32(result.ErrCode))
		}
	}
	respData, err := json.Marshal(result)
	if err != nil {
		// Still fall through so the status is flushed with an empty body.
		logrus.Errorf("marshal response failed: %v", err)
	}
	c.Writer.WriteHeader(http.StatusOK)
	// A write error means the client is gone; there is no one left to tell.
	_, _ = c.Writer.Write(respData)
}
// CorsFilter returns a permissive CORS middleware: it allows any origin,
// request header and method, and exposes all response headers.
//
// Preflight (OPTIONS) requests are answered with 204 No Content and the
// handler chain is aborted, so route handlers never run for them.
func CorsFilter() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Headers", "*")
		c.Header("Access-Control-Allow-Methods", "*")
		c.Header("Access-Control-Expose-Headers", "*")
		if c.Request.Method == http.MethodOptions {
			// Abort rather than just return: a plain return lets gin continue
			// to later handlers (e.g. a route registered for OPTIONS or ANY),
			// which would execute business logic on the preflight.
			c.AbortWithStatus(http.StatusNoContent)
		}
	}
}
// ResponseJsonHeader returns a middleware that sets the response
// Content-Type header to "application/json".
func ResponseJsonHeader() gin.HandlerFunc {
	const contentType = "application/json"
	return func(c *gin.Context) {
		header := c.Writer.Header()
		header.Set("Content-Type", contentType)
	}
}
// ResponseHtmlHeader returns a middleware that sets the response
// Content-Type header to "text/html".
func ResponseHtmlHeader() gin.HandlerFunc {
	const contentType = "text/html"
	return func(c *gin.Context) {
		header := c.Writer.Header()
		header.Set("Content-Type", contentType)
	}
}
// setDefaultContentType sets the response Content-Type to the configured
// default (see WithDefaultContentType; "application/json" when unset) if the
// handler has not already set one.
//
// The fallback is resolved into a local instead of being stored back on r:
// this runs on every request goroutine concurrently, and writing the shared
// field from there was a data race.
func (r *Register) setDefaultContentType(c *gin.Context) {
	header := c.Writer.Header()
	if header.Get("Content-Type") != "" {
		return
	}
	contentType := r.defaultContentType
	if contentType == "" {
		contentType = "application/json"
	}
	header.Set("Content-Type", contentType)
}
// isApplicationJson reports whether the response Content-Type is exactly
// "application/json"; any parameter suffix (e.g. "; charset=utf-8") makes it
// false.
func (r *Register) isApplicationJson(c *gin.Context) bool {
	contentType := c.Writer.Header().Get("Content-Type")
	return contentType == "application/json"
}
// isStatusFoundOrMoved reports whether the response status is a redirect:
// 302 Found or 301 Moved Permanently.
func (r *Register) isStatusFoundOrMoved(c *gin.Context) bool {
	switch c.Writer.Status() {
	case http.StatusFound, http.StatusMovedPermanently:
		return true
	default:
		return false
	}
}