mdbc/find_scope.go

446 lines
10 KiB
Go
Raw Normal View History

2022-02-23 08:59:45 +00:00
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type FindScope struct {
scope *Scope // 向上提权
cw *ctxWrap // 查询上下文环境
cursor *mongo.Cursor // 查询指针
err error // 操作是否有错误
limit *int64 // 限制查询数量
skip *int64 // 偏移量
count *int64 // 计数统计结果
withCount bool // 是否开启计数统计 这将只基于filter进行计数而忽略 limit skip
selects interface{} //
sort bson.D // 排序
filter interface{} // 过滤条件
opts *options.FindOptions // 查询条件
enableCache bool // 是否开启缓存
cacheKey string // 缓存key
cacheFunc FindCacheFunc // 基于什么缓存函数进行缓存
}
// FindCacheFunc 缓存Find结果集
// field: 若无作用 可置空 DefaultFindCacheFunc将使用其反射出field字段对应的value并将其设为key
// obj: 这是一个slice 只能在GetList是使用GetMap使用将无效果
type FindCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultFindCacheFunc 按照第一个结果的field字段作为key缓存list数据
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
var DefaultFindCacheFunc = func() FindCacheFunc {
return func(field string, obj interface{}) (*CacheObject, error) {
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
// 下面使用反射进行操作 保证兼容
v := reflect.ValueOf(obj)
if v.Type().Kind() != reflect.Ptr {
return nil, fmt.Errorf("invalid list type, not ptr")
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
return nil, fmt.Errorf("invalid list type, not ptr to slice")
}
// 空slice 无需缓存
if v.Len() == 0 {
return nil, nil
}
firstNode := v.Index(0)
// 判断过长度 无需判断nil
if firstNode.Kind() == reflect.Ptr {
firstNode = firstNode.Elem()
}
idVal := firstNode.FieldByName(field)
if idVal.Kind() == reflect.Invalid {
return nil, fmt.Errorf("first node %s field not found", idVal)
}
b, _ := json.Marshal(v.Interface())
co := &CacheObject{
Key: idVal.String(),
Value: string(b),
}
return co, nil
}
}
// SetLimit 设置获取的条数
func (fs *FindScope) SetLimit(limit int64) *FindScope {
fs.limit = &limit
return fs
}
// SetSkip 设置跳过的条数
func (fs *FindScope) SetSkip(skip int64) *FindScope {
fs.skip = &skip
return fs
}
// SetSort 设置排序
func (fs *FindScope) SetSort(sort bson.D) *FindScope {
fs.sort = sort
return fs
}
// SetContext 设置上下文
func (fs *FindScope) SetContext(ctx context.Context) *FindScope {
if fs.cw == nil {
fs.cw = &ctxWrap{}
}
fs.cw.ctx = ctx
return fs
}
// SetSelect 选择查询字段 格式 {key: 1, key: 0}
// 你可以传入 bson.M, map[string]integer_interface 其他类型将导致该功能失效
// 1显示 0不显示
func (fs *FindScope) SetSelect(selects interface{}) *FindScope {
vo := reflect.ValueOf(selects)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
// 不是map 提前返回
if vo.Kind() != reflect.Map {
return fs
}
fs.selects = selects
return fs
}
func (fs *FindScope) getContext() context.Context {
if fs.cw == nil {
fs.cw = &ctxWrap{
ctx: context.Background(),
}
}
return fs.cw.ctx
}
// SetFilter 设置过滤条件
func (fs *FindScope) SetFilter(filter interface{}) *FindScope {
if filter == nil {
fs.filter = bson.M{}
return fs
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
fs.filter = bson.M{}
}
}
fs.filter = filter
return fs
}
// SetFindOption 设置FindOption 优先级最低 sort/skip/limit 会被 set 函数重写
func (fs *FindScope) SetFindOption(opts options.FindOptions) *FindScope {
fs.opts = &opts
return fs
}
// WithCount 查询的同时获取总数量
// 将按照 SetFilter 的条件进行查询 未设置是获取所有文档数量
// 如果在该步骤出错 将不会进行余下操作
func (fs *FindScope) WithCount(count *int64) *FindScope {
fs.withCount = true
fs.count = count
return fs
}
func (fs *FindScope) optionAssembled() {
// 配置项被直接调用重写过
if fs.opts == nil {
fs.opts = new(options.FindOptions)
}
if fs.sort != nil {
fs.opts.Sort = fs.sort
}
if fs.skip != nil {
fs.opts.Skip = fs.skip
}
if fs.limit != nil {
fs.opts.Limit = fs.limit
}
if fs.selects != nil {
fs.opts.Projection = fs.selects
}
}
func (fs *FindScope) preCheck() {
if fs.filter == nil {
fs.filter = bson.M{}
}
var breakerTTL time.Duration
if fs.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if fs.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = fs.scope.breaker.ttl
}
if fs.cw == nil {
fs.cw = &ctxWrap{}
}
if fs.cw.ctx == nil {
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
} else {
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(fs.cw.ctx, breakerTTL)
}
}
func (fs *FindScope) assertErr() {
if fs.err == nil {
return
}
if errors.Is(fs.err, mongo.ErrNoDocuments) || errors.Is(fs.err, mongo.ErrNilDocument) {
fs.err = &ErrRecordNotFound
return
}
if errors.Is(fs.err, context.DeadlineExceeded) {
fs.err = &ErrRequestBroken
return
}
err, ok := fs.err.(mongo.CommandError)
if ok && err.HasErrorMessage(context.DeadlineExceeded.Error()) {
fs.err = &ErrRequestBroken
return
}
fs.err = err
}
func (fs *FindScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
filter, _ := bson.MarshalExtJSONWithRegistry(builder, fs.filter, true, true)
query := fmt.Sprintf("find(%s)", string(filter))
if fs.skip != nil {
query = fmt.Sprintf("%s.skip(%d)", query, *fs.skip)
}
if fs.limit != nil {
query = fmt.Sprintf("%s.limit(%d)", query, *fs.limit)
}
if fs.sort != nil {
sort, _ := bson.MarshalExtJSON(fs.sort, true, true)
query = fmt.Sprintf("%s.sort(%s)", query, string(sort))
}
return query
}
func (fs *FindScope) debug() {
if !fs.scope.debug && !fs.scope.debugWhenError {
return
}
debugger := &Debugger{
collection: fs.scope.tableName,
execT: fs.scope.execT,
action: fs,
}
// 当错误时优先输出
if fs.scope.debugWhenError {
if fs.err != nil {
debugger.errMsg = fs.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if fs.scope.debug {
debugger.String()
}
}
func (fs *FindScope) doClear() {
if fs.cw != nil && fs.cw.cancel != nil {
fs.cw.cancel()
}
fs.scope.debug = false
fs.scope.execT = 0
}
func (fs *FindScope) doSearch() {
var starTime time.Time
if fs.scope.debug {
starTime = time.Now()
}
fs.cursor, fs.err = db.Collection(fs.scope.tableName).Find(fs.getContext(), fs.filter, fs.opts)
fs.assertErr() // 断言错误
if fs.scope.debug {
fs.scope.execT = time.Since(starTime)
fs.debug()
}
// 有检测数量的
if fs.withCount {
var res int64
res, fs.err = fs.scope.Count().SetContext(fs.getContext()).SetFilter(fs.filter).Count()
*fs.count = res
}
fs.assertErr()
}
// GetList 获取列表
// list: 需要一个 *[]*struct
func (fs *FindScope) GetList(list interface{}) error {
defer fs.doClear()
fs.optionAssembled()
fs.preCheck()
if fs.err != nil {
return fs.err
}
fs.doSearch()
if fs.err != nil {
return fs.err
}
v := reflect.ValueOf(list)
if v.Type().Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid list type, not ptr")
return fs.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
fs.err = fmt.Errorf("invalid list type, not ptr to slice")
return fs.err
}
fs.err = fs.cursor.All(fs.getContext(), list)
if fs.enableCache {
fs.doCache(list)
}
return fs.err
}
// GetMap 基于结果的某个字段为Key 获取Map
// m: 传递一个 *map[string]*Struct
// field: struct的一个字段名称 需要是公开可访问的大写
func (fs *FindScope) GetMap(m interface{}, field string) error {
defer fs.doClear()
fs.optionAssembled()
fs.preCheck()
if fs.err != nil {
return fs.err
}
fs.doSearch()
if fs.err != nil {
return fs.err
}
v := reflect.ValueOf(m)
if v.Type().Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid map type, not ptr")
return fs.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Map {
fs.err = fmt.Errorf("invalid map type, not map")
return fs.err
}
mapType := v.Type()
valueType := mapType.Elem()
keyType := mapType.Key()
if valueType.Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid map value type, not prt")
return fs.err
}
if !v.CanSet() {
fs.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
return fs.err
}
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
for fs.cursor.Next(fs.getContext()) {
t := reflect.New(valueType.Elem())
if fs.err = fs.cursor.Decode(t.Interface()); fs.err != nil {
logrus.Errorf("err: %+v", fs.err)
return fs.err
}
fieldNode := t.Elem().FieldByName(field)
if fieldNode.Kind() == reflect.Invalid {
fs.err = fmt.Errorf("invalid model key: %s", field)
return fs.err
}
v.SetMapIndex(fieldNode, t)
}
return fs.err
}
func (fs *FindScope) Error() error {
return fs.err
}
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
func (fs *FindScope) SetCacheFunc(key string, cb FindCacheFunc) *FindScope {
fs.enableCache = true
fs.cacheFunc = cb
fs.cacheKey = key
return fs
}
// doCache 执行缓存
func (fs *FindScope) doCache(obj interface{}) *FindScope {
// redis句柄不存在
if fs.scope.cache == nil {
return nil
}
cacheObj, err := fs.cacheFunc(fs.cacheKey, obj)
if err != nil {
fs.err = err
return fs
}
if cacheObj == nil {
fs.err = fmt.Errorf("cache object nil")
return fs
}
ttl := fs.scope.cache.ttl
if ttl == 0 {
ttl = time.Hour
} else if ttl == -1 {
ttl = 0
}
if fs.getContext().Err() != nil {
fs.err = fs.getContext().Err()
fs.assertErr()
return fs
}
fs.err = fs.scope.cache.client.Set(fs.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
// 这里也要断言错误 看是否是熔断错误
return fs
}