446 lines
10 KiB
Go
446 lines
10 KiB
Go
package mdbc
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"reflect"
|
||
"time"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
)
|
||
|
||
type FindScope struct {
|
||
scope *Scope // 向上提权
|
||
cw *ctxWrap // 查询上下文环境
|
||
cursor *mongo.Cursor // 查询指针
|
||
err error // 操作是否有错误
|
||
limit *int64 // 限制查询数量
|
||
skip *int64 // 偏移量
|
||
count *int64 // 计数统计结果
|
||
withCount bool // 是否开启计数统计 这将只基于filter进行计数而忽略 limit skip
|
||
selects interface{} //
|
||
sort bson.D // 排序
|
||
filter interface{} // 过滤条件
|
||
opts *options.FindOptions // 查询条件
|
||
enableCache bool // 是否开启缓存
|
||
cacheKey string // 缓存key
|
||
cacheFunc FindCacheFunc // 基于什么缓存函数进行缓存
|
||
}
|
||
|
||
// FindCacheFunc 缓存Find结果集
|
||
// field: 若无作用 可置空 DefaultFindCacheFunc将使用其反射出field字段对应的value并将其设为key
|
||
// obj: 这是一个slice 只能在GetList是使用,GetMap使用将无效果
|
||
type FindCacheFunc func(field string, obj interface{}) (*CacheObject, error)
|
||
|
||
// DefaultFindCacheFunc 按照第一个结果的field字段作为key缓存list数据
|
||
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
|
||
var DefaultFindCacheFunc = func() FindCacheFunc {
|
||
return func(field string, obj interface{}) (*CacheObject, error) {
|
||
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
|
||
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
|
||
// 下面使用反射进行操作 保证兼容
|
||
v := reflect.ValueOf(obj)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
return nil, fmt.Errorf("invalid list type, not ptr")
|
||
}
|
||
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Slice {
|
||
return nil, fmt.Errorf("invalid list type, not ptr to slice")
|
||
}
|
||
|
||
// 空slice 无需缓存
|
||
if v.Len() == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
firstNode := v.Index(0)
|
||
// 判断过长度 无需判断nil
|
||
if firstNode.Kind() == reflect.Ptr {
|
||
firstNode = firstNode.Elem()
|
||
}
|
||
idVal := firstNode.FieldByName(field)
|
||
if idVal.Kind() == reflect.Invalid {
|
||
return nil, fmt.Errorf("first node %s field not found", idVal)
|
||
}
|
||
|
||
b, _ := json.Marshal(v.Interface())
|
||
co := &CacheObject{
|
||
Key: idVal.String(),
|
||
Value: string(b),
|
||
}
|
||
|
||
return co, nil
|
||
}
|
||
}
|
||
|
||
// SetLimit 设置获取的条数
|
||
func (fs *FindScope) SetLimit(limit int64) *FindScope {
|
||
fs.limit = &limit
|
||
return fs
|
||
}
|
||
|
||
// SetSkip 设置跳过的条数
|
||
func (fs *FindScope) SetSkip(skip int64) *FindScope {
|
||
fs.skip = &skip
|
||
return fs
|
||
}
|
||
|
||
// SetSort 设置排序
|
||
func (fs *FindScope) SetSort(sort bson.D) *FindScope {
|
||
fs.sort = sort
|
||
return fs
|
||
}
|
||
|
||
// SetContext 设置上下文
|
||
func (fs *FindScope) SetContext(ctx context.Context) *FindScope {
|
||
if fs.cw == nil {
|
||
fs.cw = &ctxWrap{}
|
||
}
|
||
fs.cw.ctx = ctx
|
||
return fs
|
||
}
|
||
|
||
// SetSelect 选择查询字段 格式 {key: 1, key: 0}
|
||
// 你可以传入 bson.M, map[string]integer_interface 其他类型将导致该功能失效
|
||
// 1:显示 0:不显示
|
||
func (fs *FindScope) SetSelect(selects interface{}) *FindScope {
|
||
vo := reflect.ValueOf(selects)
|
||
if vo.Kind() == reflect.Ptr {
|
||
vo = vo.Elem()
|
||
}
|
||
// 不是map 提前返回
|
||
if vo.Kind() != reflect.Map {
|
||
return fs
|
||
}
|
||
|
||
fs.selects = selects
|
||
return fs
|
||
}
|
||
|
||
func (fs *FindScope) getContext() context.Context {
|
||
if fs.cw == nil {
|
||
fs.cw = &ctxWrap{
|
||
ctx: context.Background(),
|
||
}
|
||
}
|
||
return fs.cw.ctx
|
||
}
|
||
|
||
// SetFilter 设置过滤条件
|
||
func (fs *FindScope) SetFilter(filter interface{}) *FindScope {
|
||
if filter == nil {
|
||
fs.filter = bson.M{}
|
||
return fs
|
||
}
|
||
|
||
v := reflect.ValueOf(filter)
|
||
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
|
||
if v.IsNil() {
|
||
fs.filter = bson.M{}
|
||
}
|
||
}
|
||
fs.filter = filter
|
||
return fs
|
||
}
|
||
|
||
// SetFindOption 设置FindOption 优先级最低 sort/skip/limit 会被 set 函数重写
|
||
func (fs *FindScope) SetFindOption(opts options.FindOptions) *FindScope {
|
||
fs.opts = &opts
|
||
return fs
|
||
}
|
||
|
||
// WithCount 查询的同时获取总数量
|
||
// 将按照 SetFilter 的条件进行查询 未设置是获取所有文档数量
|
||
// 如果在该步骤出错 将不会进行余下操作
|
||
func (fs *FindScope) WithCount(count *int64) *FindScope {
|
||
fs.withCount = true
|
||
fs.count = count
|
||
return fs
|
||
}
|
||
|
||
func (fs *FindScope) optionAssembled() {
|
||
// 配置项被直接调用重写过
|
||
if fs.opts == nil {
|
||
fs.opts = new(options.FindOptions)
|
||
}
|
||
|
||
if fs.sort != nil {
|
||
fs.opts.Sort = fs.sort
|
||
}
|
||
|
||
if fs.skip != nil {
|
||
fs.opts.Skip = fs.skip
|
||
}
|
||
|
||
if fs.limit != nil {
|
||
fs.opts.Limit = fs.limit
|
||
}
|
||
|
||
if fs.selects != nil {
|
||
fs.opts.Projection = fs.selects
|
||
}
|
||
}
|
||
|
||
func (fs *FindScope) preCheck() {
|
||
if fs.filter == nil {
|
||
fs.filter = bson.M{}
|
||
}
|
||
var breakerTTL time.Duration
|
||
if fs.scope.breaker == nil {
|
||
breakerTTL = defaultBreakerTime
|
||
} else if fs.scope.breaker.ttl == 0 {
|
||
breakerTTL = defaultBreakerTime
|
||
} else {
|
||
breakerTTL = fs.scope.breaker.ttl
|
||
}
|
||
if fs.cw == nil {
|
||
fs.cw = &ctxWrap{}
|
||
}
|
||
if fs.cw.ctx == nil {
|
||
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
|
||
} else {
|
||
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(fs.cw.ctx, breakerTTL)
|
||
}
|
||
}
|
||
|
||
func (fs *FindScope) assertErr() {
|
||
if fs.err == nil {
|
||
return
|
||
}
|
||
if errors.Is(fs.err, mongo.ErrNoDocuments) || errors.Is(fs.err, mongo.ErrNilDocument) {
|
||
fs.err = &ErrRecordNotFound
|
||
return
|
||
}
|
||
if errors.Is(fs.err, context.DeadlineExceeded) {
|
||
fs.err = &ErrRequestBroken
|
||
return
|
||
}
|
||
err, ok := fs.err.(mongo.CommandError)
|
||
if ok && err.HasErrorMessage(context.DeadlineExceeded.Error()) {
|
||
fs.err = &ErrRequestBroken
|
||
return
|
||
}
|
||
fs.err = err
|
||
}
|
||
|
||
func (fs *FindScope) doString() string {
|
||
builder := RegisterTimestampCodec(nil).Build()
|
||
filter, _ := bson.MarshalExtJSONWithRegistry(builder, fs.filter, true, true)
|
||
query := fmt.Sprintf("find(%s)", string(filter))
|
||
if fs.skip != nil {
|
||
query = fmt.Sprintf("%s.skip(%d)", query, *fs.skip)
|
||
}
|
||
if fs.limit != nil {
|
||
query = fmt.Sprintf("%s.limit(%d)", query, *fs.limit)
|
||
}
|
||
if fs.sort != nil {
|
||
sort, _ := bson.MarshalExtJSON(fs.sort, true, true)
|
||
query = fmt.Sprintf("%s.sort(%s)", query, string(sort))
|
||
}
|
||
return query
|
||
}
|
||
|
||
func (fs *FindScope) debug() {
|
||
if !fs.scope.debug && !fs.scope.debugWhenError {
|
||
return
|
||
}
|
||
|
||
debugger := &Debugger{
|
||
collection: fs.scope.tableName,
|
||
execT: fs.scope.execT,
|
||
action: fs,
|
||
}
|
||
|
||
// 当错误时优先输出
|
||
if fs.scope.debugWhenError {
|
||
if fs.err != nil {
|
||
debugger.errMsg = fs.err.Error()
|
||
debugger.ErrorString()
|
||
}
|
||
return
|
||
}
|
||
|
||
// 所有bug输出
|
||
if fs.scope.debug {
|
||
debugger.String()
|
||
}
|
||
}
|
||
|
||
func (fs *FindScope) doClear() {
|
||
if fs.cw != nil && fs.cw.cancel != nil {
|
||
fs.cw.cancel()
|
||
}
|
||
fs.scope.debug = false
|
||
fs.scope.execT = 0
|
||
}
|
||
|
||
func (fs *FindScope) doSearch() {
|
||
var starTime time.Time
|
||
if fs.scope.debug {
|
||
starTime = time.Now()
|
||
}
|
||
fs.cursor, fs.err = db.Collection(fs.scope.tableName).Find(fs.getContext(), fs.filter, fs.opts)
|
||
fs.assertErr() // 断言错误
|
||
if fs.scope.debug {
|
||
fs.scope.execT = time.Since(starTime)
|
||
fs.debug()
|
||
}
|
||
|
||
// 有检测数量的
|
||
if fs.withCount {
|
||
var res int64
|
||
res, fs.err = fs.scope.Count().SetContext(fs.getContext()).SetFilter(fs.filter).Count()
|
||
*fs.count = res
|
||
}
|
||
fs.assertErr()
|
||
}
|
||
|
||
// GetList 获取列表
|
||
// list: 需要一个 *[]*struct
|
||
func (fs *FindScope) GetList(list interface{}) error {
|
||
defer fs.doClear()
|
||
fs.optionAssembled()
|
||
fs.preCheck()
|
||
if fs.err != nil {
|
||
return fs.err
|
||
}
|
||
fs.doSearch()
|
||
|
||
if fs.err != nil {
|
||
return fs.err
|
||
}
|
||
v := reflect.ValueOf(list)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
fs.err = fmt.Errorf("invalid list type, not ptr")
|
||
return fs.err
|
||
}
|
||
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Slice {
|
||
fs.err = fmt.Errorf("invalid list type, not ptr to slice")
|
||
return fs.err
|
||
}
|
||
|
||
fs.err = fs.cursor.All(fs.getContext(), list)
|
||
|
||
if fs.enableCache {
|
||
fs.doCache(list)
|
||
}
|
||
|
||
return fs.err
|
||
}
|
||
|
||
// GetMap 基于结果的某个字段为Key 获取Map
|
||
// m: 传递一个 *map[string]*Struct
|
||
// field: struct的一个字段名称 需要是公开可访问的大写
|
||
func (fs *FindScope) GetMap(m interface{}, field string) error {
|
||
defer fs.doClear()
|
||
fs.optionAssembled()
|
||
fs.preCheck()
|
||
if fs.err != nil {
|
||
return fs.err
|
||
}
|
||
fs.doSearch()
|
||
|
||
if fs.err != nil {
|
||
return fs.err
|
||
}
|
||
|
||
v := reflect.ValueOf(m)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
fs.err = fmt.Errorf("invalid map type, not ptr")
|
||
return fs.err
|
||
}
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Map {
|
||
fs.err = fmt.Errorf("invalid map type, not map")
|
||
return fs.err
|
||
}
|
||
|
||
mapType := v.Type()
|
||
valueType := mapType.Elem()
|
||
keyType := mapType.Key()
|
||
|
||
if valueType.Kind() != reflect.Ptr {
|
||
fs.err = fmt.Errorf("invalid map value type, not prt")
|
||
return fs.err
|
||
}
|
||
|
||
if !v.CanSet() {
|
||
fs.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
|
||
return fs.err
|
||
}
|
||
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
|
||
|
||
for fs.cursor.Next(fs.getContext()) {
|
||
t := reflect.New(valueType.Elem())
|
||
if fs.err = fs.cursor.Decode(t.Interface()); fs.err != nil {
|
||
logrus.Errorf("err: %+v", fs.err)
|
||
return fs.err
|
||
}
|
||
fieldNode := t.Elem().FieldByName(field)
|
||
if fieldNode.Kind() == reflect.Invalid {
|
||
fs.err = fmt.Errorf("invalid model key: %s", field)
|
||
return fs.err
|
||
}
|
||
v.SetMapIndex(fieldNode, t)
|
||
}
|
||
|
||
return fs.err
|
||
}
|
||
|
||
func (fs *FindScope) Error() error {
|
||
return fs.err
|
||
}
|
||
|
||
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
|
||
func (fs *FindScope) SetCacheFunc(key string, cb FindCacheFunc) *FindScope {
|
||
fs.enableCache = true
|
||
fs.cacheFunc = cb
|
||
fs.cacheKey = key
|
||
return fs
|
||
}
|
||
|
||
// doCache 执行缓存
|
||
func (fs *FindScope) doCache(obj interface{}) *FindScope {
|
||
// redis句柄不存在
|
||
if fs.scope.cache == nil {
|
||
return nil
|
||
}
|
||
cacheObj, err := fs.cacheFunc(fs.cacheKey, obj)
|
||
if err != nil {
|
||
fs.err = err
|
||
return fs
|
||
}
|
||
|
||
if cacheObj == nil {
|
||
fs.err = fmt.Errorf("cache object nil")
|
||
return fs
|
||
}
|
||
|
||
ttl := fs.scope.cache.ttl
|
||
if ttl == 0 {
|
||
ttl = time.Hour
|
||
} else if ttl == -1 {
|
||
ttl = 0
|
||
}
|
||
|
||
if fs.getContext().Err() != nil {
|
||
fs.err = fs.getContext().Err()
|
||
fs.assertErr()
|
||
return fs
|
||
}
|
||
|
||
fs.err = fs.scope.cache.client.Set(fs.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
|
||
|
||
// 这里也要断言错误 看是否是熔断错误
|
||
|
||
return fs
|
||
}
|