mdbc/aggregate_scope.go

413 lines
9.5 KiB
Go
Raw Normal View History

2022-02-23 08:59:45 +00:00
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
jsoniter "github.com/json-iterator/go"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AggregateScope holds the state of one aggregation against the parent
// scope's collection (scope.tableName). Configure it with SetPipeline,
// SetAggregateOption, SetContext and SetCacheFunc, then call one of the
// terminal methods GetList, GetOne, GetMap or GetByFunc.
type AggregateScope struct {
// scope is the parent scope; this file reads tableName, debug, execT,
// breaker and cache from it.
scope *Scope
// cw wraps the context for the current call and, when preCheck created that
// context, its cancel func.
cw *ctxWrap
// err is the last error recorded by this scope; Error() returns it.
err error
// pipeline is set by SetPipeline; expected to be a []bson.D or mongo.Pipeline.
pipeline interface{}
// opts is set by SetAggregateOption and passed to Collection.Aggregate.
opts *options.AggregateOptions
// cursor is the result cursor produced by doSearch.
cursor *mongo.Cursor
// enableCache, cacheKey and cacheFunc are set by SetCacheFunc; cacheKey is
// passed to cacheFunc as its field argument.
enableCache bool
cacheKey string
cacheFunc AggregateCacheFunc
}
// AggregateFunc is a caller-supplied function that GetByFunc uses to consume
// the aggregation cursor and build a custom result set. Its return value is
// handed back to the GetByFunc caller unchanged, so the caller must
// type-assert it.
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)
// AggregateCacheFunc converts aggregation results into a CacheObject to be
// written to the cache.
// field: receives the key configured in SetCacheFunc. An implementation that
// does not need it may ignore it, and it may then be left empty.
// DefaultAggregateCacheFunc reflects the struct field with this name and uses
// its value as the cache key.
// obj: the decoded results. GetList passes a pointer to the result slice and
// GetOne passes the *struct. GetMap never invokes caching.
// A nil *CacheObject means the implementation produced nothing to cache
// (DefaultAggregateCacheFunc returns nil, nil for an empty list).
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultAggregateCacheFunc builds an AggregateCacheFunc that caches a whole
// result list under one key. The key is the value of the struct field named
// field on the FIRST element. The value is the JSON encoding of the entire
// list.
//
// Only the first element is inspected, so every element should carry the
// field; otherwise the list may be cached under a misleading key. This logic
// overlaps with DefaultFindCacheFunc and is a candidate for sharing code.
//
// The returned function:
//   - errors if obj is not a pointer to a slice (including a nil obj);
//   - returns (nil, nil) for an empty slice, i.e. nothing to cache;
//   - errors if the first element is nil, is not a struct (or pointer to
//     struct), or lacks the field, or if the list cannot be JSON-encoded.
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
	return func(field string, obj interface{}) (*CacheObject, error) {
		// Callers with a concrete element type can type-assert obj instead
		// (e.g. res, ok := obj.(*[]*model.ModelUser)). Reflection keeps this
		// implementation generic.
		v := reflect.ValueOf(obj)
		// Kind() (not Type().Kind()) so a nil obj yields an error, not a panic.
		if v.Kind() != reflect.Ptr {
			return nil, fmt.Errorf("invalid list type, not ptr")
		}
		v = v.Elem()
		if v.Kind() != reflect.Slice {
			return nil, fmt.Errorf("invalid list type, not ptr to slice")
		}
		// An empty slice has nothing worth caching.
		if v.Len() == 0 {
			return nil, nil
		}
		firstNode := v.Index(0)
		// A non-empty []*T can still hold a nil first element.
		if firstNode.Kind() == reflect.Ptr {
			if firstNode.IsNil() {
				return nil, fmt.Errorf("first node is nil")
			}
			firstNode = firstNode.Elem()
		}
		// FieldByName panics on anything but a struct.
		if firstNode.Kind() != reflect.Struct {
			return nil, fmt.Errorf("first node is %s, not struct", firstNode.Kind())
		}
		idVal := firstNode.FieldByName(field)
		if !idVal.IsValid() {
			return nil, fmt.Errorf("first node's %s field not found", field)
		}
		b, err := json.Marshal(v.Interface())
		if err != nil {
			return nil, fmt.Errorf("marshal list for cache: %w", err)
		}
		// reflect.Value.String yields "<T Value>" for non-string kinds, which
		// would give every list the same key. Format those through fmt instead.
		key := idVal.String()
		if idVal.Kind() != reflect.String {
			key = fmt.Sprint(idVal)
		}
		return &CacheObject{
			Key:   key,
			Value: string(b),
		}, nil
	}
}
// SetContext installs ctx as the context used by the terminal methods.
// preCheck only creates its own timeout context when none is present, so a
// context set here takes precedence over the breaker timeout. It returns the
// scope for chaining.
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
	wrap := as.cw
	if wrap == nil {
		wrap = &ctxWrap{}
		as.cw = wrap
	}
	wrap.ctx = ctx
	return as
}
// getContext returns the context held in as.cw. It dereferences as.cw
// directly, so it must only be called after preCheck or SetContext has
// populated it.
func (as *AggregateScope) getContext() context.Context {
	wrap := as.cw
	return wrap.ctx
}
// preCheck prepares the per-call context. When no context is present yet, it
// derives one from context.Background() with a timeout equal to the parent
// scope's breaker ttl. It falls back to defaultBreakerTime when no breaker is
// configured or its ttl is zero. The cancel func is stored in as.cw so
// doClear can release it.
func (as *AggregateScope) preCheck() {
	var timeout time.Duration
	switch br := as.scope.breaker; {
	case br == nil, br.ttl == 0:
		timeout = defaultBreakerTime
	default:
		timeout = br.ttl
	}

	if as.cw == nil {
		as.cw = &ctxWrap{}
	}
	if as.cw.ctx != nil {
		return
	}
	as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), timeout)
}
// SetAggregateOption replaces the driver options used for the aggregation.
// opts arrives by value, so the scope keeps a pointer to its own shallow copy
// rather than to the caller's variable. It returns the scope for chaining.
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
	stored := opts
	as.opts = &stored
	return as
}
// SetPipeline stores the aggregation pipeline, normally a []bson.D or a
// mongo.Pipeline, and returns the scope for chaining.
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
	as.pipeline = pipeline
	return as
}
// assertErr normalizes as.err after an operation. Timeouts are replaced with
// &ErrRequestBroken, whether they surface as a (possibly wrapped)
// context.DeadlineExceeded or as a server CommandError whose message
// mentions the deadline. Any other error is left untouched.
func (as *AggregateScope) assertErr() {
	if as.err == nil {
		return
	}
	if errors.Is(as.err, context.DeadlineExceeded) {
		as.err = &ErrRequestBroken
		return
	}
	// errors.As also finds a CommandError that has been wrapped, which a bare
	// type assertion would miss.
	var cmdErr mongo.CommandError
	if errors.As(as.err, &cmdErr) && cmdErr.HasErrorMessage(context.DeadlineExceeded.Error()) {
		as.err = &ErrRequestBroken
	}
}
// doString implements the Debugger's actionDo interface. It renders the
// pipeline as "aggregate([...])", encoding each stage as relaxed extended JSON
// with the registry from RegisterTimestampCodec.
//
// This is debug output only. A pipeline that is not a slice (or pointer to
// one), including a nil pipeline, is described in the output rather than
// panicking. Rendering must never take the process down.
func (as *AggregateScope) doString() string {
	vo := reflect.ValueOf(as.pipeline)
	if vo.Kind() == reflect.Ptr {
		vo = vo.Elem()
	}
	if vo.Kind() != reflect.Slice {
		return fmt.Sprintf("aggregate(<pipeline type %T not slice>)", as.pipeline)
	}
	builder := RegisterTimestampCodec(nil).Build()
	data := make([]interface{}, 0, vo.Len())
	for i := 0; i < vo.Len(); i++ {
		var body interface{}
		// Best effort: a stage that fails to encode is rendered as null.
		b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
		_ = jsoniter.Unmarshal(b, &body)
		data = append(data, body)
	}
	b, _ := jsoniter.Marshal(data)
	return fmt.Sprintf("aggregate(%s)", string(b))
}
// debug is the single debug entry point. When debugging is enabled on the
// parent scope, it builds a Debugger for this aggregation (collection name,
// measured execution time, and this scope as the action) and calls its
// String method. It does nothing when debugging is off.
func (as *AggregateScope) debug() {
	if as.scope.debug {
		d := &Debugger{
			collection: as.scope.tableName,
			execT:      as.scope.execT,
			action:     as,
		}
		d.String()
	}
}
// doSearch executes the aggregation on db.Collection(tableName) using the
// call's context, pipeline and options. It stores the cursor and error on the
// scope. When debugging is enabled, it records the elapsed time on the parent
// scope. Finally it emits debug output and normalizes the error via
// assertErr.
func (as *AggregateScope) doSearch() {
	var begin time.Time
	if as.scope.debug {
		begin = time.Now()
	}

	coll := db.Collection(as.scope.tableName)
	cur, err := coll.Aggregate(as.getContext(), as.pipeline, as.opts)
	as.cursor, as.err = cur, err

	if as.scope.debug {
		as.scope.execT = time.Since(begin)
	}
	as.debug()
	as.assertErr()
}
// doClear releases per-call resources when a terminal method returns:
//   - it closes the aggregation cursor (best effort), so early returns and
//     callers that never exhaust it do not leave a server cursor open;
//   - it cancels the timeout context created by preCheck;
//   - it resets the one-shot debug state on the parent scope.
//
// The context preCheck created, and its cancel func, are dropped so a later
// call on the same scope gets a fresh timeout instead of reusing a
// cancelled context. A context supplied via SetContext (cancel == nil) is
// left in place.
func (as *AggregateScope) doClear() {
	if as.cursor != nil {
		ctx := context.Background()
		if as.cw != nil && as.cw.ctx != nil {
			ctx = as.cw.ctx
		}
		// Cursor.Close is idempotent, so closing after cursor.All already did
		// is harmless. Close errors are irrelevant once the call has finished.
		_ = as.cursor.Close(ctx)
		as.cursor = nil
	}
	if as.cw != nil && as.cw.cancel != nil {
		as.cw.cancel()
		as.cw.ctx = nil
		as.cw.cancel = nil
	}
	as.scope.debug = false
	as.scope.execT = 0
}
// GetList runs the pipeline and decodes every result document into list,
// which must be a non-nil pointer to a slice, e.g. *[]*TypeValue.
//
// The argument is validated before the query is sent. When a cache func was
// registered via SetCacheFunc, the decoded list is passed to doCache, but
// only after a successful decode. A read error is therefore never cached or
// overwritten by a cache write.
func (as *AggregateScope) GetList(list interface{}) error {
	defer as.doClear()
	v := reflect.ValueOf(list)
	// Kind() on the Value (not Type()) so nil inputs report an error instead
	// of panicking.
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid list type, not ptr")
		return as.err
	}
	if v.Elem().Kind() != reflect.Slice {
		as.err = fmt.Errorf("invalid list type, not ptr to slice")
		return as.err
	}
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}
	if as.err = as.cursor.All(as.getContext(), list); as.err != nil {
		return as.err
	}
	if as.enableCache {
		as.doCache(list)
	}
	return as.err
}
// GetOne runs the pipeline and decodes the FIRST result document into obj,
// which must be a non-nil pointer to a struct. *obj is replaced entirely, so
// fields absent from the document end up zero.
//
// Only the first document is read, and the cursor is closed afterwards.
// Adding a {$limit: 1} stage is still recommended, so the server does not
// produce results that are never read.
//
// Returns &ErrRecordNotFound when the pipeline yields no documents.
//
// If caching is enabled, obj (a *struct, not a slice) is passed to the cache
// func. DefaultAggregateCacheFunc rejects non-slice input, so with the
// default cache func this method reports that error.
func (as *AggregateScope) GetOne(obj interface{}) error {
	v := reflect.ValueOf(obj)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid type, not ptr")
		return as.err
	}
	structType := v.Type().Elem()
	if structType.Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid type, not ptr to struct")
		return as.err
	}
	// A typed nil pointer has nowhere to store the result.
	if v.IsNil() {
		as.err = fmt.Errorf("invalid type, nil ptr")
		return as.err
	}
	defer as.doClear()
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}
	ctx := as.getContext()
	cur := as.cursor
	// Runs before the deferred doClear, while ctx is still live.
	defer func() { _ = cur.Close(ctx) }()

	if !cur.Next(ctx) {
		// Next returns false on failure as well as on an empty result.
		if as.err = cur.Err(); as.err != nil {
			return as.err
		}
		as.err = &ErrRecordNotFound
		return as.err
	}
	fresh := reflect.New(structType)
	if as.err = cur.Decode(fresh.Interface()); as.err != nil {
		return as.err
	}
	v.Elem().Set(fresh.Elem())
	if as.enableCache {
		as.doCache(obj)
	}
	return as.err
}
// GetMap runs the pipeline and builds a map from its results. m must be a
// non-nil pointer to a map whose values are pointers to structs, e.g.
// *map[string]*TypeValue.
//
// Each document is decoded into a new struct and stored under the value of
// its exported field named field. That field's type must be assignable to
// the map's key type.
//
// The map is replaced with a fresh one once the query succeeds, so on a
// later error it may hold a partial result. A cursor failure during
// iteration is reported, never silently returned as a truncated map.
// Caching is not applied.
func (as *AggregateScope) GetMap(m interface{}, field string) error {
	defer as.doClear()
	v := reflect.ValueOf(m)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map type, not ptr")
		return as.err
	}
	v = v.Elem()
	if v.Kind() != reflect.Map {
		as.err = fmt.Errorf("invalid map type, not map")
		return as.err
	}
	mapType := v.Type()
	keyType := mapType.Key()
	valueType := mapType.Elem()
	if valueType.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map value type, not ptr")
		return as.err
	}
	// FieldByName below panics on anything but a struct.
	if valueType.Elem().Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid map value type, not ptr to struct")
		return as.err
	}
	if !v.CanSet() {
		as.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
		return as.err
	}
	// Validate everything before sending the query.
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}
	ctx := as.getContext()
	cur := as.cursor
	// Early returns below would otherwise leave the server-side cursor open.
	defer func() { _ = cur.Close(ctx) }()

	v.Set(reflect.MakeMap(mapType))
	for cur.Next(ctx) {
		node := reflect.New(valueType.Elem())
		if as.err = cur.Decode(node.Interface()); as.err != nil {
			logrus.Errorf("err: %+v", as.err)
			return as.err
		}
		key := node.Elem().FieldByName(field)
		if !key.IsValid() {
			as.err = fmt.Errorf("invalid model key: %s", field)
			return as.err
		}
		// SetMapIndex panics on unexported fields and on key types that are
		// not assignable to the map key.
		if !key.CanInterface() || !key.Type().AssignableTo(keyType) {
			as.err = fmt.Errorf("invalid model key: %s, type %s not usable as map key %s", field, key.Type(), keyType)
			return as.err
		}
		v.SetMapIndex(key, node)
	}
	if as.err = cur.Err(); as.err != nil {
		return as.err
	}
	return nil
}
// Error returns the error recorded by the most recent operation on this
// scope, or nil if there was none.
func (as *AggregateScope) Error() error {
	last := as.err
	return last
}
// GetByFunc runs the pipeline and hands the cursor to cb, returning whatever
// cb returns. The caller type-asserts val.
//
// cb must finish with the cursor before returning: when GetByFunc returns,
// the cursor is closed and doClear releases the call's context. A nil cb is
// rejected before any query is sent. Caching is not applied.
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
	defer as.doClear()
	if cb == nil {
		as.err = fmt.Errorf("invalid aggregate func, nil")
		return nil, as.err
	}
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return nil, as.err
	}
	ctx := as.getContext()
	cur := as.cursor
	// Runs before the deferred doClear, while ctx is still live. Close is
	// idempotent, so it is harmless if cb already closed the cursor.
	defer func() { _ = cur.Close(ctx) }()

	val, as.err = cb(cur)
	if as.err != nil {
		return nil, as.err
	}
	return val, nil
}
// SetCacheFunc enables result caching for this scope.
//
// key is passed to cb as its field argument; DefaultAggregateCacheFunc reads
// the struct field with that name from the first result. cb converts the
// decoded results into the CacheObject that doCache writes.
//
// Caching runs only from GetList (which passes the *slice) and GetOne (which
// passes the *struct). GetMap and GetByFunc never cache.
//
// A nil cb falls back to DefaultAggregateCacheFunc(); otherwise doCache
// would call a nil function. It returns the scope for chaining.
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
	if cb == nil {
		cb = DefaultAggregateCacheFunc()
	}
	as.enableCache = true
	as.cacheFunc = cb
	as.cacheKey = key
	return as
}
// doCache writes the cache entry that as.cacheFunc derives from obj, using
// the parent scope's cache client. It returns the scope.
//
// It is a no-op when:
//   - no cache is configured;
//   - no cache func is set;
//   - the preceding read already failed (the read error is kept and nothing
//     stale is cached);
//   - the cache func returns a nil CacheObject, meaning "nothing to cache".
//     DefaultAggregateCacheFunc returns nil for an empty list, so empty
//     results are not turned into an error.
//
// Otherwise a cache-func error, an expired call context, or a client error
// replaces as.err; context and client errors pass through assertErr.
//
// TTL: the cache's ttl is used as-is. 0 becomes one hour, and -1 is mapped
// to 0, which presumably means "no expiry" for the redis client (confirm).
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
	if as.scope.cache == nil || as.cacheFunc == nil || as.err != nil {
		return as
	}
	cacheObj, err := as.cacheFunc(as.cacheKey, obj)
	if err != nil {
		as.err = err
		return as
	}
	if cacheObj == nil {
		return as
	}
	ttl := as.scope.cache.ttl
	switch ttl {
	case 0:
		ttl = time.Hour
	case -1:
		ttl = 0
	}
	ctx := as.getContext()
	if as.err = ctx.Err(); as.err != nil {
		as.assertErr()
		return as
	}
	if as.err = as.scope.cache.client.Set(ctx, cacheObj.Key, cacheObj.Value, ttl).Err(); as.err != nil {
		as.assertErr()
	}
	return as
}