413 lines
9.5 KiB
Go
413 lines
9.5 KiB
Go
package mdbc
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"reflect"
|
||
"time"
|
||
|
||
jsoniter "github.com/json-iterator/go"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
)
|
||
|
||
type AggregateScope struct {
|
||
scope *Scope
|
||
cw *ctxWrap
|
||
err error
|
||
pipeline interface{}
|
||
opts *options.AggregateOptions
|
||
cursor *mongo.Cursor
|
||
enableCache bool
|
||
cacheKey string
|
||
cacheFunc AggregateCacheFunc
|
||
}
|
||
|
||
// AggregateFunc 自定义获取结果集的注入函数
|
||
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)
|
||
|
||
// AggregateCacheFunc 缓存Find结果集
|
||
// field: 若无作用 可置空 DefaultAggregateCacheFunc 将使用其反射出field字段对应的value并将其设为key
|
||
// obj: 这是一个slice 只能在GetList是使用,GetMap使用将无效果
|
||
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)
|
||
|
||
// DefaultAggregateCacheFunc 按照第一个结果的field字段作为key缓存list数据
|
||
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
|
||
// 和 DefaultFindCacheFunc 有重复嫌疑 需要抽象
|
||
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
|
||
return func(field string, obj interface{}) (*CacheObject, error) {
|
||
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
|
||
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
|
||
// 下面使用反射进行操作 保证兼容
|
||
v := reflect.ValueOf(obj)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
return nil, fmt.Errorf("invalid list type, not ptr")
|
||
}
|
||
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Slice {
|
||
return nil, fmt.Errorf("invalid list type, not ptr to slice")
|
||
}
|
||
|
||
// 空slice 无需缓存
|
||
if v.Len() == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
firstNode := v.Index(0)
|
||
// 判断过长度 无需判断nil
|
||
if firstNode.Kind() == reflect.Ptr {
|
||
firstNode = firstNode.Elem()
|
||
}
|
||
idVal := firstNode.FieldByName(field)
|
||
if idVal.Kind() == reflect.Invalid {
|
||
return nil, fmt.Errorf("first node's %s field not found", field)
|
||
}
|
||
|
||
b, _ := json.Marshal(v.Interface())
|
||
co := &CacheObject{
|
||
Key: idVal.String(),
|
||
Value: string(b),
|
||
}
|
||
|
||
return co, nil
|
||
}
|
||
}
|
||
|
||
// SetContext 设置上下文
|
||
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
|
||
if as.cw == nil {
|
||
as.cw = &ctxWrap{}
|
||
}
|
||
as.cw.ctx = ctx
|
||
return as
|
||
}
|
||
|
||
// getContext 获取ctx
|
||
func (as *AggregateScope) getContext() context.Context {
|
||
return as.cw.ctx
|
||
}
|
||
|
||
// preCheck 预检查
|
||
func (as *AggregateScope) preCheck() {
|
||
var breakerTTL time.Duration
|
||
if as.scope.breaker == nil {
|
||
breakerTTL = defaultBreakerTime
|
||
} else if as.scope.breaker.ttl == 0 {
|
||
breakerTTL = defaultBreakerTime
|
||
} else {
|
||
breakerTTL = as.scope.breaker.ttl
|
||
}
|
||
if as.cw == nil {
|
||
as.cw = &ctxWrap{}
|
||
}
|
||
if as.cw.ctx == nil {
|
||
as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
|
||
}
|
||
}
|
||
|
||
// SetAggregateOption 重写配置项
|
||
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
|
||
as.opts = &opts
|
||
return as
|
||
}
|
||
|
||
// SetPipeline 设置管道
|
||
// 传递 []bson.D 或者mongo.Pipeline
|
||
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
|
||
as.pipeline = pipeline
|
||
return as
|
||
}
|
||
|
||
// assertErr 断言错误
|
||
func (as *AggregateScope) assertErr() {
|
||
if as.err == nil {
|
||
return
|
||
}
|
||
if errors.Is(as.err, context.DeadlineExceeded) {
|
||
as.err = &ErrRequestBroken
|
||
return
|
||
}
|
||
err, ok := as.err.(mongo.CommandError)
|
||
if !ok {
|
||
return
|
||
}
|
||
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
|
||
as.err = &ErrRequestBroken
|
||
}
|
||
}
|
||
|
||
// doString 实现debugger的 actionDo 接口
|
||
func (as *AggregateScope) doString() string {
|
||
var data []interface{}
|
||
builder := RegisterTimestampCodec(nil).Build()
|
||
vo := reflect.ValueOf(as.pipeline)
|
||
if vo.Kind() == reflect.Ptr {
|
||
vo = vo.Elem()
|
||
}
|
||
if vo.Kind() != reflect.Slice {
|
||
panic("pipeline type not slice")
|
||
}
|
||
for i := 0; i < vo.Len(); i++ {
|
||
var body interface{}
|
||
b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
|
||
_ = jsoniter.Unmarshal(b, &body)
|
||
data = append(data, body)
|
||
}
|
||
b, _ := jsoniter.Marshal(data)
|
||
return fmt.Sprintf("aggregate(%s)", string(b))
|
||
}
|
||
|
||
// debug 统一debug入口
|
||
func (as *AggregateScope) debug() {
|
||
if !as.scope.debug {
|
||
return
|
||
}
|
||
|
||
debugger := &Debugger{
|
||
collection: as.scope.tableName,
|
||
execT: as.scope.execT,
|
||
action: as,
|
||
}
|
||
debugger.String()
|
||
}
|
||
|
||
// doSearch 查询的真正语句
|
||
func (as *AggregateScope) doSearch() {
|
||
var starTime time.Time
|
||
if as.scope.debug {
|
||
starTime = time.Now()
|
||
}
|
||
as.cursor, as.err = db.Collection(as.scope.tableName).Aggregate(as.getContext(), as.pipeline, as.opts)
|
||
if as.scope.debug {
|
||
as.scope.execT = time.Since(starTime)
|
||
}
|
||
as.debug()
|
||
as.assertErr()
|
||
}
|
||
|
||
func (as *AggregateScope) doClear() {
|
||
if as.cw != nil && as.cw.cancel != nil {
|
||
as.cw.cancel()
|
||
}
|
||
as.scope.debug = false
|
||
as.scope.execT = 0
|
||
}
|
||
|
||
// GetList 获取管道数据列表 传递 *[]*TypeValue
|
||
func (as *AggregateScope) GetList(list interface{}) error {
|
||
defer as.doClear()
|
||
as.preCheck()
|
||
as.doSearch()
|
||
|
||
if as.err != nil {
|
||
return as.err
|
||
}
|
||
v := reflect.ValueOf(list)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
as.err = fmt.Errorf("invalid list type, not ptr")
|
||
return as.err
|
||
}
|
||
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Slice {
|
||
as.err = fmt.Errorf("invalid list type, not ptr to slice")
|
||
return as.err
|
||
}
|
||
|
||
as.err = as.cursor.All(as.getContext(), list)
|
||
|
||
if as.enableCache {
|
||
as.doCache(list)
|
||
}
|
||
|
||
return as.err
|
||
}
|
||
|
||
// GetOne 获取管道数据的第一个元素 传递 *struct 不建议使用而建议直接用 limit:1
|
||
// 然后数组取第一个数据的方式获取结果 为了确保效率 当明确只获取一条时 请用 limit 限制一下
|
||
// 当获取不到数据的时候 会报错 ErrRecordNotFound 目前不支持缓存
|
||
func (as *AggregateScope) GetOne(obj interface{}) error {
|
||
v := reflect.ValueOf(obj)
|
||
vt := v.Type()
|
||
if vt.Kind() != reflect.Ptr {
|
||
as.err = fmt.Errorf("invalid type, not ptr")
|
||
return as.err
|
||
}
|
||
vt = vt.Elem()
|
||
if vt.Kind() != reflect.Struct {
|
||
as.err = fmt.Errorf("invalid type, not ptr to struct")
|
||
return as.err
|
||
}
|
||
|
||
defer as.doClear()
|
||
as.preCheck()
|
||
as.doSearch()
|
||
|
||
if as.err != nil {
|
||
return as.err
|
||
}
|
||
|
||
objType := reflect.TypeOf(obj)
|
||
objSliceType := reflect.SliceOf(objType)
|
||
objSlice := reflect.New(objSliceType)
|
||
objSliceInterface := objSlice.Interface()
|
||
|
||
as.err = as.cursor.All(as.getContext(), objSliceInterface)
|
||
if as.err != nil {
|
||
return as.err
|
||
}
|
||
|
||
realV := reflect.ValueOf(objSliceInterface)
|
||
if realV.Kind() == reflect.Ptr {
|
||
realV = realV.Elem()
|
||
}
|
||
|
||
if realV.Len() == 0 {
|
||
as.err = &ErrRecordNotFound
|
||
return as.err
|
||
}
|
||
|
||
firstNode := realV.Index(0)
|
||
if firstNode.Kind() != reflect.Ptr {
|
||
as.err = &ErrFindObjectTypeNotSupport
|
||
return as.err
|
||
}
|
||
|
||
v.Elem().Set(firstNode.Elem())
|
||
|
||
if as.enableCache {
|
||
as.doCache(obj)
|
||
}
|
||
|
||
return as.err
|
||
}
|
||
|
||
// GetMap 获取管道数据map 传递 *map[string]*TypeValue, field 基于哪个字段进行map构建
|
||
func (as *AggregateScope) GetMap(m interface{}, field string) error {
|
||
defer as.doClear()
|
||
as.preCheck()
|
||
as.doSearch()
|
||
|
||
if as.err != nil {
|
||
return as.err
|
||
}
|
||
|
||
v := reflect.ValueOf(m)
|
||
if v.Type().Kind() != reflect.Ptr {
|
||
as.err = fmt.Errorf("invalid map type, not ptr")
|
||
return as.err
|
||
}
|
||
v = v.Elem()
|
||
if v.Type().Kind() != reflect.Map {
|
||
as.err = fmt.Errorf("invalid map type, not map")
|
||
return as.err
|
||
}
|
||
|
||
mapType := v.Type()
|
||
valueType := mapType.Elem()
|
||
keyType := mapType.Key()
|
||
|
||
if valueType.Kind() != reflect.Ptr {
|
||
as.err = fmt.Errorf("invalid map value type, not prt")
|
||
return as.err
|
||
}
|
||
|
||
if !v.CanSet() {
|
||
as.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
|
||
return as.err
|
||
}
|
||
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
|
||
|
||
for as.cursor.Next(as.getContext()) {
|
||
t := reflect.New(valueType.Elem())
|
||
if as.err = as.cursor.Decode(t.Interface()); as.err != nil {
|
||
logrus.Errorf("err: %+v", as.err)
|
||
return as.err
|
||
}
|
||
keyv := t.Elem().FieldByName(field)
|
||
if keyv.Kind() == reflect.Invalid {
|
||
as.err = fmt.Errorf("invalid model key: %s", field)
|
||
return as.err
|
||
}
|
||
v.SetMapIndex(keyv, t)
|
||
}
|
||
|
||
return as.err
|
||
}
|
||
|
||
// Error 获取错误
|
||
func (as *AggregateScope) Error() error {
|
||
return as.err
|
||
}
|
||
|
||
// GetByFunc 自定义函数 获取数据集
|
||
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
|
||
defer as.doClear()
|
||
as.preCheck()
|
||
as.doSearch()
|
||
|
||
if as.err != nil {
|
||
return nil, as.err
|
||
}
|
||
|
||
val, as.err = cb(as.cursor)
|
||
if as.err != nil {
|
||
return nil, as.err
|
||
}
|
||
|
||
// 自行断言val
|
||
return val, nil
|
||
}
|
||
|
||
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
|
||
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
|
||
as.enableCache = true
|
||
as.cacheFunc = cb
|
||
as.cacheKey = key
|
||
return as
|
||
}
|
||
|
||
// doCache 执行缓存
|
||
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
|
||
// redis句柄不存在
|
||
if as.scope.cache == nil {
|
||
return nil
|
||
}
|
||
cacheObj, err := as.cacheFunc(as.cacheKey, obj)
|
||
if err != nil {
|
||
as.err = err
|
||
return as
|
||
}
|
||
|
||
if cacheObj == nil {
|
||
as.err = fmt.Errorf("cache object nil")
|
||
return as
|
||
}
|
||
|
||
ttl := as.scope.cache.ttl
|
||
if ttl == 0 {
|
||
ttl = time.Hour
|
||
} else if ttl == -1 {
|
||
ttl = 0
|
||
}
|
||
|
||
if as.getContext().Err() != nil {
|
||
as.err = as.getContext().Err()
|
||
as.assertErr()
|
||
return as
|
||
}
|
||
|
||
as.err = as.scope.cache.client.Set(as.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
|
||
if as.err != nil {
|
||
as.assertErr()
|
||
}
|
||
|
||
return as
|
||
}
|