413 lines
9.5 KiB
Go
413 lines
9.5 KiB
Go
|
package mdbc
|
|||
|
|
|||
|
import (
|
|||
|
"context"
|
|||
|
"encoding/json"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"reflect"
|
|||
|
"time"
|
|||
|
|
|||
|
jsoniter "github.com/json-iterator/go"
|
|||
|
|
|||
|
"github.com/sirupsen/logrus"
|
|||
|
"go.mongodb.org/mongo-driver/bson"
|
|||
|
"go.mongodb.org/mongo-driver/mongo"
|
|||
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|||
|
)
|
|||
|
|
|||
|
type AggregateScope struct {
|
|||
|
scope *Scope
|
|||
|
cw *ctxWrap
|
|||
|
err error
|
|||
|
pipeline interface{}
|
|||
|
opts *options.AggregateOptions
|
|||
|
cursor *mongo.Cursor
|
|||
|
enableCache bool
|
|||
|
cacheKey string
|
|||
|
cacheFunc AggregateCacheFunc
|
|||
|
}
|
|||
|
|
|||
|
// AggregateFunc 自定义获取结果集的注入函数
|
|||
|
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)
|
|||
|
|
|||
|
// AggregateCacheFunc 缓存Find结果集
|
|||
|
// field: 若无作用 可置空 DefaultAggregateCacheFunc 将使用其反射出field字段对应的value并将其设为key
|
|||
|
// obj: 这是一个slice 只能在GetList是使用,GetMap使用将无效果
|
|||
|
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)
|
|||
|
|
|||
|
// DefaultAggregateCacheFunc 按照第一个结果的field字段作为key缓存list数据
|
|||
|
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
|
|||
|
// 和 DefaultFindCacheFunc 有重复嫌疑 需要抽象
|
|||
|
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
|
|||
|
return func(field string, obj interface{}) (*CacheObject, error) {
|
|||
|
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
|
|||
|
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
|
|||
|
// 下面使用反射进行操作 保证兼容
|
|||
|
v := reflect.ValueOf(obj)
|
|||
|
if v.Type().Kind() != reflect.Ptr {
|
|||
|
return nil, fmt.Errorf("invalid list type, not ptr")
|
|||
|
}
|
|||
|
|
|||
|
v = v.Elem()
|
|||
|
if v.Type().Kind() != reflect.Slice {
|
|||
|
return nil, fmt.Errorf("invalid list type, not ptr to slice")
|
|||
|
}
|
|||
|
|
|||
|
// 空slice 无需缓存
|
|||
|
if v.Len() == 0 {
|
|||
|
return nil, nil
|
|||
|
}
|
|||
|
|
|||
|
firstNode := v.Index(0)
|
|||
|
// 判断过长度 无需判断nil
|
|||
|
if firstNode.Kind() == reflect.Ptr {
|
|||
|
firstNode = firstNode.Elem()
|
|||
|
}
|
|||
|
idVal := firstNode.FieldByName(field)
|
|||
|
if idVal.Kind() == reflect.Invalid {
|
|||
|
return nil, fmt.Errorf("first node's %s field not found", field)
|
|||
|
}
|
|||
|
|
|||
|
b, _ := json.Marshal(v.Interface())
|
|||
|
co := &CacheObject{
|
|||
|
Key: idVal.String(),
|
|||
|
Value: string(b),
|
|||
|
}
|
|||
|
|
|||
|
return co, nil
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// SetContext 设置上下文
|
|||
|
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
|
|||
|
if as.cw == nil {
|
|||
|
as.cw = &ctxWrap{}
|
|||
|
}
|
|||
|
as.cw.ctx = ctx
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
// getContext 获取ctx
|
|||
|
func (as *AggregateScope) getContext() context.Context {
|
|||
|
return as.cw.ctx
|
|||
|
}
|
|||
|
|
|||
|
// preCheck 预检查
|
|||
|
func (as *AggregateScope) preCheck() {
|
|||
|
var breakerTTL time.Duration
|
|||
|
if as.scope.breaker == nil {
|
|||
|
breakerTTL = defaultBreakerTime
|
|||
|
} else if as.scope.breaker.ttl == 0 {
|
|||
|
breakerTTL = defaultBreakerTime
|
|||
|
} else {
|
|||
|
breakerTTL = as.scope.breaker.ttl
|
|||
|
}
|
|||
|
if as.cw == nil {
|
|||
|
as.cw = &ctxWrap{}
|
|||
|
}
|
|||
|
if as.cw.ctx == nil {
|
|||
|
as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// SetAggregateOption 重写配置项
|
|||
|
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
|
|||
|
as.opts = &opts
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
// SetPipeline 设置管道
|
|||
|
// 传递 []bson.D 或者mongo.Pipeline
|
|||
|
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
|
|||
|
as.pipeline = pipeline
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
// assertErr 断言错误
|
|||
|
func (as *AggregateScope) assertErr() {
|
|||
|
if as.err == nil {
|
|||
|
return
|
|||
|
}
|
|||
|
if errors.Is(as.err, context.DeadlineExceeded) {
|
|||
|
as.err = &ErrRequestBroken
|
|||
|
return
|
|||
|
}
|
|||
|
err, ok := as.err.(mongo.CommandError)
|
|||
|
if !ok {
|
|||
|
return
|
|||
|
}
|
|||
|
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
|
|||
|
as.err = &ErrRequestBroken
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// doString 实现debugger的 actionDo 接口
|
|||
|
func (as *AggregateScope) doString() string {
|
|||
|
var data []interface{}
|
|||
|
builder := RegisterTimestampCodec(nil).Build()
|
|||
|
vo := reflect.ValueOf(as.pipeline)
|
|||
|
if vo.Kind() == reflect.Ptr {
|
|||
|
vo = vo.Elem()
|
|||
|
}
|
|||
|
if vo.Kind() != reflect.Slice {
|
|||
|
panic("pipeline type not slice")
|
|||
|
}
|
|||
|
for i := 0; i < vo.Len(); i++ {
|
|||
|
var body interface{}
|
|||
|
b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
|
|||
|
_ = jsoniter.Unmarshal(b, &body)
|
|||
|
data = append(data, body)
|
|||
|
}
|
|||
|
b, _ := jsoniter.Marshal(data)
|
|||
|
return fmt.Sprintf("aggregate(%s)", string(b))
|
|||
|
}
|
|||
|
|
|||
|
// debug 统一debug入口
|
|||
|
func (as *AggregateScope) debug() {
|
|||
|
if !as.scope.debug {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
debugger := &Debugger{
|
|||
|
collection: as.scope.tableName,
|
|||
|
execT: as.scope.execT,
|
|||
|
action: as,
|
|||
|
}
|
|||
|
debugger.String()
|
|||
|
}
|
|||
|
|
|||
|
// doSearch 查询的真正语句
|
|||
|
func (as *AggregateScope) doSearch() {
|
|||
|
var starTime time.Time
|
|||
|
if as.scope.debug {
|
|||
|
starTime = time.Now()
|
|||
|
}
|
|||
|
as.cursor, as.err = db.Collection(as.scope.tableName).Aggregate(as.getContext(), as.pipeline, as.opts)
|
|||
|
if as.scope.debug {
|
|||
|
as.scope.execT = time.Since(starTime)
|
|||
|
}
|
|||
|
as.debug()
|
|||
|
as.assertErr()
|
|||
|
}
|
|||
|
|
|||
|
func (as *AggregateScope) doClear() {
|
|||
|
if as.cw != nil && as.cw.cancel != nil {
|
|||
|
as.cw.cancel()
|
|||
|
}
|
|||
|
as.scope.debug = false
|
|||
|
as.scope.execT = 0
|
|||
|
}
|
|||
|
|
|||
|
// GetList 获取管道数据列表 传递 *[]*TypeValue
|
|||
|
func (as *AggregateScope) GetList(list interface{}) error {
|
|||
|
defer as.doClear()
|
|||
|
as.preCheck()
|
|||
|
as.doSearch()
|
|||
|
|
|||
|
if as.err != nil {
|
|||
|
return as.err
|
|||
|
}
|
|||
|
v := reflect.ValueOf(list)
|
|||
|
if v.Type().Kind() != reflect.Ptr {
|
|||
|
as.err = fmt.Errorf("invalid list type, not ptr")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
v = v.Elem()
|
|||
|
if v.Type().Kind() != reflect.Slice {
|
|||
|
as.err = fmt.Errorf("invalid list type, not ptr to slice")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
as.err = as.cursor.All(as.getContext(), list)
|
|||
|
|
|||
|
if as.enableCache {
|
|||
|
as.doCache(list)
|
|||
|
}
|
|||
|
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
// GetOne 获取管道数据的第一个元素 传递 *struct 不建议使用而建议直接用 limit:1
|
|||
|
// 然后数组取第一个数据的方式获取结果 为了确保效率 当明确只获取一条时 请用 limit 限制一下
|
|||
|
// 当获取不到数据的时候 会报错 ErrRecordNotFound 目前不支持缓存
|
|||
|
func (as *AggregateScope) GetOne(obj interface{}) error {
|
|||
|
v := reflect.ValueOf(obj)
|
|||
|
vt := v.Type()
|
|||
|
if vt.Kind() != reflect.Ptr {
|
|||
|
as.err = fmt.Errorf("invalid type, not ptr")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
vt = vt.Elem()
|
|||
|
if vt.Kind() != reflect.Struct {
|
|||
|
as.err = fmt.Errorf("invalid type, not ptr to struct")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
defer as.doClear()
|
|||
|
as.preCheck()
|
|||
|
as.doSearch()
|
|||
|
|
|||
|
if as.err != nil {
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
objType := reflect.TypeOf(obj)
|
|||
|
objSliceType := reflect.SliceOf(objType)
|
|||
|
objSlice := reflect.New(objSliceType)
|
|||
|
objSliceInterface := objSlice.Interface()
|
|||
|
|
|||
|
as.err = as.cursor.All(as.getContext(), objSliceInterface)
|
|||
|
if as.err != nil {
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
realV := reflect.ValueOf(objSliceInterface)
|
|||
|
if realV.Kind() == reflect.Ptr {
|
|||
|
realV = realV.Elem()
|
|||
|
}
|
|||
|
|
|||
|
if realV.Len() == 0 {
|
|||
|
as.err = &ErrRecordNotFound
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
firstNode := realV.Index(0)
|
|||
|
if firstNode.Kind() != reflect.Ptr {
|
|||
|
as.err = &ErrFindObjectTypeNotSupport
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
v.Elem().Set(firstNode.Elem())
|
|||
|
|
|||
|
if as.enableCache {
|
|||
|
as.doCache(obj)
|
|||
|
}
|
|||
|
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
// GetMap 获取管道数据map 传递 *map[string]*TypeValue, field 基于哪个字段进行map构建
|
|||
|
func (as *AggregateScope) GetMap(m interface{}, field string) error {
|
|||
|
defer as.doClear()
|
|||
|
as.preCheck()
|
|||
|
as.doSearch()
|
|||
|
|
|||
|
if as.err != nil {
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
v := reflect.ValueOf(m)
|
|||
|
if v.Type().Kind() != reflect.Ptr {
|
|||
|
as.err = fmt.Errorf("invalid map type, not ptr")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
v = v.Elem()
|
|||
|
if v.Type().Kind() != reflect.Map {
|
|||
|
as.err = fmt.Errorf("invalid map type, not map")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
mapType := v.Type()
|
|||
|
valueType := mapType.Elem()
|
|||
|
keyType := mapType.Key()
|
|||
|
|
|||
|
if valueType.Kind() != reflect.Ptr {
|
|||
|
as.err = fmt.Errorf("invalid map value type, not prt")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
if !v.CanSet() {
|
|||
|
as.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
|
|||
|
return as.err
|
|||
|
}
|
|||
|
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
|
|||
|
|
|||
|
for as.cursor.Next(as.getContext()) {
|
|||
|
t := reflect.New(valueType.Elem())
|
|||
|
if as.err = as.cursor.Decode(t.Interface()); as.err != nil {
|
|||
|
logrus.Errorf("err: %+v", as.err)
|
|||
|
return as.err
|
|||
|
}
|
|||
|
keyv := t.Elem().FieldByName(field)
|
|||
|
if keyv.Kind() == reflect.Invalid {
|
|||
|
as.err = fmt.Errorf("invalid model key: %s", field)
|
|||
|
return as.err
|
|||
|
}
|
|||
|
v.SetMapIndex(keyv, t)
|
|||
|
}
|
|||
|
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
// Error 获取错误
|
|||
|
func (as *AggregateScope) Error() error {
|
|||
|
return as.err
|
|||
|
}
|
|||
|
|
|||
|
// GetByFunc 自定义函数 获取数据集
|
|||
|
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
|
|||
|
defer as.doClear()
|
|||
|
as.preCheck()
|
|||
|
as.doSearch()
|
|||
|
|
|||
|
if as.err != nil {
|
|||
|
return nil, as.err
|
|||
|
}
|
|||
|
|
|||
|
val, as.err = cb(as.cursor)
|
|||
|
if as.err != nil {
|
|||
|
return nil, as.err
|
|||
|
}
|
|||
|
|
|||
|
// 自行断言val
|
|||
|
return val, nil
|
|||
|
}
|
|||
|
|
|||
|
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
|
|||
|
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
|
|||
|
as.enableCache = true
|
|||
|
as.cacheFunc = cb
|
|||
|
as.cacheKey = key
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
// doCache 执行缓存
|
|||
|
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
|
|||
|
// redis句柄不存在
|
|||
|
if as.scope.cache == nil {
|
|||
|
return nil
|
|||
|
}
|
|||
|
cacheObj, err := as.cacheFunc(as.cacheKey, obj)
|
|||
|
if err != nil {
|
|||
|
as.err = err
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
if cacheObj == nil {
|
|||
|
as.err = fmt.Errorf("cache object nil")
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
ttl := as.scope.cache.ttl
|
|||
|
if ttl == 0 {
|
|||
|
ttl = time.Hour
|
|||
|
} else if ttl == -1 {
|
|||
|
ttl = 0
|
|||
|
}
|
|||
|
|
|||
|
if as.getContext().Err() != nil {
|
|||
|
as.err = as.getContext().Err()
|
|||
|
as.assertErr()
|
|||
|
return as
|
|||
|
}
|
|||
|
|
|||
|
as.err = as.scope.cache.client.Set(as.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
|
|||
|
if as.err != nil {
|
|||
|
as.assertErr()
|
|||
|
}
|
|||
|
|
|||
|
return as
|
|||
|
}
|