diff --git a/router_register.go b/router_register.go index be4c3bf..1332d0b 100644 --- a/router_register.go +++ b/router_register.go @@ -33,39 +33,56 @@ type Result struct { // 只能通过NewRegister()进行实例化 否则会panic type Register struct { // GET参数解析器 - queryDecoder *schema.Decoder - engine *gin.Engine - addr string + queryDecoder *schema.Decoder + engine *gin.Engine + addr string + defaultContentType string + paramsValidate bool } type RegisterOptions func(register *Register) -// WithGinMode set gin router mode +// WithGinMode 设置运行模式 func WithGinMode(mode string) RegisterOptions { return func(register *Register) { gin.SetMode(mode) } } -// WithRecovery returns a middleware that recovers from any panics and writes a 500 if there was one. +// WithRecovery 自动处理panic异常 func WithRecovery() RegisterOptions { return func(register *Register) { register.engine.Use(gin.Recovery()) } } +// WithCors 设置跨域 func WithCors() RegisterOptions { return func(register *Register) { register.engine.Use(CorsFilter()) } } +// WithListenAddress 设置监听地址 func WithListenAddress(addr string) RegisterOptions { return func(register *Register) { register.addr = addr } } +// WithDisableRequestParamsValidate 是否默认关闭请求参数校验 +func WithDisableRequestParamsValidate() RegisterOptions { + return func(register *Register) { + register.paramsValidate = true + } +} + +func WithDefaultContentType(ct string) RegisterOptions { + return func(register *Register) { + register.defaultContentType = ct + } +} + // NewRegister 实例化注册器 func NewRegister() *Register { var register = &Register{queryDecoder: schema.NewDecoder()} @@ -260,66 +277,21 @@ func (r *Register) getCallFunc(rFunc, rGroup reflect.Value) (gin.HandlerFunc, er var returnValues = rFunc.Call([]reflect.Value{rGroup, reflect.ValueOf(&Context{Context: c}), req}) - // 设置缺省 Content-Type - if c.Writer.Header().Get("Content-Type") == "" { - c.Writer.Header().Set("Content-Type", "application/json") - } - // 重定向的情况 - if c.Writer.Status() == http.StatusFound || c.Writer.Status() == http.StatusMovedPermanently { + if r.isStatusFoundOrMoved(c) { return } // 传输文件直接下载的情况 - if !strings.Contains(c.Writer.Header().Get("Content-Type"), "application/json") { + if !r.isApplicationJson(c) { return } - if returnValues != nil { - resp := returnValues[0].Interface() - rerr := returnValues[1].Interface() + // 设置缺省 Content-Type + r.setDefaultContentType(c) - if rerr == nil { - c.Writer.WriteHeader(http.StatusOK) - respData, _ := json.Marshal(Result{ - ErrCode: ErrorNil, - ErrMsg: "ok", - Data: resp, - }) - _, _ = c.Writer.Write(respData) - return - } - - var err error - var errCode int - var errMsg string - - var isAutonomy bool - if reflect.TypeOf(rerr).String() == "*core.ErrMsg" { - e := rerr.(*ErrMsg) - if e.Autonomy { - err = e - errCode = int(e.ErrCode) - errMsg = e.ErrMsg - isAutonomy = true - } - } - - if !isAutonomy { - err = rerr.(error) - errCode = GetErrMsgCode(err) - errMsg = GetErrMsgMsg(int32(errCode)) - } - - c.Writer.WriteHeader(http.StatusOK) - respData, _ := json.Marshal(Result{ - ErrCode: errCode, - ErrMsg: errMsg, - Data: resp, - }) - _, _ = c.Writer.Write(respData) - return - } + // 处理返回值 + r.processReturnValue(c, returnValues) }, nil } @@ -332,6 +304,9 @@ func (r *Register) mergeRequestParam(a, b map[string]interface{}) map[string]int // bindAndValidate 绑定并校验参数 func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error { + if !r.paramsValidate { + return nil + } // 非 application/json 直接透传 if c.GetHeader("Content-Type") != "application/json" { // 如果只有query有数据,则获取query中的参数 @@ -419,6 +394,56 @@ func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error { return nil } +// processReturnValue 处理返回值 +func (r *Register) processReturnValue(c *gin.Context, returnValues []reflect.Value) { + if len(returnValues) != 2 { + return + } + resp := returnValues[0].Interface() + rerr := returnValues[1].Interface() + + if rerr == nil { + c.Writer.WriteHeader(http.StatusOK) + respData, _ := json.Marshal(Result{ + ErrCode: ErrorNil, + ErrMsg: GetErrMsgMsg(ErrorNil), + Data: resp, + }) + _, _ = c.Writer.Write(respData) + return + } + + var err error + var errCode int + var errMsg string + + var isAutonomy bool + if reflect.TypeOf(rerr).String() == "*core.ErrMsg" { + e := rerr.(*ErrMsg) + if e.Autonomy { + err = e + errCode = int(e.ErrCode) + errMsg = e.ErrMsg + isAutonomy = true + } + } + + if !isAutonomy { + err = rerr.(error) + errCode = GetErrMsgCode(err) + errMsg = GetErrMsgMsg(int32(errCode)) + } + + c.Writer.WriteHeader(http.StatusOK) + respData, _ := json.Marshal(Result{ + ErrCode: errCode, + ErrMsg: errMsg, + Data: resp, + }) + _, _ = c.Writer.Write(respData) + return +} + // CorsFilter 跨域过滤器 func CorsFilter() gin.HandlerFunc { return func(c *gin.Context) { @@ -445,3 +470,20 @@ func ResponseHtmlHeader() gin.HandlerFunc { c.Writer.Header().Set("Content-Type", "text/html") } } + +func (r *Register) setDefaultContentType(c *gin.Context) { + if r.defaultContentType == "" { + r.defaultContentType = "application/json" + } + if c.Writer.Header().Get("Content-Type") == "" { + c.Writer.Header().Set("Content-Type", r.defaultContentType) + } +} + +func (r *Register) isApplicationJson(c *gin.Context) bool { + return c.Writer.Header().Get("Content-Type") == "application/json" +} + +func (r *Register) isStatusFoundOrMoved(c *gin.Context) bool { + return c.Writer.Status() == http.StatusFound || c.Writer.Status() == http.StatusMovedPermanently +} diff --git a/router_register_test.go b/router_register_test.go index fe516c0..4a33fab 100644 --- a/router_register_test.go +++ b/router_register_test.go @@ -27,8 +27,9 @@ func TestRegister_getCallFunc(t *testing.T) { func TestNewRegister(t *testing.T) { var register = NewRegister() - register.DefaultRouter(WithListenAddress(""), WithGinMode(gin.ReleaseMode), WithCors(), WithRecovery()) - register.PreRun(nil) + register.DefaultRouter(WithListenAddress("localhost:8080"), WithGinMode(gin.ReleaseMode), WithCors(), WithRecovery()) + + // register.PreRun(nil) // protoc core/file_module.proto --coco_out=core --go_out=core //reg.RegisterStruct(AutoGenXXXRouterMap, &XXX{}) register.Run() diff --git a/utils/reflect.go b/utils/reflect.go deleted file mode 100644 index bde8328..0000000 --- a/utils/reflect.go +++ /dev/null @@ -1,174 +0,0 @@ -package utils - -import ( - "fmt" - log "github.com/sirupsen/logrus" - "reflect" - "strings" -) - -func ReflectGet(obj interface{}, path string, failHint *string) (res interface{}, success bool) { - setFailHint := func(hint string) { - if failHint != nil { - *failHint = hint - } - } - for obj != nil && path != "" { - v := reflect.ValueOf(obj) - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - setFailHint("not struct type") - break - } - - var fieldName string - pos := strings.IndexByte(path, '.') - if pos < 0 { - fieldName = path - path = "" - } else { - fieldName = path[:pos] - path = path[pos+1:] - } - - f := v.FieldByName(fieldName) - if !f.IsValid() { - setFailHint(fmt.Sprintf("%s not found", fieldName)) - break - } - - if path == "" { - res = f.Interface() - success = true - break - } - - obj = f.Interface() - } - - return -} - -func ReflectGetInt(obj interface{}, path string, failHint *string) (res int, success bool) { - setFailHint := func(hint string) { - if failHint != nil { - *failHint = hint - } - } - - var i interface{} - i, success = ReflectGet(obj, path, failHint) - if !success { - return - } - switch v := i.(type) { - case int: - res = v - success = true - case int32: - res = int(v) - success = true - default: - setFailHint(fmt.Sprintf("type not match: %s", reflect.TypeOf(i).String())) - } - return -} - -func ReflectGetStr(obj interface{}, path string, failHint *string) (res string, success bool) { - setFailHint := func(hint string) { - if failHint != nil { - *failHint = hint - } - } - - var i interface{} - i, success = ReflectGet(obj, path, failHint) - if !success { - return - } - switch v := i.(type) { - case string: - res = v - success = true - case []byte: - res = string(v) - success = true - default: - setFailHint(fmt.Sprintf("type not match: %s", reflect.TypeOf(i).String())) - } - return -} - -func Interface2Int(i interface{}) int { - vo := reflect.ValueOf(i) - vk := vo.Kind() - switch vk { - case reflect.Uint, reflect.Uint32, reflect.Uint64, reflect.Uint8, reflect.Uint16: - return int(vo.Uint()) - } - return int(vo.Int()) -} - -func Interface2String(i interface{}) string { - vo := reflect.ValueOf(i) - if vo.Kind() != reflect.String { - log.Infof("expected string type, but got %v", vo.Type()) - panic("expected string type") - } - return vo.String() -} - -func EnsureIsSliceOrArray(obj interface{}) (res reflect.Value) { - vo := reflect.ValueOf(obj) - for vo.Kind() == reflect.Ptr || vo.Kind() == reflect.Interface { - vo = vo.Elem() - } - k := vo.Kind() - if k != reflect.Slice && k != reflect.Array { - panic(fmt.Sprintf("obj required slice or array type, but got %v", vo.Type())) - } - res = vo - return -} - -func EnsureIsMapType(m reflect.Value, keyType, valType reflect.Type) { - if m.Kind() != reflect.Map { - panic(fmt.Sprintf("required map type, but got %v", m.Type())) - } - - t := m.Type() - if t.Key() != keyType { - panic(fmt.Sprintf("map key type not equal, %v != %v", t.Key(), keyType)) - } - - if t.Elem() != valType { - panic(fmt.Sprintf("map val type not equal, %v != %v", t.Elem(), valType)) - } -} - -func ClearSlice(ptr interface{}) { - vo := reflect.ValueOf(ptr) - if vo.Kind() != reflect.Ptr { - panic("required ptr to slice type") - } - for vo.Kind() == reflect.Ptr { - vo = vo.Elem() - } - if vo.Kind() != reflect.Slice { - panic("required ptr to slice type") - } - vo.Set(reflect.MakeSlice(vo.Type(), 0, 0)) -} - -func GetSliceLen(i interface{}) int { - vo := reflect.ValueOf(i) - for vo.Kind() == reflect.Ptr { - vo = vo.Elem() - } - if vo.Kind() != reflect.Slice && vo.Kind() != reflect.Array { - panic("required slice or array type") - } - return vo.Len() -}