mdbc/find_scope.go
2022-02-23 16:59:45 +08:00

446 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// FindScope builds and executes a Find query against the owning scope's
// collection. It is configured through its Set*/With* methods and executed by
// GetList or GetMap.
type FindScope struct {
	scope       *Scope               // owning scope: collection name, debug flags, breaker and cache settings
	cw          *ctxWrap             // query context and its cancel func
	cursor      *mongo.Cursor        // cursor returned by Find
	err         error                // error recorded by the most recent step
	limit       *int64               // max documents to return (nil = not set)
	skip        *int64               // number of documents to skip (nil = not set)
	count       *int64               // destination for the total requested via WithCount
	withCount   bool                 // also count documents; the count uses only the filter and ignores limit/skip
	selects     interface{}          // projection map set by SetSelect
	sort        bson.D               // sort specification
	filter      interface{}          // query filter
	opts        *options.FindOptions // base find options; sort/skip/limit/projection set via setters override them
	enableCache bool                 // whether GetList writes its result to the cache
	cacheKey    string               // passed to cacheFunc as its field argument
	cacheFunc   FindCacheFunc        // builds the cache entry from the result list
}
// FindCacheFunc converts a Find result set into a cache entry.
//
// field: may be left empty if the implementation does not need it.
// DefaultFindCacheFunc treats it as a struct field name: it reads that field
// from the first element via reflection and uses the value as the cache key.
//
// obj: the list destination given to GetList (a pointer to a slice). Only
// GetList runs the cache step; it has no effect with GetMap.
type FindCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultFindCacheFunc returns a FindCacheFunc that caches a whole result list
// under a single key: the value of struct field `field` in the FIRST element.
// Every element is expected to have that field, otherwise caching may fail.
//
// The list is stored JSON-encoded. An empty list yields (nil, nil), meaning
// there is nothing to cache. String-kinded fields are used verbatim as the
// key; other kinds (ints, ObjectIDs, ...) are formatted with fmt.Sprint.
var DefaultFindCacheFunc = func() FindCacheFunc {
	return func(field string, obj interface{}) (*CacheObject, error) {
		// Custom implementations should usually type-assert obj first, e.g.
		// res, ok := obj.(*[]*model.ModelUser), and return nil when that fails
		// so nothing is cached. This default uses reflection to stay generic.
		v := reflect.ValueOf(obj)
		// Kind() (unlike Type()) is safe on the zero Value produced by a nil obj.
		if v.Kind() != reflect.Ptr {
			return nil, fmt.Errorf("invalid list type, not ptr")
		}
		v = v.Elem()
		// A nil *slice yields an invalid Value here and is rejected as well.
		if v.Kind() != reflect.Slice {
			return nil, fmt.Errorf("invalid list type, not ptr to slice")
		}
		// An empty slice needs no caching.
		if v.Len() == 0 {
			return nil, nil
		}
		firstNode := v.Index(0)
		// A non-empty slice of pointers can still hold a nil first element.
		if firstNode.Kind() == reflect.Ptr {
			if firstNode.IsNil() {
				return nil, fmt.Errorf("first node is nil")
			}
			firstNode = firstNode.Elem()
		}
		// FieldByName panics on anything but a struct.
		if firstNode.Kind() != reflect.Struct {
			return nil, fmt.Errorf("first node is not a struct")
		}
		idVal := firstNode.FieldByName(field)
		if !idVal.IsValid() {
			return nil, fmt.Errorf("first node %s field not found", field)
		}
		b, err := json.Marshal(v.Interface())
		if err != nil {
			return nil, err
		}
		// reflect.Value.String() returns "<T Value>" for non-string kinds,
		// so format those through fmt, which prints the held value.
		var key string
		if idVal.Kind() == reflect.String {
			key = idVal.String()
		} else {
			key = fmt.Sprint(idVal)
		}
		co := &CacheObject{
			Key:   key,
			Value: string(b),
		}
		return co, nil
	}
}
// SetLimit caps the number of documents the query returns. It takes
// precedence over a limit supplied through SetFindOption.
func (fs *FindScope) SetLimit(limit int64) *FindScope {
	n := limit
	fs.limit = &n
	return fs
}
// SetSkip sets how many matching documents are skipped before results are
// returned. It takes precedence over a skip supplied through SetFindOption.
func (fs *FindScope) SetSkip(skip int64) *FindScope {
	n := skip
	fs.skip = &n
	return fs
}
// SetSort sets the sort specification. When non-nil it replaces any sort
// supplied through SetFindOption at execution time.
func (fs *FindScope) SetSort(sort bson.D) *FindScope {
	target := fs
	target.sort = sort
	return target
}
// SetContext sets the parent context for the query. When the query runs, a
// timeout derived from the breaker TTL is layered on top of it.
func (fs *FindScope) SetContext(ctx context.Context) *FindScope {
	wrap := fs.cw
	if wrap == nil {
		wrap = &ctxWrap{}
		fs.cw = wrap
	}
	wrap.ctx = ctx
	return fs
}
// SetSelect sets the projection, e.g. {key: 1, other: 0} (1 = include,
// 0 = exclude). Pass a map such as bson.M or map[string]int, or a pointer to
// one; any other value is silently ignored and the projection stays unchanged.
// The value is stored as given (a pointer stays a pointer).
func (fs *FindScope) SetSelect(selects interface{}) *FindScope {
	target := reflect.ValueOf(selects)
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() == reflect.Map {
		fs.selects = selects
	}
	return fs
}
// getContext returns the query context. It lazily falls back to
// context.Background() when no context has been set, including when
// SetContext(nil) left a wrapper with a nil context, so callers never receive
// a nil context.
func (fs *FindScope) getContext() context.Context {
	if fs.cw == nil {
		fs.cw = &ctxWrap{}
	}
	if fs.cw.ctx == nil {
		fs.cw.ctx = context.Background()
	}
	return fs.cw.ctx
}
// SetFilter sets the query filter. A nil filter, whether an untyped nil or a
// typed nil pointer/map/slice, is normalised to an empty bson.M, which matches
// every document.
func (fs *FindScope) SetFilter(filter interface{}) *FindScope {
	if filter == nil {
		fs.filter = bson.M{}
		return fs
	}
	v := reflect.ValueOf(filter)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice:
		if v.IsNil() {
			fs.filter = bson.M{}
			// Return here: falling through would overwrite the normalised
			// filter with the typed nil again.
			return fs
		}
	}
	fs.filter = filter
	return fs
}
// SetFindOption supplies base FindOptions. They have the lowest priority:
// sort, skip, limit and projection set through the dedicated setters
// overwrite the corresponding fields. The struct is stored as a private copy.
func (fs *FindScope) SetFindOption(opts options.FindOptions) *FindScope {
	base := opts
	fs.opts = &base
	return fs
}
// WithCount asks the query to also store the total number of documents
// matching the filter (set via SetFilter; all documents if unset) into count.
// skip and limit do not affect this total. If counting fails, the query
// reports that error and the remaining steps are not performed.
// A nil count disables counting, since there would be nowhere to store it.
func (fs *FindScope) WithCount(count *int64) *FindScope {
	// Enabling the count with a nil destination would make the search step
	// dereference a nil pointer.
	fs.withCount = count != nil
	fs.count = count
	return fs
}
// optionAssembled merges the setter values into fs.opts. Options supplied via
// SetFindOption form the base (a fresh FindOptions is used if none were
// given); every setter value that was set overrides its field.
func (fs *FindScope) optionAssembled() {
	opts := fs.opts
	if opts == nil {
		opts = new(options.FindOptions)
		fs.opts = opts
	}
	if fs.sort != nil {
		opts.Sort = fs.sort
	}
	if fs.skip != nil {
		opts.Skip = fs.skip
	}
	if fs.limit != nil {
		opts.Limit = fs.limit
	}
	if fs.selects != nil {
		opts.Projection = fs.selects
	}
}
// preCheck prepares the scope for execution. It defaults a missing filter to
// an empty bson.M, then wraps the query context (or context.Background() if
// none was set) in a timeout. The timeout is the breaker TTL, or
// defaultBreakerTime when there is no breaker or its TTL is zero. The cancel
// func is kept so doClear can release it.
func (fs *FindScope) preCheck() {
	if fs.filter == nil {
		fs.filter = bson.M{}
	}
	var ttl time.Duration = defaultBreakerTime
	if b := fs.scope.breaker; b != nil && b.ttl != 0 {
		ttl = b.ttl
	}
	if fs.cw == nil {
		fs.cw = &ctxWrap{}
	}
	parent := fs.cw.ctx
	if parent == nil {
		parent = context.Background()
	}
	fs.cw.ctx, fs.cw.cancel = context.WithTimeout(parent, ttl)
}
// assertErr normalises fs.err into the package's error values:
//   - mongo.ErrNoDocuments / mongo.ErrNilDocument -> &ErrRecordNotFound
//   - context.DeadlineExceeded, or a mongo.CommandError carrying that message,
//     -> &ErrRequestBroken (breaker timeout)
//
// Any other error is left untouched so callers still see the original cause.
func (fs *FindScope) assertErr() {
	if fs.err == nil {
		return
	}
	switch {
	case errors.Is(fs.err, mongo.ErrNoDocuments), errors.Is(fs.err, mongo.ErrNilDocument):
		fs.err = &ErrRecordNotFound
	case errors.Is(fs.err, context.DeadlineExceeded):
		fs.err = &ErrRequestBroken
	default:
		// errors.As also finds a CommandError wrapped by another error. The old
		// bare type assertion replaced every non-CommandError with a zero-value
		// CommandError, destroying the real error.
		var cmdErr mongo.CommandError
		if errors.As(fs.err, &cmdErr) && cmdErr.HasErrorMessage(context.DeadlineExceeded.Error()) {
			fs.err = &ErrRequestBroken
		}
	}
}
// doString renders the query in mongo-shell style for debug output, e.g.
// find({...}).skip(10).limit(5).sort({...}). The filter is encoded as
// extended JSON with the timestamp codec registry. Encoding errors are
// ignored, leaving that part empty.
func (fs *FindScope) doString() string {
	registry := RegisterTimestampCodec(nil).Build()
	filterJSON, _ := bson.MarshalExtJSONWithRegistry(registry, fs.filter, true, true)
	out := fmt.Sprintf("find(%s)", string(filterJSON))
	if skip := fs.skip; skip != nil {
		out = fmt.Sprintf("%s.skip(%d)", out, *skip)
	}
	if limit := fs.limit; limit != nil {
		out = fmt.Sprintf("%s.limit(%d)", out, *limit)
	}
	if fs.sort == nil {
		return out
	}
	sortJSON, _ := bson.MarshalExtJSON(fs.sort, true, true)
	return fmt.Sprintf("%s.sort(%s)", out, string(sortJSON))
}
// debug reports the query through a Debugger according to the scope's flags.
// With debugWhenError set, only failed queries are reported (via ErrorString),
// and this mode takes priority. Otherwise, with debug set, every query is
// reported (via String). With neither flag set it does nothing.
func (fs *FindScope) debug() {
	sc := fs.scope
	newDebugger := func() *Debugger {
		return &Debugger{
			collection: sc.tableName,
			execT:      sc.execT,
			action:     fs,
		}
	}
	switch {
	case sc.debugWhenError:
		if fs.err != nil {
			d := newDebugger()
			d.errMsg = fs.err.Error()
			d.ErrorString()
		}
	case sc.debug:
		newDebugger().String()
	}
}
// doClear runs after GetList/GetMap. It cancels the timeout context created
// by preCheck, then resets scope.execT and scope.debug. Debug output therefore
// has to be re-enabled on the scope for each query.
func (fs *FindScope) doClear() {
	if cw := fs.cw; cw != nil && cw.cancel != nil {
		cw.cancel()
	}
	fs.scope.execT = 0
	fs.scope.debug = false
}
// doSearch runs Find with the assembled filter and options. When WithCount
// was requested, it then counts the documents matching the filter; skip and
// limit are not applied to the count.
//
// On return, either fs.err is non-nil or fs.cursor holds the open result
// cursor. A Find failure stops here: the count is skipped so that its result
// cannot overwrite the Find error, which would leave a nil cursor behind a
// nil error.
func (fs *FindScope) doSearch() {
	var startTime time.Time
	if fs.scope.debug {
		startTime = time.Now()
	}
	fs.cursor, fs.err = db.Collection(fs.scope.tableName).Find(fs.getContext(), fs.filter, fs.opts)
	fs.assertErr() // map not-found / breaker-timeout errors to package errors
	if fs.scope.debug {
		fs.scope.execT = time.Since(startTime)
		fs.debug()
	}
	if fs.err != nil || !fs.withCount {
		return
	}
	var total int64
	total, fs.err = fs.scope.Count().SetContext(fs.getContext()).SetFilter(fs.filter).Count()
	if fs.err != nil {
		// Callers stop at the error and never read the cursor, so release it now.
		_ = fs.cursor.Close(fs.getContext())
		fs.assertErr()
		return
	}
	if fs.count != nil {
		*fs.count = total
	}
}
// GetList runs the query and decodes every matching document into list.
//
// list must be a non-nil pointer to a slice, e.g. *[]*Model. The destination
// is validated before the query is sent. When caching is enabled via
// SetCacheFunc, the cache step runs only if decoding succeeded, so a cache
// write can never mask a decode error.
func (fs *FindScope) GetList(list interface{}) error {
	defer fs.doClear()
	fs.optionAssembled()
	fs.preCheck()
	if fs.err != nil {
		return fs.err
	}
	// Validate before querying so a bad argument costs no round trip and
	// leaves no open cursor. Kind() is safe on nil input, where Type() panics.
	dst := reflect.ValueOf(list)
	if dst.Kind() != reflect.Ptr {
		fs.err = fmt.Errorf("invalid list type, not ptr")
		return fs.err
	}
	if dst.Elem().Kind() != reflect.Slice {
		fs.err = fmt.Errorf("invalid list type, not ptr to slice")
		return fs.err
	}
	fs.doSearch()
	if fs.err != nil {
		return fs.err
	}
	// Cursor.All drains and closes the cursor.
	fs.err = fs.cursor.All(fs.getContext(), list)
	if fs.err == nil && fs.enableCache {
		fs.doCache(list)
	}
	return fs.err
}
// GetMap runs the query and collects the results into a map keyed by one
// struct field of each document.
//
// m: a non-nil pointer to a map whose values are pointers to structs, e.g.
// *map[string]*Model.
// field: the name of an exported field of that struct. Its type must be
// assignable to the map's key type.
//
// Arguments are validated before the query is sent. The destination map is
// replaced with a fresh map before decoding starts, so after a mid-iteration
// error it holds the documents decoded so far.
func (fs *FindScope) GetMap(m interface{}, field string) error {
	defer fs.doClear()
	fs.optionAssembled()
	fs.preCheck()
	if fs.err != nil {
		return fs.err
	}
	// Validate the destination and key field up front. This avoids a wasted
	// round trip and the panics reflect would raise later (FieldByName on a
	// non-struct, SetMapIndex with an unassignable or unexported key).
	v := reflect.ValueOf(m)
	if v.Kind() != reflect.Ptr {
		fs.err = fmt.Errorf("invalid map type, not ptr")
		return fs.err
	}
	v = v.Elem()
	if v.Kind() != reflect.Map {
		fs.err = fmt.Errorf("invalid map type, not map")
		return fs.err
	}
	mapType := v.Type()
	keyType := mapType.Key()
	valueType := mapType.Elem()
	if valueType.Kind() != reflect.Ptr {
		fs.err = fmt.Errorf("invalid map value type, not ptr")
		return fs.err
	}
	structType := valueType.Elem()
	if structType.Kind() != reflect.Struct {
		fs.err = fmt.Errorf("invalid map value type, not ptr to struct")
		return fs.err
	}
	keyField, ok := structType.FieldByName(field)
	if !ok || keyField.PkgPath != "" { // PkgPath is non-empty for unexported fields
		fs.err = fmt.Errorf("invalid model key: %s", field)
		return fs.err
	}
	if !keyField.Type.AssignableTo(keyType) {
		fs.err = fmt.Errorf("invalid model key: %s, type %s is not assignable to map key type %s", field, keyField.Type, keyType)
		return fs.err
	}
	if !v.CanSet() {
		fs.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
		return fs.err
	}

	fs.doSearch()
	if fs.err != nil {
		return fs.err
	}
	ctx := fs.getContext()
	// Manual Next iteration does not close the cursor on an early return.
	// This defer runs before the earlier-registered doClear cancels ctx.
	defer func() { _ = fs.cursor.Close(ctx) }()

	v.Set(reflect.MakeMap(mapType))
	for fs.cursor.Next(ctx) {
		elem := reflect.New(structType)
		if fs.err = fs.cursor.Decode(elem.Interface()); fs.err != nil {
			logrus.Errorf("err: %+v", fs.err)
			return fs.err
		}
		key := elem.Elem().FieldByName(field)
		if !key.IsValid() {
			fs.err = fmt.Errorf("invalid model key: %s", field)
			return fs.err
		}
		v.SetMapIndex(key, elem)
	}
	// Next returns false both on exhaustion and on failure; tell them apart.
	fs.err = fs.cursor.Err()
	return fs.err
}
// Error returns the error recorded by the last operation on this scope
// (e.g. GetList or GetMap), or nil if it succeeded.
func (fs *FindScope) Error() error {
	err := fs.err
	return err
}
// SetCacheFunc enables caching of GetList results. After the list is decoded,
// cb is called with key as its field argument and the list destination as
// obj, and the returned entry is written to the scope's cache client.
// DefaultFindCacheFunc() treats key as the struct field whose value becomes
// the cache key.
func (fs *FindScope) SetCacheFunc(key string, cb FindCacheFunc) *FindScope {
	fs.cacheKey, fs.cacheFunc = key, cb
	fs.enableCache = true
	return fs
}
// doCache passes obj (the decoded GetList destination) to the cache function
// and stores the resulting key/value through the scope's cache client.
//
// It does nothing when no cache client is configured, or when the cache
// function returns a nil *CacheObject. The FindCacheFunc contract uses nil to
// mean "do not cache", and DefaultFindCacheFunc returns nil for an empty list.
// If no cache function was set, DefaultFindCacheFunc() is used. Errors from
// the cache function or the client are recorded in fs.err. It always returns fs.
func (fs *FindScope) doCache(obj interface{}) *FindScope {
	// no cache client (e.g. redis handle) configured
	if fs.scope.cache == nil {
		return fs
	}
	cacheFunc := fs.cacheFunc
	if cacheFunc == nil {
		cacheFunc = DefaultFindCacheFunc()
	}
	cacheObj, err := cacheFunc(fs.cacheKey, obj)
	if err != nil {
		fs.err = err
		return fs
	}
	if cacheObj == nil {
		// Nothing to cache. This is not an error: treating it as one used to
		// make every empty GetList fail.
		return fs
	}
	// A configured ttl of 0 means "use the one-hour default"; -1 is passed to
	// the client as 0.
	ttl := fs.scope.cache.ttl
	switch ttl {
	case 0:
		ttl = time.Hour
	case -1:
		ttl = 0
	}
	ctx := fs.getContext()
	if ctx.Err() != nil {
		fs.err = ctx.Err()
		fs.assertErr()
		return fs
	}
	fs.err = fs.scope.cache.client.Set(ctx, cacheObj.Key, cacheObj.Value, ttl).Err()
	// A timed-out cache write is reported like any other breaker timeout.
	if errors.Is(fs.err, context.DeadlineExceeded) {
		fs.err = &ErrRequestBroken
	}
	return fs
}