package mdbc

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
	"time"

	jsoniter "github.com/json-iterator/go"
	"github.com/sirupsen/logrus"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// AggregateScope holds the state of one aggregate operation: the pipeline
// and options to run, the context it runs under, the resulting cursor, the
// last error, and the optional result-caching configuration.
type AggregateScope struct {
	// scope is the parent scope (collection name, debug flag, breaker and cache settings).
	scope *Scope
	// cw holds the context; cancel is only set when preCheck created the context itself.
	cw *ctxWrap
	// err is the last error produced by this scope.
	err error
	// pipeline is passed unchanged to Collection.Aggregate.
	pipeline interface{}
	// opts are optional aggregate options (may be nil).
	opts *options.AggregateOptions
	// cursor is the cursor returned by the last doSearch.
	cursor *mongo.Cursor
	// enableCache, cacheKey and cacheFunc are configured by SetCacheFunc.
	enableCache bool
	cacheKey    string
	cacheFunc   AggregateCacheFunc
}

// AggregateFunc is a user-supplied callback used by GetByFunc to consume the
// aggregate cursor and build an arbitrary result from it.
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)

// AggregateCacheFunc turns an aggregate result into a cache entry.
//
// field: the name handed over via SetCacheFunc. It may be empty when the
// implementation does not need it. DefaultAggregateCacheFunc reads this field
// from the first result element via reflection and uses its value as the key.
//
// obj: the result filled by GetList (a pointer to a slice) or GetOne (a
// pointer to a struct). GetMap never caches.
//
// Returning (nil, nil) means "do not cache this result". A non-nil error is
// reported as the error of the query method that triggered the caching.
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)

// DefaultAggregateCacheFunc caches a whole result list under the key taken
// from the `field` field of the first element. Every element is expected to
// have that field; otherwise caching may fail.
//
// NOTE(review): this overlaps with DefaultFindCacheFunc and should probably
// share a helper with it.
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
	return func(field string, obj interface{}) (*CacheObject, error) {
		// Custom implementations should prefer a type assertion such as
		// res, ok := obj.(*[]*model.ModelUser) and return nil when it fails so
		// caching is skipped. This default uses reflection to stay type-agnostic.
		// Kind() (unlike Type()) is safe on a zero Value, so nil inputs do not panic.
		v := reflect.ValueOf(obj)
		if v.Kind() != reflect.Ptr {
			return nil, fmt.Errorf("invalid list type, not ptr")
		}
		v = v.Elem()
		if v.Kind() != reflect.Slice {
			return nil, fmt.Errorf("invalid list type, not ptr to slice")
		}
		// Nothing to cache for an empty result; nil tells doCache to skip.
		if v.Len() == 0 {
			return nil, nil
		}
		firstNode := v.Index(0)
		if firstNode.Kind() == reflect.Ptr {
			// A non-empty slice can still hold a nil pointer.
			if firstNode.IsNil() {
				return nil, fmt.Errorf("first node is nil")
			}
			firstNode = firstNode.Elem()
		}
		if firstNode.Kind() != reflect.Struct {
			return nil, fmt.Errorf("first node is not a struct")
		}
		idVal := firstNode.FieldByName(field)
		if !idVal.IsValid() {
			return nil, fmt.Errorf("first node's %s field not found", field)
		}
		// reflect.Value.String() returns "<T Value>" for non-string kinds,
		// which would make every such key collide; format the value instead.
		key := idVal.String()
		if idVal.Kind() != reflect.String && idVal.CanInterface() {
			key = fmt.Sprint(idVal.Interface())
		}
		b, err := json.Marshal(v.Interface())
		if err != nil {
			return nil, err
		}
		return &CacheObject{
			Key:   key,
			Value: string(b),
		}, nil
	}
}

// SetContext sets the context the aggregate runs under. A context supplied
// here is used as-is: preCheck only applies the breaker timeout when no
// context has been set.
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
	if as.cw == nil {
		as.cw = &ctxWrap{}
	}
	as.cw.ctx = ctx
	return as
}

// getContext returns the context of this scope, or context.Background() when
// none has been set yet.
func (as *AggregateScope) getContext() context.Context {
	if as.cw == nil || as.cw.ctx == nil {
		return context.Background()
	}
	return as.cw.ctx
}

// preCheck makes sure a context is available before querying. When the
// caller did not supply one, it creates a context that times out after the
// breaker TTL (defaultBreakerTime when no breaker or no TTL is configured).
// Its cancel func is released by doClear.
func (as *AggregateScope) preCheck() {
	breakerTTL := time.Duration(defaultBreakerTime)
	if as.scope.breaker != nil && as.scope.breaker.ttl != 0 {
		breakerTTL = as.scope.breaker.ttl
	}
	if as.cw == nil {
		as.cw = &ctxWrap{}
	}
	if as.cw.ctx == nil {
		as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
	}
}

// SetAggregateOption replaces the aggregate options with a copy of opts.
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
	as.opts = &opts
	return as
}

