refactor: new json lib

This commit is contained in:
xuthus5 2024-04-25 23:18:43 +08:00
parent 0469c4d45a
commit e2f6f1ad3a
Signed by: xuthus5
GPG Key ID: A23CF9620CBB55F9
3 changed files with 101 additions and 232 deletions

View File

@ -33,39 +33,56 @@ type Result struct {
// 只能通过NewRegister()进行实例化 否则会panic
type Register struct {
// GET参数解析器
queryDecoder *schema.Decoder
engine *gin.Engine
addr string
queryDecoder *schema.Decoder
engine *gin.Engine
addr string
defaultContentType string
paramsValidate bool
}
type RegisterOptions func(register *Register)
// WithGinMode set gin router mode
// WithGinMode 设置运行模式
func WithGinMode(mode string) RegisterOptions {
return func(register *Register) {
gin.SetMode(mode)
}
}
// WithRecovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// WithRecovery 自动处理panic异常
func WithRecovery() RegisterOptions {
return func(register *Register) {
register.engine.Use(gin.Recovery())
}
}
// WithCors 设置跨域
func WithCors() RegisterOptions {
return func(register *Register) {
register.engine.Use(CorsFilter())
}
}
// WithListenAddress 设置监听地址
func WithListenAddress(addr string) RegisterOptions {
return func(register *Register) {
register.addr = addr
}
}
// WithDisableRequestParamsValidate 是否默认关闭请求参数校验
func WithDisableRequestParamsValidate() RegisterOptions {
return func(register *Register) {
register.paramsValidate = true
}
}
func WithDefaultContentType(ct string) RegisterOptions {
return func(register *Register) {
register.defaultContentType = ct
}
}
// NewRegister 实例化注册器
func NewRegister() *Register {
var register = &Register{queryDecoder: schema.NewDecoder()}
@ -260,66 +277,21 @@ func (r *Register) getCallFunc(rFunc, rGroup reflect.Value) (gin.HandlerFunc, er
var returnValues = rFunc.Call([]reflect.Value{rGroup, reflect.ValueOf(&Context{Context: c}), req})
// 设置缺省 Content-Type
if c.Writer.Header().Get("Content-Type") == "" {
c.Writer.Header().Set("Content-Type", "application/json")
}
// 重定向的情况
if c.Writer.Status() == http.StatusFound || c.Writer.Status() == http.StatusMovedPermanently {
if r.isStatusFoundOrMoved(c) {
return
}
// 传输文件直接下载的情况
if !strings.Contains(c.Writer.Header().Get("Content-Type"), "application/json") {
if !r.isApplicationJson(c) {
return
}
if returnValues != nil {
resp := returnValues[0].Interface()
rerr := returnValues[1].Interface()
// 设置缺省 Content-Type
r.setDefaultContentType(c)
if rerr == nil {
c.Writer.WriteHeader(http.StatusOK)
respData, _ := json.Marshal(Result{
ErrCode: ErrorNil,
ErrMsg: "ok",
Data: resp,
})
_, _ = c.Writer.Write(respData)
return
}
var err error
var errCode int
var errMsg string
var isAutonomy bool
if reflect.TypeOf(rerr).String() == "*core.ErrMsg" {
e := rerr.(*ErrMsg)
if e.Autonomy {
err = e
errCode = int(e.ErrCode)
errMsg = e.ErrMsg
isAutonomy = true
}
}
if !isAutonomy {
err = rerr.(error)
errCode = GetErrMsgCode(err)
errMsg = GetErrMsgMsg(int32(errCode))
}
c.Writer.WriteHeader(http.StatusOK)
respData, _ := json.Marshal(Result{
ErrCode: errCode,
ErrMsg: errMsg,
Data: resp,
})
_, _ = c.Writer.Write(respData)
return
}
// 处理返回值
r.processReturnValue(c, returnValues)
}, nil
}
@ -332,6 +304,9 @@ func (r *Register) mergeRequestParam(a, b map[string]interface{}) map[string]int
// bindAndValidate 绑定并校验参数
func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error {
if !r.paramsValidate {
return nil
}
// 非 application/json 直接透传
if c.GetHeader("Content-Type") != "application/json" {
// 如果只有query有数据则获取query中的参数
@ -419,6 +394,56 @@ func (r *Register) bindAndValidate(c *gin.Context, req interface{}) error {
return nil
}
// processReturnValue 处理返回值
func (r *Register) processReturnValue(c *gin.Context, returnValues []reflect.Value) {
if len(returnValues) != 2 {
return
}
resp := returnValues[0].Interface()
rerr := returnValues[1].Interface()
if rerr == nil {
c.Writer.WriteHeader(http.StatusOK)
respData, _ := json.Marshal(Result{
ErrCode: ErrorNil,
ErrMsg: GetErrMsgMsg(ErrorNil),
Data: resp,
})
_, _ = c.Writer.Write(respData)
return
}
var err error
var errCode int
var errMsg string
var isAutonomy bool
if reflect.TypeOf(rerr).String() == "*core.ErrMsg" {
e := rerr.(*ErrMsg)
if e.Autonomy {
err = e
errCode = int(e.ErrCode)
errMsg = e.ErrMsg
isAutonomy = true
}
}
if !isAutonomy {
err = rerr.(error)
errCode = GetErrMsgCode(err)
errMsg = GetErrMsgMsg(int32(errCode))
}
c.Writer.WriteHeader(http.StatusOK)
respData, _ := json.Marshal(Result{
ErrCode: errCode,
ErrMsg: errMsg,
Data: resp,
})
_, _ = c.Writer.Write(respData)
return
}
// CorsFilter 跨域过滤器
func CorsFilter() gin.HandlerFunc {
return func(c *gin.Context) {
@ -445,3 +470,20 @@ func ResponseHtmlHeader() gin.HandlerFunc {
c.Writer.Header().Set("Content-Type", "text/html")
}
}
func (r *Register) setDefaultContentType(c *gin.Context) {
if r.defaultContentType == "" {
r.defaultContentType = "application/json"
}
if c.Writer.Header().Get("Content-Type") == "" {
c.Writer.Header().Set("Content-Type", r.defaultContentType)
}
}
func (r *Register) isApplicationJson(c *gin.Context) bool {
return c.Writer.Header().Get("Content-Type") == "application/json"
}
func (r *Register) isStatusFoundOrMoved(c *gin.Context) bool {
return c.Writer.Status() == http.StatusFound || c.Writer.Status() == http.StatusMovedPermanently
}

View File

@ -27,8 +27,9 @@ func TestRegister_getCallFunc(t *testing.T) {
func TestNewRegister(t *testing.T) {
var register = NewRegister()
register.DefaultRouter(WithListenAddress(""), WithGinMode(gin.ReleaseMode), WithCors(), WithRecovery())
register.PreRun(nil)
register.DefaultRouter(WithListenAddress("localhost:8080"), WithGinMode(gin.ReleaseMode), WithCors(), WithRecovery())
// register.PreRun(nil)
// protoc core/file_module.proto --coco_out=core --go_out=core
//reg.RegisterStruct(AutoGenXXXRouterMap, &XXX{})
register.Run()

View File

@ -1,174 +0,0 @@
package utils
import (
"fmt"
log "github.com/sirupsen/logrus"
"reflect"
"strings"
)
func ReflectGet(obj interface{}, path string, failHint *string) (res interface{}, success bool) {
setFailHint := func(hint string) {
if failHint != nil {
*failHint = hint
}
}
for obj != nil && path != "" {
v := reflect.ValueOf(obj)
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
setFailHint("not struct type")
break
}
var fieldName string
pos := strings.IndexByte(path, '.')
if pos < 0 {
fieldName = path
path = ""
} else {
fieldName = path[:pos]
path = path[pos+1:]
}
f := v.FieldByName(fieldName)
if !f.IsValid() {
setFailHint(fmt.Sprintf("%s not found", fieldName))
break
}
if path == "" {
res = f.Interface()
success = true
break
}
obj = f.Interface()
}
return
}
func ReflectGetInt(obj interface{}, path string, failHint *string) (res int, success bool) {
setFailHint := func(hint string) {
if failHint != nil {
*failHint = hint
}
}
var i interface{}
i, success = ReflectGet(obj, path, failHint)
if !success {
return
}
switch v := i.(type) {
case int:
res = v
success = true
case int32:
res = int(v)
success = true
default:
setFailHint(fmt.Sprintf("type not match: %s", reflect.TypeOf(i).String()))
}
return
}
func ReflectGetStr(obj interface{}, path string, failHint *string) (res string, success bool) {
setFailHint := func(hint string) {
if failHint != nil {
*failHint = hint
}
}
var i interface{}
i, success = ReflectGet(obj, path, failHint)
if !success {
return
}
switch v := i.(type) {
case string:
res = v
success = true
case []byte:
res = string(v)
success = true
default:
setFailHint(fmt.Sprintf("type not match: %s", reflect.TypeOf(i).String()))
}
return
}
func Interface2Int(i interface{}) int {
vo := reflect.ValueOf(i)
vk := vo.Kind()
switch vk {
case reflect.Uint, reflect.Uint32, reflect.Uint64, reflect.Uint8, reflect.Uint16:
return int(vo.Uint())
}
return int(vo.Int())
}
func Interface2String(i interface{}) string {
vo := reflect.ValueOf(i)
if vo.Kind() != reflect.String {
log.Infof("expected string type, but got %v", vo.Type())
panic("expected string type")
}
return vo.String()
}
func EnsureIsSliceOrArray(obj interface{}) (res reflect.Value) {
vo := reflect.ValueOf(obj)
for vo.Kind() == reflect.Ptr || vo.Kind() == reflect.Interface {
vo = vo.Elem()
}
k := vo.Kind()
if k != reflect.Slice && k != reflect.Array {
panic(fmt.Sprintf("obj required slice or array type, but got %v", vo.Type()))
}
res = vo
return
}
func EnsureIsMapType(m reflect.Value, keyType, valType reflect.Type) {
if m.Kind() != reflect.Map {
panic(fmt.Sprintf("required map type, but got %v", m.Type()))
}
t := m.Type()
if t.Key() != keyType {
panic(fmt.Sprintf("map key type not equal, %v != %v", t.Key(), keyType))
}
if t.Elem() != valType {
panic(fmt.Sprintf("map val type not equal, %v != %v", t.Elem(), valType))
}
}
func ClearSlice(ptr interface{}) {
vo := reflect.ValueOf(ptr)
if vo.Kind() != reflect.Ptr {
panic("required ptr to slice type")
}
for vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Slice {
panic("required ptr to slice type")
}
vo.Set(reflect.MakeSlice(vo.Type(), 0, 0))
}
func GetSliceLen(i interface{}) int {
vo := reflect.ValueOf(i)
for vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Slice && vo.Kind() != reflect.Array {
panic("required slice or array type")
}
return vo.Len()
}