package mdbc import ( "context" "encoding/json" "errors" "fmt" "reflect" "time" "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type FindScope struct { scope *Scope // 向上提权 cw *ctxWrap // 查询上下文环境 cursor *mongo.Cursor // 查询指针 err error // 操作是否有错误 limit *int64 // 限制查询数量 skip *int64 // 偏移量 count *int64 // 计数统计结果 withCount bool // 是否开启计数统计 这将只基于filter进行计数而忽略 limit skip selects interface{} // sort bson.D // 排序 filter interface{} // 过滤条件 opts *options.FindOptions // 查询条件 enableCache bool // 是否开启缓存 cacheKey string // 缓存key cacheFunc FindCacheFunc // 基于什么缓存函数进行缓存 } // FindCacheFunc 缓存Find结果集 // field: 若无作用 可置空 DefaultFindCacheFunc将使用其反射出field字段对应的value并将其设为key // obj: 这是一个slice 只能在GetList是使用,GetMap使用将无效果 type FindCacheFunc func(field string, obj interface{}) (*CacheObject, error) // DefaultFindCacheFunc 按照第一个结果的field字段作为key缓存list数据 // field字段需要满足所有数据均有该属性 否则有概率导致缓存失败 var DefaultFindCacheFunc = func() FindCacheFunc { return func(field string, obj interface{}) (*CacheObject, error) { // 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行 // 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可 // 下面使用反射进行操作 保证兼容 v := reflect.ValueOf(obj) if v.Type().Kind() != reflect.Ptr { return nil, fmt.Errorf("invalid list type, not ptr") } v = v.Elem() if v.Type().Kind() != reflect.Slice { return nil, fmt.Errorf("invalid list type, not ptr to slice") } // 空slice 无需缓存 if v.Len() == 0 { return nil, nil } firstNode := v.Index(0) // 判断过长度 无需判断nil if firstNode.Kind() == reflect.Ptr { firstNode = firstNode.Elem() } idVal := firstNode.FieldByName(field) if idVal.Kind() == reflect.Invalid { return nil, fmt.Errorf("first node %s field not found", idVal) } b, _ := json.Marshal(v.Interface()) co := &CacheObject{ Key: idVal.String(), Value: string(b), } return co, nil } } // SetLimit 设置获取的条数 func (fs *FindScope) SetLimit(limit int64) *FindScope { fs.limit = &limit return fs } // SetSkip 设置跳过的条数 func (fs *FindScope) SetSkip(skip int64) *FindScope { fs.skip = &skip return fs } // SetSort 设置排序 func (fs *FindScope) SetSort(sort bson.D) *FindScope { fs.sort = sort return fs } // SetContext 设置上下文 func (fs *FindScope) SetContext(ctx context.Context) *FindScope { if fs.cw == nil { fs.cw = &ctxWrap{} } fs.cw.ctx = ctx return fs } // SetSelect 选择查询字段 格式 {key: 1, key: 0} // 你可以传入 bson.M, map[string]integer_interface 其他类型将导致该功能失效 // 1:显示 0:不显示 func (fs *FindScope) SetSelect(selects interface{}) *FindScope { vo := reflect.ValueOf(selects) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } // 不是map 提前返回 if vo.Kind() != reflect.Map { return fs } fs.selects = selects return fs } func (fs *FindScope) getContext() context.Context { if fs.cw == nil { fs.cw = &ctxWrap{ ctx: context.Background(), } } return fs.cw.ctx } // SetFilter 设置过滤条件 func (fs *FindScope) SetFilter(filter interface{}) *FindScope { if filter == nil { fs.filter = bson.M{} return fs } v := reflect.ValueOf(filter) if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice { if v.IsNil() { fs.filter = bson.M{} } } fs.filter = filter return fs } // SetFindOption 设置FindOption 优先级最低 sort/skip/limit 会被 set 函数重写 func (fs *FindScope) SetFindOption(opts options.FindOptions) *FindScope { fs.opts = &opts return fs } // WithCount 查询的同时获取总数量 // 将按照 SetFilter 的条件进行查询 未设置是获取所有文档数量 // 如果在该步骤出错 将不会进行余下操作 func (fs *FindScope) WithCount(count *int64) *FindScope { fs.withCount = true fs.count = count return fs } func (fs *FindScope) optionAssembled() { // 配置项被直接调用重写过 if fs.opts == nil { fs.opts = new(options.FindOptions) } if fs.sort != nil { fs.opts.Sort = fs.sort } if fs.skip != nil { fs.opts.Skip = fs.skip } if fs.limit != nil { fs.opts.Limit = fs.limit } if fs.selects != nil { fs.opts.Projection = fs.selects } } func (fs *FindScope) preCheck() { if fs.filter == nil { fs.filter = bson.M{} } var breakerTTL time.Duration if fs.scope.breaker == nil { breakerTTL = defaultBreakerTime } else if fs.scope.breaker.ttl == 0 { breakerTTL = defaultBreakerTime } else { breakerTTL = fs.scope.breaker.ttl } if fs.cw == nil { fs.cw = &ctxWrap{} } if fs.cw.ctx == nil { fs.cw.ctx, fs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL) } else { fs.cw.ctx, fs.cw.cancel = context.WithTimeout(fs.cw.ctx, breakerTTL) } } func (fs *FindScope) assertErr() { if fs.err == nil { return } if errors.Is(fs.err, mongo.ErrNoDocuments) || errors.Is(fs.err, mongo.ErrNilDocument) { fs.err = &ErrRecordNotFound return } if errors.Is(fs.err, context.DeadlineExceeded) { fs.err = &ErrRequestBroken return } err, ok := fs.err.(mongo.CommandError) if ok && err.HasErrorMessage(context.DeadlineExceeded.Error()) { fs.err = &ErrRequestBroken return } fs.err = err } func (fs *FindScope) doString() string { builder := RegisterTimestampCodec(nil).Build() filter, _ := bson.MarshalExtJSONWithRegistry(builder, fs.filter, true, true) query := fmt.Sprintf("find(%s)", string(filter)) if fs.skip != nil { query = fmt.Sprintf("%s.skip(%d)", query, *fs.skip) } if fs.limit != nil { query = fmt.Sprintf("%s.limit(%d)", query, *fs.limit) } if fs.sort != nil { sort, _ := bson.MarshalExtJSON(fs.sort, true, true) query = fmt.Sprintf("%s.sort(%s)", query, string(sort)) } return query } func (fs *FindScope) debug() { if !fs.scope.debug && !fs.scope.debugWhenError { return } debugger := &Debugger{ collection: fs.scope.tableName, execT: fs.scope.execT, action: fs, } // 当错误时优先输出 if fs.scope.debugWhenError { if fs.err != nil { debugger.errMsg = fs.err.Error() debugger.ErrorString() } return } // 所有bug输出 if fs.scope.debug { debugger.String() } } func (fs *FindScope) doClear() { if fs.cw != nil && fs.cw.cancel != nil { fs.cw.cancel() } fs.scope.debug = false fs.scope.execT = 0 } func (fs *FindScope) doSearch() { var starTime time.Time if fs.scope.debug { starTime = time.Now() } fs.cursor, fs.err = db.Collection(fs.scope.tableName).Find(fs.getContext(), fs.filter, fs.opts) fs.assertErr() // 断言错误 if fs.scope.debug { fs.scope.execT = time.Since(starTime) fs.debug() } // 有检测数量的 if fs.withCount { var res int64 res, fs.err = fs.scope.Count().SetContext(fs.getContext()).SetFilter(fs.filter).Count() *fs.count = res } fs.assertErr() } // GetList 获取列表 // list: 需要一个 *[]*struct func (fs *FindScope) GetList(list interface{}) error { defer fs.doClear() fs.optionAssembled() fs.preCheck() if fs.err != nil { return fs.err } fs.doSearch() if fs.err != nil { return fs.err } v := reflect.ValueOf(list) if v.Type().Kind() != reflect.Ptr { fs.err = fmt.Errorf("invalid list type, not ptr") return fs.err } v = v.Elem() if v.Type().Kind() != reflect.Slice { fs.err = fmt.Errorf("invalid list type, not ptr to slice") return fs.err } fs.err = fs.cursor.All(fs.getContext(), list) if fs.enableCache { fs.doCache(list) } return fs.err } // GetMap 基于结果的某个字段为Key 获取Map // m: 传递一个 *map[string]*Struct // field: struct的一个字段名称 需要是公开可访问的大写 func (fs *FindScope) GetMap(m interface{}, field string) error { defer fs.doClear() fs.optionAssembled() fs.preCheck() if fs.err != nil { return fs.err } fs.doSearch() if fs.err != nil { return fs.err } v := reflect.ValueOf(m) if v.Type().Kind() != reflect.Ptr { fs.err = fmt.Errorf("invalid map type, not ptr") return fs.err } v = v.Elem() if v.Type().Kind() != reflect.Map { fs.err = fmt.Errorf("invalid map type, not map") return fs.err } mapType := v.Type() valueType := mapType.Elem() keyType := mapType.Key() if valueType.Kind() != reflect.Ptr { fs.err = fmt.Errorf("invalid map value type, not prt") return fs.err } if !v.CanSet() { fs.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields") return fs.err } v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType))) for fs.cursor.Next(fs.getContext()) { t := reflect.New(valueType.Elem()) if fs.err = fs.cursor.Decode(t.Interface()); fs.err != nil { logrus.Errorf("err: %+v", fs.err) return fs.err } fieldNode := t.Elem().FieldByName(field) if fieldNode.Kind() == reflect.Invalid { fs.err = fmt.Errorf("invalid model key: %s", field) return fs.err } v.SetMapIndex(fieldNode, t) } return fs.err } func (fs *FindScope) Error() error { return fs.err } // SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现 func (fs *FindScope) SetCacheFunc(key string, cb FindCacheFunc) *FindScope { fs.enableCache = true fs.cacheFunc = cb fs.cacheKey = key return fs } // doCache 执行缓存 func (fs *FindScope) doCache(obj interface{}) *FindScope { // redis句柄不存在 if fs.scope.cache == nil { return nil } cacheObj, err := fs.cacheFunc(fs.cacheKey, obj) if err != nil { fs.err = err return fs } if cacheObj == nil { fs.err = fmt.Errorf("cache object nil") return fs } ttl := fs.scope.cache.ttl if ttl == 0 { ttl = time.Hour } else if ttl == -1 { ttl = 0 } if fs.getContext().Err() != nil { fs.err = fs.getContext().Err() fs.assertErr() return fs } fs.err = fs.scope.cache.client.Set(fs.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err() // 这里也要断言错误 看是否是熔断错误 return fs }