490 lines
12 KiB
Go
490 lines
12 KiB
Go
package coco
|
||
|
||
import (
|
||
"bytes"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"reflect"
|
||
"runtime"
|
||
"runtime/debug"
|
||
"strings"
|
||
|
||
json "github.com/bytedance/sonic"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/go-playground/validator/v10"
|
||
"github.com/gorilla/schema"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
type ImplGetRouterMap interface {
|
||
GetRouterMap() *Routers
|
||
}
|
||
|
||
// Result is the JSON envelope written for every handler response.
//
// err_code and err_msg are always present; hint and data are omitted when empty.
type Result struct {
	// ErrCode is the business error code.
	ErrCode int `json:"err_code"`
	// ErrMsg is the human-readable message for ErrCode.
	ErrMsg string `json:"err_msg"`
	// Hint is an optional extra hint for the caller.
	Hint string `json:"hint,omitempty"`
	// Data carries the handler's response payload.
	Data interface{} `json:"data,omitempty"`
}
|
||
|
||
// Register 是一个组路由注射器
|
||
// 只能通过NewRegister()进行实例化 否则会panic
|
||
type Register struct {
|
||
// GET参数解析器
|
||
queryDecoder *schema.Decoder
|
||
engine *gin.Engine
|
||
addr string
|
||
defaultContentType string
|
||
paramsValidate bool
|
||
}
|
||
|
||
// RegisterOptions is a functional option that configures a Register.
// DefaultRouter applies the options in the order they are given.
type RegisterOptions func(*Register)
|
||
|
||
// WithGinMode 设置运行模式
|
||
func WithGinMode(mode string) RegisterOptions {
|
||
return func(register *Register) {
|
||
gin.SetMode(mode)
|
||
}
|
||
}
|
||
|
||
// WithRecovery 自动处理panic异常
|
||
func WithRecovery() RegisterOptions {
|
||
return func(register *Register) {
|
||
register.engine.Use(gin.Recovery())
|
||
}
|
||
}
|
||
|
||
// WithCors 设置跨域
|
||
func WithCors() RegisterOptions {
|
||
return func(register *Register) {
|
||
register.engine.Use(CorsFilter())
|
||
}
|
||
}
|
||
|
||
// WithListenAddress 设置监听地址
|
||
func WithListenAddress(addr string) RegisterOptions {
|
||
return func(register *Register) {
|
||
register.addr = addr
|
||
}
|
||
}
|
||
|
||
// WithDisableRequestParamsValidate 是否默认关闭请求参数校验
|
||
func WithDisableRequestParamsValidate() RegisterOptions {
|
||
return func(register *Register) {
|
||
register.paramsValidate = true
|
||
}
|
||
}
|
||
|
||
func WithDefaultContentType(ct string) RegisterOptions {
|
||
return func(register *Register) {
|
||
register.defaultContentType = ct
|
||
}
|
||
}
|
||
|
||
// NewRegister 实例化注册器
|
||
func NewRegister() *Register {
|
||
var register = &Register{queryDecoder: schema.NewDecoder()}
|
||
register.queryDecoder.IgnoreUnknownKeys(true)
|
||
register.queryDecoder.SetAliasTag("json")
|
||
return register
|
||
}
|
||
|
||
func (r *Register) DefaultRouter(opts ...RegisterOptions) *Register {
|
||
r.engine = gin.New()
|
||
for _, option := range opts {
|
||
option(r)
|
||
}
|
||
return r
|
||
}
|
||
|
||
// RegisterStruct 按照 struct 的方法进行路由注册
|
||
// routers: 绑定自动生成的路由配置文件 proto同级的 autogen_router_module.go 文件中的 AutoGenXXXMap
|
||
// ig: 需要注册的API组的struct ptr
|
||
// 对于错误或异常零容忍 直接panic
|
||
func (r *Register) RegisterStruct(drv interface{}, mws ...gin.HandlerFunc) {
|
||
if drv == nil {
|
||
logrus.Warnf("struct nil, skip register")
|
||
return
|
||
}
|
||
|
||
implGetRouters, ok := drv.(ImplGetRouterMap)
|
||
if !ok {
|
||
logrus.Panicf("struct not impl ImplGetRouterMap interface")
|
||
}
|
||
routers := implGetRouters.GetRouterMap()
|
||
if routers == nil {
|
||
logrus.Warnf("routers empty, skip register")
|
||
return
|
||
}
|
||
if len(routers.Apis) == 0 {
|
||
logrus.Warnf("%s api list empty, skip register", routers.StructName)
|
||
return
|
||
}
|
||
|
||
// new router group
|
||
if r.engine == nil {
|
||
r.DefaultRouter(WithGinMode(gin.ReleaseMode))
|
||
}
|
||
newRouter := r.engine.Group(routers.BaseURL, mws...)
|
||
|
||
refVal := reflect.ValueOf(drv)
|
||
refTyp := reflect.TypeOf(drv)
|
||
|
||
//routConfig := r.routers[bind.Bind()]
|
||
//if routConfig == nil {
|
||
// panic("no func to register")
|
||
//}
|
||
//if routConfig.Apis == nil {
|
||
// panic("no func to register")
|
||
//}
|
||
// 注册路由公共中间件
|
||
//if len(routConfig.Middlewares) != 0 {
|
||
// r.registerMiddleware(rout, routConfig.Middlewares)
|
||
//}
|
||
routMap := routers.Apis
|
||
for m := 0; m < refTyp.NumMethod(); m++ {
|
||
// 这里取出方法
|
||
method := refTyp.Method(m)
|
||
|
||
var node *RouterNode
|
||
var exist bool
|
||
if node, exist = routMap[method.Name]; !exist {
|
||
continue
|
||
}
|
||
|
||
// 注册路由
|
||
if err := r.registerHandle(newRouter, node, method.Func, refVal); err != nil {
|
||
logrus.Panicf("register handle failed: %+v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// PreRun 运行前动作
|
||
func (r *Register) PreRun(f func() error) error {
|
||
if f == nil {
|
||
return nil
|
||
}
|
||
return f()
|
||
}
|
||
|
||
// RawEngine 获取gin路由引擎
|
||
func (r *Register) RawEngine() *gin.Engine {
|
||
return r.engine
|
||
}
|
||
|
||
// Run 服务运行
|
||
func (r *Register) Run() {
|
||
if r.addr == "" {
|
||
r.addr = ":8080"
|
||
}
|
||
if err := r.engine.Run(r.addr); err != nil {
|
||
logrus.Errorf("run server failed: %v", err)
|
||
}
|
||
}
|
||
|
||
// registerHandle 注册Handle
|
||
func (r *Register) registerHandle(router gin.IRouter, rc *RouterNode, rFunc, rGroup reflect.Value) error {
|
||
call, err := r.getCallFunc(rFunc, rGroup)
|
||
if err != nil {
|
||
logrus.Errorf("register handle failed: %+v", err)
|
||
return err
|
||
}
|
||
if call == nil {
|
||
logrus.Warnf("register handle failed: handle nil")
|
||
return nil
|
||
}
|
||
|
||
var hfs []gin.HandlerFunc
|
||
if len(rc.Middlewares) != 0 {
|
||
hfs = append(hfs, rc.Middlewares...)
|
||
hfs = append(hfs, call)
|
||
} else {
|
||
hfs = append(hfs, call)
|
||
}
|
||
|
||
switch rc.Method {
|
||
case http.MethodPost:
|
||
router.POST(rc.API, hfs...)
|
||
case http.MethodGet:
|
||
router.GET(rc.API, hfs...)
|
||
case http.MethodDelete:
|
||
router.DELETE(rc.API, hfs...)
|
||
case http.MethodPatch:
|
||
router.PATCH(rc.API, hfs...)
|
||
case http.MethodPut:
|
||
router.PUT(rc.API, hfs...)
|
||
case http.MethodOptions:
|
||
router.OPTIONS(rc.API, hfs...)
|
||
case http.MethodHead:
|
||
router.HEAD(rc.API, hfs...)
|
||
case "ANY":
|
||
router.Any(rc.API, hfs...)
|
||
default:
|
||
return fmt.Errorf("method: [%v] not support", rc.Method)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// getCallFunc 获取运行函数入口
|
||
func (r *Register) getCallFunc(rFunc, rGroup reflect.Value) (gin.HandlerFunc, error) {
|
||
typ := rFunc.Type() // 获取函数的类型
|
||
|
||
// 参数检查
|
||
if typ.NumIn() != 3 {
|
||
return nil, fmt.Errorf("func need two request param, (ctx, req)")
|
||
}
|
||
|
||
// 响应检查
|
||
if typ.NumOut() != 2 {
|
||
return nil, fmt.Errorf("func need two response param, (resp, error)")
|
||
}
|
||
|
||
// 第二返回参数是否是error
|
||
if returnType := typ.Out(1); returnType != reflect.TypeOf((*error)(nil)).Elem() {
|
||
return nil, fmt.Errorf("method : %v , returns[1] %v not error",
|
||
runtime.FuncForPC(rFunc.Pointer()).Name(), returnType.String())
|
||
}
|
||
|
||
ctxType, reqType := typ.In(1), typ.In(2)
|
||
if ctxType != reflect.TypeOf(&Context{}) {
|
||
return nil, fmt.Errorf("first param must *core.Context")
|
||
}
|
||
|
||
if reqType.Kind() != reflect.Ptr {
|
||
return nil, fmt.Errorf("req type not ptr")
|
||
}
|
||
|
||
return func(c *gin.Context) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
logrus.Errorf("err: %+v\nstack: %+v", err, string(debug.Stack()))
|
||
return
|
||
}
|
||
}()
|
||
|
||
req := reflect.New(reqType.Elem())
|
||
// 参数校验
|
||
err := r.bindAndValidate(c, req.Interface())
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, Result{
|
||
ErrCode: ErrorBadRequest,
|
||
ErrMsg: err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var returnValues = rFunc.Call([]reflect.Value{rGroup, reflect.ValueOf(&Context{Context: c}), req})
|
||
|
||
// 重定向的情况
|
||
if r.isStatusFoundOrMoved(c) {
|
||
return
|
||
}
|
||
|
||
// 传输文件直接下载的情况
|
||
if !r.isApplicationJson(c) {
|
||
return
|
||
}
|
||
|
||
// 设置缺省 Content-Type
|
||
r.setDefaultContentType(c)
|
||
|
||
// 处理返回值
|
||
r.processReturnValue(c, returnValues)
|
||
}, nil
|
||
}
|
||
|
||
func (r *Register) mergeRequestParam(a, b map[string]interface{}) map[string]interface{} {
|
||
for k, v := range a {
|
||
b[k] = v
|
||
}
|
||
return b
|
||
}
|
||
|
||
// bindAndValidate 绑定并校验参数
|
||
func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error {
|
||
if !r.paramsValidate {
|
||
return nil
|
||
}
|
||
// 非 application/json 直接透传
|
||
if c.GetHeader("Content-Type") != "application/json" {
|
||
// 如果只有query有数据,则获取query中的参数
|
||
if len(c.Request.URL.Query()) != 0 {
|
||
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
|
||
if err != nil {
|
||
logrus.Errorf("unmarshal query failed: %+v", err)
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||
_ = c.Request.Body.Close()
|
||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||
|
||
// 如果只有query有数据,则获取query中的参数
|
||
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) == 0 {
|
||
err := r.queryDecoder.Decode(req, c.Request.URL.Query())
|
||
if err != nil {
|
||
logrus.Errorf("unmarshal query failed: %+v", err)
|
||
return err
|
||
}
|
||
}
|
||
// 如果只有body中有数据,获取body中的参数
|
||
if len(c.Request.URL.Query()) == 0 && len(bodyBytes) != 0 {
|
||
err := json.Unmarshal(bodyBytes, req)
|
||
if err != nil {
|
||
logrus.Errorf("unmarshal body failed: %+v", err)
|
||
return err
|
||
}
|
||
}
|
||
// query和body中都有数据 合并他们 并且query有高优先级
|
||
if len(c.Request.URL.Query()) != 0 && len(bodyBytes) != 0 {
|
||
var queryMap = make(map[string]interface{})
|
||
err := r.queryDecoder.Decode(queryMap, c.Request.URL.Query())
|
||
if err != nil {
|
||
logrus.Errorf("unmarshal query failed: %+v", err)
|
||
return err
|
||
}
|
||
var bodyMap = make(map[string]interface{})
|
||
err = json.Unmarshal(bodyBytes, bodyMap)
|
||
if err != nil {
|
||
logrus.Errorf("unmarshal body failed: %+v", err)
|
||
return err
|
||
}
|
||
mergeMap := r.mergeRequestParam(queryMap, bodyMap)
|
||
body, err := json.Marshal(mergeMap)
|
||
if err != nil {
|
||
logrus.Errorf("merge query and body failed: %v", err)
|
||
return err
|
||
}
|
||
if err := json.Unmarshal(body, req); err != nil {
|
||
logrus.Errorf("unmarshal query and body failed: %v", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
var validate = validator.New()
|
||
validate.RegisterTagNameFunc(func(field reflect.StructField) string {
|
||
name := strings.SplitN(field.Tag.Get("json"), ",", 2)[0]
|
||
if name == "" {
|
||
return field.Name
|
||
}
|
||
if name == "-" {
|
||
return ""
|
||
}
|
||
return name
|
||
})
|
||
|
||
err := validate.Struct(req)
|
||
if err != nil {
|
||
var invalidValidationError *validator.InvalidValidationError
|
||
if errors.As(err, &invalidValidationError) {
|
||
logrus.Errorf("validate failed: %+v", err)
|
||
return err
|
||
}
|
||
|
||
for _, val := range err.(validator.ValidationErrors) {
|
||
return fmt.Errorf("invalid request, field %s, rule %s", val.Field(), val.Tag())
|
||
}
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// processReturnValue 处理返回值
|
||
func (r *Register) processReturnValue(c *gin.Context, returnValues []reflect.Value) {
|
||
if len(returnValues) != 2 {
|
||
return
|
||
}
|
||
resp := returnValues[0].Interface()
|
||
rerr := returnValues[1].Interface()
|
||
|
||
if rerr == nil {
|
||
c.Writer.WriteHeader(http.StatusOK)
|
||
respData, _ := json.Marshal(Result{
|
||
ErrCode: ErrorNil,
|
||
ErrMsg: GetErrMsgMsg(ErrorNil),
|
||
Data: resp,
|
||
})
|
||
_, _ = c.Writer.Write(respData)
|
||
return
|
||
}
|
||
|
||
var err error
|
||
var errCode int
|
||
var errMsg string
|
||
|
||
var isAutonomy bool
|
||
if reflect.TypeOf(rerr).String() == "*core.ErrMsg" {
|
||
e := rerr.(*ErrMsg)
|
||
if e.Autonomy {
|
||
err = e
|
||
errCode = int(e.ErrCode)
|
||
errMsg = e.ErrMsg
|
||
isAutonomy = true
|
||
}
|
||
}
|
||
|
||
if !isAutonomy {
|
||
err = rerr.(error)
|
||
errCode = GetErrMsgCode(err)
|
||
errMsg = GetErrMsgMsg(int32(errCode))
|
||
}
|
||
|
||
c.Writer.WriteHeader(http.StatusOK)
|
||
respData, _ := json.Marshal(Result{
|
||
ErrCode: errCode,
|
||
ErrMsg: errMsg,
|
||
Data: resp,
|
||
})
|
||
_, _ = c.Writer.Write(respData)
|
||
return
|
||
}
|
||
|
||
// CorsFilter 跨域过滤器
|
||
func CorsFilter() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
c.Header("Access-Control-Allow-Origin", "*")
|
||
c.Header("Access-Control-Allow-Headers", "*")
|
||
c.Header("Access-Control-Allow-Methods", "*")
|
||
c.Header("Access-Control-Expose-Headers", "*")
|
||
if c.Request.Method == "OPTIONS" {
|
||
c.String(http.StatusNoContent, "OPTIONS")
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// ResponseJsonHeader 统一设置返回json格式数据
|
||
func ResponseJsonHeader() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
c.Writer.Header().Set("Content-Type", "application/json")
|
||
}
|
||
}
|
||
|
||
func ResponseHtmlHeader() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
c.Writer.Header().Set("Content-Type", "text/html")
|
||
}
|
||
}
|
||
|
||
func (r *Register) setDefaultContentType(c *gin.Context) {
|
||
if r.defaultContentType == "" {
|
||
r.defaultContentType = "application/json"
|
||
}
|
||
if c.Writer.Header().Get("Content-Type") == "" {
|
||
c.Writer.Header().Set("Content-Type", r.defaultContentType)
|
||
}
|
||
}
|
||
|
||
func (r *Register) isApplicationJson(c *gin.Context) bool {
|
||
return c.Writer.Header().Get("Content-Type") == "application/json"
|
||
}
|
||
|
||
func (r *Register) isStatusFoundOrMoved(c *gin.Context) bool {
|
||
return c.Writer.Status() == http.StatusFound || c.Writer.Status() == http.StatusMovedPermanently
|
||
}
|