// SetPipeline sets the aggregation pipeline, e.g. a []bson.D or a
// mongo.Pipeline.
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
	as.pipeline = pipeline
	return as
}

// assertErr normalizes timeout errors: a context deadline, or a server
// command error whose message mentions the deadline, is replaced with
// ErrRequestBroken. Other errors are left untouched.
func (as *AggregateScope) assertErr() {
	if as.err == nil {
		return
	}
	if errors.Is(as.err, context.DeadlineExceeded) {
		as.err = &ErrRequestBroken
		return
	}
	// errors.As also matches a CommandError that has been wrapped.
	var cmdErr mongo.CommandError
	if errors.As(as.err, &cmdErr) && cmdErr.HasErrorMessage(context.DeadlineExceeded.Error()) {
		as.err = &ErrRequestBroken
	}
}

// doString implements the debugger's actionDo interface: it renders the
// pipeline as `aggregate([...])` using extended JSON for each stage.
//
// It never panics: doString also runs after a failed Aggregate (e.g. no
// pipeline set), so an unsupported pipeline type is reported in the output
// instead.
func (as *AggregateScope) doString() string {
	vo := reflect.ValueOf(as.pipeline)
	if vo.Kind() == reflect.Ptr {
		vo = vo.Elem()
	}
	if vo.Kind() != reflect.Slice && vo.Kind() != reflect.Array {
		return fmt.Sprintf("aggregate(<unsupported pipeline type %T>)", as.pipeline)
	}
	builder := RegisterTimestampCodec(nil).Build()
	data := make([]interface{}, 0, vo.Len())
	for i := 0; i < vo.Len(); i++ {
		var body interface{}
		// Debug output is best-effort: a stage that fails to render shows up as null.
		b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
		_ = jsoniter.Unmarshal(b, &body)
		data = append(data, body)
	}
	b, _ := jsoniter.Marshal(data)
	return fmt.Sprintf("aggregate(%s)", string(b))
}

// debug is the single debug entry point: when debugging is enabled on the
// parent scope it prints the statement and its execution time.
func (as *AggregateScope) debug() {
	if !as.scope.debug {
		return
	}
	debugger := &Debugger{
		collection: as.scope.tableName,
		execT:      as.scope.execT,
		action:     as,
	}
	debugger.String()
}

// doSearch runs the aggregate against the scope's collection and stores the
// cursor and (normalized) error on the scope. It measures execution time when
// debugging is enabled.
func (as *AggregateScope) doSearch() {
	var startTime time.Time
	if as.scope.debug {
		startTime = time.Now()
	}
	as.cursor, as.err = db.Collection(as.scope.tableName).Aggregate(as.getContext(), as.pipeline, as.opts)
	if as.scope.debug {
		as.scope.execT = time.Since(startTime)
	}
	as.debug()
	as.assertErr()
}

// doClear releases per-call state: it cancels the timeout context that
// preCheck created (a caller-supplied context is left alone) and resets the
// debug settings on the parent scope.
func (as *AggregateScope) doClear() {
	if as.cw != nil && as.cw.cancel != nil {
		as.cw.cancel()
		// Drop the cancelled context so a reused scope gets a fresh one from
		// preCheck instead of failing immediately.
		as.cw.ctx = nil
		as.cw.cancel = nil
	}
	as.scope.debug = false
	as.scope.execT = 0
}

// GetList decodes all pipeline results into list, which must be a pointer to
// a slice (typically *[]*TypeValue).
//
// When caching is enabled and decoding succeeded, the result is passed to the
// cache func. A caching error is then returned even though list was filled.
func (as *AggregateScope) GetList(list interface{}) error {
	defer as.doClear()

	// Validate before querying; Kind() is safe on nil / typed-nil inputs.
	v := reflect.ValueOf(list)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid list type, not ptr")
		return as.err
	}
	if v.IsNil() || v.Elem().Kind() != reflect.Slice {
		as.err = fmt.Errorf("invalid list type, not ptr to slice")
		return as.err
	}

	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}
	// All drains and closes the cursor.
	if as.err = as.cursor.All(as.getContext(), list); as.err != nil {
		as.assertErr()
		// Do not cache a partial result, and do not let a successful cache
		// write overwrite the decode error.
		return as.err
	}
	if as.enableCache {
		as.doCache(list)
	}
	return as.err
}

