mdbc/aggregate_scope.go
2022-02-23 16:59:45 +08:00

413 lines
9.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
jsoniter "github.com/json-iterator/go"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// AggregateScope builds and executes a MongoDB aggregation on the collection
// of its parent Scope and decodes the results (GetList, GetOne, GetMap,
// GetByFunc). It is configured through SetPipeline, SetAggregateOption,
// SetContext and SetCacheFunc.
type AggregateScope struct {
	scope       *Scope                    // parent scope: table name, debug flag and timing, breaker and cache settings
	cw          *ctxWrap                  // context for the query; cancel is only set when preCheck created the context
	err         error                     // last error recorded by this scope (returned by Error)
	pipeline    interface{}               // aggregation pipeline, e.g. []bson.D or mongo.Pipeline
	opts        *options.AggregateOptions // options passed to Aggregate; set by SetAggregateOption
	cursor      *mongo.Cursor             // cursor returned by the most recent Aggregate call
	enableCache bool                      // set by SetCacheFunc; GetList and GetOne then call doCache
	cacheKey    string                    // passed to cacheFunc as its field argument
	cacheFunc   AggregateCacheFunc        // builds the CacheObject that doCache writes
}
// AggregateFunc is a caller-supplied function that reads the aggregation
// cursor itself and returns a custom result (used by GetByFunc).
// NOTE(review): GetByFunc does not close the cursor after calling this
// function, so the function should exhaust or close it.
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)
// AggregateCacheFunc converts an aggregation result into a CacheObject that
// is then written to the cache.
//
// field: the key given to SetCacheFunc. Implementations that do not need it
// may ignore it. DefaultAggregateCacheFunc reads this struct field from the
// first element via reflection and uses its value as the cache key.
//
// obj: the decoded result. GetList passes its list argument (a pointer to a
// slice). GetMap never invokes caching.
// NOTE(review): GetOne currently passes its *struct argument here.
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultAggregateCacheFunc returns an AggregateCacheFunc that caches a whole
// result list under one key taken from the list's first element.
//
// field names the struct field whose value, read from the FIRST element,
// becomes the cache key. Every element should have that field; otherwise the
// cached list may end up stored under an unexpected key.
//
// obj must be a non-nil pointer to a slice of structs or struct pointers
// (e.g. *[]*Model). An empty slice yields (nil, nil): there is nothing to
// cache. The cache value is the JSON encoding of the whole slice.
//
// Non-string key fields (ints, ObjectIDs, ...) are formatted with fmt.Sprint.
// Pointer fields are dereferenced first.
//
// TODO: this overlaps with DefaultFindCacheFunc and should be factored out.
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
	return func(field string, obj interface{}) (*CacheObject, error) {
		// Custom cache functions should prefer a type assertion, e.g.
		// res, ok := obj.(*[]*model.ModelUser), and return nil when it fails
		// so that nothing is cached. Reflection keeps this default generic.
		v := reflect.ValueOf(obj)
		// The zero Value (obj == nil) has Kind Invalid and is rejected here.
		if v.Kind() != reflect.Ptr {
			return nil, fmt.Errorf("invalid list type, not ptr")
		}
		if v.IsNil() || v.Elem().Kind() != reflect.Slice {
			return nil, fmt.Errorf("invalid list type, not ptr to slice")
		}
		v = v.Elem()
		// Empty slice: nothing to cache.
		if v.Len() == 0 {
			return nil, nil
		}
		firstNode := v.Index(0)
		// A non-empty slice of pointers may still hold nil elements.
		if firstNode.Kind() == reflect.Ptr {
			if firstNode.IsNil() {
				return nil, fmt.Errorf("first node is nil")
			}
			firstNode = firstNode.Elem()
		}
		// FieldByName panics on non-struct values.
		if firstNode.Kind() != reflect.Struct {
			return nil, fmt.Errorf("invalid list element type %s, not struct", firstNode.Type())
		}
		idVal := firstNode.FieldByName(field)
		if !idVal.IsValid() {
			return nil, fmt.Errorf("first node's %s field not found", field)
		}
		for idVal.Kind() == reflect.Ptr {
			if idVal.IsNil() {
				return nil, fmt.Errorf("first node's %s field is nil", field)
			}
			idVal = idVal.Elem()
		}
		var key string
		switch {
		case idVal.Kind() == reflect.String:
			// String() also works for unexported string fields.
			key = idVal.String()
		case idVal.CanInterface():
			// reflect.Value.String() would return "<T Value>" for non-string
			// kinds, which would give every list the same cache key.
			key = fmt.Sprint(idVal.Interface())
		default:
			return nil, fmt.Errorf("first node's %s field is unexported", field)
		}
		b, err := json.Marshal(v.Interface())
		if err != nil {
			return nil, fmt.Errorf("marshal list: %w", err)
		}
		return &CacheObject{Key: key, Value: string(b)}, nil
	}
}
// SetContext binds ctx to this aggregation and returns the scope for chaining.
// The bound context is what getContext hands to the query, the cursor and the
// cache write. While a context is bound, preCheck does not create its default
// timeout context. An existing wrapper (and its cancel func) is reused.
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
	wrap := as.cw
	if wrap == nil {
		wrap = &ctxWrap{}
		as.cw = wrap
	}
	wrap.ctx = ctx
	return as
}
// getContext returns the context bound to this scope. It dereferences as.cw
// without a nil check, so it must only be called after preCheck or SetContext
// has created the wrapper.
func (as *AggregateScope) getContext() context.Context {
	wrap := as.cw
	return wrap.ctx
}
// preCheck ensures a context is available before the query runs.
//
// The timeout is the breaker's TTL. It falls back to defaultBreakerTime when
// the scope has no breaker or the breaker's TTL is zero. If no context has
// been bound (see SetContext), a timeout context is created from
// context.Background(). Its cancel func is kept in as.cw.cancel so that
// doClear can release it. An already-bound context is left untouched.
func (as *AggregateScope) preCheck() {
	var ttl time.Duration
	switch b := as.scope.breaker; {
	case b != nil && b.ttl != 0:
		ttl = b.ttl
	default:
		ttl = defaultBreakerTime
	}

	if as.cw == nil {
		as.cw = &ctxWrap{}
	}
	if as.cw.ctx != nil {
		return
	}
	as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), ttl)
}
// SetAggregateOption replaces the options passed to Aggregate and returns the
// scope for chaining. opts arrives by value, so the scope stores a pointer to
// its own (shallow) copy.
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
	copied := opts
	as.opts = &copied
	return as
}
// SetPipeline sets the aggregation pipeline (for example a []bson.D or a
// mongo.Pipeline) that Aggregate will run, and returns the scope for chaining.
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
	target := as
	target.pipeline = pipeline
	return target
}
// assertErr normalizes as.err after a database or cache operation.
//
// It replaces as.err with ErrRequestBroken in two cases:
//   - context.DeadlineExceeded, matched anywhere in the wrap chain;
//   - a mongo.CommandError whose message contains the deadline-exceeded text.
//
// All other errors, and a nil error, are left unchanged.
func (as *AggregateScope) assertErr() {
	if as.err == nil {
		return
	}
	if errors.Is(as.err, context.DeadlineExceeded) {
		as.err = &ErrRequestBroken
		return
	}
	// errors.As (instead of a direct type assertion) also matches a
	// CommandError that another layer has wrapped.
	var cmdErr mongo.CommandError
	if errors.As(as.err, &cmdErr) && cmdErr.HasErrorMessage(context.DeadlineExceeded.Error()) {
		as.err = &ErrRequestBroken
	}
}
// doString implements the debugger's actionDo interface.
//
// It renders the pipeline as "aggregate([...])". Each stage is encoded as
// canonical Extended JSON using the registry built from
// RegisterTimestampCodec, and the stages are joined into a single JSON array.
//
// This only feeds debug output, so it never panics. A pipeline that is not a
// slice or array (or a pointer to one), including nil, is reported in the
// output instead. Stages that fail to encode are shown with their error.
func (as *AggregateScope) doString() string {
	vo := reflect.ValueOf(as.pipeline)
	if vo.Kind() == reflect.Ptr {
		vo = vo.Elem()
	}
	if vo.Kind() != reflect.Slice && vo.Kind() != reflect.Array {
		return fmt.Sprintf("aggregate(<unsupported pipeline type %T>)", as.pipeline)
	}

	builder := RegisterTimestampCodec(nil).Build()
	// Non-nil so that an empty pipeline renders as [] rather than null.
	stages := make([]interface{}, 0, vo.Len())
	for i := 0; i < vo.Len(); i++ {
		var stage interface{}
		b, err := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
		if err != nil {
			stage = fmt.Sprintf("<marshal error: %v>", err)
		} else if err = jsoniter.Unmarshal(b, &stage); err != nil {
			stage = string(b)
		}
		stages = append(stages, stage)
	}

	b, err := jsoniter.Marshal(stages)
	if err != nil {
		return fmt.Sprintf("aggregate(<marshal error: %v>)", err)
	}
	return fmt.Sprintf("aggregate(%s)", string(b))
}
// debug is the single debug-output entry point. When debug mode is on for the
// parent scope, it builds a Debugger for this aggregation (collection name,
// measured execution time, and the scope itself as the action) and calls its
// String method. Otherwise it does nothing.
func (as *AggregateScope) debug() {
	if as.scope.debug {
		(&Debugger{
			collection: as.scope.tableName,
			execT:      as.scope.execT,
			action:     as,
		}).String()
	}
}
// doSearch runs the aggregation on the scope's collection through the
// package-level db handle and stores the resulting cursor and error on as.
//
// In debug mode the elapsed time is recorded in scope.execT before the debug
// output is emitted. The error is then normalized by assertErr.
func (as *AggregateScope) doSearch() {
	debugOn := as.scope.debug
	var began time.Time
	if debugOn {
		began = time.Now()
	}

	coll := db.Collection(as.scope.tableName)
	as.cursor, as.err = coll.Aggregate(as.getContext(), as.pipeline, as.opts)

	if debugOn {
		as.scope.execT = time.Since(began)
	}
	as.debug()
	as.assertErr()
}
// doClear releases per-call state after a Get* method returns.
//
// If preCheck created a timeout context (cancel is set), the context is
// cancelled and then dropped, together with its cancel func. Otherwise a
// later call on this scope would reuse the already-cancelled context and fail
// immediately. A caller-supplied context (SetContext, cancel == nil) is left
// untouched. The scope's debug flag and timing are reset.
func (as *AggregateScope) doClear() {
	if as.cw != nil && as.cw.cancel != nil {
		as.cw.cancel()
		as.cw.ctx = nil
		as.cw.cancel = nil
	}
	as.scope.debug = false
	as.scope.execT = 0
}
// GetList runs the pipeline and decodes every result document into list.
//
// list must be a non-nil pointer to a slice, e.g. *[]*Model. When caching is
// enabled (SetCacheFunc), the decoded list is passed to doCache, but only
// after a successful decode.
func (as *AggregateScope) GetList(list interface{}) error {
	defer as.doClear()

	// Validate before querying: an invalid argument must not open a cursor
	// that would then never be closed.
	v := reflect.ValueOf(list)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid list type, not ptr")
		return as.err
	}
	if v.IsNil() {
		as.err = fmt.Errorf("invalid list type, nil ptr")
		return as.err
	}
	if v.Elem().Kind() != reflect.Slice {
		as.err = fmt.Errorf("invalid list type, not ptr to slice")
		return as.err
	}

	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}

	// cursor.All closes the cursor when it returns.
	if as.err = as.cursor.All(as.getContext(), list); as.err != nil {
		// Don't cache a partial result, and don't let doCache overwrite
		// the decode error.
		return as.err
	}
	if as.enableCache {
		as.doCache(list)
	}
	return as.err
}
// GetOne runs the pipeline and decodes the FIRST result document into obj.
//
// obj must be a non-nil pointer to a struct. It is overwritten with a freshly
// decoded value, so fields absent from the document end up zeroed. The method
// returns ErrRecordNotFound when the pipeline yields no documents.
//
// Only the first document is decoded. The cursor is then closed, which
// releases the server-side cursor. Adding a $limit: 1 stage to the pipeline
// is still recommended so the server does not produce results that would be
// discarded.
//
// NOTE(review): when caching is enabled, obj (a *struct) is passed to
// doCache, whereas AggregateCacheFunc implementations such as
// DefaultAggregateCacheFunc expect a pointer to a slice.
func (as *AggregateScope) GetOne(obj interface{}) error {
	v := reflect.ValueOf(obj)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid type, not ptr")
		return as.err
	}
	structType := v.Type().Elem()
	if structType.Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid type, not ptr to struct")
		return as.err
	}
	if v.IsNil() {
		as.err = fmt.Errorf("invalid type, nil ptr")
		return as.err
	}

	defer as.doClear()
	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}

	ctx := as.getContext()
	// Next/Decode do not close the cursor; release it on every exit path.
	// This runs before the deferred doClear cancels the context.
	defer func() { _ = as.cursor.Close(ctx) }()

	if !as.cursor.Next(ctx) {
		// Next returns false both when there are no results and on failure.
		if as.err = as.cursor.Err(); as.err != nil {
			return as.err
		}
		as.err = &ErrRecordNotFound
		return as.err
	}

	// Decode into a fresh value so the result does not mix with obj's
	// previous contents.
	node := reflect.New(structType)
	if as.err = as.cursor.Decode(node.Interface()); as.err != nil {
		return as.err
	}
	v.Elem().Set(node.Elem())

	if as.enableCache {
		as.doCache(obj)
	}
	return as.err
}
// GetMap runs the pipeline and collects the results into the map that m
// points to, keyed by the struct field named field.
//
// m must be a non-nil pointer to a map whose values are pointers to structs,
// e.g. *map[string]*Model. field must name an exported struct field whose
// type is assignable to the map's key type. The caller's map is replaced by
// a freshly built one only when every document has been read and decoded
// successfully. Caching (SetCacheFunc) is not applied by GetMap.
func (as *AggregateScope) GetMap(m interface{}, field string) error {
	defer as.doClear()

	// Validate before querying, so that an invalid argument neither opens an
	// unclosed cursor nor panics inside reflect once results arrive.
	v := reflect.ValueOf(m)
	if v.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map type, not ptr")
		return as.err
	}
	if v.IsNil() {
		as.err = fmt.Errorf("invalid map type, nil ptr")
		return as.err
	}
	v = v.Elem()
	if v.Kind() != reflect.Map {
		as.err = fmt.Errorf("invalid map type, not map")
		return as.err
	}
	mapType := v.Type()
	keyType := mapType.Key()
	valueType := mapType.Elem()
	if valueType.Kind() != reflect.Ptr {
		as.err = fmt.Errorf("invalid map value type, not ptr")
		return as.err
	}
	elemType := valueType.Elem()
	if elemType.Kind() != reflect.Struct {
		as.err = fmt.Errorf("invalid map value type, not ptr to struct")
		return as.err
	}
	if !v.CanSet() {
		as.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
		return as.err
	}
	keyField, ok := elemType.FieldByName(field)
	if !ok {
		as.err = fmt.Errorf("invalid model key: %s", field)
		return as.err
	}
	// SetMapIndex panics on keys read from unexported fields or of the wrong type.
	if keyField.PkgPath != "" {
		as.err = fmt.Errorf("invalid model key: %s, field is unexported", field)
		return as.err
	}
	if !keyField.Type.AssignableTo(keyType) {
		as.err = fmt.Errorf("invalid model key: %s, type %s not assignable to map key type %s", field, keyField.Type, keyType)
		return as.err
	}

	as.preCheck()
	as.doSearch()
	if as.err != nil {
		return as.err
	}

	ctx := as.getContext()
	// Next/Decode do not close the cursor; release it on every exit path.
	// This runs before the deferred doClear cancels the context.
	defer func() { _ = as.cursor.Close(ctx) }()

	result := reflect.MakeMap(mapType)
	for as.cursor.Next(ctx) {
		node := reflect.New(elemType)
		if as.err = as.cursor.Decode(node.Interface()); as.err != nil {
			logrus.Errorf("err: %+v", as.err)
			return as.err
		}
		result.SetMapIndex(node.Elem().FieldByName(field), node)
	}
	// Next returns false both at the end of the results and on failure
	// (e.g. a timeout); only Err tells the two apart.
	if as.err = as.cursor.Err(); as.err != nil {
		return as.err
	}
	v.Set(result)
	return nil
}
// Error returns the error recorded by the most recent operation on this
// scope, or nil if there was none.
func (as *AggregateScope) Error() error {
	recorded := as.err
	return recorded
}
// GetByFunc runs the pipeline and hands the resulting cursor to cb, returning
// whatever cb produces. The caller type-asserts val. If the query or cb fails,
// val is nil and the error is also recorded on the scope.
//
// NOTE(review): the cursor is not closed here, so cb should exhaust or close
// it. When preCheck created the default timeout context, doClear cancels that
// context as soon as this method returns, so the cursor should not be used
// afterwards.
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
	defer as.doClear()
	as.preCheck()
	as.doSearch()
	if as.err == nil {
		val, as.err = cb(as.cursor)
	}
	if as.err != nil {
		return nil, as.err
	}
	return val, nil
}
// SetCacheFunc enables result caching for this scope and returns it for
// chaining.
//
// key is passed to cb as its field argument; DefaultAggregateCacheFunc, for
// example, reads that struct field to build the cache key. cb turns the
// decoded result into the CacheObject that doCache writes to the scope's
// cache client. GetList and GetOne invoke doCache; GetMap does not.
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
	as.cacheKey, as.cacheFunc, as.enableCache = key, cb, true
	return as
}
// doCache passes obj (the decoded query result) to the configured cache
// function and writes the resulting key/value through the scope's cache
// client.
//
// It is a no-op in three cases:
//   - the scope has no cache client;
//   - no cache function is set;
//   - the cache function returns a nil CacheObject, which is the convention
//     for "nothing to cache" (DefaultAggregateCacheFunc returns nil for an
//     empty slice).
//
// Errors from the cache function, from an already cancelled or expired
// context, or from the cache write are stored in as.err. Deadline errors are
// normalized by assertErr.
//
// TTL: a scope.cache.ttl of 0 becomes one hour, and -1 is passed to Set as 0.
// NOTE(review): whether an expiration of 0 means "never expire" depends on
// the cache client; confirm.
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
	if as.scope.cache == nil || as.cacheFunc == nil {
		return as
	}

	cacheObj, err := as.cacheFunc(as.cacheKey, obj)
	if err != nil {
		as.err = err
		return as
	}
	// A nil object is a deliberate "skip": it must not turn a successful
	// query into a failed one.
	if cacheObj == nil {
		return as
	}

	ttl := as.scope.cache.ttl
	switch ttl {
	case 0:
		ttl = time.Hour
	case -1:
		ttl = 0
	}

	ctx := as.getContext()
	if ctxErr := ctx.Err(); ctxErr != nil {
		as.err = ctxErr
		as.assertErr()
		return as
	}
	if as.err = as.scope.cache.client.Set(ctx, cacheObj.Key, cacheObj.Value, ttl).Err(); as.err != nil {
		as.assertErr()
	}
	return as
}