coco/router_register.go

490 lines
12 KiB
Go
Raw Normal View History

2023-10-13 22:30:29 +08:00
package coco
2023-03-13 00:16:52 +08:00
import (
"bytes"
2024-04-20 02:00:06 +08:00
"errors"
2023-03-13 00:16:52 +08:00
"fmt"
2023-03-19 00:49:46 +08:00
"io"
2023-03-13 00:16:52 +08:00
"net/http"
"reflect"
"runtime"
"runtime/debug"
"strings"
2024-04-20 02:00:06 +08:00
json "github.com/bytedance/sonic"
2023-03-13 00:16:52 +08:00
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/gorilla/schema"
"github.com/sirupsen/logrus"
)
2023-10-13 22:30:29 +08:00
// ImplGetRouterMap is implemented by API group structs that can describe
// their own routes. RegisterStruct panics if its argument does not
// implement it.
type ImplGetRouterMap interface {
	// GetRouterMap returns the route table of the implementing struct.
	// RegisterStruct skips registration when it returns nil.
	GetRouterMap() *Routers
}
2023-03-13 00:16:52 +08:00
// Result is the JSON envelope written by registered handlers, both for
// request binding/validation failures and for handler return values.
// ErrCode/ErrMsg carry the business status (ErrorNil on success); Hint and
// Data are omitted from the JSON when empty.
type Result struct {
	ErrCode int         `json:"err_code"`
	ErrMsg  string      `json:"err_msg"`
	Hint    string      `json:"hint,omitempty"`
	Data    interface{} `json:"data,omitempty"`
}
// Register 是一个组路由注射器
// 只能通过NewRegister()进行实例化 否则会panic
type Register struct {
// GET参数解析器
2024-04-25 23:18:43 +08:00
queryDecoder *schema.Decoder
engine *gin.Engine
addr string
defaultContentType string
paramsValidate bool
2023-03-19 00:49:46 +08:00
}
// RegisterOptions configures a Register; DefaultRouter applies them in order.
type RegisterOptions func(*Register)
2024-04-25 23:18:43 +08:00
// WithGinMode 设置运行模式
2023-03-19 00:49:46 +08:00
func WithGinMode(mode string) RegisterOptions {
return func(register *Register) {
gin.SetMode(mode)
}
}
2024-04-25 23:18:43 +08:00
// WithRecovery 自动处理panic异常
2023-03-19 00:49:46 +08:00
func WithRecovery() RegisterOptions {
return func(register *Register) {
register.engine.Use(gin.Recovery())
}
}
2024-04-25 23:18:43 +08:00
// WithCors 设置跨域
2023-03-19 00:49:46 +08:00
func WithCors() RegisterOptions {
return func(register *Register) {
register.engine.Use(CorsFilter())
}
}
2024-04-25 23:18:43 +08:00
// WithListenAddress 设置监听地址
2023-03-19 00:49:46 +08:00
func WithListenAddress(addr string) RegisterOptions {
return func(register *Register) {
register.addr = addr
}
2023-03-13 00:16:52 +08:00
}
2024-04-25 23:18:43 +08:00
// WithDisableRequestParamsValidate is meant to turn off request parameter
// validation. It sets the Register's paramsValidate flag to true; see
// bindAndValidate for how that flag is interpreted.
func WithDisableRequestParamsValidate() RegisterOptions {
	return func(reg *Register) { reg.paramsValidate = true }
}
// WithDefaultContentType sets the Content-Type that setDefaultContentType
// applies to responses which do not set one. Empty means "application/json".
func WithDefaultContentType(ct string) RegisterOptions {
	setContentType := func(reg *Register) {
		reg.defaultContentType = ct
	}
	return setContentType
}
2023-03-13 00:16:52 +08:00
// NewRegister creates a Register. Its query decoder ignores unknown query
// keys and maps query parameters to struct fields through their `json` tags.
func NewRegister() *Register {
	decoder := schema.NewDecoder()
	decoder.IgnoreUnknownKeys(true)
	decoder.SetAliasTag("json")
	return &Register{queryDecoder: decoder}
}
2023-03-19 00:49:46 +08:00
// DefaultRouter replaces the Register's engine with a fresh gin.New() engine
// (no middleware attached) and then applies opts in order. It returns r so
// calls can be chained.
func (r *Register) DefaultRouter(opts ...RegisterOptions) *Register {
	r.engine = gin.New()
	for i := range opts {
		opts[i](r)
	}
	return r
}
2023-03-13 00:16:52 +08:00
// RegisterStruct 按照 struct 的方法进行路由注册
// routers: 绑定自动生成的路由配置文件 proto同级的 autogen_router_module.go 文件中的 AutoGenXXXMap
2023-03-19 00:49:46 +08:00
// ig: 需要注册的API组的struct ptr
2023-03-13 00:16:52 +08:00
// 对于错误或异常零容忍 直接panic
2023-10-13 22:30:29 +08:00
func (r *Register) RegisterStruct(drv interface{}, mws ...gin.HandlerFunc) {
if drv == nil {
logrus.Warnf("struct nil, skip register")
2023-03-13 00:16:52 +08:00
return
}
2023-10-13 22:30:29 +08:00
implGetRouters, ok := drv.(ImplGetRouterMap)
if !ok {
logrus.Panicf("struct not impl ImplGetRouterMap interface")
}
routers := implGetRouters.GetRouterMap()
if routers == nil {
logrus.Warnf("routers empty, skip register")
return
}
if len(routers.Apis) == 0 {
logrus.Warnf("%s api list empty, skip register", routers.StructName)
2023-03-19 00:49:46 +08:00
return
2023-03-13 00:16:52 +08:00
}
2023-03-19 00:49:46 +08:00
// new router group
if r.engine == nil {
r.DefaultRouter(WithGinMode(gin.ReleaseMode))
}
newRouter := r.engine.Group(routers.BaseURL, mws...)
refVal := reflect.ValueOf(drv)
refTyp := reflect.TypeOf(drv)
//routConfig := r.routers[bind.Bind()]
//if routConfig == nil {
// panic("no func to register")
//}
//if routConfig.Apis == nil {
// panic("no func to register")
//}
// 注册路由公共中间件
//if len(routConfig.Middlewares) != 0 {
// r.registerMiddleware(rout, routConfig.Middlewares)
//}
routMap := routers.Apis
for m := 0; m < refTyp.NumMethod(); m++ {
// 这里取出方法
method := refTyp.Method(m)
var node *RouterNode
var exist bool
if node, exist = routMap[method.Name]; !exist {
continue
}
// 注册路由
if err := r.registerHandle(newRouter, node, method.Func, refVal); err != nil {
2023-07-23 14:05:38 +08:00
logrus.Panicf("register handle failed: %+v", err)
2023-03-13 00:16:52 +08:00
}
}
}
2023-03-19 22:24:12 +08:00
// PreRun executes the pre-run hook f and returns its error.
// A nil f is a no-op that returns nil.
func (r *Register) PreRun(f func() error) error {
	if f != nil {
		return f()
	}
	return nil
}
2023-03-22 21:41:03 +08:00
// RawEngine exposes the underlying gin engine. It is nil until DefaultRouter
// has been called (RegisterStruct calls it on demand).
func (r *Register) RawEngine() *gin.Engine {
	engine := r.engine
	return engine
}
2023-03-19 00:49:46 +08:00
// Run 服务运行
func (r *Register) Run() {
if r.addr == "" {
r.addr = ":8080"
}
if err := r.engine.Run(r.addr); err != nil {
logrus.Errorf("run server failed: %v", err)
}
2023-03-13 00:16:52 +08:00
}
// registerHandle 注册Handle
func (r *Register) registerHandle(router gin.IRouter, rc *RouterNode, rFunc, rGroup reflect.Value) error {
call, err := r.getCallFunc(rFunc, rGroup)
if err != nil {
2023-07-23 14:05:38 +08:00
logrus.Errorf("register handle failed: %+v", err)
2023-03-13 00:16:52 +08:00
return err
}
if call == nil {
2023-07-23 14:05:38 +08:00
logrus.Warnf("register handle failed: handle nil")
2023-03-13 00:16:52 +08:00
return nil
}
var hfs []gin.HandlerFunc
if len(rc.Middlewares) != 0 {
hfs = append(hfs, rc.Middlewares...)
hfs = append(hfs, call)
} else {
hfs = append(hfs, call)
}
switch rc.Method {
case http.MethodPost:
router.POST(rc.API, hfs...)
case http.MethodGet:
router.GET(rc.API, hfs...)
case http.MethodDelete:
router.DELETE(rc.API, hfs...)
case http.MethodPatch:
router.PATCH(rc.API, hfs...)
case http.MethodPut:
router.PUT(rc.API, hfs...)
case http.MethodOptions:
router.OPTIONS(rc.API, hfs...)
case http.MethodHead:
router.HEAD(rc.API, hfs...)
case "ANY":
router.Any(rc.API, hfs...)
default:
2023-07-23 14:05:38 +08:00
return fmt.Errorf("method: [%v] not support", rc.Method)
2023-03-13 00:16:52 +08:00
}
return nil
}
// getCallFunc 获取运行函数入口
func (r *Register) getCallFunc(rFunc, rGroup reflect.Value) (gin.HandlerFunc, error) {
typ := rFunc.Type() // 获取函数的类型
// 参数检查
if typ.NumIn() != 3 {
return nil, fmt.Errorf("func need two request param, (ctx, req)")
}
// 响应检查
if typ.NumOut() != 2 {
return nil, fmt.Errorf("func need two response param, (resp, error)")
}
// 第二返回参数是否是error
if returnType := typ.Out(1); returnType != reflect.TypeOf((*error)(nil)).Elem() {
2024-04-20 02:00:06 +08:00
return nil, fmt.Errorf("method : %v , returns[1] %v not error",
2023-03-13 00:16:52 +08:00
runtime.FuncForPC(rFunc.Pointer()).Name(), returnType.String())
}
ctxType, reqType := typ.In(1), typ.In(2)
if ctxType != reflect.TypeOf(&Context{}) {
return nil, fmt.Errorf("first param must *core.Context")
}
if reqType.Kind() != reflect.Ptr {
return nil, fmt.Errorf("req type not ptr")
}
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
logrus.Errorf("err: %+v\nstack: %+v", err, string(debug.Stack()))
return
}
}()
req := reflect.New(reqType.Elem())
// 参数校验
err := r.bindAndValidate(c, req.Interface())
if err != nil {
2023-07-23 14:05:38 +08:00
c.JSON(http.StatusOK, Result{
2024-04-06 01:50:08 +08:00
ErrCode: ErrorBadRequest,
2023-03-13 00:16:52 +08:00
ErrMsg: err.Error(),
})
return
}
var returnValues = rFunc.Call([]reflect.Value{rGroup, reflect.ValueOf(&Context{Context: c}), req})
// 重定向的情况
2024-04-25 23:18:43 +08:00
if r.isStatusFoundOrMoved(c) {
2023-03-13 00:16:52 +08:00
return
}
// 传输文件直接下载的情况
2024-04-25 23:18:43 +08:00
if !r.isApplicationJson(c) {
2023-03-13 00:16:52 +08:00
return
}
2024-04-25 23:18:43 +08:00
// 设置缺省 Content-Type
r.setDefaultContentType(c)
2023-03-13 00:16:52 +08:00
2024-04-25 23:18:43 +08:00
// 处理返回值
r.processReturnValue(c, returnValues)
2023-03-13 00:16:52 +08:00
}, nil
}
2023-07-23 14:05:38 +08:00
// mergeRequestParam copies every entry of a into b, so values from a win on
// key conflicts, and returns b. b is modified in place and must be non-nil
// whenever a is non-empty.
func (r *Register) mergeRequestParam(a, b map[string]interface{}) map[string]interface{} {
	merged := b
	for key := range a {
		merged[key] = a[key]
	}
	return merged
}
2023-03-13 00:16:52 +08:00
// bindAndValidate 绑定并校验参数
func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error {
2024-04-25 23:18:43 +08:00
if !r.paramsValidate {
return nil
}
2023-09-03 21:45:15 +08:00
// 非 application/json 直接透传
if c.GetHeader("Content-Type") != "application/json" {
2023-09-03 22:14:14 +08:00
// 如果只有query有数据则获取query中的参数
if len(c.Request.URL.Query()) != 0 {
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
if err != nil {
logrus.Errorf("unmarshal query failed: %+v", err)
return err
}
}
2023-09-03 21:45:15 +08:00
return nil
}
2023-03-19 00:49:46 +08:00
bodyBytes, _ := io.ReadAll(c.Request.Body)
2024-04-06 01:50:08 +08:00
_ = c.Request.Body.Close()
2023-03-19 00:49:46 +08:00
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
2023-03-13 00:16:52 +08:00
2023-07-23 14:05:38 +08:00
// 如果只有query有数据则获取query中的参数
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) == 0 {
2023-03-13 00:16:52 +08:00
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
if err != nil {
2023-07-23 14:05:38 +08:00
logrus.Errorf("unmarshal query failed: %+v", err)
2023-03-13 00:16:52 +08:00
return err
}
2023-07-23 14:05:38 +08:00
}
// 如果只有body中有数据获取body中的参数
if len(c.Request.URL.Query()) == 0 && len(bodyBytes) != 0 {
2024-04-20 02:00:06 +08:00
err := json.Unmarshal(bodyBytes, req)
2023-07-23 14:05:38 +08:00
if err != nil {
logrus.Errorf("unmarshal body failed: %+v", err)
return err
2023-03-13 00:16:52 +08:00
}
2023-07-23 14:05:38 +08:00
}
// query和body中都有数据 合并他们 并且query有高优先级
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) != 0 {
var queryMap = make(map[string]interface{})
err := r.queryDecoder.Decode(queryMap, c.Request.URL.Query())
if err != nil {
logrus.Errorf("unmarshal query failed: %+v", err)
return err
2023-03-13 00:16:52 +08:00
}
2023-07-23 14:05:38 +08:00
var bodyMap = make(map[string]interface{})
2024-04-20 02:00:06 +08:00
err = json.Unmarshal(bodyBytes, bodyMap)
2023-03-13 00:16:52 +08:00
if err != nil {
2023-07-23 14:05:38 +08:00
logrus.Errorf("unmarshal body failed: %+v", err)
return err
}
mergeMap := r.mergeRequestParam(queryMap, bodyMap)
2024-04-20 02:00:06 +08:00
body, err := json.Marshal(mergeMap)
2023-07-23 14:05:38 +08:00
if err != nil {
logrus.Errorf("merge query and body failed: %v", err)
return err
}
2024-04-20 02:00:06 +08:00
if err := json.Unmarshal(body, req); err != nil {
2023-07-23 14:05:38 +08:00
logrus.Errorf("unmarshal query and body failed: %v", err)
2023-03-13 00:16:52 +08:00
return err
}
}
var validate = validator.New()
validate.RegisterTagNameFunc(func(field reflect.StructField) string {
name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
if name == "" {
return field.Name
}
if name == "-" {
return ""
}
return name
})
err := validate.Struct(req)
if err != nil {
2024-04-06 01:50:08 +08:00
var invalidValidationError *validator.InvalidValidationError
if errors.As(err, &invalidValidationError) {
2023-07-23 14:05:38 +08:00
logrus.Errorf("validate failed: %+v", err)
2023-03-13 00:16:52 +08:00
return err
}
for _, val := range err.(validator.ValidationErrors) {
return fmt.Errorf("invalid request, field %s, rule %s", val.Field(), val.Tag())
}
return err
}
return nil
}
2023-03-19 00:49:46 +08:00
2024-04-25 23:18:43 +08:00
// processReturnValue serializes a handler's (resp, error) return values as a
// Result and writes it to the response with HTTP 200.
//
// The error value decides the code and message:
//   - nil: ErrorNil and its message.
//   - an *ErrMsg (possibly wrapped) with Autonomy set: its own code and
//     message.
//   - any other error: the code and message from GetErrMsgCode /
//     GetErrMsgMsg.
//
// resp is sent as Data in every case. If the Result cannot be marshaled, the
// failure is logged and a bare 500 is sent instead of an empty 200.
func (r *Register) processReturnValue(c *gin.Context, returnValues []reflect.Value) {
	if len(returnValues) != 2 {
		return
	}
	resp := returnValues[0].Interface()
	rerr := returnValues[1].Interface()

	var result Result
	if rerr == nil {
		result.ErrCode = ErrorNil
		result.ErrMsg = GetErrMsgMsg(ErrorNil)
	} else {
		// getCallFunc guarantees the second return value is declared as error.
		err := rerr.(error)
		// Match by type (not by reflect type name), guarding against a typed
		// nil *ErrMsg.
		var autonomous *ErrMsg
		if errors.As(err, &autonomous) && autonomous != nil && autonomous.Autonomy {
			result.ErrCode = int(autonomous.ErrCode)
			result.ErrMsg = autonomous.ErrMsg
		} else {
			errCode := GetErrMsgCode(err)
			result.ErrCode = errCode
			result.ErrMsg = GetErrMsgMsg(int32(errCode))
		}
	}
	result.Data = resp

	respData, err := json.Marshal(result)
	if err != nil {
		logrus.Errorf("marshal response failed: %+v", err)
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	c.Writer.WriteHeader(http.StatusOK)
	// A write error here means the client is gone; there is no one to report it to.
	_, _ = c.Writer.Write(respData)
}
2023-03-19 00:49:46 +08:00
// CorsFilter returns a middleware that allows cross-origin requests from
// any origin, with any request headers and methods, and exposes all response
// headers.
//
// Preflight (OPTIONS) requests are answered with 204 No Content and the
// handler chain is aborted, so route handlers (e.g. ANY routes) do not run
// for them.
func CorsFilter() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Headers", "*")
		c.Header("Access-Control-Allow-Methods", "*")
		c.Header("Access-Control-Expose-Headers", "*")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
	}
}
// ResponseJsonHeader returns a middleware that sets the response
// Content-Type to application/json for every handler that runs after it.
func ResponseJsonHeader() gin.HandlerFunc {
	return func(c *gin.Context) {
		header := c.Writer.Header()
		header.Set("Content-Type", "application/json")
	}
}
// ResponseHtmlHeader returns a middleware that sets the response
// Content-Type to text/html for every handler that runs after it.
func ResponseHtmlHeader() gin.HandlerFunc {
	return func(c *gin.Context) {
		header := c.Writer.Header()
		header.Set("Content-Type", "text/html")
	}
}
2024-04-25 23:18:43 +08:00
// setDefaultContentType sets the response Content-Type to the Register's
// defaultContentType ("application/json" when unset) if none is set yet.
//
// It only reads r.defaultContentType. It is called from request handlers,
// which net/http runs concurrently, so writing the shared field here would
// be a data race.
func (r *Register) setDefaultContentType(c *gin.Context) {
	header := c.Writer.Header()
	if header.Get("Content-Type") != "" {
		return
	}
	contentType := r.defaultContentType
	if contentType == "" {
		contentType = "application/json"
	}
	header.Set("Content-Type", contentType)
}
// isApplicationJson reports whether the response Content-Type is exactly
// "application/json". A value carrying parameters, such as
// "application/json; charset=utf-8", does not match.
func (r *Register) isApplicationJson(c *gin.Context) bool {
	contentType := c.Writer.Header().Get("Content-Type")
	return contentType == "application/json"
}
// isStatusFoundOrMoved reports whether the response status is 302 Found or
// 301 Moved Permanently, i.e. the handler issued one of those redirects.
func (r *Register) isStatusFoundOrMoved(c *gin.Context) bool {
	switch c.Writer.Status() {
	case http.StatusFound, http.StatusMovedPermanently:
		return true
	default:
		return false
	}
}