// GetOne decodes the first pipeline result into obj, which must be a non-nil
// pointer to a struct. Only the first document is read; prefer adding a
// $limit: 1 stage so the server does not produce more than needed.
//
// It returns ErrRecordNotFound when the pipeline yields no documents.
// When caching is enabled, obj (a struct pointer) is passed to the cache func.
// DefaultAggregateCacheFunc only accepts slice pointers and therefore
// reports an error here.
func (as *AggregateScope) GetOne(obj interface{}) error {
	v := reflect.ValueOf(obj)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid type, not ptr")
		return as.err
	}
	elemType := v.Type().Elem()
	if elemType.Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid type, not ptr to struct")
		return as.err
	}
	if v.IsNil() {
		as.err = fmt.Errorf("invalid type, nil ptr")
		return as.err
	}

	defer as.doClear()
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}

	ctx := as.getContext()
	// Only the first document is needed; close the cursor instead of
	// draining it. Registered after doClear, so it runs before the
	// context is cancelled.
	defer func() { _ = as.cursor.Close(ctx) }()

	if !as.cursor.Next(ctx) {
		if as.err = as.cursor.Err(); as.err != nil {
			as.assertErr()
			return as.err
		}
		as.err = &ErrRecordNotFound
		return as.err
	}
	// Decode into a fresh value so fields missing from the document do not
	// keep stale data from obj.
	node := reflect.New(elemType)
	if as.err = as.cursor.Decode(node.Interface()); as.err != nil {
		return as.err
	}
	v.Elem().Set(node.Elem())

	if as.enableCache {
		as.doCache(obj)
	}
	return as.err
}

// GetMap decodes the pipeline results into m, which must be a pointer to a
// map whose values are pointers to structs (e.g. *map[string]*TypeValue).
// Each entry is keyed by the value of the struct field named field.
// The map is replaced by a new one before decoding starts.
// Results are not cached.
func (as *AggregateScope) GetMap(m interface{}, field string) error {
	defer as.doClear()

	// Validate before querying; Kind() is safe on nil / typed-nil inputs.
	v := reflect.ValueOf(m)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map type, not ptr")
		return as.err
	}
	if v.IsNil() || v.Elem().Kind() != reflect.Map {
		as.err = fmt.Errorf("invalid map type, not map")
		return as.err
	}
	// The element of a non-nil pointer is always settable.
	v = v.Elem()
	mapType := v.Type()
	keyType := mapType.Key()
	valueType := mapType.Elem()
	if valueType.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map value type, not prt")
		return as.err
	}
	// FieldByName below panics on non-struct values.
	if valueType.Elem().Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid map value type, not ptr to struct")
		return as.err
	}

	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}

	ctx := as.getContext()
	// The loop below may stop early; always release the cursor.
	defer func() { _ = as.cursor.Close(ctx) }()

	v.Set(reflect.MakeMap(mapType))
	for as.cursor.Next(ctx) {
		node := reflect.New(valueType.Elem())
		if as.err = as.cursor.Decode(node.Interface()); as.err != nil {
			logrus.Errorf("err: %+v", as.err)
			return as.err
		}
		key := node.Elem().FieldByName(field)
		// Values of unexported fields cannot be used as map keys via reflection.
		if !key.IsValid() || !key.CanInterface() {
			as.err = fmt.Errorf("invalid model key: %s", field)
			return as.err
		}
		if kt := key.Type(); !kt.AssignableTo(keyType) {
			// Allow e.g. `type ID string` into map[string]...; the same-kind
			// requirement rejects lossy conversions such as int -> string.
			if kt.Kind() != keyType.Kind() || !kt.ConvertibleTo(keyType) {
				as.err = fmt.Errorf("model key %s of type %s cannot be used as map key type %s", field, kt, keyType)
				return as.err
			}
			key = key.Convert(keyType)
		}
		v.SetMapIndex(key, node)
	}
	// Next returns false on both exhaustion and failure; distinguish them.
	if as.err = as.cursor.Err(); as.err != nil {
		as.assertErr()
	}
	return as.err
}

// Error returns the last error produced by this scope.
func (as *AggregateScope) Error() error {
	return as.err
}

// GetByFunc runs the pipeline and hands the cursor to cb, which builds the
// result. The caller type-asserts val. The context is cancelled once
// GetByFunc returns, so cb should finish consuming the cursor before returning.
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
	defer as.doClear()
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return nil, as.err
	}
	val, as.err = cb(as.cursor)
	if as.err != nil {
		return nil, as.err
	}
	return val, nil
}

// SetCacheFunc enables result caching: after a successful GetList/GetOne,
// cb builds a cache entry (key is forwarded to it as `field`) which is
// written to the scope's cache client. GetMap and GetByFunc do not cache.
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
	as.enableCache = true
	as.cacheFunc = cb
	as.cacheKey = key
	return as
}

// doCache builds a cache entry for obj with the configured cache func and
// writes it to the cache client. Errors are stored on the scope.
//
// Nothing is cached when no cache client or cache func is configured, or
// when the cache func returns a nil entry (its way of saying "skip").
// A cache TTL of 0 means one hour; -1 means no expiration.
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
	if as.scope.cache == nil || as.cacheFunc == nil {
		return as
	}
	cacheObj, err := as.cacheFunc(as.cacheKey, obj)
	if err != nil {
		as.err = err
		return as
	}
	if cacheObj == nil {
		return as
	}

	ttl := as.scope.cache.ttl
	switch ttl {
	case 0:
		ttl = time.Hour
	case -1:
		ttl = 0
	}

	if ctxErr := as.getContext().Err(); ctxErr != nil {
		as.err = ctxErr
		as.assertErr()
		return as
	}
	as.err = as.scope.cache.client.Set(as.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
	if as.err != nil {
		as.assertErr()
	}
	return as
}