first commit

This commit is contained in:
xuthus5 2022-02-23 16:59:45 +08:00
commit 32b8997441
Signed by: xuthus5
GPG Key ID: A23CF9620CBB55F9
52 changed files with 15078 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
.idea
go.sum

62
README.md Normal file
View File

@ -0,0 +1,62 @@
<img src="icon.png" alt="MDBC" width="300">
### 快速开始
初始化mongo数据库连接
```go
client, err := mongodb.ConnInit("mongodb://admin:admin@10.0.0.135:27017/admin")
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
mdbc.InitDB(client.Database("mdbc"))
```
声明 model
```go
var m = mdbc.NewModel(&ModelSchedTask{})
```
然后就可以使用 m 进行链式操作
### 注册全局对象
可以将model注册成一个全局变量
```go
type WsConnectRecordScope struct {
*mdbc.Scope
}
var WsConnectRecord *WsConnectRecordScope
func NewWsConnectRecord() {
WsConnectRecord = new(WsConnectRecordScope)
WsConnectRecord.Scope = mdbc.NewModel(&model.ModelWsConnectRecord{})
}
```
使用:
```go
func beforeRemoveWs(ctx context.Context, recordID, key string) {
if WsConnectRecord == nil {
NewWsConnectRecord()
}
tm := time.Now().UnixNano() / 1e6
if message_common.GetEtcdWatcher().RemoveWatch(key) {
// 已经移除 变更最近的一条消息
err := WsConnectRecord.SetContext(ctx).FindOne().SetFilter(bson.M{
model.ModelWsConnectRecordField_Id.DbFieldName: recordID,
}).Update(bson.M{
"$set": bson.M{
model.ModelWsConnectRecordField_LogoutAt.DbFieldName: tm,
},
})
if err != nil {
log.Errorf("update ws conn record err: %+v", err)
common.Logger.Error(ctx, "WsConn", log2.String("error", err.Error()))
}
}
}
```

412
aggregate_scope.go Normal file
View File

@ -0,0 +1,412 @@
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
jsoniter "github.com/json-iterator/go"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type AggregateScope struct {
scope *Scope
cw *ctxWrap
err error
pipeline interface{}
opts *options.AggregateOptions
cursor *mongo.Cursor
enableCache bool
cacheKey string
cacheFunc AggregateCacheFunc
}
// AggregateFunc 自定义获取结果集的注入函数
type AggregateFunc func(cursor *mongo.Cursor) (interface{}, error)
// AggregateCacheFunc 缓存Find结果集
// field: 若无作用 可置空 DefaultAggregateCacheFunc 将使用其反射出field字段对应的value并将其设为key
// obj: 这是一个slice 只能在GetList是使用GetMap使用将无效果
type AggregateCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultAggregateCacheFunc 按照第一个结果的field字段作为key缓存list数据
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
// 和 DefaultFindCacheFunc 有重复嫌疑 需要抽象
var DefaultAggregateCacheFunc = func() AggregateCacheFunc {
return func(field string, obj interface{}) (*CacheObject, error) {
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
// 下面使用反射进行操作 保证兼容
v := reflect.ValueOf(obj)
if v.Type().Kind() != reflect.Ptr {
return nil, fmt.Errorf("invalid list type, not ptr")
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
return nil, fmt.Errorf("invalid list type, not ptr to slice")
}
// 空slice 无需缓存
if v.Len() == 0 {
return nil, nil
}
firstNode := v.Index(0)
// 判断过长度 无需判断nil
if firstNode.Kind() == reflect.Ptr {
firstNode = firstNode.Elem()
}
idVal := firstNode.FieldByName(field)
if idVal.Kind() == reflect.Invalid {
return nil, fmt.Errorf("first node's %s field not found", field)
}
b, _ := json.Marshal(v.Interface())
co := &CacheObject{
Key: idVal.String(),
Value: string(b),
}
return co, nil
}
}
// SetContext 设置上下文
func (as *AggregateScope) SetContext(ctx context.Context) *AggregateScope {
if as.cw == nil {
as.cw = &ctxWrap{}
}
as.cw.ctx = ctx
return as
}
// getContext 获取ctx
func (as *AggregateScope) getContext() context.Context {
return as.cw.ctx
}
// preCheck 预检查
func (as *AggregateScope) preCheck() {
var breakerTTL time.Duration
if as.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if as.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = as.scope.breaker.ttl
}
if as.cw == nil {
as.cw = &ctxWrap{}
}
if as.cw.ctx == nil {
as.cw.ctx, as.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
// SetAggregateOption 重写配置项
func (as *AggregateScope) SetAggregateOption(opts options.AggregateOptions) *AggregateScope {
as.opts = &opts
return as
}
// SetPipeline 设置管道
// 传递 []bson.D 或者mongo.Pipeline
func (as *AggregateScope) SetPipeline(pipeline interface{}) *AggregateScope {
as.pipeline = pipeline
return as
}
// assertErr 断言错误
func (as *AggregateScope) assertErr() {
if as.err == nil {
return
}
if errors.Is(as.err, context.DeadlineExceeded) {
as.err = &ErrRequestBroken
return
}
err, ok := as.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
as.err = &ErrRequestBroken
}
}
// doString 实现debugger的 actionDo 接口
func (as *AggregateScope) doString() string {
var data []interface{}
builder := RegisterTimestampCodec(nil).Build()
vo := reflect.ValueOf(as.pipeline)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Slice {
panic("pipeline type not slice")
}
for i := 0; i < vo.Len(); i++ {
var body interface{}
b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
_ = jsoniter.Unmarshal(b, &body)
data = append(data, body)
}
b, _ := jsoniter.Marshal(data)
return fmt.Sprintf("aggregate(%s)", string(b))
}
// debug 统一debug入口
func (as *AggregateScope) debug() {
if !as.scope.debug {
return
}
debugger := &Debugger{
collection: as.scope.tableName,
execT: as.scope.execT,
action: as,
}
debugger.String()
}
// doSearch 查询的真正语句
func (as *AggregateScope) doSearch() {
var starTime time.Time
if as.scope.debug {
starTime = time.Now()
}
as.cursor, as.err = db.Collection(as.scope.tableName).Aggregate(as.getContext(), as.pipeline, as.opts)
if as.scope.debug {
as.scope.execT = time.Since(starTime)
}
as.debug()
as.assertErr()
}
func (as *AggregateScope) doClear() {
if as.cw != nil && as.cw.cancel != nil {
as.cw.cancel()
}
as.scope.debug = false
as.scope.execT = 0
}
// GetList 获取管道数据列表 传递 *[]*TypeValue
func (as *AggregateScope) GetList(list interface{}) error {
defer as.doClear()
as.preCheck()
as.doSearch()
if as.err != nil {
return as.err
}
v := reflect.ValueOf(list)
if v.Type().Kind() != reflect.Ptr {
as.err = fmt.Errorf("invalid list type, not ptr")
return as.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
as.err = fmt.Errorf("invalid list type, not ptr to slice")
return as.err
}
as.err = as.cursor.All(as.getContext(), list)
if as.enableCache {
as.doCache(list)
}
return as.err
}
// GetOne 获取管道数据的第一个元素 传递 *struct 不建议使用而建议直接用 limit:1
// 然后数组取第一个数据的方式获取结果 为了确保效率 当明确只获取一条时 请用 limit 限制一下
// 当获取不到数据的时候 会报错 ErrRecordNotFound 目前不支持缓存
func (as *AggregateScope) GetOne(obj interface{}) error {
v := reflect.ValueOf(obj)
vt := v.Type()
if vt.Kind() != reflect.Ptr {
as.err = fmt.Errorf("invalid type, not ptr")
return as.err
}
vt = vt.Elem()
if vt.Kind() != reflect.Struct {
as.err = fmt.Errorf("invalid type, not ptr to struct")
return as.err
}
defer as.doClear()
as.preCheck()
as.doSearch()
if as.err != nil {
return as.err
}
objType := reflect.TypeOf(obj)
objSliceType := reflect.SliceOf(objType)
objSlice := reflect.New(objSliceType)
objSliceInterface := objSlice.Interface()
as.err = as.cursor.All(as.getContext(), objSliceInterface)
if as.err != nil {
return as.err
}
realV := reflect.ValueOf(objSliceInterface)
if realV.Kind() == reflect.Ptr {
realV = realV.Elem()
}
if realV.Len() == 0 {
as.err = &ErrRecordNotFound
return as.err
}
firstNode := realV.Index(0)
if firstNode.Kind() != reflect.Ptr {
as.err = &ErrFindObjectTypeNotSupport
return as.err
}
v.Elem().Set(firstNode.Elem())
if as.enableCache {
as.doCache(obj)
}
return as.err
}
// GetMap 获取管道数据map 传递 *map[string]*TypeValue, field 基于哪个字段进行map构建
func (as *AggregateScope) GetMap(m interface{}, field string) error {
defer as.doClear()
as.preCheck()
as.doSearch()
if as.err != nil {
return as.err
}
v := reflect.ValueOf(m)
if v.Type().Kind() != reflect.Ptr {
as.err = fmt.Errorf("invalid map type, not ptr")
return as.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Map {
as.err = fmt.Errorf("invalid map type, not map")
return as.err
}
mapType := v.Type()
valueType := mapType.Elem()
keyType := mapType.Key()
if valueType.Kind() != reflect.Ptr {
as.err = fmt.Errorf("invalid map value type, not prt")
return as.err
}
if !v.CanSet() {
as.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
return as.err
}
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
for as.cursor.Next(as.getContext()) {
t := reflect.New(valueType.Elem())
if as.err = as.cursor.Decode(t.Interface()); as.err != nil {
logrus.Errorf("err: %+v", as.err)
return as.err
}
keyv := t.Elem().FieldByName(field)
if keyv.Kind() == reflect.Invalid {
as.err = fmt.Errorf("invalid model key: %s", field)
return as.err
}
v.SetMapIndex(keyv, t)
}
return as.err
}
// Error 获取错误
func (as *AggregateScope) Error() error {
return as.err
}
// GetByFunc 自定义函数 获取数据集
func (as *AggregateScope) GetByFunc(cb AggregateFunc) (val interface{}, err error) {
defer as.doClear()
as.preCheck()
as.doSearch()
if as.err != nil {
return nil, as.err
}
val, as.err = cb(as.cursor)
if as.err != nil {
return nil, as.err
}
// 自行断言val
return val, nil
}
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
func (as *AggregateScope) SetCacheFunc(key string, cb AggregateCacheFunc) *AggregateScope {
as.enableCache = true
as.cacheFunc = cb
as.cacheKey = key
return as
}
// doCache 执行缓存
func (as *AggregateScope) doCache(obj interface{}) *AggregateScope {
// redis句柄不存在
if as.scope.cache == nil {
return nil
}
cacheObj, err := as.cacheFunc(as.cacheKey, obj)
if err != nil {
as.err = err
return as
}
if cacheObj == nil {
as.err = fmt.Errorf("cache object nil")
return as
}
ttl := as.scope.cache.ttl
if ttl == 0 {
ttl = time.Hour
} else if ttl == -1 {
ttl = 0
}
if as.getContext().Err() != nil {
as.err = as.getContext().Err()
as.assertErr()
return as
}
as.err = as.scope.cache.client.Set(as.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
if as.err != nil {
as.assertErr()
}
return as
}

80
aggregate_scope_test.go Normal file
View File

@ -0,0 +1,80 @@
package mdbc
import (
"fmt"
"testing"
"github.com/sirupsen/logrus"
"gitlab.com/gotk/mdbc/builder"
"go.mongodb.org/mongo-driver/bson"
)
func TestAggregateScope(t *testing.T) {
cfg := &Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
bu := builder.NewBuilder().Pipeline()
bu.Match(bson.M{"created_at": 0}).
Group(bson.M{"_id": "$task_type", "count": bson.M{"$sum": 1}})
type Res struct {
Count int `bson:"count"`
Type string `bson:"_id"`
}
//ds := []bson.D{{{"$match", bson.M{"created_at": bson.M{"$gt": 0}}}}}
var records Res
err = m.SetDebug(true).
Aggregate().SetPipeline(bu.Build()).GetOne(&records)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", records)
}
func TestAggregateScope_GetMap(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
bu := builder.NewBuilder()
//bu.Match(bson.M{"created_at": bson.M{"$gt": 0}}).
// Group(bson.M{"_id": "$task_type", "count": bson.M{"$sum": 1}}).
// Project(bson.M{"type": "$_id", "_id": 0, "count": 1}).
// Sort(bson.M{"count": 1})
type Res struct {
Count int `bson:"count"`
Type string `bson:"type"`
}
record := make(map[string]*Res)
err = m.
SetDebug(true).
Aggregate().SetPipeline(bu.Pipeline()).GetMap(&record, "Type")
if err != nil {
panic(err)
}
fmt.Println(record)
}

2411
autogen_model_field_mdbc.go Normal file

File diff suppressed because it is too large Load Diff

2835
autogen_model_mdbc.go Normal file

File diff suppressed because it is too large Load Diff

21
breaker.go Normal file
View File

@ -0,0 +1,21 @@
package mdbc
// breakerDo db熔断相关的告警处理
// 熔断应该记录执行时间 和单位时间间隔内执行次数
// 这里执行时间已经记录 单位时间间隔内执行次数暂未统计
type breakerDo interface {
doReporter() string // 告警 返回告警内容
}
// BreakerReporter 告警器
type BreakerReporter struct {
reportTitle string
reportMsg string
reportErrorWrap error
breakerDo
}
// Report 告警向Slack发出告警信息
func (b *BreakerReporter) Report() {
}

12
builder/index.go Normal file
View File

@ -0,0 +1,12 @@
package builder
// Add 添加一个索引项
//func (ib *IndexBuilder) Add(key string, sort mdbc.KeySort) *IndexBuilder {
// ib.d = append(ib.d, bson.E{Key: key, Value: sort})
// return ib
//}
// Build 构建完整索引
//func (ib *IndexBuilder) Build() bson.D {
// return ib.d
//}

63
builder/pipeline.go Normal file
View File

@ -0,0 +1,63 @@
package builder
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
// AddFields do something
func (b *PipelineBuilder) AddFields(filter bson.M) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$addFields", filter}})
return b
}
// Match 过滤匹配
func (b *PipelineBuilder) Match(filter bson.M) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$match", filter}})
return b
}
// Sort 排序
func (b *PipelineBuilder) Sort(filter bson.M) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$sort", filter}})
return b
}
// Group 分组
func (b *PipelineBuilder) Group(filter bson.M) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$group", filter}})
return b
}
// Project do something
func (b *PipelineBuilder) Project(filter bson.M) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$project", filter}})
return b
}
// Limit do something
func (b *PipelineBuilder) Limit(limit int) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$limit", limit}})
return b
}
// Skip do something
func (b *PipelineBuilder) Skip(limit int) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$skip", limit}})
return b
}
// Count 统计
func (b *PipelineBuilder) Count(fieldName string) *PipelineBuilder {
b.pipeline = append(b.pipeline, bson.D{{"$count", fieldName}})
return b
}
func (b *PipelineBuilder) Other(d ...bson.D) *PipelineBuilder {
b.pipeline = append(b.pipeline, d...)
return b
}
func (b *PipelineBuilder) Build() mongo.Pipeline {
return b.pipeline
}

19
builder/query.go Normal file
View File

@ -0,0 +1,19 @@
package builder
import "go.mongodb.org/mongo-driver/bson"
// Add 添加一个默认的隐式and操作
func (qb *QueryBuilder) Add(opName string, value interface{}) *QueryBuilder {
qb.q[opName] = value
return qb
}
// Or 添加一个 or 操作
func (qb *QueryBuilder) Or(filters ...bson.M) *QueryBuilder {
if qb.q["$or"] == nil {
return qb
}
return qb
}

40
builder/type.go Normal file
View File

@ -0,0 +1,40 @@
// Package builder 用来快速构建 aggregate 查询
package builder
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
type Builder struct{}
type PipelineBuilder struct {
pipeline mongo.Pipeline
}
type IndexBuilder struct {
d bson.D
}
type QueryBuilder struct {
q bson.M
}
func NewBuilder() *Builder {
return &Builder{}
}
// Pipeline 构建器
func (b *Builder) Pipeline() *PipelineBuilder {
return &PipelineBuilder{pipeline: mongo.Pipeline{}}
}
// Index key构建器
func (b *Builder) Index() *IndexBuilder {
return &IndexBuilder{}
}
// Query 查询构建器
func (b *Builder) Query() *QueryBuilder {
return &QueryBuilder{}
}

218
bulk_write_scope.go Normal file
View File

@ -0,0 +1,218 @@
package mdbc
import (
"context"
"errors"
"fmt"
"time"
jsoniter "github.com/json-iterator/go"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type BulkWriteScope struct {
scope *Scope
cw *ctxWrap
err error
chunkSize uint32
opts *options.BulkWriteOptions
models []mongo.WriteModel
result *mongo.BulkWriteResult
chunkFunc func([]mongo.WriteModel) error
}
// doString debug string
func (bws *BulkWriteScope) doString() string {
var data []interface{}
builder := RegisterTimestampCodec(nil).Build()
for _, v := range bws.models {
var body interface{}
rawv := Struct2MapOmitEmpty(v)
b, _ := bson.MarshalExtJSONWithRegistry(builder, rawv, true, true)
_ = jsoniter.Unmarshal(b, &body)
data = append(data, body)
}
b, _ := jsoniter.Marshal(data)
return fmt.Sprintf("bulkWrite(%s)", string(b))
}
// debug debug
func (bws *BulkWriteScope) debug() {
if !bws.scope.debug {
return
}
debugger := &Debugger{
collection: bws.scope.tableName,
execT: bws.scope.execT,
action: bws,
}
debugger.String()
}
// SetContext 设置上下文
func (bws *BulkWriteScope) SetContext(ctx context.Context) *BulkWriteScope {
if bws.cw == nil {
bws.cw = &ctxWrap{}
}
bws.cw.ctx = ctx
return bws
}
func (bws BulkWriteScope) getContext() context.Context {
return bws.cw.ctx
}
// SetBulkWriteOption 设置BulkWriteOption
func (bws *BulkWriteScope) SetBulkWriteOption(opts options.BulkWriteOptions) *BulkWriteScope {
bws.opts = &opts
return bws
}
// SetOrdered 设置BulkWriteOptions中的Ordered
func (bws *BulkWriteScope) SetOrdered(ordered bool) *BulkWriteScope {
if bws.opts == nil {
bws.opts = new(options.BulkWriteOptions)
}
bws.opts.Ordered = &ordered
return bws
}
// SetChunkSize 指定分块操作大小 默认不分块 当数据足够大时 可能导致deadlock问题不确定这个问题
func (bws *BulkWriteScope) SetChunkSize(size uint32) *BulkWriteScope {
bws.chunkSize = size
return bws
}
// SetChunkFunc 可以在进行批量插入之前做一些事情 若错误将终止这一批数据的写入而执行下一批
func (bws *BulkWriteScope) SetChunkFunc(f func(models []mongo.WriteModel) error) *BulkWriteScope {
bws.chunkFunc = f
return bws
}
// SetWriteModel 设置需要操作的数据
func (bws *BulkWriteScope) SetWriteModel(models []mongo.WriteModel) *BulkWriteScope {
bws.models = models
return bws
}
// SetWriteModelFunc 可以定义函数来返回需要操作的数据
func (bws *BulkWriteScope) SetWriteModelFunc(f func() []mongo.WriteModel) *BulkWriteScope {
bws.models = f()
return bws
}
// preCheck 预检查
func (bws *BulkWriteScope) preCheck() {
var breakerTTL time.Duration
if bws.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if bws.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = bws.scope.breaker.ttl
}
if bws.cw == nil {
bws.cw = &ctxWrap{}
}
if bws.cw.ctx == nil {
bws.cw.ctx, bws.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (bws *BulkWriteScope) doClear() {
if bws.cw != nil && bws.cw.cancel != nil {
bws.cw.cancel()
}
bws.scope.debug = false
bws.scope.execT = 0
}
func (bws *BulkWriteScope) assertErr() {
if bws.err == nil {
return
}
if errors.Is(bws.err, context.DeadlineExceeded) {
bws.err = &ErrRequestBroken
return
}
err, ok := bws.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
bws.err = &ErrRequestBroken
}
}
func (bws *BulkWriteScope) doBulkWrite() {
defer bws.assertErr()
var starTime time.Time
if bws.scope.debug {
starTime = time.Now()
}
bws.result, bws.err = db.Collection(bws.scope.tableName).BulkWrite(bws.getContext(), bws.models, bws.opts)
if bws.scope.debug {
bws.scope.execT = time.Since(starTime)
bws.debug()
}
}
func (bws *BulkWriteScope) splitBulkWrite(arr []mongo.WriteModel, chunkSize int) [][]mongo.WriteModel {
var newArr [][]mongo.WriteModel
for i := 0; i < len(arr); i += chunkSize {
end := i + chunkSize
if end > len(arr) {
end = len(arr)
}
newArr = append(newArr, arr[i:end])
}
return newArr
}
func (bws *BulkWriteScope) checkModel() {
if len(bws.models) == 0 {
bws.err = fmt.Errorf("models empty")
return
}
// 命令检测
}
// Do 执行批量操作 请确保 SetWriteModel 已被设置 否则报错
func (bws *BulkWriteScope) Do() (*mongo.BulkWriteResult, error) {
defer bws.doClear()
bws.checkModel()
bws.preCheck()
if bws.err != nil {
return nil, bws.err
}
// 如果设置了chunkSize就分片插入
if bws.chunkSize > 0 {
models := bws.splitBulkWrite(bws.models, int(bws.chunkSize))
for _, model := range models {
bws.models = model
if bws.chunkFunc != nil {
if bws.err = bws.chunkFunc(bws.models); bws.err != nil {
continue
}
}
bws.doBulkWrite()
}
} else {
bws.doBulkWrite()
}
if bws.err != nil {
return nil, bws.err
}
return bws.result, nil
}

114
bulk_write_scope_test.go Normal file
View File

@ -0,0 +1,114 @@
package mdbc
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
func TestBulkWriteScope(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelRobotFriend{})
var record ModelRobotFriend
if err := m.FindOne().SetFilter(bson.M{
ModelRobotFriendField_Id.DbFieldName: "498b1c85be41266efb29b6a79560ec7f",
}).Get(&record); err != nil {
panic(err)
}
record.AddAt = 12345
record.UpdateTime = time.Now().Unix()
var updateData = bson.M{
"$set": &record,
"$setOnInsert": bson.M{
"create_time": 1000000,
},
}
b, _ := bson.MarshalExtJSON(updateData, true, true)
logrus.Infof("updateData: %+v", string(b))
SetMapOmitInsertField(&updateData)
b, _ = json.Marshal(updateData)
logrus.Infof("updateData: %+v", string(b))
var opData []mongo.WriteModel
opData = append(opData, mongo.NewUpdateOneModel().SetFilter(bson.M{
ModelRobotFriendField_Id.DbFieldName: "498b1c85be41266efb29b6a79560ec7f",
}).SetUpsert(true).SetUpdate(updateData))
res, err := m.SetDebug(true).BulkWrite().
SetWriteModel(opData).Do()
if err != nil {
panic(err)
}
fmt.Printf("res %+v\n", res)
logrus.Infof("res %+v\n", res)
}
func TestSliceChunk(t *testing.T) {
a := []int{1, 2, 3}
chunkSize := 2
var b [][]int
for i := 0; i < len(a); i += chunkSize {
end := i + chunkSize
if end > len(a) {
end = len(a)
}
b = append(b, a[i:end])
}
fmt.Println(b)
}
func TestBulkWriteScope_SetChunkSize(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelRobotFriend{})
var opData []mongo.WriteModel
for i := 0; i < 10; i++ {
opData = append(opData, mongo.NewInsertOneModel().SetDocument(&ModelRobotFriend{Id: uuid.NewString()}))
}
res, err := m.SetDebug(true).BulkWrite().SetChunkSize(3).SetWriteModel(opData).Do()
if err != nil {
panic(err)
}
fmt.Println("res", res)
}

67
codec.go Normal file
View File

@ -0,0 +1,67 @@
package mdbc
import (
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"google.golang.org/protobuf/types/known/timestamppb"
)
// 该文件帮助自定义序列化bson文档
// RegisterTimestampCodec 注册一个针对 timestamppb.Timestamp 结构的 bson文档解析器
var (
timeTimeType = reflect.TypeOf(time.Time{})
timestampType = reflect.TypeOf(&timestamppb.Timestamp{})
)
// TimestampCodec 对 timestamppb.Timestamp <-> time.Time 进行互向转换
// time.Time 在bson中被转换为 Date 对象
type TimestampCodec struct{}
func (t *TimestampCodec) EncodeValue(encodeContext bsoncodec.EncodeContext, writer bsonrw.ValueWriter, value reflect.Value) error {
var rawv time.Time
switch t := value.Interface().(type) {
case *timestamppb.Timestamp:
rawv = t.AsTime()
case time.Time:
rawv = t
default:
panic("TimestampCodec get type: " + reflect.TypeOf(value.Interface()).String() + ", not support")
}
enc, err := encodeContext.LookupEncoder(timeTimeType)
if err != nil {
return err
}
return enc.EncodeValue(encodeContext, writer, reflect.ValueOf(rawv.In(time.UTC)))
}
func (t *TimestampCodec) DecodeValue(decodeContext bsoncodec.DecodeContext, reader bsonrw.ValueReader, value reflect.Value) error {
enc, err := decodeContext.LookupDecoder(timeTimeType)
if err != nil {
return err
}
var tt time.Time
if err := enc.DecodeValue(decodeContext, reader, reflect.ValueOf(&tt).Elem()); err != nil {
return err
}
ts := timestamppb.New(tt.In(time.UTC))
value.Set(reflect.ValueOf(ts))
return nil
}
// RegisterTimestampCodec 注册一个针对 timestamppb.Timestamp 结构的 bson文档解析器
// 将 mongodb 中 bson 字段的 Date(Go中的 time.Time ) 对象解析成 timestamppb.Timestamp
func RegisterTimestampCodec(rb *bsoncodec.RegistryBuilder) *bsoncodec.RegistryBuilder {
if rb == nil {
rb = bson.NewRegistryBuilder()
}
return rb.RegisterCodec(timestampType, &TimestampCodec{})
}

43
config.go Normal file
View File

@ -0,0 +1,43 @@
package mdbc
import (
"context"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
var std *ClientInit
// Config MongoDB连接配置
type Config struct {
// URI 连接DSN 格式: protocol://username:password@host:port/auth_db
URI string `yaml:"uri"`
// DBName
DBName string `yaml:"db-name"`
// MinPoolSize 连接池最小 默认1个
MinPoolSize uint64 `yaml:"min-pool-size"`
// MaxPoolSize 连接池最大 默认32
MaxPoolSize uint64 `yaml:"max-pool-size"`
// ConnTimeout 连接超时时间 单位秒 默认10秒
ConnTimeout uint64 `yaml:"conn-timeout"`
// RegistryBuilder 注册bson文档的自定义解析器 详见当前目录 codec.go 其中定义了一系列的bson文档解析器
RegistryBuilder *bsoncodec.RegistryBuilder
// ReadPreference 读配置
ReadPreference *readpref.ReadPref
}
func (c *Config) Init(ctx context.Context) error {
var err error
c.RegistryBuilder = RegisterTimestampCodec(nil)
std, err = ConnInit(c)
if err != nil {
return err
}
return nil
}
func GetClient() *ClientInit {
return std
}

241
convert.go Normal file
View File

@ -0,0 +1,241 @@
package mdbc
import (
"fmt"
"reflect"
"strings"
)
// 这里定义了一个特殊功能 格式转换
// Struct2MapWithBsonTag 结构体转map 用 bson字段名 做Key
// 注意 obj非struct结构将PANIC
func Struct2MapWithBsonTag(obj interface{}) map[string]interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Struct {
panic("object type not struct")
}
var data = make(map[string]interface{})
for i := 0; i < vo.NumField(); i++ {
vf := vo.Field(i)
key := vo.Type().Field(i).Tag.Get("bson")
if key == "" {
key = vo.Type().Field(i).Name
}
if vf.CanSet() {
data[key] = vf.Interface()
}
}
return data
}
// Struct2MapWithJsonTag 结构体转map 用 json字段名 做Key
// 注意 obj非struct结构将PANIC
func Struct2MapWithJsonTag(obj interface{}) map[string]interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Struct {
panic("object type not struct")
}
var data = make(map[string]interface{})
for i := 0; i < vo.NumField(); i++ {
vf := vo.Field(i)
key := vo.Type().Field(i).Tag.Get("json")
if key == "" {
key = vo.Type().Field(i).Name
}
// 过滤掉 omitempty 选项
if strings.Contains(key, ",omitempty") {
key = strings.Replace(key, ",omitempty", "", 1)
}
if vf.CanSet() {
data[key] = vf.Interface()
}
}
return data
}
// Struct2MapOmitEmpty 结构体转map并忽略空字段 用 字段名 做Key
// 注意 只忽略顶层字段 obj非struct结构将PANIC
func Struct2MapOmitEmpty(obj interface{}) map[string]interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Struct {
panic("object type not struct")
}
var data = make(map[string]interface{})
for i := 0; i < vo.NumField(); i++ {
vf := vo.Field(i)
if !vf.IsZero() && vf.CanSet() {
data[vo.Type().Field(i).Name] = vf.Interface()
}
}
return data
}
// Struct2MapOmitEmptyWithBsonTag 结构体转map并忽略空字段 按照 bson 标签做key
// obj 需要是一个指针
// 注意 只忽略顶层字段 obj非struct结构将PANIC
func Struct2MapOmitEmptyWithBsonTag(obj interface{}) map[string]interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Struct {
panic("object type not struct")
}
var data = make(map[string]interface{})
for i := 0; i < vo.NumField(); i++ {
vf := vo.Field(i)
key := vo.Type().Field(i).Tag.Get("bson")
if key == "" {
key = vo.Type().Field(i).Name
}
if !vf.IsZero() && vf.CanSet() {
data[key] = vf.Interface()
}
}
return data
}
// Struct2MapOmitEmptyWithJsonTag 结构体转map并忽略空字段 按照 json 标签做key
// 注意 只忽略顶层字段 obj非struct结构将PANIC
func Struct2MapOmitEmptyWithJsonTag(obj interface{}) map[string]interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Struct {
panic("object type not struct")
}
var data = make(map[string]interface{})
for i := 0; i < vo.NumField(); i++ {
vf := vo.Field(i)
key := vo.Type().Field(i).Tag.Get("json")
if key == "" {
key = vo.Type().Field(i).Name
}
// 过滤掉 omitempty 选项
if strings.Contains(key, ",omitempty") {
key = strings.Replace(key, ",omitempty", "", 1)
}
if !vf.IsZero() && vf.CanSet() {
data[key] = vf.Interface()
}
}
return data
}
// SliceStruct2MapOmitEmpty 结构体数组转map数组并忽略空字段
// 注意 子元素非struct将被忽略
func SliceStruct2MapOmitEmpty(obj interface{}) interface{} {
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Slice && vo.Kind() != reflect.Array {
panic("object type not slice")
}
var data []map[string]interface{}
for i := 0; i < vo.Len(); i++ {
node := vo.Index(i)
if node.Kind() == reflect.Ptr {
node = node.Elem()
}
if node.Kind() != reflect.Struct && node.Kind() != reflect.Interface {
continue
}
fn := Struct2MapOmitEmpty(node.Interface())
data = append(data, fn)
}
return data
}
// SetMapOmitInsertField 对于一个 update-bson-map 忽略$set中的$setOnInsert字段
// 需要传递一个map的指针 确保数据可写 否则PANIC
// $set支持map和struct 其他结构体将PANIC
// $setOnInsert只支持map 其他结构体将PANIC
// 只对包含$set和$setOnInsert的map生效 若$set或者$setOnInsert缺失 将PANIC
func SetMapOmitInsertField(m interface{}) {
vo := reflect.ValueOf(m)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Map {
panic("object type not map")
}
setVal := vo.MapIndex(reflect.ValueOf("$set"))
if !setVal.IsValid() {
panic("$set not found")
}
soiVal := vo.MapIndex(reflect.ValueOf("$setOnInsert"))
if !soiVal.IsValid() {
panic("$setOnInsert not found")
}
if !vo.CanSet() {
panic("map can't set")
}
soiRealVal := reflect.ValueOf(soiVal.Interface())
if soiRealVal.Kind() == reflect.Ptr {
soiRealVal = soiRealVal.Elem()
}
if soiRealVal.Kind() != reflect.Map {
err := fmt.Errorf("$setOnInsert type not map: type(%v)", soiRealVal.Kind())
panic(err)
}
setRealTyp := reflect.TypeOf(setVal.Interface())
if setRealTyp.Kind() == reflect.Ptr {
setRealTyp = setRealTyp.Elem()
}
setRealVal := reflect.ValueOf(setVal.Interface())
if setRealVal.Kind() == reflect.Ptr {
setRealVal = setRealVal.Elem()
}
if setRealVal.Kind() != reflect.Struct && setRealVal.Kind() != reflect.Map {
err := fmt.Errorf("$set type not map or struct: type(%v)", setRealVal.Kind())
panic(err)
}
var setMap = make(map[string]interface{})
//builder := RegisterTimestampCodec(nil).Build()
if setRealVal.Kind() == reflect.Struct {
var data = make(map[string]interface{})
for i := 0; i < setRealVal.NumField(); i++ {
field := setRealTyp.Field(i)
bsonTag := field.Tag.Get("bson")
if bsonTag == "" {
continue
}
fieldVal := setRealVal.Field(i)
data[bsonTag] = fieldVal.Interface()
}
vo.SetMapIndex(reflect.ValueOf("$set"), reflect.ValueOf(data))
}
setRealVal = reflect.ValueOf(vo.MapIndex(reflect.ValueOf("$set")).Interface())
if setRealVal.Kind() == reflect.Ptr {
setRealVal = setRealVal.Elem()
}
if setRealVal.Kind() == reflect.Map {
iter := setRealVal.MapRange()
for iter.Next() {
key := iter.Key()
val := iter.Value()
soiFType := soiRealVal.MapIndex(reflect.ValueOf(key.String()))
if !soiFType.IsValid() {
setMap[key.String()] = val.Interface()
}
}
}
vo.SetMapIndex(reflect.ValueOf("$set"), reflect.ValueOf(setMap))
}

217
count_scope.go Normal file
View File

@ -0,0 +1,217 @@
package mdbc
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type CountScope struct {
scope *Scope
cw *ctxWrap
err error
limit *int64
skip *int64
hint interface{}
filter interface{}
maxTimeMS *time.Duration
collation *options.Collation
opts *options.CountOptions
result int64
}
// SetLimit 设置获取数量
func (cs *CountScope) SetLimit(limit int64) *CountScope {
cs.limit = &limit
return cs
}
// SetSkip 设置跳过的数量
func (cs *CountScope) SetSkip(skip int64) *CountScope {
cs.skip = &skip
return cs
}
// SetHint 设置hit
func (cs *CountScope) SetHint(hint interface{}) *CountScope {
if hint == nil {
cs.hint = bson.M{}
return cs
}
v := reflect.ValueOf(hint)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
cs.hint = bson.M{}
}
}
cs.hint = hint
return cs
}
// SetFilter 设置过滤条件
func (cs *CountScope) SetFilter(filter interface{}) *CountScope {
if filter == nil {
cs.filter = bson.M{}
return cs
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
cs.filter = bson.M{}
}
}
cs.filter = filter
return cs
}
// SetMaxTime 设置MaxTime
func (cs *CountScope) SetMaxTime(maxTime time.Duration) *CountScope {
cs.maxTimeMS = &maxTime
return cs
}
// SetCollation 设置文档
func (cs *CountScope) SetCollation(collation options.Collation) *CountScope {
cs.collation = &collation
return cs
}
// SetContext 设置上下文
func (cs *CountScope) SetContext(ctx context.Context) *CountScope {
if cs.cw == nil {
cs.cw = &ctxWrap{}
}
cs.cw.ctx = ctx
return cs
}
func (cs *CountScope) doClear() {
if cs.cw != nil && cs.cw.cancel != nil {
cs.cw.cancel()
}
cs.scope.execT = 0
cs.scope.debug = false
}
func (cs *CountScope) getContext() context.Context {
return cs.cw.ctx
}
// Count 执行计数
func (cs *CountScope) Count() (int64, error) {
defer cs.doClear()
cs.doCount()
if cs.err != nil {
return 0, cs.err
}
return cs.result, nil
}
func (cs *CountScope) optionAssembled() {
// 配置项被直接调用重写过
if cs.opts != nil {
return
}
cs.opts = new(options.CountOptions)
if cs.skip != nil {
cs.opts.Skip = cs.skip
}
if cs.limit != nil {
cs.opts.Limit = cs.limit
}
if cs.collation != nil {
cs.opts.Collation = cs.collation
}
if cs.hint != nil {
cs.opts.Hint = cs.hint
}
if cs.maxTimeMS != nil {
cs.opts.MaxTime = cs.maxTimeMS
}
}
func (cs *CountScope) assertErr() {
if cs.err == nil {
return
}
if errors.Is(cs.err, context.DeadlineExceeded) {
cs.err = &ErrRequestBroken
return
}
err, ok := cs.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
cs.err = &ErrRequestBroken
return
}
}
func (cs *CountScope) debug() {
if !cs.scope.debug {
return
}
debugger := &Debugger{
collection: cs.scope.tableName,
execT: cs.scope.execT,
action: cs,
}
debugger.String()
}
func (cs *CountScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
filter, _ := bson.MarshalExtJSONWithRegistry(builder, cs.filter, true, true)
return fmt.Sprintf("count(%s)", string(filter))
}
func (cs *CountScope) doCount() {
defer cs.assertErr()
cs.optionAssembled()
cs.preCheck()
var starTime time.Time
if cs.scope.debug {
starTime = time.Now()
}
cs.result, cs.err = db.Collection(cs.scope.tableName).CountDocuments(cs.getContext(), cs.filter, cs.opts)
if cs.scope.debug {
cs.scope.execT = time.Since(starTime)
cs.debug()
}
}
func (cs *CountScope) preCheck() {
if cs.filter == nil {
cs.filter = bson.M{}
}
var breakerTTL time.Duration
if cs.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if cs.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = cs.scope.breaker.ttl
}
if cs.cw == nil {
cs.cw = &ctxWrap{}
}
if cs.cw.ctx == nil {
cs.cw.ctx, cs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}

41
count_scope_test.go Normal file
View File

@ -0,0 +1,41 @@
package mdbc
import (
"context"
"fmt"
"testing"
"time"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
)
func TestCountScope_Count(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
count, err := m.Count().SetContext(context.Background()).
SetFilter(bson.M{"created_at": bson.M{"$gt": 0}}).
SetSkip(1).
SetLimit(10).
SetMaxTime(1 * time.Second).
Count()
if err != nil {
panic(err)
return
}
fmt.Println(count)
}

62
debugger.go Normal file
View File

@ -0,0 +1,62 @@
package mdbc
import (
"fmt"
"os"
"time"
)
// 语句执行解析器
// actionDo 实现对 操作部份 的字符化 方便debug输出
type actionDo interface {
doString() string
}
type Debugger struct {
collection string // 操作集合
errMsg string // 错误信息
execT time.Duration // 执行时间
action actionDo // 执行语句
}
// String 输出执行的基于query的语句
// 思路:语句分三部分 集合部份;操作部份;附加条件部份
// 集合操作部份 这个可以直接从collection中获取集合名称
// 拼接操作部份 这是大头部份 需要每个操作实现 actionDo 接口 拼接出入参
// 附加条件部份 对于一些sort skip limit等参数进行拼接
func (d *Debugger) String() {
_, _ = fmt.Fprintf(os.Stdout, "db.getCollection(\"%s\").%s; execTime: %s\n", d.collection, d.action.doString(), d.execT.String())
}
// GetString 获取执行的SQL信息
func (d *Debugger) GetString() string {
return fmt.Sprintf("db.getCollection(\"%s\").%s;", d.collection, d.action.doString())
}
// ErrorString 输出执行的基于query的语句
// 暂时不考虑执行时长
func (d *Debugger) ErrorString() {
queryAction := "╔\x1b[32mquery:\x1b[0m"
queryMsg := fmt.Sprintf("\x1b[36m%s\x1b[0m", fmt.Sprintf("db.getCollection(\"%s\").%s;", d.collection, d.action.doString()))
errorAction := "╠\x1b[33merror:\x1b[0m"
errorMsg := fmt.Sprintf("\x1b[31m%s\x1b[0m", d.errMsg)
execAction := "╚\x1b[34mexect:\x1b[0m"
execMsg := fmt.Sprintf("\x1b[31m%s\x1b[0m", d.execT.String())
_, _ = fmt.Fprintf(os.Stdout, "%s %s\n%s %s\n%s %s\n", queryAction, queryMsg, errorAction, errorMsg, execAction, execMsg)
}
// Echo 返回执行记录的string
func (d *Debugger) Echo() string {
return fmt.Sprintf("db.getCollection(\"%s\").%s; execTime: %s\n", d.collection, d.action.doString(), d.execT.String())
}
//前景 背景 颜色
//30 40 黑色
//31 41 红色
//32 42 绿色
//33 43 黄色
//34 44 蓝色
//35 45 紫色
//36 46 深绿
//37 47 白色

107
define.go Normal file
View File

@ -0,0 +1,107 @@
package mdbc
import (
"time"
"go.mongodb.org/mongo-driver/bson"
"gitlab.com/gotk/gotk/core"
)
const defaultBreakerTime = time.Second * 5
const (
// ErrFilterParamEmpty Query参数为空
ErrFilterParamEmpty = 20001
// ErrUpdateObjectEmpty 更新参数为空
ErrUpdateObjectEmpty = 20002
// ErrUpdateObjectNotSupport 更新参数类型不支持
ErrUpdateObjectNotSupport = 20003
// ErrOpTransaction 事务执行错误
ErrOpTransaction = 20004
// ErrObjectTypeNoMatch 获取对象类型不正确
ErrObjectTypeNoMatch = 20005
)
var (
coreCodeMap = map[int32]string{
ErrFilterParamEmpty: "query param is empty",
ErrUpdateObjectEmpty: "update object is empty",
ErrUpdateObjectNotSupport: "update object type not support",
ErrOpTransaction: "abort transaction failed",
ErrObjectTypeNoMatch: "find object type not support",
}
)
var (
ErrRecordNotFound = core.ErrMsg{
ErrCode: core.ErrRecordNotFound,
ErrMsg: core.GetErrMsg(core.ErrRecordNotFound),
}
ErrRequestBroken = core.ErrMsg{
ErrCode: core.ErrRequestBroken,
ErrMsg: core.GetErrMsg(core.ErrRequestBroken),
}
ErrFilterEmpty = core.ErrMsg{
ErrCode: ErrFilterParamEmpty,
ErrMsg: core.GetErrMsg(ErrFilterParamEmpty),
}
ErrObjectEmpty = core.ErrMsg{
ErrCode: ErrUpdateObjectEmpty,
ErrMsg: core.GetErrMsg(ErrUpdateObjectEmpty),
}
ErrUpdateObjectTypeNotSupport = core.ErrMsg{
ErrCode: ErrUpdateObjectNotSupport,
ErrMsg: core.GetErrMsg(ErrUpdateObjectNotSupport),
}
ErrRollbackTransaction = core.ErrMsg{
ErrCode: ErrOpTransaction,
ErrMsg: core.GetErrMsg(ErrOpTransaction),
}
ErrFindObjectTypeNotSupport = core.ErrMsg{
ErrCode: ErrObjectTypeNoMatch,
ErrMsg: core.GetErrMsg(ErrObjectTypeNoMatch),
}
)
func register() {
core.RegisterError(coreCodeMap)
}
type M bson.M
type A bson.A
type D bson.D
type E bson.E
type action string
const (
Aggregate action = "Aggregate"
BulkWrite action = "BulkWrite"
CountDocuments action = "Count"
DeleteOne action = "DeleteOne"
DeleteMany action = "DeleteMany"
Distinct action = "Distinct"
Drop action = "Drop"
Find action = "Find"
FindOne action = "FindOne"
FindOneAndDelete action = "FindOneAndDelete"
FindOneAndReplace action = "FindOneAndReplace"
FindOneAndUpdate action = "FindOneAndUpdate"
InsertOne action = "InsertOne"
InsertMany action = "InsertMany"
Indexes action = "Indexes"
ReplaceOne action = "ReplaceOne"
UpdateOne action = "UpdateOne"
UpdateMany action = "UpdateMany"
Watch action = "Watch"
)

339
delete_scope.go Normal file
View File

@ -0,0 +1,339 @@
package mdbc
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type deleteAction string
const (
deleteOne deleteAction = "deleteOne"
deleteMany deleteAction = "deleteMany"
)
type DeleteScope struct {
scope *Scope
cw *ctxWrap
err error
upsert *bool
id interface{}
filter interface{}
action deleteAction
opts *options.DeleteOptions
result *mongo.DeleteResult
}
// SetContext 设置上下文
func (ds *DeleteScope) SetContext(ctx context.Context) *DeleteScope {
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
ds.cw.ctx = ctx
return ds
}
// SetUpsert 设置upsert属性
func (ds *DeleteScope) SetUpsert(upsert bool) *DeleteScope {
ds.upsert = &upsert
return ds
}
// SetID 有设置ID优先基于ID删除 其次基于filter删除
func (ds *DeleteScope) SetID(id interface{}) *DeleteScope {
ds.id = id
ds.filter = bson.M{"_id": ds.id}
return ds
}
// SetIDs 指定_id列表的方式进行删除 是 bson.M{"_id": {$in: []array}} 的快捷键
// 如果参数 ids 为空 将报错
func (ds *DeleteScope) SetIDs(ids []string) *DeleteScope {
if len(ids) == 0 {
ds.err = &ErrFilterEmpty
return ds
}
ds.filter = bson.M{"_id": bson.M{"$in": ids}}
return ds
}
// SetDeleteOption 设置删除选项
func (ds *DeleteScope) SetDeleteOption(opts options.DeleteOptions) *DeleteScope {
ds.opts = &opts
return ds
}
// SetFilter 设置过滤条件
// filter建议为map,因为它不需要进行额外的类型转换
// 若设置 filter 为struct 则会将第一级不为零值的属性平铺转换为map
// 若设置 filter 为一个切片 但基于以下规则 获取不到任何值时 将触发PANIC
// 若 filter 传递一个数组切片 如果结构为 []string/[]number/[]ObjectID 将被解析成 "_id": {"$in": []interface{}}
// 若 filter 只传递一个参数(即filterFields 为空) []struct/[]*struct 将从struct元素的tag中抽取字段_id的bson标签组成id的数组切片
// 并组装成 "$in": []interface 然后将其解析为 "_id": {"$in": []interface{}}
// 当然 在传递 []struct/[]*struct 后 可以指定第二参数 filterFields 用于获取 struct的指定字段(确保可访问该字段)的集合 作为查询条件
// 生成规则是: 基于可访问到的字段名 获取字段值 并按照其 bson_tag 将其组合成 $in 操作[无bson_tag则按传入的字段]
// 示例结构: User struct: {Id int bson:_id} {Name string bson:name} {Age int bson: age}
// 示例: ds.Many([]A{{Id: 1, Name: "张三", Age: 20},{Id: 2, Name: "李四", Age: 22}}, "Id", "Age")
// 生成: filter := bson.M{"_id": {"$in": [1,2]}, "age": {"$in": [20, 22]}}
func (ds *DeleteScope) SetFilter(filter interface{}, filterFields ...string) *DeleteScope {
if filter == nil {
ds.filter = bson.M{}
return ds
}
// 空指针
vo := reflect.ValueOf(filter)
if vo.Kind() == reflect.Ptr || vo.Kind() == reflect.Map ||
vo.Kind() == reflect.Slice || vo.Kind() == reflect.Interface {
if vo.IsNil() {
ds.filter = bson.M{}
return ds
}
}
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
// 是一个结构体 剥离零值字段
if vo.Kind() == reflect.Struct {
ds.filter = Struct2MapOmitEmptyWithBsonTag(filter)
return ds
}
// 是一个map
if vo.Kind() == reflect.Map {
ds.filter = filter
return ds
}
// 对于数组切片的struct将获取字段的和ID相关的数组转$in操作
if vo.Kind() == reflect.Array || vo.Kind() == reflect.Slice {
var filterMap = make(map[string][]interface{})
// 对于数组切片的number,string,objectID
for i := 0; i < vo.Len(); i++ {
cvo := vo.Index(i)
if cvo.Kind() == reflect.Ptr {
cvo = cvo.Elem()
}
// 是数值类型
if isReflectNumber(cvo) {
filterMap["_id"] = append(filterMap["_id"], cvo.Interface())
continue
}
// 是字符串
if cvo.Kind() == reflect.String {
filterMap["_id"] = append(filterMap["_id"], cvo.Interface())
continue
}
// 来自mongo的ObjectID
if cvo.Type().String() == "primitive.ObjectID" {
filterMap["_id"] = append(filterMap["_id"], cvo.Interface())
continue
}
// 是struct 剥离字段
if cvo.Kind() == reflect.Struct {
if len(filterFields) == 0 {
for i := 0; i < cvo.Type().NumField(); i++ {
field := cvo.Type().Field(i)
if field.Tag.Get("bson") != "_id" {
continue
}
kk := cvo.FieldByName(field.Name)
if kk.IsValid() && !kk.IsZero() {
filterMap["_id"] = append(filterMap["_id"], kk.Interface())
}
}
} else {
for _, field := range filterFields {
bsonName := getBsonTagByReflectTypeField(cvo.Type(), field)
kk := cvo.FieldByName(field)
if kk.IsValid() && !kk.IsZero() {
filterMap[bsonName] = append(filterMap[bsonName], kk.Interface())
}
}
}
}
}
if len(filterMap) != 0 {
var f = bson.M{}
for key, val := range filterMap {
f[key] = bson.M{
"$in": val,
}
}
ds.filter = f
return ds
}
}
// 不符合参数接受条件 直接PANIC
panic("args invalid(required: map,struct,number slice,string slice,struct slice)")
}
func (ds *DeleteScope) optionAssembled() {
// 配置项被直接调用重写过
if ds.opts != nil {
return
}
ds.opts = new(options.DeleteOptions)
}
func (ds *DeleteScope) preCheck() {
var breakerTTL time.Duration
if ds.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if ds.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = ds.scope.breaker.ttl
}
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
if ds.cw.ctx == nil {
ds.cw.ctx, ds.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
if ds.filter == nil {
ds.err = &ErrFilterEmpty
return
}
}
func (ds *DeleteScope) assertErr() {
if ds.err == nil {
return
}
if errors.Is(ds.err, context.DeadlineExceeded) {
ds.err = &ErrRequestBroken
return
}
err, ok := ds.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
ds.err = &ErrRequestBroken
}
}
func (ds *DeleteScope) getContext() context.Context {
return ds.cw.ctx
}
func (ds *DeleteScope) doClear() {
if ds.cw != nil && ds.cw.cancel != nil {
ds.cw.cancel()
}
ds.scope.execT = 0
ds.scope.debug = false
}
func (ds *DeleteScope) debug() {
if !ds.scope.debug && !ds.scope.debugWhenError {
return
}
debugger := &Debugger{
collection: ds.scope.tableName,
execT: ds.scope.execT,
action: ds,
}
// 当错误时输出
if ds.scope.debugWhenError {
if ds.err != nil {
debugger.errMsg = ds.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if ds.scope.debug {
debugger.String()
}
}
func (ds *DeleteScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
filter, _ := bson.MarshalExtJSONWithRegistry(builder, ds.filter, true, true)
switch ds.action {
case deleteOne:
return fmt.Sprintf("deleteOne(%s)", string(filter))
case deleteMany:
return fmt.Sprintf("deleteMany(%s)", string(filter))
default:
panic("not support delete type")
}
}
func (ds *DeleteScope) doDelete(isMany bool) {
if isMany {
var starTime time.Time
if ds.scope.debug {
starTime = time.Now()
}
ds.result, ds.err = db.Collection(ds.scope.tableName).DeleteMany(ds.getContext(), ds.filter, ds.opts)
ds.assertErr()
if ds.scope.debug {
ds.action = deleteMany
ds.scope.execT = time.Since(starTime)
ds.debug()
}
return
}
var starTime time.Time
if ds.scope.debug {
starTime = time.Now()
}
ds.result, ds.err = db.Collection(ds.scope.tableName).DeleteOne(ds.getContext(), ds.filter, ds.opts)
ds.assertErr()
if ds.scope.debug {
ds.action = deleteOne
ds.scope.execT = time.Since(starTime)
ds.debug()
}
}
// One 删除单个文档 返回影响行数
func (ds *DeleteScope) One() (int64, error) {
defer ds.doClear()
ds.optionAssembled()
ds.preCheck()
if ds.err != nil {
return 0, ds.err
}
ds.doDelete(false)
if ds.err != nil {
return 0, ds.err
}
return ds.result.DeletedCount, nil
}
// Many 删除多个文档 返回影响行数
func (ds *DeleteScope) Many() (int64, error) {
defer ds.doClear()
ds.optionAssembled()
ds.preCheck()
if ds.err != nil {
return 0, ds.err
}
ds.doDelete(true)
if ds.err != nil {
return 0, ds.err
}
return ds.result.DeletedCount, nil
}

142
delete_scope_test.go Normal file
View File

@ -0,0 +1,142 @@
package mdbc
import (
"context"
"fmt"
"testing"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
)
func TestDeleteScope_OneID(t *testing.T) {
cfg := &Config{
URI: "mongodb://10.0.0.135:27117/mdbc",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelWsConnectRecord{})
one, err := m.Delete().SetContext(context.Background()).SetID("697022b263e2c1528f32a26e704e63d6").One()
if err != nil {
logrus.Errorf("get err: %+v", err)
return
}
fmt.Println(one)
}
func TestDeleteScope_OneFilter(t *testing.T) {
cfg := &Config{
URI: "mongodb://10.0.0.135:27117/mdbc",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelWsConnectRecord{})
one, err := m.Delete().SetContext(context.Background()).SetFilter(bson.M{}).One()
if err != nil {
logrus.Errorf("get err: %+v", err)
return
}
fmt.Println(one)
}
func TestDeleteScope_Many(t *testing.T) {
cfg := &Config{
URI: "mongodb://10.0.0.135:27117/mdbc",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelSchedTask{})
//objs := []primitive.ObjectID{primitive.NewObjectID(), primitive.NewObjectID()}
ms := []*ModelSchedTask{
{
Id: "f36e55a5e4e64cc2947ae8c8a6333f6e",
TaskState: 3,
},
{
Id: "bc227d25df1552ab4eca2610295398e4",
TaskState: 2,
},
{
Id: "2ad97d5d330b20af0ad2ab7fd01cf32a",
TaskState: 0,
},
}
one, err := m.SetDebug(true).Delete().SetContext(context.Background()).
SetFilter(ms, "Id", "TaskState").Many()
//one, err := m.SetDebug(true).Delete().SetContext(context.Background()).Many()
if err != nil {
panic(err)
}
fmt.Println(one)
}
func TestDeleteScope_Other(t *testing.T) {
cfg := &Config{
URI: "mongodb://10.0.0.135:27117/mdbc",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelSchedTask{})
//通过id删除
one, err := m.Delete().SetContext(context.Background()).SetIDs([]string{
"2cab037c1ea1a96e4010b397afe703b9",
"bd13a92ff734b2912920c5baa677432b",
"435f56b94ba4f6387dca9b081c58f93b",
}).Many()
if err != nil {
panic(err)
}
fmt.Println(one)
//通过条件删除
//one, err = m.Delete().SetContext(context.Background()).SetFilter(bson.M{"_id": "13123"}).One()
//if err != nil {
// panic(err)
//}
//fmt.Println(one)
//
////设置DeleteOption
//one, err = m.Delete().SetContext(context.Background()).SetDeleteOption(options.DeleteOptions{
// Collation: nil,
// Hint: nil,
//}).SetID("aac0a95ddfc5c3344777bef62bb8baae").One()
//if err != nil {
// panic(err)
//}
//fmt.Println(one)
}

305
distinct_scope.go Normal file
View File

@ -0,0 +1,305 @@
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type DistinctScope struct {
scope *Scope
cw *ctxWrap
err error
fieldName string
filter interface{}
opts *options.DistinctOptions
result []interface{}
enableCache bool
cacheFunc DistinctCacheFunc
cacheKey string
}
type DistinctCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultDistinctCacheFunc 默认的缓存方法
// key: 缓存对象的key
// obj: 缓存对象
var DefaultDistinctCacheFunc = func() DistinctCacheFunc {
return func(key string, obj interface{}) (co *CacheObject, err error) {
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
// 下面使用反射进行操作 保证兼容
v := reflect.ValueOf(obj)
if v.Type().Kind() != reflect.Ptr {
return nil, fmt.Errorf("invalid list type, not ptr")
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
return nil, fmt.Errorf("invalid list type, not ptr to slice")
}
// 空slice 无需缓存
if v.Len() == 0 {
return nil, nil
}
b, _ := json.Marshal(obj)
co = &CacheObject{
Key: key,
Value: string(b),
}
return co, nil
}
}
// SetContext 设置上下文
func (ds *DistinctScope) SetContext(ctx context.Context) *DistinctScope {
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
ds.cw.ctx = ctx
return ds
}
func (ds *DistinctScope) getContext() context.Context {
return ds.cw.ctx
}
func (ds *DistinctScope) doClear() {
if ds.cw != nil && ds.cw.cancel != nil {
ds.cw.cancel()
}
ds.scope.execT = 0
ds.scope.debug = false
}
// SetUpdateOption 设置更新选项
func (ds *DistinctScope) SetUpdateOption(opts options.DistinctOptions) *DistinctScope {
ds.opts = &opts
return ds
}
// SetFilter 设置过滤条件
func (ds *DistinctScope) SetFilter(filter interface{}) *DistinctScope {
if filter == nil {
ds.filter = bson.M{}
return ds
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
ds.filter = bson.M{}
}
}
ds.filter = filter
return ds
}
// SetFieldName 设置字段名
func (ds *DistinctScope) SetFieldName(name string) *DistinctScope {
ds.fieldName = name
return ds
}
func (ds *DistinctScope) optionAssembled() {
// 配置项被直接调用重写过
if ds.opts != nil {
return
}
ds.opts = new(options.DistinctOptions)
}
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存
func (ds *DistinctScope) SetCacheFunc(key string, cb DistinctCacheFunc) *DistinctScope {
ds.enableCache = true
ds.cacheFunc = cb
ds.cacheKey = key
return ds
}
func (ds *DistinctScope) preCheck() {
if ds.filter == nil {
ds.filter = bson.M{}
}
var breakerTTL time.Duration
if ds.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if ds.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = ds.scope.breaker.ttl
}
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
if ds.cw.ctx == nil {
ds.cw.ctx, ds.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (ds *DistinctScope) assertErr() {
if ds.err == nil {
return
}
if errors.Is(ds.err, context.DeadlineExceeded) {
ds.err = &ErrRequestBroken
return
}
err, ok := ds.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
ds.err = &ErrRequestBroken
return
}
}
// debug 判断是否开启debug开启的话就打印
func (ds *DistinctScope) debug() {
if !ds.scope.debug {
return
}
debugger := &Debugger{
collection: ds.scope.tableName,
execT: ds.scope.execT,
action: ds,
}
// 当错误时优先输出
if ds.scope.debugWhenError {
if ds.err != nil {
debugger.errMsg = ds.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if ds.scope.debug {
debugger.String()
}
}
func (ds *DistinctScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
filter, _ := bson.MarshalExtJSONWithRegistry(builder, ds.filter, true, true)
return fmt.Sprintf(`distinct("%s",%s)`, ds.fieldName, string(filter))
}
func (ds *DistinctScope) doGet() {
defer ds.assertErr()
var starTime time.Time
if ds.scope.debug {
starTime = time.Now()
}
ds.result, ds.err = db.Collection(ds.scope.tableName).Distinct(ds.getContext(), ds.fieldName, ds.filter, ds.opts)
if ds.scope.debug {
ds.scope.execT = time.Since(starTime)
ds.debug()
}
}
// doCache 执行缓存
// 检测句柄存不存在
// 从cacheFunc中获取cacheObj
// 判断下数据没问题以及没有错误就进行缓存
func (ds *DistinctScope) doCache(obj interface{}) *DistinctScope {
// redis句柄不存在
if ds.scope.cache == nil {
return nil
}
cacheObj, err := ds.cacheFunc(ds.cacheKey, obj)
if err != nil {
ds.err = err
return ds
}
if cacheObj == nil {
ds.err = fmt.Errorf("cache object nil")
return ds
}
ttl := ds.scope.cache.ttl
if ttl == 0 {
ttl = time.Hour
} else if ttl == -1 {
ttl = 0
}
if ds.getContext().Err() != nil {
ds.err = ds.getContext().Err()
ds.assertErr()
return ds
}
ds.err = ds.scope.cache.client.Set(ds.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
return ds
}
// Get 获取结果
// list: 必须 *[]string 或者 *[]struct
func (ds *DistinctScope) Get(list interface{}) error {
defer ds.doClear()
ds.optionAssembled()
ds.preCheck()
if ds.fieldName == "" {
return fmt.Errorf("field name empty")
}
ds.doGet()
if ds.err != nil {
return ds.err
}
vo := reflect.ValueOf(list)
if vo.Kind() != reflect.Ptr {
return fmt.Errorf("arg not ptr")
}
vo = vo.Elem()
if vo.Kind() != reflect.Slice {
return fmt.Errorf("arg not ptr to slice")
}
vot := vo.Type()
if vot.Kind() != reflect.Slice {
return fmt.Errorf("arg not ptr to slice")
}
vot = vot.Elem()
if vot.Kind() != reflect.String && vot.Kind() != reflect.Struct {
return fmt.Errorf("slice subtype must string or struct but get %+v", vot.Kind())
}
for _, obj := range ds.result {
vo = reflect.Append(vo, reflect.ValueOf(obj))
}
rtv := reflect.ValueOf(list)
rtv.Elem().Set(vo)
if ds.enableCache {
ds.doCache(list)
}
return nil
}

41
distinct_scope_test.go Normal file
View File

@ -0,0 +1,41 @@
package mdbc
import (
"testing"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
)
func TestDistinctScope(*testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
var data []string
err = m.
SetDebug(true).
Distinct().
SetCacheFunc("task_type", DefaultDistinctCacheFunc()).
SetFieldName("task_type").
SetFilter(bson.M{"task_type": bson.M{"$ne": "test"}}).
Get(&data)
if err != nil {
panic(err)
}
logrus.Infof("get ttl: %+v\n", m.cache.ttl)
logrus.Infof("res %+v\n", data)
}

110
drop_scope.go Normal file
View File

@ -0,0 +1,110 @@
package mdbc
import (
"context"
"errors"
"time"
"go.mongodb.org/mongo-driver/mongo"
)
type DropScope struct {
scope *Scope
cw *ctxWrap
err error
}
// SetContext 设置上下文
func (ds *DropScope) SetContext(ctx context.Context) *DropScope {
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
ds.cw.ctx = ctx
return ds
}
func (ds *DropScope) preCheck() {
var breakerTTL time.Duration
if ds.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if ds.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = ds.scope.breaker.ttl
}
if ds.cw == nil {
ds.cw = &ctxWrap{}
}
if ds.cw.ctx == nil {
ds.cw.ctx, ds.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (ds *DropScope) assertErr() {
if ds.err == nil {
return
}
if errors.Is(ds.err, context.DeadlineExceeded) {
ds.err = &ErrRequestBroken
return
}
err, ok := ds.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
ds.err = &ErrRequestBroken
return
}
}
func (ds *DropScope) doString() string {
return "drop()"
}
func (ds *DropScope) debug() {
if !ds.scope.debug {
return
}
debugger := &Debugger{
collection: ds.scope.tableName,
execT: ds.scope.execT,
action: ds,
}
debugger.String()
}
func (ds *DropScope) getContext() context.Context {
return ds.cw.ctx
}
func (ds *DropScope) doClear() {
if ds.cw != nil && ds.cw.cancel != nil {
ds.cw.cancel()
}
ds.scope.execT = 0
ds.scope.debug = false
}
// Do 删除Model绑定的集合
func (ds *DropScope) Do() error {
defer ds.doClear()
ds.preCheck()
defer ds.assertErr()
var starTime time.Time
if ds.scope.debug {
starTime = time.Now()
}
ds.err = db.Collection(ds.scope.tableName).Drop(ds.getContext())
if ds.scope.debug {
ds.scope.execT = time.Since(starTime)
ds.debug()
}
if ds.err != nil {
return ds.err
}
return nil
}

591
find_one_scope.go Normal file
View File

@ -0,0 +1,591 @@
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type findOneAction string
const (
findOne findOneAction = "findOne"
findOneAndDelete findOneAction = "findOneAndDelete"
findOneAndUpdate findOneAction = "findOneAndUpdate"
findOneAndReplace findOneAction = "findOneAndReplace"
)
type FindOneScope struct {
scope *Scope
cw *ctxWrap
err error
limit *int64
skip *int64
sort bson.D
upsert *bool // 是否允许 upsert 操作
filter interface{} // 过滤条件
bindObject interface{} // 绑定结果
bindSQL *string // 绑定执行SQL
enableCache bool // 允许缓存
cacheKey string // 缓存KEY
cacheFunc FindOneCacheFunc // 缓存调用函数
action findOneAction // 标注什么动作 findOne findOneAndUpdate...
fOpt *options.FindOneOptions
dOpt *options.FindOneAndDeleteOptions
uOpt *options.FindOneAndUpdateOptions
rOpt *options.FindOneAndReplaceOptions
result *mongo.SingleResult
}
// FindOneCacheFunc 创建缓存的key
// field 是需要缓存的字段 当你需要基于结果obj进行操作
// obj 是结果 你可以对他进行操作
// defaultFindOneCacheFunc 是一个样例 基于obj某一个字段 取出其值 将其当作key进行结果缓存
// 返回 key, value error
// key: 需要存储的key
// value: 需要存储的value 请自行将其实例化
// error: callback错误信息
type FindOneCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultFindOneCacheFunc 基于obj的字段key 返回该key对应value 以及该obj的marshal string
var DefaultFindOneCacheFunc = func() FindOneCacheFunc {
return func(key string, obj interface{}) (co *CacheObject, err error) {
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
// 下面使用反射进行操作 保证兼容
v := reflect.ValueOf(obj)
if v.Type().Kind() != reflect.Ptr {
return nil, fmt.Errorf("obj not ptr")
}
v = v.Elem()
idVal := v.FieldByName(key)
if idVal.Kind() == reflect.Invalid {
return nil, fmt.Errorf("%s field not found", key)
}
b, _ := json.Marshal(obj)
co = &CacheObject{
Key: idVal.String(),
Value: string(b),
}
return co, nil
}
}
// SetLimit 设置获取的条数
func (fos *FindOneScope) SetLimit(limit int64) *FindOneScope {
fos.limit = &limit
return fos
}
// SetSkip 设置跳过的条数
func (fos *FindOneScope) SetSkip(skip int64) *FindOneScope {
fos.skip = &skip
return fos
}
// SetSort 设置排序
func (fos *FindOneScope) SetSort(sort bson.D) *FindOneScope {
fos.sort = sort
return fos
}
// SetContext 设置上下文
func (fos *FindOneScope) SetContext(ctx context.Context) *FindOneScope {
if fos.cw == nil {
fos.cw = &ctxWrap{}
}
fos.cw.ctx = ctx
return fos
}
func (fos *FindOneScope) getContext() context.Context {
return fos.cw.ctx
}
// SetUpsert 设置upsert属性
func (fos *FindOneScope) SetUpsert(upsert bool) *FindOneScope {
fos.upsert = &upsert
return fos
}
// SetFilter 设置过滤条件
func (fos *FindOneScope) SetFilter(filter interface{}) *FindOneScope {
if filter == nil {
fos.filter = bson.M{}
return fos
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
fos.filter = bson.M{}
}
}
fos.filter = filter
return fos
}
// SetFindOneOption 设置FindOneOption 优先级较低 有 set 时参数将被重写
func (fos *FindOneScope) SetFindOneOption(opts options.FindOneOptions) *FindOneScope {
fos.fOpt = &opts
return fos
}
// SetFindOneAndDeleteOption 设置FindOneAndDeleteOption 优先级较低 有 set 时参数将被重写
func (fos *FindOneScope) SetFindOneAndDeleteOption(opts options.FindOneAndDeleteOptions) *FindOneScope {
fos.dOpt = &opts
return fos
}
// SetFindOneAndUpdateOption 设置FindOneAndUpdateOption 优先级较低 有 set 时参数将被重写
func (fos *FindOneScope) SetFindOneAndUpdateOption(opts options.FindOneAndUpdateOptions) *FindOneScope {
fos.uOpt = &opts
return fos
}
// SetFindOneAndReplaceOption 设置FindOneAndReplaceOption 优先级较低 有 set 时参数将被重写
func (fos *FindOneScope) SetFindOneAndReplaceOption(opts options.FindOneAndReplaceOptions) *FindOneScope {
fos.rOpt = &opts
return fos
}
// BindResult 绑定结果到一个变量中
// 一般在链式末尾没调用 Get(obj) 方法但想获取结果时使用
// 入参 r 必须是一个 *struct
func (fos *FindOneScope) BindResult(r interface{}) *FindOneScope {
vo := reflect.ValueOf(r)
if vo.Kind() != reflect.Ptr {
panic("BindResult arg r must be pointer")
}
vo = vo.Elem()
if vo.Kind() != reflect.Struct {
panic("BindResult arg r must be struct pointer")
}
fos.bindObject = r
return fos
}
// BindSQL 获取本次执行的SQL信息
func (fos *FindOneScope) BindSQL(sql *string) *FindOneScope {
fos.bindSQL = sql
return fos
}
func (fos *FindOneScope) preCheck() {
if fos.filter == nil {
fos.filter = bson.M{}
}
var breakerTTL time.Duration
if fos.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if fos.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = fos.scope.breaker.ttl
}
if fos.cw == nil {
fos.cw = &ctxWrap{}
}
if fos.cw.ctx == nil {
fos.cw.ctx, fos.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (fos *FindOneScope) findOneOptionAssembled() {
// 配置项被直接调用重写过
if fos.fOpt == nil {
fos.fOpt = new(options.FindOneOptions)
}
if fos.skip != nil {
fos.fOpt.Skip = fos.skip
}
if len(fos.sort) != 0 {
fos.fOpt.Sort = fos.sort
}
}
func (fos *FindOneScope) findOneDeleteOptionAssembled() {
// 配置项被直接调用重写过
if fos.dOpt == nil {
fos.dOpt = new(options.FindOneAndDeleteOptions)
}
if fos.sort != nil {
fos.dOpt.Sort = fos.sort
}
}
func (fos *FindOneScope) findOneUpdateOptionAssembled() {
// 配置项被直接调用重写过
if fos.uOpt == nil {
fos.uOpt = new(options.FindOneAndUpdateOptions)
}
if fos.sort != nil {
fos.uOpt.Sort = fos.sort
}
if fos.upsert != nil {
fos.uOpt.Upsert = fos.upsert
}
}
func (fos *FindOneScope) findOneReplaceOptionAssembled() {
// 配置项被直接调用重写过
if fos.rOpt == nil {
fos.rOpt = new(options.FindOneAndReplaceOptions)
}
if fos.sort != nil {
fos.rOpt.Sort = fos.sort
}
if fos.upsert != nil {
fos.rOpt.Upsert = fos.upsert
}
}
func (fos *FindOneScope) doString() string {
filter, _ := bson.MarshalExtJSON(fos.filter, true, true)
var query string
switch fos.action {
case findOne:
query = fmt.Sprintf("findOne(%s)", string(filter))
case findOneAndDelete:
query = fmt.Sprintf("findOneAndDelete(%s)", string(filter))
case findOneAndUpdate:
query = fmt.Sprintf("findOneAndUpdate(%s)", string(filter))
case findOneAndReplace:
query = fmt.Sprintf("findOneAndReplace(%s)", string(filter))
}
if fos.skip != nil {
query = fmt.Sprintf("%s.skip(%d)", query, *fos.skip)
}
if fos.limit != nil {
query = fmt.Sprintf("%s.limit(%d)", query, *fos.limit)
}
if fos.sort != nil {
sort, _ := bson.MarshalExtJSON(fos.sort, true, true)
query = fmt.Sprintf("%s.sort(%s)", query, string(sort))
}
return query
}
func (fos *FindOneScope) assertErr() {
if fos.err == nil {
return
}
if errors.Is(fos.err, mongo.ErrNoDocuments) || errors.Is(fos.err, mongo.ErrNilDocument) {
fos.err = &ErrRecordNotFound
return
}
if errors.Is(fos.err, context.DeadlineExceeded) {
fos.err = &ErrRequestBroken
return
}
err, ok := fos.err.(mongo.CommandError)
if ok && err.HasErrorMessage(context.DeadlineExceeded.Error()) {
fos.err = &ErrRequestBroken
return
}
fos.err = err
}
func (fos *FindOneScope) debug() {
if !fos.scope.debug && !fos.scope.debugWhenError && fos.bindSQL == nil {
return
}
debugger := &Debugger{
collection: fos.scope.tableName,
execT: fos.scope.execT,
action: fos,
}
// 当错误时优先输出
if fos.scope.debugWhenError {
if fos.err != nil {
debugger.errMsg = fos.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if fos.scope.debug {
debugger.String()
}
// 绑定 SQL
if fos.bindSQL != nil {
*fos.bindSQL = debugger.GetString()
}
}
func (fos *FindOneScope) doSearch() {
var starTime time.Time
if fos.scope.debug {
starTime = time.Now()
}
fos.result = db.Collection(fos.scope.tableName).FindOne(fos.getContext(), fos.filter, fos.fOpt)
if fos.scope.debug {
fos.scope.execT = time.Since(starTime)
}
fos.action = findOne
fos.err = fos.result.Err()
fos.assertErr()
// debug
fos.debug()
}
func (fos *FindOneScope) doUpdate(obj interface{}, isReplace bool) {
if isReplace {
var starTime time.Time
if fos.scope.debug {
starTime = time.Now()
}
fos.result = db.Collection(fos.scope.tableName).FindOneAndReplace(fos.getContext(), fos.filter, obj, fos.rOpt)
if fos.scope.debug {
fos.scope.execT = time.Since(starTime)
}
fos.action = findOneAndReplace
fos.debug()
fos.err = fos.result.Err()
fos.assertErr()
return
}
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() == reflect.Struct {
obj = bson.M{
"$set": obj,
}
}
var starTime time.Time
if fos.scope.debug {
starTime = time.Now()
}
fos.result = db.Collection(fos.scope.tableName).FindOneAndUpdate(fos.getContext(), fos.filter, obj, fos.uOpt)
if fos.scope.debug {
fos.scope.execT = time.Since(starTime)
}
fos.action = findOneAndUpdate
fos.err = fos.result.Err()
fos.assertErr()
fos.debug()
}
func (fos *FindOneScope) doDelete() {
var starTime time.Time
if fos.scope.debug {
starTime = time.Now()
}
fos.result = db.Collection(fos.scope.tableName).FindOneAndDelete(fos.getContext(), fos.filter, fos.dOpt)
if fos.scope.debug {
fos.scope.execT = time.Since(starTime)
}
fos.action = findOneAndDelete
fos.err = fos.result.Err()
fos.assertErr()
fos.debug()
}
func (fos *FindOneScope) doClear() {
if fos.cw != nil && fos.cw.cancel != nil {
fos.cw.cancel()
}
fos.scope.debug = false
fos.scope.execT = 0
}
// Get 传递一个指针类型来接收结果
// m: 必须是 *struct
func (fos *FindOneScope) Get(m interface{}) error {
defer fos.doClear()
if fos.err != nil {
return fos.err
}
fos.preCheck()
fos.findOneOptionAssembled()
fos.doSearch()
if fos.err != nil {
return fos.err
}
fos.err = fos.result.Decode(m)
if fos.err != nil {
return fos.err
}
// 进行缓存
if fos.enableCache {
fos.doCache(m)
return fos.err
}
return nil
}
// Delete 匹配一个并删除
func (fos *FindOneScope) Delete() error {
defer fos.doClear()
if fos.err != nil {
return fos.err
}
fos.preCheck()
fos.findOneDeleteOptionAssembled()
fos.doDelete()
if fos.err != nil {
return fos.err
}
return nil
}
// Update 进行数据字段懒更新
// obj 传入 *struct 会 $set 自动包裹
// obj 传入 bson.M 不会 $set 自动包裹
func (fos *FindOneScope) Update(obj interface{}) error {
defer fos.doClear()
if fos.err != nil {
return fos.err
}
fos.preCheck()
fos.findOneUpdateOptionAssembled()
fos.doUpdate(obj, false)
if fos.err != nil {
return fos.err
}
return nil
}
// Replace 进行数据字段全强制更新
func (fos *FindOneScope) Replace(obj interface{}) error {
defer fos.doClear()
if fos.err != nil {
return fos.err
}
fos.preCheck()
fos.findOneReplaceOptionAssembled()
fos.doUpdate(obj, true)
if fos.err != nil {
return fos.err
}
return nil
}
// IsExist 是否存在 不存在返回 false, nil
// 非 NotFound 错误将返回 false, error
// 若找到了并想绑定结果 请在该方法前调用 BindResult
func (fos *FindOneScope) IsExist() (bool, error) {
defer fos.doClear()
if fos.err != nil {
return false, fos.err
}
fos.preCheck()
fos.findOneOptionAssembled()
fos.doSearch()
if fos.err != nil {
if IsRecordNotFound(fos.err) {
return false, nil
}
return false, fos.err
}
if fos.bindObject != nil {
fos.err = fos.result.Decode(fos.bindObject)
if fos.err != nil {
return false, fos.err
}
}
return true, nil
}
func (fos *FindOneScope) Error() error {
return fos.err
}
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存
func (fos *FindOneScope) SetCacheFunc(key string, cb FindOneCacheFunc) *FindOneScope {
fos.enableCache = true
fos.cacheFunc = cb
fos.cacheKey = key
return fos
}
// doCache 执行缓存
func (fos *FindOneScope) doCache(obj interface{}) *FindOneScope {
// redis句柄不存在
if fos.scope.cache == nil {
return nil
}
cacheObj, err := fos.cacheFunc(fos.cacheKey, obj)
if err != nil {
fos.err = err
return fos
}
if cacheObj == nil {
fos.err = fmt.Errorf("cache object nil")
return fos
}
ttl := fos.scope.cache.ttl
if ttl == 0 {
ttl = time.Hour
} else if ttl == -1 {
ttl = 0
}
if fos.getContext().Err() != nil {
fos.err = fos.getContext().Err()
fos.assertErr()
return fos
}
fos.err = fos.scope.cache.client.Set(fos.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
if fos.err != nil {
fos.assertErr()
}
return fos
}

62
find_one_scope_test.go Normal file
View File

@ -0,0 +1,62 @@
package mdbc
import (
"testing"
"time"
"go.mongodb.org/mongo-driver/mongo/readpref"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
)
func TestFindOneScope(t *testing.T) {
client, err := ConnInit(&Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
ReadPreference: readpref.Nearest(),
RegistryBuilder: RegisterTimestampCodec(nil),
})
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("mdbc"))
time.Sleep(time.Second * 5)
var m = NewModel(&ModelSchedTask{})
var record ModelSchedTask
err = m.SetDebugError(true).FindOne().SetFilter(bson.M{"_id": "insertffdddfknkodsanfkasdf"}).Get(&record)
if err != nil {
return
}
logrus.Infof("get: %+v", &record)
time.Sleep(time.Second * 5)
}
func TestFindOneScope_Delete(t *testing.T) {
var m = NewModel(&ModelSchedTask{})
err := m.SetCacheExpiredAt(time.Second*300).FindOne().
SetFilter(bson.M{"_id": "13123"}).SetCacheFunc("Id", DefaultFindOneCacheFunc()).Delete()
if err != nil {
panic(err)
}
logrus.Infof("get ttl: %+v", m.cache.ttl)
}
func TestFindOneScope_Replace(t *testing.T) {
var m = NewModel(&ModelSchedTask{})
var record ModelSchedTask
err := m.FindOne().
SetFilter(bson.M{"_id": "0e63f5962e18a8da331289caaa3fa224"}).Get(&record)
record.RspJson = "hahahahah"
err = m.FindOne().SetFilter(bson.M{"_id": record.Id}).Replace(&record)
if err != nil {
panic(err)
}
logrus.Infof("get ttl: %+v", m.cache.ttl)
}

445
find_scope.go Normal file
View File

@ -0,0 +1,445 @@
package mdbc
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type FindScope struct {
scope *Scope // 向上提权
cw *ctxWrap // 查询上下文环境
cursor *mongo.Cursor // 查询指针
err error // 操作是否有错误
limit *int64 // 限制查询数量
skip *int64 // 偏移量
count *int64 // 计数统计结果
withCount bool // 是否开启计数统计 这将只基于filter进行计数而忽略 limit skip
selects interface{} //
sort bson.D // 排序
filter interface{} // 过滤条件
opts *options.FindOptions // 查询条件
enableCache bool // 是否开启缓存
cacheKey string // 缓存key
cacheFunc FindCacheFunc // 基于什么缓存函数进行缓存
}
// FindCacheFunc 缓存Find结果集
// field: 若无作用 可置空 DefaultFindCacheFunc将使用其反射出field字段对应的value并将其设为key
// obj: 这是一个slice 只能在GetList是使用GetMap使用将无效果
type FindCacheFunc func(field string, obj interface{}) (*CacheObject, error)
// DefaultFindCacheFunc 按照第一个结果的field字段作为key缓存list数据
// field字段需要满足所有数据均有该属性 否则有概率导致缓存失败
var DefaultFindCacheFunc = func() FindCacheFunc {
return func(field string, obj interface{}) (*CacheObject, error) {
// 建议先断言obj 再操作 若obj断言失败 请返回nil 这样缓存将不会执行
// 示例: res,ok := obj.(*[]*model.ModelUser) 然后操作res 将 key,value 写入 CacheObject 既可
// 下面使用反射进行操作 保证兼容
v := reflect.ValueOf(obj)
if v.Type().Kind() != reflect.Ptr {
return nil, fmt.Errorf("invalid list type, not ptr")
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
return nil, fmt.Errorf("invalid list type, not ptr to slice")
}
// 空slice 无需缓存
if v.Len() == 0 {
return nil, nil
}
firstNode := v.Index(0)
// 判断过长度 无需判断nil
if firstNode.Kind() == reflect.Ptr {
firstNode = firstNode.Elem()
}
idVal := firstNode.FieldByName(field)
if idVal.Kind() == reflect.Invalid {
return nil, fmt.Errorf("first node %s field not found", idVal)
}
b, _ := json.Marshal(v.Interface())
co := &CacheObject{
Key: idVal.String(),
Value: string(b),
}
return co, nil
}
}
// SetLimit 设置获取的条数
func (fs *FindScope) SetLimit(limit int64) *FindScope {
fs.limit = &limit
return fs
}
// SetSkip 设置跳过的条数
func (fs *FindScope) SetSkip(skip int64) *FindScope {
fs.skip = &skip
return fs
}
// SetSort 设置排序
func (fs *FindScope) SetSort(sort bson.D) *FindScope {
fs.sort = sort
return fs
}
// SetContext 设置上下文
func (fs *FindScope) SetContext(ctx context.Context) *FindScope {
if fs.cw == nil {
fs.cw = &ctxWrap{}
}
fs.cw.ctx = ctx
return fs
}
// SetSelect 选择查询字段 格式 {key: 1, key: 0}
// 你可以传入 bson.M, map[string]integer_interface 其他类型将导致该功能失效
// 1显示 0不显示
func (fs *FindScope) SetSelect(selects interface{}) *FindScope {
vo := reflect.ValueOf(selects)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
// 不是map 提前返回
if vo.Kind() != reflect.Map {
return fs
}
fs.selects = selects
return fs
}
func (fs *FindScope) getContext() context.Context {
if fs.cw == nil {
fs.cw = &ctxWrap{
ctx: context.Background(),
}
}
return fs.cw.ctx
}
// SetFilter 设置过滤条件
func (fs *FindScope) SetFilter(filter interface{}) *FindScope {
if filter == nil {
fs.filter = bson.M{}
return fs
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
fs.filter = bson.M{}
}
}
fs.filter = filter
return fs
}
// SetFindOption 设置FindOption 优先级最低 sort/skip/limit 会被 set 函数重写
func (fs *FindScope) SetFindOption(opts options.FindOptions) *FindScope {
fs.opts = &opts
return fs
}
// WithCount 查询的同时获取总数量
// 将按照 SetFilter 的条件进行查询 未设置是获取所有文档数量
// 如果在该步骤出错 将不会进行余下操作
func (fs *FindScope) WithCount(count *int64) *FindScope {
fs.withCount = true
fs.count = count
return fs
}
func (fs *FindScope) optionAssembled() {
// 配置项被直接调用重写过
if fs.opts == nil {
fs.opts = new(options.FindOptions)
}
if fs.sort != nil {
fs.opts.Sort = fs.sort
}
if fs.skip != nil {
fs.opts.Skip = fs.skip
}
if fs.limit != nil {
fs.opts.Limit = fs.limit
}
if fs.selects != nil {
fs.opts.Projection = fs.selects
}
}
func (fs *FindScope) preCheck() {
if fs.filter == nil {
fs.filter = bson.M{}
}
var breakerTTL time.Duration
if fs.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if fs.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = fs.scope.breaker.ttl
}
if fs.cw == nil {
fs.cw = &ctxWrap{}
}
if fs.cw.ctx == nil {
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
} else {
fs.cw.ctx, fs.cw.cancel = context.WithTimeout(fs.cw.ctx, breakerTTL)
}
}
func (fs *FindScope) assertErr() {
if fs.err == nil {
return
}
if errors.Is(fs.err, mongo.ErrNoDocuments) || errors.Is(fs.err, mongo.ErrNilDocument) {
fs.err = &ErrRecordNotFound
return
}
if errors.Is(fs.err, context.DeadlineExceeded) {
fs.err = &ErrRequestBroken
return
}
err, ok := fs.err.(mongo.CommandError)
if ok && err.HasErrorMessage(context.DeadlineExceeded.Error()) {
fs.err = &ErrRequestBroken
return
}
fs.err = err
}
func (fs *FindScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
filter, _ := bson.MarshalExtJSONWithRegistry(builder, fs.filter, true, true)
query := fmt.Sprintf("find(%s)", string(filter))
if fs.skip != nil {
query = fmt.Sprintf("%s.skip(%d)", query, *fs.skip)
}
if fs.limit != nil {
query = fmt.Sprintf("%s.limit(%d)", query, *fs.limit)
}
if fs.sort != nil {
sort, _ := bson.MarshalExtJSON(fs.sort, true, true)
query = fmt.Sprintf("%s.sort(%s)", query, string(sort))
}
return query
}
func (fs *FindScope) debug() {
if !fs.scope.debug && !fs.scope.debugWhenError {
return
}
debugger := &Debugger{
collection: fs.scope.tableName,
execT: fs.scope.execT,
action: fs,
}
// 当错误时优先输出
if fs.scope.debugWhenError {
if fs.err != nil {
debugger.errMsg = fs.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if fs.scope.debug {
debugger.String()
}
}
func (fs *FindScope) doClear() {
if fs.cw != nil && fs.cw.cancel != nil {
fs.cw.cancel()
}
fs.scope.debug = false
fs.scope.execT = 0
}
func (fs *FindScope) doSearch() {
var starTime time.Time
if fs.scope.debug {
starTime = time.Now()
}
fs.cursor, fs.err = db.Collection(fs.scope.tableName).Find(fs.getContext(), fs.filter, fs.opts)
fs.assertErr() // 断言错误
if fs.scope.debug {
fs.scope.execT = time.Since(starTime)
fs.debug()
}
// 有检测数量的
if fs.withCount {
var res int64
res, fs.err = fs.scope.Count().SetContext(fs.getContext()).SetFilter(fs.filter).Count()
*fs.count = res
}
fs.assertErr()
}
// GetList 获取列表
// list: 需要一个 *[]*struct
func (fs *FindScope) GetList(list interface{}) error {
defer fs.doClear()
fs.optionAssembled()
fs.preCheck()
if fs.err != nil {
return fs.err
}
fs.doSearch()
if fs.err != nil {
return fs.err
}
v := reflect.ValueOf(list)
if v.Type().Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid list type, not ptr")
return fs.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Slice {
fs.err = fmt.Errorf("invalid list type, not ptr to slice")
return fs.err
}
fs.err = fs.cursor.All(fs.getContext(), list)
if fs.enableCache {
fs.doCache(list)
}
return fs.err
}
// GetMap 基于结果的某个字段为Key 获取Map
// m: 传递一个 *map[string]*Struct
// field: struct的一个字段名称 需要是公开可访问的大写
func (fs *FindScope) GetMap(m interface{}, field string) error {
defer fs.doClear()
fs.optionAssembled()
fs.preCheck()
if fs.err != nil {
return fs.err
}
fs.doSearch()
if fs.err != nil {
return fs.err
}
v := reflect.ValueOf(m)
if v.Type().Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid map type, not ptr")
return fs.err
}
v = v.Elem()
if v.Type().Kind() != reflect.Map {
fs.err = fmt.Errorf("invalid map type, not map")
return fs.err
}
mapType := v.Type()
valueType := mapType.Elem()
keyType := mapType.Key()
if valueType.Kind() != reflect.Ptr {
fs.err = fmt.Errorf("invalid map value type, not prt")
return fs.err
}
if !v.CanSet() {
fs.err = fmt.Errorf("invalid map value type, not addressable or obtained by the use of unexported struct fields")
return fs.err
}
v.Set(reflect.MakeMap(reflect.MapOf(keyType, valueType)))
for fs.cursor.Next(fs.getContext()) {
t := reflect.New(valueType.Elem())
if fs.err = fs.cursor.Decode(t.Interface()); fs.err != nil {
logrus.Errorf("err: %+v", fs.err)
return fs.err
}
fieldNode := t.Elem().FieldByName(field)
if fieldNode.Kind() == reflect.Invalid {
fs.err = fmt.Errorf("invalid model key: %s", field)
return fs.err
}
v.SetMapIndex(fieldNode, t)
}
return fs.err
}
func (fs *FindScope) Error() error {
return fs.err
}
// SetCacheFunc 传递一个函数 处理查询操作的结果进行缓存 还没有实现
func (fs *FindScope) SetCacheFunc(key string, cb FindCacheFunc) *FindScope {
fs.enableCache = true
fs.cacheFunc = cb
fs.cacheKey = key
return fs
}
// doCache 执行缓存
func (fs *FindScope) doCache(obj interface{}) *FindScope {
// redis句柄不存在
if fs.scope.cache == nil {
return nil
}
cacheObj, err := fs.cacheFunc(fs.cacheKey, obj)
if err != nil {
fs.err = err
return fs
}
if cacheObj == nil {
fs.err = fmt.Errorf("cache object nil")
return fs
}
ttl := fs.scope.cache.ttl
if ttl == 0 {
ttl = time.Hour
} else if ttl == -1 {
ttl = 0
}
if fs.getContext().Err() != nil {
fs.err = fs.getContext().Err()
fs.assertErr()
return fs
}
fs.err = fs.scope.cache.client.Set(fs.getContext(), cacheObj.Key, cacheObj.Value, ttl).Err()
// 这里也要断言错误 看是否是熔断错误
return fs
}

98
find_scope_test.go Normal file
View File

@ -0,0 +1,98 @@
package mdbc
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/sirupsen/logrus"
"gitlab.com/gotk/gotk/utils"
"go.mongodb.org/mongo-driver/bson"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestFindScope(t *testing.T) {
codec := bson.NewRegistryBuilder()
codec.RegisterCodec(reflect.TypeOf(&timestamppb.Timestamp{}), &TimestampCodec{})
client, err := ConnInit(&Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
RegistryBuilder: codec,
})
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("heywoods_golang_jingliao_crm_dev"))
var m = NewModel(&ModelRobot{})
var record []*ModelRobot
var count int64
err = m.Find().SetFilter(bson.M{
ModelRobotField.GetStatusField(): 11,
}).SetSort(bson.D{
{
Key: ModelRobotField.GetWechatIdField(),
Value: -1,
},
}).WithCount(&count).SetLimit(10).GetList(&record)
if err != nil {
panic(err)
}
logrus.Infof("count: %+v", count)
logrus.Infof("list: %+v", utils.PluckString(record, "NickName"))
//p1 := utils.PluckString(record, "NickName")
//logrus.Infof("all: %+v", p1)
//
//err = m.Find().SetFilter(nil).SetSort(bson.D{
// {
// Key: ModelRobotField.GetWechatIdField(),
// Value: -1,
// },
//}).SetSkip(0).SetLimit(10).GetList(&record)
//if err != nil {
// panic(err)
//}
//p1 = utils.PluckString(record, "NickName")
//logrus.Infof("p1: %+v", p1)
//
//err = m.Find().SetFilter(nil).SetSort(bson.D{
// {
// Key: ModelRobotField.GetWechatIdField(),
// Value: -1,
// },
//}).SetSkip(10).SetLimit(10).GetList(&record)
//if err != nil {
// panic(err)
//}
//p2 := utils.PluckString(record, "NickName")
//logrus.Infof("p2: %+v", p2)
}
func TestFindScope_GetMap(t *testing.T) {
var m = NewModel(&ModelSchedTask{})
var record = make(map[string]*ModelSchedTask)
err := m.Find().
SetFilter(nil).SetLimit(2).SetCacheFunc("Id", DefaultFindCacheFunc()).GetMap(&record, "Id")
if err != nil {
panic(err)
}
marshal, err := json.Marshal(record)
if err != nil {
return
}
fmt.Printf("%+v\n", string(marshal))
}
func TestFindScope_SetCacheFunc(t *testing.T) {
var record []*ModelSchedTask
var m = NewModel(&ModelSchedTask{})
err := m.Find().SetCacheFunc("Id", DefaultFindCacheFunc()).GetList(&record)
if err != nil {
return
}
}

3
gen.sh Executable file
View File

@ -0,0 +1,3 @@
#!/bin/bash
gotker gen --path . --out . --no-scope

15
go.mod Normal file
View File

@ -0,0 +1,15 @@
module gitlab.com/gotk/mdbc
go 1.16
require (
github.com/antonfisher/nested-logrus-formatter v1.3.1
github.com/go-redis/redis/v8 v8.11.4
github.com/google/uuid v1.3.0
github.com/json-iterator/go v1.1.12
github.com/sirupsen/logrus v1.8.1
gitlab.com/gotk/gotk v0.0.0-20220223083201-5d05a06943c3
go.mongodb.org/mongo-driver v1.8.3
golang.org/x/sys v0.0.0-20220222200937-f2425489ef4c // indirect
google.golang.org/protobuf v1.27.1
)

12
hook.go Normal file
View File

@ -0,0 +1,12 @@
package mdbc
// Hook mongodb的指令hook
type Hook interface {
Before() error
After() error
}
// UpdateHook 更新指令的hook
type UpdateHook interface {
Hook
}

BIN
icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

7
index_scope.go Normal file
View File

@ -0,0 +1,7 @@
package mdbc
import (
"errors"
)
var ErrorCursorIsNil = errors.New("cursor is nil")

97
index_scope_test.go Normal file
View File

@ -0,0 +1,97 @@
package mdbc
import (
"context"
"fmt"
"testing"
"time"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestIndexScope_GetList(t *testing.T) {
client, err := ConnInit(&Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
ReadPreference: readpref.Nearest(),
RegistryBuilder: RegisterTimestampCodec(nil),
})
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
InitDB(client.Database("mdbc"))
var list IndexList
var m = NewModel(&ModelSchedTask{})
err = m.Index().SetContext(context.Background()).GetIndexList(&list)
if err != nil {
panic(err)
}
for _, card := range list {
logrus.Infof("%+v\n", card)
}
}
func TestIndexScope_DropAll(t *testing.T) {
var m = NewModel(&ModelSchedTask{})
err := m.Index().SetContext(context.Background()).DropAll()
if err != nil {
panic(err)
}
}
func TestIndexScope_DropOne(t *testing.T) {
var m = NewModel(&ModelSchedTask{})
err := m.Index().SetContext(context.Background()).DropOne("expire_time_2")
if err != nil {
panic(err)
}
}
func TestIndexScope_AddIndexModel(t *testing.T) {
var err error
var m = NewModel(&ModelSchedTask{})
// 创建多个索引
_, err = m.Index().SetContext(context.Background()).AddIndexModels(mongo.IndexModel{
Keys: bson.D{{Key: "expire_time", Value: 2}}, // 设置索引列
Options: options.Index().SetExpireAfterSeconds(0), // 设置过期时间
}).AddIndexModels(mongo.IndexModel{
Keys: bson.D{{Key: "created_at", Value: 1}}, // 设置索引列
}).CreateMany()
// 创建单个索引
_, err = m.Index().SetContext(context.Background()).AddIndexModels(mongo.IndexModel{
Keys: bson.D{{Key: "expire_time", Value: 2}}, // 设置索引列
Options: options.Index().SetExpireAfterSeconds(0), // 设置过期时间
}).CreateOne()
if err != nil {
panic(err)
}
}
func TestIndexScope_SetListIndexesOption(t *testing.T) {
var list IndexList
s := int32(2)
mt := 10 * time.Second
var m = NewModel(&ModelSchedTask{})
//设定ListIndexesOptions
err := m.Index().SetListIndexesOption(options.ListIndexesOptions{
BatchSize: &s,
MaxTime: &mt,
}).GetIndexList(&list)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", list)
}

273
index_scope_v2.go Normal file
View File

@ -0,0 +1,273 @@
package mdbc
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// KeySort 排序方式 类型别名
type KeySort = int
const (
ASC KeySort = 1 // 正序排列
DESC KeySort = -1 // 倒序排列
)
// IndexCard 索引信息
type IndexCard struct {
Version int32 `bson:"v"` // 版本信息
Name string `bson:"name"` // 索引名
Key map[string]KeySort `bson:"key"` // 索引方式
}
// IndexList 索引列表
type IndexList []*IndexCard
type IndexScope struct {
scope *Scope
cw *ctxWrap
iv mongo.IndexView
cursor *mongo.Cursor
ims []mongo.IndexModel
err error
lOpts *options.ListIndexesOptions
dOpts *options.DropIndexesOptions
cOpts *options.CreateIndexesOptions
execResult interface{}
}
func (is *IndexScope) assertErr() {
if is.err == nil {
return
}
err, ok := is.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
is.err = &ErrRequestBroken
}
}
func (is *IndexScope) doClear() {
if is.cw != nil && is.cw.cancel != nil {
is.cw.cancel()
}
is.scope.debug = false
is.scope.execT = 0
}
// doGet 拿到 iv 后续的 增删改查都基于 iv 对象
func (is *IndexScope) doGet() {
is.iv = db.Collection(is.scope.tableName).Indexes(is.getContext())
}
// doCursor 拿到 cursor
func (is *IndexScope) doCursor() {
if is.lOpts == nil {
is.lOpts = &options.ListIndexesOptions{}
}
is.cursor, is.err = is.iv.List(is.getContext(), is.lOpts)
is.assertErr()
}
// SetContext 设定 Context
func (is *IndexScope) SetContext(ctx context.Context) *IndexScope {
if is.cw == nil {
is.cw = &ctxWrap{}
}
is.cw.ctx = ctx
return is
}
// SetListIndexesOption 设定 ListIndexesOptions
func (is *IndexScope) SetListIndexesOption(opts options.ListIndexesOptions) *IndexScope {
is.lOpts = &opts
return is
}
// SetDropIndexesOption 设定 DropIndexesOptions
func (is *IndexScope) SetDropIndexesOption(opts options.DropIndexesOptions) *IndexScope {
is.dOpts = &opts
return is
}
// SetCreateIndexesOption 设定 CreateIndexesOptions
func (is *IndexScope) SetCreateIndexesOption(opts options.CreateIndexesOptions) *IndexScope {
is.cOpts = &opts
return is
}
// getContext 获取ctx
func (is *IndexScope) getContext() context.Context {
return is.cw.ctx
}
// GetIndexList 获取索引列表
func (is *IndexScope) GetIndexList(data *IndexList) error {
defer is.doClear()
is.doGet()
is.doCursor()
if is.err != nil {
return is.err
}
is.err = is.cursor.All(is.getContext(), data)
is.assertErr()
if is.err != nil {
return is.err
}
return nil
}
// GetListCursor 获取结果集的指针
func (is *IndexScope) GetListCursor() (*mongo.Cursor, error) {
defer is.doClear()
is.doGet()
is.doCursor()
if is.err != nil {
return nil, is.err
}
return is.cursor, nil
}
// DropOne 删除一个索引 name 为索引名称
func (is *IndexScope) DropOne(name string) error {
defer is.doClear()
is.doGet()
is.doDropOne(name)
if is.err != nil {
return is.err
}
return nil
}
// DropAll 删除所有的索引
func (is *IndexScope) DropAll() error {
defer is.doClear()
is.doGet()
is.doDropAll()
if is.err != nil {
return is.err
}
return nil
}
func (is *IndexScope) doDropAll() {
if is.dOpts == nil {
is.dOpts = &options.DropIndexesOptions{}
}
is.execResult, is.err = is.iv.DropAll(is.getContext(), is.dOpts)
is.assertErr()
}
// doDropOne 删除一个索引
func (is *IndexScope) doDropOne(name string) {
if is.dOpts == nil {
is.dOpts = &options.DropIndexesOptions{}
}
is.execResult, is.err = is.iv.DropOne(is.getContext(), name, is.dOpts)
is.assertErr()
}
// CreateOne 创建单个索引
// 在调用本方法前,需要调用 AddIndexModel() 添加
// 当添加了多个 indexModel 后仅会创建第一个索引
// 返回 res 索引名称 err 创建时错误信息
func (is *IndexScope) CreateOne() (res string, err error) {
defer is.doClear()
is.doGet()
is.doCreateOne()
if is.err != nil {
return "", is.err
}
res, ok := is.execResult.(string)
if !ok {
return "", fmt.Errorf("create success but get index name empty")
}
return res, nil
}
// CreateMany 创建多个索引,在调用本方法前,需要调用 AddIndexModel() 添加
func (is *IndexScope) CreateMany() (res []string, err error) {
defer is.doClear()
is.doGet()
is.doCreateMany()
if is.err != nil {
return nil, is.err
}
res, ok := is.execResult.([]string)
if !ok {
return nil, fmt.Errorf("create success but get index names empty")
}
return res, nil
}
func (is *IndexScope) doCreateOne() {
if len(is.ims) == 0 {
is.err = fmt.Errorf("no such index model, you need to use AddIndexModel() to add one")
return
}
if is.cOpts == nil {
is.cOpts = &options.CreateIndexesOptions{}
}
is.execResult, is.err = is.iv.CreateOne(is.getContext(), is.ims[0], is.cOpts)
is.assertErr()
}
func (is *IndexScope) doCreateMany() {
if len(is.ims) == 0 {
is.err = fmt.Errorf("no such index model, you need to use AddIndexModel() to add some")
return
}
if is.cOpts == nil {
is.cOpts = &options.CreateIndexesOptions{}
}
is.execResult, is.err = is.iv.CreateMany(is.getContext(), is.ims, is.cOpts)
is.assertErr()
}
// AddIndexModels 添加 indexModel
func (is *IndexScope) AddIndexModels(obj ...mongo.IndexModel) *IndexScope {
if len(obj) == 0 {
is.err = fmt.Errorf("add index model empty")
return is
}
is.ims = append(is.ims, obj...)
return is
}

281
insert_scope.go Normal file
View File

@ -0,0 +1,281 @@
package mdbc
import (
"context"
"errors"
"fmt"
"reflect"
"time"
jsoniter "github.com/json-iterator/go"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type insertAction string
const (
insertOne insertAction = "insertOne"
insertMany insertAction = "insertMany"
)
type InsertResult struct {
id interface{}
}
// GetString 获取单条插入的字符串ID
func (i *InsertResult) GetString() string {
vo := reflect.ValueOf(i.id)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() == reflect.String {
return vo.String()
}
return ""
}
// GetListString 获取多条插入的字符串ID
func (i *InsertResult) GetListString() []string {
var list []string
vo := reflect.ValueOf(i.id)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() == reflect.Slice || vo.Kind() == reflect.Array {
for i := 0; i < vo.Len(); i++ {
val := vo.Index(i).Interface()
valVo := reflect.ValueOf(val)
if valVo.Kind() == reflect.Ptr {
valVo = valVo.Elem()
}
if valVo.Kind() == reflect.String {
list = append(list, valVo.String())
}
}
return list
}
return nil
}
type InsertScope struct {
scope *Scope
cw *ctxWrap
err error
filter interface{}
action insertAction
insertObject interface{}
oopts *options.InsertOneOptions
mopts *options.InsertManyOptions
oResult *mongo.InsertOneResult
mResult *mongo.InsertManyResult
}
// SetContext 设置上下文
func (is *InsertScope) SetContext(ctx context.Context) *InsertScope {
if is.cw == nil {
is.cw = &ctxWrap{}
}
return is
}
func (is *InsertScope) getContext() context.Context {
return is.cw.ctx
}
func (is *InsertScope) doClear() {
if is.cw != nil && is.cw.cancel != nil {
is.cw.cancel()
}
is.scope.debug = false
is.scope.execT = 0
}
// SetInsertOneOption 设置插InsertOneOption
func (is *InsertScope) SetInsertOneOption(opts options.InsertOneOptions) *InsertScope {
is.oopts = &opts
return is
}
// SetInsertManyOption 设置InsertManyOption
func (is *InsertScope) SetInsertManyOption(opts options.InsertManyOptions) *InsertScope {
is.mopts = &opts
return is
}
// SetFilter 设置过滤条件
func (is *InsertScope) SetFilter(filter interface{}) *InsertScope {
if filter == nil {
is.filter = bson.M{}
return is
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
is.filter = bson.M{}
}
}
is.filter = filter
return is
}
// preCheck 预检查
func (is *InsertScope) preCheck() {
var breakerTTL time.Duration
if is.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if is.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = is.scope.breaker.ttl
}
if is.cw == nil {
is.cw = &ctxWrap{}
}
if is.cw.ctx == nil {
is.cw.ctx, is.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (is *InsertScope) assertErr() {
if is.err == nil {
return
}
if errors.Is(is.err, context.DeadlineExceeded) {
is.err = &ErrRequestBroken
return
}
err, ok := is.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
is.err = &ErrRequestBroken
}
}
func (is *InsertScope) debug() {
if !is.scope.debug && !is.scope.debugWhenError {
return
}
debugger := &Debugger{
collection: is.scope.tableName,
execT: is.scope.execT,
action: is,
}
// 当错误时优先输出
if is.scope.debugWhenError {
if is.err != nil {
debugger.errMsg = is.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if is.scope.debug {
debugger.String()
}
}
func (is *InsertScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
switch is.action {
case insertOne:
b, _ := bson.MarshalExtJSONWithRegistry(builder, is.insertObject, true, true)
return fmt.Sprintf("insertOne(%s)", string(b))
case insertMany:
vo := reflect.ValueOf(is.insertObject)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() != reflect.Slice {
panic("insertMany object is not slice")
}
var data []interface{}
for i := 0; i < vo.Len(); i++ {
var body interface{}
b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true)
_ = jsoniter.Unmarshal(b, &body)
data = append(data, body)
}
b, _ := jsoniter.Marshal(data)
return fmt.Sprintf("insertMany(%s)", string(b))
default:
panic("not support insert type")
}
}
func (is *InsertScope) doInsert(obj interface{}) {
var starTime time.Time
if is.scope.debug {
starTime = time.Now()
}
is.oResult, is.err = db.Collection(is.scope.tableName).InsertOne(is.getContext(), obj, is.oopts)
is.assertErr()
if is.scope.debug {
is.action = insertOne
is.insertObject = obj
is.scope.execT = time.Since(starTime)
is.debug()
}
}
func (is *InsertScope) doManyInsert(obj []interface{}) {
defer is.assertErr()
var starTime time.Time
if is.scope.debug {
starTime = time.Now()
}
is.mResult, is.err = db.Collection(is.scope.tableName).InsertMany(is.getContext(), obj, is.mopts)
is.assertErr()
if is.scope.debug {
is.action = insertMany
is.insertObject = obj
is.scope.execT = time.Since(starTime)
is.debug()
}
}
// One 插入一个对象 返回的id通过InsertResult中的方法进行获取
func (is *InsertScope) One(obj interface{}) (id *InsertResult, err error) {
defer is.doClear()
is.preCheck()
is.doInsert(obj)
if is.err != nil {
return nil, is.err
}
if is.oResult == nil {
return nil, fmt.Errorf("insert success but result empty")
}
return &InsertResult{id: is.oResult.InsertedID}, nil
}
// Many 插入多个对象 返回的id通过InsertResult中的方法进行获取
func (is *InsertScope) Many(obj []interface{}) (ids *InsertResult, err error) {
defer is.doClear()
is.preCheck()
is.doManyInsert(obj)
if is.err != nil {
return nil, is.err
}
if is.mResult == nil {
return nil, fmt.Errorf("insert success but result empty")
}
return &InsertResult{id: is.mResult.InsertedIDs}, nil
}

67
insert_scope_test.go Normal file
View File

@ -0,0 +1,67 @@
package mdbc
import (
"testing"
"github.com/sirupsen/logrus"
)
func TestInsertScope_One(t *testing.T) {
cfg := &Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
var record = ModelSchedTask{}
record.TaskState = uint32(TaskState_TaskStateCompleted)
record.Id = "insertffddddfdfdsdnkodsanfkasdf"
res, err := m.SetDebug(true).Insert().One(&record)
if err != nil {
panic(err)
}
logrus.Infof("res: %+v", res.GetString())
}
func TestInsertScope_Many(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
var list = []interface{}{
&ModelSchedTask{
Id: "insertManyaaaajkfnjaksdf",
},
&ModelSchedTask{
Id: "insertManybbbwdsfsadnasjkf",
},
}
res, err := m.SetDebug(true).Insert().Many(list)
if err != nil {
panic(err)
}
logrus.Infof("res: %+v", res.GetListString())
}

3086
mdbc.pb.go Normal file

File diff suppressed because it is too large Load Diff

278
mdbc.proto Normal file
View File

@ -0,0 +1,278 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
package mdbc;
option go_package = "./;mdbc";
// @table_name: tb_friends_info
message ModelFriendInfo {
string id = 1; // ID wxid md5
string wechat_id = 2; // ID
// @bson: nick_name
string nickname = 3; //
string wechat_alias = 4; //
string avatar_url = 5; //
string phone = 6; //
string country = 7; //
string province = 8; //
string city = 9; //
int32 sex = 10; // 0 1 2
int64 create_time = 12; //
int64 update_time = 13; //
}
enum AdminType {
// @desc:
AdminTypeNil = 0;
// @desc:
AdminTypeAdmin = 1;
// @desc:
AdminTypeOwner = 2;
}
// @table_name: tb_crm_group_chat
message ModelGroupChat {
string id = 1; // ID
int64 created_at = 2; //
int64 updated_at = 3; //
int64 deleted_at = 4; // 0
string robot_wx_id = 6; // id
string group_wx_id = 7; // id
string owner_wx_id = 8; // id
string group_name = 9; //
uint32 member_count = 10; //
string owner_name = 11; //
string group_avatar_url = 12; //
bool is_watch = 13; //
bool has_been_watch = 14; //
bool is_default_group_name = 15; //
bool in_contact = 16; //
bool disable_invite = 17; // true false
int64 last_sync_at = 20; //
int64 last_sync_member_at = 21; //
string notice = 22; //
int64 qrcode_updated_at = 23; //
string qrcode_url = 24; //
AdminType admin_type = 25; //
}
// @table_name: tb_crm_group_chat_member
message ModelGroupChatMember {
string id = 1; // id
int64 created_at = 2; //
int64 updated_at = 3; //
int64 deleted_at = 4; //
string group_chat_id = 5; // ModelGroupChat ID
string member_wx_id = 6; // id
string member_name = 7; //
string member_avatar = 8; //
string member_alias = 9; //
uint32 member_sex = 10; //
bool is_robot = 11; //
AdminType admin_type = 12; //
int64 last_sync_at = 13; //
}
// @table_name: tb_crm_private_msg_session
message ModelTbPrivateMsgSession {
string id = 1; //ID (md5(id+id))
int32 all = 2; //(:)
int32 read = 3; //
int32 unread = 4; //
int64 last_msg_at = 5; //
int64 last_friend_msg_at = 6; //
string robot_wx_id = 7; //id
string user_wx_id = 8; //id
string last_msg_id = 9; //id
string last_friend_msg_id = 10; //id
}
// @table_name: tb_crm_group_msg_session
message ModelTbGroupMsgSession {
string id = 1; //ID (md5(id+id))
int32 all = 2; //(:)
int32 read = 3; //
int32 unread = 4; //
int64 last_msg_at = 5; //
int64 last_friend_msg_at = 6; //
string robot_wx_id = 7; //id
string user_wx_id = 8; //id
string last_msg_id = 9; //id
string last_friend_msg_id = 10; //id
string last_member_wx_id = 11; //id
}
// @table_name: tb_crm_robot_private_msg
message ModelTbRobotPrivateMsg {
string id = 1; // ID
string bind_id = 3; // id
string robot_wx_id = 4; // id
string user_wx_id = 5; // id
string msg_id = 6; // idid
int32 msg_type = 7; //
int32 send_status = 8; // 012343
int32 direct = 9; // 12
int32 send_error_code = 10; // :-1 ; -2 ; -3
// ; -4 ;
bool content_read = 12; //
int64 created_at = 13; //
int64 updated_at = 14; //
string fail_reason = 15; //
int64 call_back_at = 16; //
int64 cursor = 17; // session的all
int64 send_at = 18; //
int64 expire_at = 19; //
ContentData content_data = 20; //
}
// @table_name: tb_crm_robot_group_msg
message ModelTbRobotGroupMsg {
string id = 1; // ID
string bind_id = 3; // id
string robot_wx_id = 4; // id
string user_wx_id = 5; // id
string msg_id = 6; // idid
int32 msg_type = 7; //
int32 send_status = 8; // 012343
int32 direct = 9; // 12
int32 send_error_code = 10; // :-1 ; -2 ; -3
// ; -4 ;
bool content_read = 12; //
int64 created_at = 13; //
int64 updated_at = 14; //
string fail_reason = 15; //
int64 call_back_at = 16; //
int64 cursor = 17; // session的all
int64 send_at = 18; //
int64 expire_at = 19; //
ContentData content_data = 20; //
string sender_wx_id = 21; // id
}
message ContentData {
string raw_content = 1; // xml数据
string content = 2; // 12 urlamr格式6xml
string share_title = 3; // 5
string share_desc = 4; // 5
string share_url = 5; // 5URL
string file_url = 6; // 3url4Url58urlgif9url
string share_user_name = 7; // 7id
string share_nick_name = 8; // 7
repeated AtMsgItem at_msg_item = 9; // @
int32 wx_msg_type = 10; // : 1 2 3 4 5 6 7
// 8 9 10 11 12
// 13
double file_size = 11; // KB单位
int32 resource_duration = 12; // s
repeated string at_user_name = 13; // at消息
bool is_at_myself = 14; // at我自己 便
}
message AtMsgItem {
int32 SubType = 1; // 01@
string Content = 2; //
string UserName = 3; // @(wx_id)
string NickName = 4; // @
}
enum TaskState {
TaskStateNil = 0; //
TaskStateRunning = 1; //
TaskStateFailed = 2; //退
TaskStateCompleted = 3; //
}
//
// @table_name: tb_crm_sched_task
message ModelSchedTask {
string id = 1; //id
int64 created_at = 2; //
int64 updated_at = 3; //
uint32 task_state = 4; // TaskState
string task_type = 5; //
string req_id = 6; //便 id[: id来查询该记录]
string req_json = 7; //
string rsp_json = 8; // []
string robot_wx_id = 9; //id
google.protobuf.Timestamp expired_at = 10; //
}
// @table_name: tb_robot_friend
message ModelRobotFriend {
string id = 1; // ID id+id md5
string robot_wechat_id = 2; // :ID
string user_wechat_id = 3; // ID,
int64 deleted = 4; // 0 1 2 3
int64 offline_add = 5; // 线
string remark_name = 6; //
string pinyin = 7; //
string pinyin_head = 8; //
int64 delete_time = 9; //
int64 create_time = 10; // :
int64 update_time = 11; //
int64 add_at = 12; //
string crm_phone = 13; // CRM自己设置的好友手机号
}
// @table_name: tb_ws_connect_record
message ModelWsConnectRecord {
string id = 1; // ID wxid md5
string user_id = 2; // id
int64 created_at = 3; //
int64 login_at = 4; //
int64 logout_at = 5; //
string bind_id = 6; // ws绑定的id
google.protobuf.Timestamp expired_at = 10; //
}
// @table_name: tb_robot
message ModelRobot {
// @json: _id
string id = 1; // ID wxid md5
string user_id = 2; // id
string crm_shop_id = 3; // id
string alias_name = 4; //
string nick_name = 5; //
string wechat_id = 6; // ID (wxidxxxxxx)
string wechat_alias = 7; // ID ()
string avatar_url = 8; //
int32 sex = 9; // 0 1 2
string mobile = 10; //
string qrcode = 11; //
int64 status = 12; // PC是否在线 10线 11线 (pc登录流程和其他接口,)
int64 limited = 13; // 0 1
int64 ability_limit = 14; //
int64 init_friend = 15; //
int64 now_friend = 16; //
int64 auto_add_friend = 17; // 0 1
int64 last_login_time = 18; //
int64 last_log_out_time = 19; //
string last_region_code = 20; //
string last_city = 21; //
int64 today_require_time = 22; //
int64 last_require_add_friend_time = 23; //
int64 crm_auto_add_friend = 24; // crm系统自动通过好友 1 0
int64 delete_time = 25; //
int64 create_time = 26; //
int64 update_time = 27; //
int64 log_and_out_time = 28; //
int64 android_status = 29; // Android是否在线 10线 11线
string greet_id = 30; // id
string android_wechat_version = 31; //
uint32 risk_control_group = 33; //
int64 last_pc_login_at = 34; // PC登录时间
int64 last_pc_logout_at = 35; // PC登出时间
int64 last_android_login_at = 36; //
int64 last_android_logout_at = 37; //
string risk_control_task = 38; // 0123456 7 1,2,3
bool open_for_stranger = 39; //
int32 moment_privacy_type = 40; //
string cover_url = 41; // url
string country = 42; //
string province = 43; //
string city = 44; //
string signature = 45; //
}

200
mongo.go Normal file
View File

@ -0,0 +1,200 @@
package mdbc
import (
"context"
"fmt"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
type ClientInit struct {
*mongo.Client
}
var (
ci *ClientInit
)
type Database struct {
*mongo.Database
dbname string
}
type Collection struct {
*mongo.Collection
dbname string
colname string
}
//ConnInit 初始化mongo
func ConnInit(config *Config) (*ClientInit, error) {
if config == nil {
return nil, fmt.Errorf("config nil")
}
if config.URI == "" {
return nil, fmt.Errorf("empty uri")
}
if config.MinPoolSize == 0 {
config.MinPoolSize = 1
}
if config.MaxPoolSize == 0 {
config.MaxPoolSize = 32
}
var timeout time.Duration
if config.ConnTimeout == 0 {
config.ConnTimeout = 10
}
timeout = time.Duration(config.ConnTimeout) * time.Second
if config.ReadPreference == nil {
config.ReadPreference = readpref.PrimaryPreferred()
}
op := options.Client().ApplyURI(config.URI).SetMinPoolSize(config.MinPoolSize).
SetMaxPoolSize(config.MaxPoolSize).SetConnectTimeout(timeout).
SetReadPreference(config.ReadPreference)
if config.RegistryBuilder != nil {
op.SetRegistry(config.RegistryBuilder.Build())
}
c, err := mongo.NewClient(op)
if err != nil {
return nil, err
}
var ctx = context.Background()
err = c.Connect(ctx)
if err != nil {
return nil, err
}
err = c.Ping(ctx, readpref.Primary())
if err != nil {
return nil, err
}
ci = &ClientInit{c}
return ci, nil
}
func (c *ClientInit) Database(dbname string, opts ...*options.DatabaseOptions) *Database {
db := c.Client.Database(dbname, opts...)
return &Database{db, dbname}
}
func (db *Database) Collection(collection string, opts ...*options.CollectionOptions) *Collection {
col := db.Database.Collection(collection, opts...)
return &Collection{col, db.dbname, collection}
}
func (col *Collection) InsertOne(ctx context.Context, document interface{},
opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) {
res, err := col.Collection.InsertOne(ctx, document, opts...)
return res, err
}
func (col *Collection) InsertMany(ctx context.Context, documents []interface{},
opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) {
res, err := col.Collection.InsertMany(ctx, documents, opts...)
return res, err
}
func (col *Collection) DeleteOne(ctx context.Context, filter interface{},
opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
res, err := col.Collection.DeleteOne(ctx, filter, opts...)
return res, err
}
func (col *Collection) DeleteMany(ctx context.Context, filter interface{},
opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
res, err := col.Collection.DeleteMany(ctx, filter, opts...)
return res, err
}
func (col *Collection) UpdateOne(ctx context.Context, filter interface{}, update interface{},
opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := col.Collection.UpdateOne(ctx, filter, update, opts...)
return res, err
}
func (col *Collection) UpdateMany(ctx context.Context, filter interface{}, update interface{},
opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := col.Collection.UpdateMany(ctx, filter, update, opts...)
return res, err
}
func (col *Collection) ReplaceOne(ctx context.Context, filter interface{},
replacement interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) {
res, err := col.Collection.ReplaceOne(ctx, filter, replacement, opts...)
return res, err
}
func (col *Collection) Aggregate(ctx context.Context, pipeline interface{},
opts ...*options.AggregateOptions) (*mongo.Cursor, error) {
res, err := col.Collection.Aggregate(ctx, pipeline, opts...)
return res, err
}
func (col *Collection) CountDocuments(ctx context.Context, filter interface{},
opts ...*options.CountOptions) (int64, error) {
res, err := col.Collection.CountDocuments(ctx, filter, opts...)
return res, err
}
func (col *Collection) Distinct(ctx context.Context, fieldName string, filter interface{},
opts ...*options.DistinctOptions) ([]interface{}, error) {
res, err := col.Collection.Distinct(ctx, fieldName, filter, opts...)
return res, err
}
func (col *Collection) Find(ctx context.Context, filter interface{},
opts ...*options.FindOptions) (*mongo.Cursor, error) {
res, err := col.Collection.Find(ctx, filter, opts...)
return res, err
}
func (col *Collection) FindOne(ctx context.Context, filter interface{},
opts ...*options.FindOneOptions) *mongo.SingleResult {
res := col.Collection.FindOne(ctx, filter, opts...)
return res
}
func (col *Collection) FindOneAndDelete(ctx context.Context, filter interface{},
opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult {
res := col.Collection.FindOneAndDelete(ctx, filter, opts...)
return res
}
func (col *Collection) FindOneAndReplace(ctx context.Context, filter interface{},
replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult {
res := col.Collection.FindOneAndReplace(ctx, filter, replacement, opts...)
return res
}
func (col *Collection) FindOneAndUpdate(ctx context.Context, filter interface{},
update interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult {
res := col.Collection.FindOneAndUpdate(ctx, filter, update, opts...)
return res
}
func (col *Collection) Watch(ctx context.Context, pipeline interface{},
opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) {
res, err := col.Collection.Watch(ctx, pipeline, opts...)
return res, err
}
func (col *Collection) Indexes(ctx context.Context) mongo.IndexView {
res := col.Collection.Indexes()
return res
}
func (col *Collection) Drop(ctx context.Context) error {
err := col.Collection.Drop(ctx)
return err
}
func (col *Collection) BulkWrite(ctx context.Context, models []mongo.WriteModel,
opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) {
res, err := col.Collection.BulkWrite(ctx, models, opts...)
return res, err
}

118
new.go Normal file
View File

@ -0,0 +1,118 @@
package mdbc
import (
"errors"
"fmt"
"os"
"reflect"
"runtime"
"strings"
"google.golang.org/protobuf/proto"
nested "github.com/antonfisher/nested-logrus-formatter"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/mongo"
)
func init() {
var getServeDir = func(path string) string {
var run, _ = os.Getwd()
return strings.Replace(path, run, ".", -1)
}
var formatter = &nested.Formatter{
NoColors: false,
HideKeys: true,
TimestampFormat: "2006-01-02 15:04:05",
CallerFirst: true,
CustomCallerFormatter: func(f *runtime.Frame) string {
s := strings.Split(f.Function, ".")
funcName := s[len(s)-1]
return fmt.Sprintf(" [%s:%d][%s()]", getServeDir(f.File), f.Line, funcName)
},
}
logrus.SetFormatter(formatter)
logrus.SetReportCaller(true)
// 注册错误码
register()
}
var (
// db 直连数据库连接句柄
db *Database
// rawClient mongo的client客户端 是db的上级
rawClient *mongo.Client
)
func GetDatabase() *Database {
return db
}
func GetMongodbClient() *mongo.Client {
return rawClient
}
type Model struct {
Type proto.Message // 模型原形
modelKind reflect.Type // 模型类型
modelName string // 模型名称
tableName string // 绑定表名
}
// NewModel 实例化model
func NewModel(msg interface{}) *Scope {
if db == nil {
panic("database nil")
}
if msg == nil {
panic("proto message nil")
}
remsg, ok := msg.(proto.Message)
if !ok {
panic("msg no match proto message")
}
m := &Model{Type: remsg}
typ := reflect.TypeOf(m.Type)
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
m.modelKind = typ
m.modelName = string(m.Type.ProtoReflect().Descriptor().FullName())
m.tableName = GetTableName(m.Type)
s := &Scope{
Model: m,
}
return s
}
// InitC 初始化client
type InitC struct{}
// Init 初始化client
func Init(c *ClientInit) *InitC {
rawClient = c.Client
return &InitC{}
}
// InitDB 初始化db
func (i *InitC) InitDB(c *Database) *InitC {
db = c
return &InitC{}
}
// InitDB 兼容以前的初始化方式
func InitDB(c *Database) {
db = c
rawClient = c.Client()
}
// IsRecordNotFound 检测记录是否存在
func IsRecordNotFound(err error) bool {
return errors.Is(err, &ErrRecordNotFound)
}

221
scope.go Normal file
View File

@ -0,0 +1,221 @@
package mdbc
import (
"context"
"time"
"github.com/go-redis/redis/v8"
)
type cacheConfig struct {
enable bool // 是否缓存
ttl time.Duration // 缓存时间
client redis.UniversalClient // 全局缓存redis句柄 [缓存、熔断器]
}
type breakerConfig struct {
ttl time.Duration // 熔断时间
reporter bool // 是否开启熔断告警 开启熔断告警 需要配置全局的redis 否则开启无效
}
// CacheObject 生成的缓存对象
type CacheObject struct {
// Key redis的key
Key string
// Value redis的value 请自行marshal对象value
// 如果value是一个对象 请将该对象实现 encoding.BinaryMarshaler
// 你可以参考 DefaultFindOneCacheFunc 其利用 json.Marshal 实现了对 value 的序列化
Value interface{}
}
// Scope 所有操作基于此结构 自定义结构体 将其匿名以继承其子scope
type Scope struct {
*Model // 模型
cache *cacheConfig // 缓存配置项
breaker *breakerConfig // 熔断配置项
execT time.Duration // 执行时间 后续debug和熔断告警用到
debug bool // 是否开启debug
debugWhenError bool // 当错误时才输出debug信息
}
type ctxWrap struct {
ctx context.Context // 上下文
cancel context.CancelFunc // 上下文取消
}
func (s *Scope) newDefaultCtxWrap() *ctxWrap {
return &ctxWrap{}
}
// SetDebug 是否开启执行语句的日志输出 目前仅支持终端输出
func (s *Scope) SetDebug(debug bool) *Scope {
s.debug = debug
return s
}
// SetDebugError 是否开启仅当错误时打印执行语句的日志 目前仅支持终端输出
func (s *Scope) SetDebugError(debug bool) *Scope {
s.debugWhenError = debug
return s
}
// SetBreaker 设置熔断时间 不配置默认5s熔断
// 此处建议在全局进行配置一次 否则会有覆盖后失效风险
func (s *Scope) SetBreaker(duration time.Duration) *Scope {
if s.breaker == nil {
s.breaker = &breakerConfig{}
}
s.breaker.ttl = duration
return s
}
// SetBreakerReporter 设置开启熔断告警
func (s *Scope) SetBreakerReporter(br bool) *Scope {
if s.breaker == nil {
s.breaker = &breakerConfig{}
}
s.breaker.reporter = br
return s
}
// SetCacheIdle 设置redis句柄 方便查询操作进行缓存
// 此处建议在全局进行配置一次 否则会有覆盖后失效风险
func (s *Scope) SetCacheIdle(cli redis.UniversalClient) *Scope {
if s.cache == nil {
s.cache = &cacheConfig{}
}
s.cache.client = cli
return s
}
// SetCacheExpiredAt 设置redis缓存过期时间 不设置 则不缓存 单位time.Duration
// 此处建议在全局进行配置一次 否则会有覆盖后失效风险
func (s *Scope) SetCacheExpiredAt(t time.Duration) *Scope {
if s.cache == nil {
s.cache = &cacheConfig{}
}
s.cache.ttl = t
s.cache.enable = true
return s
}
// SetCacheForever 设置key永不过期 true永不过期
// 此处建议在全局进行配置一次 否则会有覆盖后失效风险
func (s *Scope) SetCacheForever(b bool) *Scope {
if s.cache == nil {
s.cache = &cacheConfig{}
}
if b {
s.cache.ttl = -1
s.cache.enable = true
}
return s
}
// GetTableName 获取当前连接的表名
func (s *Scope) GetTableName() string {
return s.tableName
}
// check ctx检测和初始化
func (s *Scope) check() {
if s.breaker == nil {
s.breaker = &breakerConfig{}
}
// 设置默认熔断时间 5秒
if s.breaker.ttl == 0 {
s.breaker.ttl = 5 * time.Second
}
}
// RawConn 返回原始的mongoDB 操作资源句柄
func (s *Scope) RawConn() *Collection {
return db.Collection(s.tableName)
}
// Aggregate 聚合操作
func (s *Scope) Aggregate() *AggregateScope {
s.check()
as := &AggregateScope{cw: s.newDefaultCtxWrap(), scope: s}
return as
}
// Find 查询列表操作
func (s *Scope) Find() *FindScope {
s.check()
fs := &FindScope{cw: s.newDefaultCtxWrap(), scope: s}
return fs
}
// FindOne 查询一条操作
func (s *Scope) FindOne() *FindOneScope {
s.check()
fs := &FindOneScope{cw: s.newDefaultCtxWrap(), scope: s}
return fs
}
// Update 更新文档操作
func (s *Scope) Update() *UpdateScope {
s.check()
us := &UpdateScope{cw: s.newDefaultCtxWrap(), scope: s}
return us
}
// Delete 删除文档操作
func (s *Scope) Delete() *DeleteScope {
s.check()
ds := &DeleteScope{cw: s.newDefaultCtxWrap(), scope: s}
return ds
}
// Insert 插入文档操作
func (s *Scope) Insert() *InsertScope {
s.check()
is := &InsertScope{cw: s.newDefaultCtxWrap(), scope: s}
return is
}
// Distinct Distinct操作
func (s *Scope) Distinct() *DistinctScope {
s.check()
ds := &DistinctScope{cw: s.newDefaultCtxWrap(), scope: s}
return ds
}
// Index 索引操作
func (s *Scope) Index() *IndexScope {
s.check()
is := &IndexScope{cw: s.newDefaultCtxWrap(), scope: s}
return is
}
// BulkWrite 批量写操作
func (s *Scope) BulkWrite() *BulkWriteScope {
s.check()
bws := &BulkWriteScope{cw: s.newDefaultCtxWrap(), scope: s}
return bws
}
// Count 计数操作
func (s *Scope) Count() *CountScope {
s.check()
cs := &CountScope{cw: s.newDefaultCtxWrap(), scope: s}
return cs
}
// Drop 集合删除操作
func (s *Scope) Drop() *DropScope {
s.check()
ds := &DropScope{cw: s.newDefaultCtxWrap(), scope: s}
return ds
}
// Transaction 事务操作
func (s *Scope) Transaction() *TransactionScope {
s.check()
tx := &TransactionScope{cw: s.newDefaultCtxWrap(), scope: s}
return tx
}
// Retry mongo失败后的重试机制
func (s *Scope) Retry(since int64) {}

32
scope_test.go Normal file
View File

@ -0,0 +1,32 @@
package mdbc
import (
"context"
"fmt"
"testing"
"time"
)
func TestScope_check(t *testing.T) {
var a Scope
fmt.Println(a.execT)
}
func TestScopeCtx(t *testing.T) {
//var ctx, cancel = context.WithTimeout(context.Background(), time.Second)
//defer cancel()
//time.Sleep(time.Second * 2)
//fmt.Println(ctx.Err())
var c1 = context.Background()
var c2, cancel = context.WithTimeout(c1, time.Second)
fmt.Println(c1, c2)
cancel()
cancel()
fmt.Println("------")
//fmt.Printf("c1: %p, c2: %p", &c1, &c2)
fmt.Println(c2.Err())
}

144
transaction_scope.go Normal file
View File

@ -0,0 +1,144 @@
package mdbc
//
//// trMap 事务集合 key:string(trid) value: mongo.Session
//var trMap = sync.Map{}
//
//// ctxTrIDKey 后续用于检测session与ctx是否一致
//const ctxTrIDKey = "mdbc-tx-id"
//
//// opentracingKey 用于链路追踪时span的存储
//const opentracingKey = "mdbc-tx-span"
//
//// trid 事务ID
//type trid = bson.Raw
//
//func getTrID(id trid) (string, error) {
// raw, err := id.LookupErr("id")
// if err != nil {
// return "", err
// }
// _, uuid := raw.Binary()
// return hex.EncodeToString(uuid), nil
//}
//
//// equalTrID 事务ID校验 检测两个id是否相同
//func equalTrID(a, b trid) bool {
// as, aerr := a.LookupErr("id")
// if aerr != nil {
// return false
// }
// bs, berr := b.LookupErr("id")
// if berr != nil {
// return false
// }
//
// _, auuid := as.Binary()
// _, buuid := bs.Binary()
// return bytes.Equal(auuid, buuid)
//}
//
//// checkSession 检测session是否配置 未配置 将分配一次
//func (tr *TransactionScope) checkSession() {
//
//}
//
//// getSession 设置session
//func (tr *TransactionScope) setSession(s mongo.Session) {
// tr.session.s = s
//}
//
//// checkTransaction 检测当前事务 未开启事务而操作事务是否将触发panic
//func (tr *TransactionScope) checkTransaction() {
//
//}
//
//// setSessionID 设置session id
//func (tr *TransactionScope) setSessionID() {
// if tr.err != nil {
// return
// }
// tr.trID = tr.getSession().ID()
// trMap.Store(tr.GetTxID(), tr.session)
//}
//
//// deleteSessionID 删除session id
//func (tr *TransactionScope) deleteSessionID() {
// if tr.err != nil {
// return
// }
//
// trMap.Delete(tr.GetTxID())
//}
//
//// getSessionID 获取session id
//func (tr *TransactionScope) getSessionID() (mongo.Session, bool) {
// if tr.err != nil {
// return nil, false
// }
//
// res, exist := trMap.Load(tr.GetTxID())
// if !exist {
// return nil, false
// }
//
// return res.(mongo.Session), true
//}
//
//// EndSession 关闭会话
//func (tr *TransactionScope) EndSession() *TransactionScope {
// tr.getSession().EndSession(tr.cw.ctx)
// tr.deleteSessionID()
// return tr
//}
//
//// StartTransaction 开启事务
//// 如果session未设置或未开启 将panic
//func (tr *TransactionScope) StartTransaction() *TransactionScope {
// tr.checkSession()
// tr.err = tr.getSession().StartTransaction(tr.trOpts...)
// tr.hasStartTx = true
// return tr
//}
//
//// GetSession 获取会话 如果未设置 session 将返回 nil
//func (tr *TransactionScope) GetSession() mongo.Session {
// tr.checkSession()
// return tr.getSession()
//}
//
//// GetSessionContext 获取当前会话的上下文
//func (tr *TransactionScope) GetSessionContext() mongo.SessionContext {
// tr.checkSession()
// return tr.session.ctx
//}
//
//// GetCollection 获取当前会话的collection
//func (tr *TransactionScope) GetCollection() *Scope {
// return tr.scope
//}
//
//// GetTxID 获取事务ID的字符串
//func (tr *TransactionScope) GetTxID() string {
// tr.checkSession()
// if tr.trID == nil {
// tr.err = fmt.Errorf("session id empty")
// return ""
// }
//
// id, err := getTrID(tr.trID)
// if err != nil {
// tr.err = err
// return ""
// }
//
// return id
//}
//
//// Error 获取执行的error
//func (tr *TransactionScope) Error() error {
// if tr.err != nil {
// return tr.err
// }
// return nil
//}

61
transaction_scope_hook.go Normal file
View File

@ -0,0 +1,61 @@
package mdbc
//import (
// "context"
// "github.com/opentracing/opentracing-go"
// "github.com/opentracing/opentracing-go/ext"
//)
//// TrFunc 用于执行一次事务用到的执行函数
//// 在该回调函数内不要进行 提交/回滚
//type TrFunc func(tr *TransactionScope) error
//
//// TrHook 事务钩子 对每一个事务都执行该方法而不是对每一个会话
//type TrHook interface {
// // BeforeTransaction 执行事务之前 需要做的操作 当error时不执行事务
// // 其执行时机是在会话创建之后 事务开启之后尚为执行之前
// BeforeTransaction(ctx context.Context, ts *TransactionScope) error
// // AfterTransaction 执行事务之后 需要做的操作 error需要自行处理 不会对提交的事务构成变更
// AfterTransaction(ctx context.Context, ts *TransactionScope) error
//}
//
//// AddHook 添加钩子
//func (tr *TransactionScope) AddHook(th TrHook) *TransactionScope {
// tr.hooks = append(tr.hooks, th)
// return tr
//}
//
//// OpentracingHook 链路追踪钩子 将该事务的id和操作记录起来
//// 当然你可以自己实现 但这里提供一个常规的同用的案例
//type OpentracingHook struct{}
//
//// BeforeTransaction 执行事务之前 需要做的操作 当error时不执行事务
//// 其执行时机是在会话创建之后 事务开启之后尚为执行之前
//func (oh *OpentracingHook) BeforeTransaction(ctx context.Context, ts *TransactionScope) error {
// txID := ts.GetTrID()
// span := trace.ObtainChildSpan(ctx, "mdbc::transaction")
// span.SetTag("txid", txID)
// ctx = context.WithValue(ctx, opentracingKey, span)
// return nil
//}
//
//// AfterTransaction 执行事务之后 需要做的操作 error需要自行处理 不会对提交的事务构成变更
//func (oh *OpentracingHook) AfterTransaction(ctx context.Context, ts *TransactionScope) error {
// spanRaw := ctx.Value(opentracingKey)
// if spanRaw == nil {
// return nil
// }
//
// span, ok := spanRaw.(opentracing.Span)
// if !ok {
// return nil
// }
//
// // 失败了 标记一下
// if ts.Error() != nil {
// ext.Error.Set(span, true)
// }
//
// span.Finish()
// return nil
//}

150
transaction_scope_test.go Normal file
View File

@ -0,0 +1,150 @@
package mdbc
//func TestTransaction(t *testing.T) {
// cfg := &Config{
// URI: "mongodb://10.0.0.135:27117/mdbc",
// MinPoolSize: 32,
// ConnTimeout: 10,
// }
// cfg.RegistryBuilder = RegisterTimestampCodec(nil)
// client, err := ConnInit(cfg)
//
// if err != nil {
// logrus.Fatalf("get err: %+v", err)
// }
// var m = NewModel(&ModelWsConnectRecord{})
//
// ctx := context.Background()
//
// cli := client.Client
// stx, err := cli.StartSession()
// if err != nil {
// panic(err)
// }
//
// // ef6e39c011ee4ed8bd0f82efb0fbccee
// // 80a850d0d8584668b89809e1308a9878
// // 457188bf227341e5a4a6c3d0242d425a
// _, uuid := stx.ID().Lookup("id").Binary()
// logrus.Infof("id: %+v", hex.EncodeToString(uuid))
//
// col := client.Database("mdbc").Collection(m.GetTableName())
//
// if err := stx.StartTransaction(); err != nil {
// panic(err)
// }
//
// logrus.Infof("id: %+v", stx.ID().Lookup("id").String())
//
// err = mongo.WithSession(ctx, stx, func(tx mongo.SessionContext) error {
// logrus.Infof("id: %+v", tx.ID())
// var record ModelWsConnectRecord
// if err := col.FindOne(tx, bson.M{"_id": "87c9c841b078bd213c79b682ffbdbec7"}).Decode(&record); err != nil {
// logrus.Errorf("err: %+v", err)
// panic(err)
// }
//
// logrus.Infof("record: %+v", &record)
//
// record.BindId = "myidsssssssssss"
// if err := col.FindOneAndUpdate(tx, bson.M{"_id": "87c9c841b078bd213c79b682ffbdbec7"}, bson.M{
// "$set": &record,
// }).Err(); err != nil {
// if err := stx.AbortTransaction(tx); err != nil {
// panic(err)
// }
//
// return err
// }
//
// if err := stx.CommitTransaction(tx); err != nil {
// panic(err)
// }
//
// logrus.Infof("id: %+v", tx.ID().Lookup("id").String())
//
// return nil
// })
//
// if err != nil {
// panic(err)
// }
//
// stx.EndSession(ctx)
//}
//
//func TestNewTransaction(t *testing.T) {
// cfg := &Config{
// URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/mdbc",
// MinPoolSize: 32,
// ConnTimeout: 10,
// }
// cfg.RegistryBuilder = RegisterTimestampCodec(nil)
// client, err := ConnInit(cfg)
// client.Database("")
//
// if err != nil {
// logrus.Fatalf("get err: %+v", err)
// }
// var m = NewModel(&ModelWsConnectRecord{})
//
// err = m.Transaction().StartSession().AddHook(&OpentracingHook{}).RollbackOnError(func(tx *TransactionScope) error {
// ctx := tx.getSessionContext()
// coll := tx.scope
// var record ModelWsConnectRecord
// if err := coll.FindOne().SetContext(ctx).SetFilter(bson.M{"_id": "d1ef466f75b9a03256a8c2a3da0409a3"}).Get(&record); err != nil {
// logrus.Errorf("update err: %+v", err)
// return err
// }
// record.BindId = uuid.New().String()
// logrus.Infof("bind_id: %+v", record.BindId)
// if _, err := coll.Update().SetID("d1ef466f75b9a03256a8c2a3da0409a3").One(&record); err != nil {
// logrus.Errorf("update err: %+v", err)
// return err
// }
// return nil
// })
// if err != nil {
// panic(err)
// }
//
// time.Sleep(3 * time.Second)
//}
//
//func TestNewWithTransaction(t *testing.T) {
// cfg := &Config{
// URI: "mongodb://10.0.0.135:27117/mdbc",
// MinPoolSize: 32,
// ConnTimeout: 10,
// }
// cfg.RegistryBuilder = RegisterTimestampCodec(nil)
// client, err := ConnInit(cfg)
// client.Database("")
//
// if err != nil {
// logrus.Fatalf("get err: %+v", err)
// }
// var m = NewModel(&ModelWsConnectRecord{})
// err = m.Transaction().SetContext(context.Background()).StartSession().SingleRollbackOnError(func(tx *TransactionScope) error {
// ctx := tx.getSessionContext()
// coll := tx.scope
// var record ModelWsConnectRecord
// if err := coll.FindOne().SetContext(ctx).SetFilter(bson.M{"_id": "697022b263e2c1528f32a26e704e63d6"}).Get(&record); err != nil {
// logrus.Errorf("get err: %+v", err)
// return err
// }
// record.BindId = uuid.New().String()
// logrus.Infof("bind_id: %+v", record.BindId)
// if res, err := coll.Update().SetContext(ctx).SetID("697022b263e2c1528f32a26e704e63d6").One(&record); err != nil {
// logrus.Errorf("update err: %+v", err)
// return err
// } else if res.MatchedCount != 2 {
// //return errors.New("no match")
// return nil
// }
// return nil
// }).CloseSession().Error()
// if err != nil {
// panic(err)
// }
//}

295
transaction_scope_v2.go Normal file
View File

@ -0,0 +1,295 @@
package mdbc
import (
"context"
"fmt"
"sync"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// trMap 事务集合 key:string(trid) value: mongo.Session
var trMap = sync.Map{}
// trid 事务ID
type trid = bson.Raw
// ctxTrIDKey 后续用于检测session与ctx是否一致
const ctxTrIDKey = "mdbc-tx-id"
// opentracingKey 用于链路追踪时span的存储
const opentracingKey = "mdbc-tx-span"
// TransactionScope 事务支持
// 由于事务不太好从gotk.driver.mongodb里面实现链路追踪 所以在这里单独实现一下opentracing
type TransactionScope struct {
scope *Scope // 能力范围
cw *ctxWrap // 操作上下文 wrap
err error // 错误
trID trid // 事务ID
session *sessionC // 会话
hasStartTx bool // 是否开启事务
sOpts []*options.SessionOptions // 有关会话的隔离级别配置项
trOpts []*options.TransactionOptions // 有关事务的隔离级别配置项
//hooks []TrHook // 事务钩子
}
type sessionC struct {
s mongo.Session // 会话
ctx mongo.SessionContext // 会话上下文
}
// getContext 获取操作上下文而非会话上下文
func (tr *TransactionScope) getContext() context.Context {
return tr.cw.ctx
}
// setContext 设置操作上下文
func (tr *TransactionScope) setContext(ctx context.Context) {
tr.cw.ctx = ctx
}
// checkContext 检测操作上下文是否配置
func (tr *TransactionScope) checkContext() {
if tr.cw == nil || tr.cw.ctx == nil {
tr.err = fmt.Errorf("op context empty")
panic(tr.err)
}
}
func (tr *TransactionScope) getSessionContext() mongo.SessionContext {
return tr.session.ctx
}
// setSessionContext 设置会话session上下文
func (tr *TransactionScope) setSessionContext(ctx mongo.SessionContext) {
tr.session.ctx = ctx
}
// getSession 获取内部session
func (tr *TransactionScope) getSession() mongo.Session {
return tr.session.s
}
func (tr *TransactionScope) getSessionC() *sessionC {
return tr.session
}
func (tr *TransactionScope) checkSessionC() {
if tr.session == nil {
tr.session = new(sessionC)
}
}
func (tr *TransactionScope) checkSessionPanic() {
if tr.session == nil || tr.session.ctx == nil || tr.session.s == nil {
tr.err = fmt.Errorf("session context empty")
panic(tr.err)
}
}
// GetTrID 获取事务ID的字符串
func (tr *TransactionScope) GetTrID() string {
tr.checkSessionPanic()
//if tr.trID == nil {
// tr.err = fmt.Errorf("session id empty")
// return ""
//}
//
//id, err := getTrID(tr.trID)
//if err != nil {
// tr.err = err
// return ""
//}
return "id"
}
// SetContext 设置操作的上下文
func (tr *TransactionScope) SetContext(ctx context.Context) *TransactionScope {
tr.cw.ctx = ctx
return tr
}
// SetSessionContext 设置会话的上下文
func (tr *TransactionScope) SetSessionContext(ctx mongo.SessionContext) *TransactionScope {
tr.checkSessionC()
tr.setSessionContext(ctx)
return tr
}
// NewSessionContext 生成 [会话/事务] 对应的ctx
// 如果 ctx 传递空 将panic
func (tr *TransactionScope) NewSessionContext(ctx context.Context) mongo.SessionContext {
tr.checkSessionC()
if ctx == nil {
panic("ctx nil")
}
sctx := mongo.NewSessionContext(ctx, tr.getSession())
tr.setSessionContext(sctx)
if ctx.Value(ctxTrIDKey) == nil {
ctx = context.WithValue(ctx, ctxTrIDKey, tr.GetTrID())
}
return sctx
}
// SetSessionOptions 设置事务的各项机制 [读写关注 一致性]
func (tr *TransactionScope) SetSessionOptions(opts ...*options.SessionOptions) *TransactionScope {
tr.sOpts = opts
return tr
}
// SetTransactionOptions 设置事务的各项机制 [读写关注 一致性]
func (tr *TransactionScope) SetTransactionOptions(opts ...*options.TransactionOptions) *TransactionScope {
tr.trOpts = opts
return tr
}
// StartSession 开启一个新的session 一个session对应一个sessionID
func (tr *TransactionScope) StartSession() *TransactionScope {
// 先检测 ctx 没有设置外部的ctx 将直接PANIC
tr.checkContext()
// 检测 sessionC 不能让他为空 否则空指针异常
tr.checkSessionC()
// 开启会话
tr.session.s, tr.err = rawClient.StartSession(tr.sOpts...)
if tr.err != nil {
return tr
}
// 检测 session 上下文 未设置将默认指定一个
if tr.getSessionContext() == nil {
tr.session.ctx = tr.NewSessionContext(tr.getContext())
}
// 会话ID 先不用理会
//tr.setSessionID()
return tr
}
// CloseSession 关闭会话
func (tr *TransactionScope) CloseSession() *TransactionScope {
tr.getSession().EndSession(tr.cw.ctx)
//tr.deleteSessionID()
return tr
}
// Commit 提交事务 如果事务不存在、已提交或者已回滚 会触发error
func (tr *TransactionScope) Commit() error {
tr.checkSessionPanic()
tr.err = tr.getSession().CommitTransaction(tr.getSessionContext())
return tr.err
}
// Rollback 回滚事务 如果事务不存在、已提交或者已回滚 会触发error
func (tr *TransactionScope) Rollback() error {
tr.checkSessionPanic()
tr.err = tr.getSession().AbortTransaction(tr.getSessionContext())
return tr.err
}
// StartTransaction 开启事务
// 如果session未设置或未开启 将panic
func (tr *TransactionScope) StartTransaction() *TransactionScope {
tr.checkSessionPanic()
tr.err = tr.getSession().StartTransaction(tr.trOpts...)
tr.hasStartTx = true
return tr
}
// RollbackOnError 基于回调函数执行并提交事务 失败则自动回滚
// 该方法是一个会话对应一个事务 若希望一个会话对应多个事务 请使用 SingleRollbackOnError
// 不要在 回调函数内进行 rollback 和 commit 否则该方法恒为 error
// 注:该方法会在结束后自动结束会话
// 若希望手动结束会话 请使用 SingleRollbackOnError
//func (tr *TransactionScope) RollbackOnError(cb TrFunc) error {
// defer tr.CloseSession()
// if tr.err != nil {
// return tr.err
// }
//
// if !tr.hasStartTx {
// tr.StartTransaction()
// }
// if tr.err != nil {
// return tr.err
// }
//
// // 执行事务前钩子
// var hookIndex int
// var retErr error
// for ; hookIndex < len(tr.hooks); hookIndex++ {
// retErr = tr.hooks[hookIndex].BeforeTransaction(tr.getContext(), tr)
// if retErr != nil {
// return retErr
// }
// }
//
// tr.err = cb(tr)
// if tr.err != nil {
// if err := tr.Rollback(); err != nil {
// return &ErrRollbackTransaction
// }
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("tx rollback: %+v", tr.err))
// }
//
// tr.err = tr.Commit()
// if tr.err != nil {
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("commit tx error: %+v", tr.err))
// }
//
// // 执行事务后钩子
// var aHookIndex int
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr.err
// }
// }
//
// return nil
//}
// SingleRollbackOnError 基于已有会话 通过链式调用 可以多次单个事务执行
// 注:请手动结束会话 EndSession
//func (tr *TransactionScope) SingleRollbackOnError(cb TrFunc) *TransactionScope {
// if tr.err != nil {
// return tr
// }
//
// tr.checkSessionPanic()
//
// // 执行事务前钩子
// var bHookIndex int
// for ; bHookIndex < len(tr.hooks); bHookIndex++ {
// tr.err = tr.hooks[bHookIndex].BeforeTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr
// }
// }
//
// // 该方法会自动提交 失败自动回滚
// _, tr.err = tr.getSession().WithTransaction(tr.getSessionContext(),
// func(sessCtx mongo.SessionContext) (interface{}, error) {
// return nil, cb(tr)
// })
//
// // 执行事务后钩子
// var aHookIndex int
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr
// }
// }
//
// return tr
//}
// Error 获取执行的error
func (tr *TransactionScope) Error() error {
if tr.err != nil {
return tr.err
}
return nil
}

9
update_module.sh Executable file
View File

@ -0,0 +1,9 @@
#!/bin/bash
# 该脚本用来自动更新module
go mod tidy
match_required=$(cat go.mod | grep -zoE "\((.*?)\)" | awk -F ' ' '{print $1}' | awk '{if($1>1){print $1}}')
for i in $match_required;do go get -u "$i";done

404
update_scope.go Normal file
View File

@ -0,0 +1,404 @@
package mdbc
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type updateAction string
const (
updateOne updateAction = "updateOne"
updateMany updateAction = "updateMany"
)
type UpdateScope struct {
scope *Scope
cw *ctxWrap
err error
upsert *bool
id interface{}
filter interface{}
action updateAction
updateData interface{}
opts *options.UpdateOptions
result *mongo.UpdateResult
}
// SetContext 设置上下文
func (us *UpdateScope) SetContext(ctx context.Context) *UpdateScope {
if us.cw == nil {
us.cw = &ctxWrap{}
}
us.cw.ctx = ctx
return us
}
func (us *UpdateScope) getContext() context.Context {
return us.cw.ctx
}
// SetUpsert 设置upsert属性
func (us *UpdateScope) SetUpsert(upsert bool) *UpdateScope {
us.upsert = &upsert
return us
}
// SetID 基于ID进行文档更新 这将忽略 filter 只按照ID过滤
func (us *UpdateScope) SetID(id interface{}) *UpdateScope {
us.id = id
return us
}
// SetUpdateOption 设置UpdateOption
func (us *UpdateScope) SetUpdateOption(opts options.UpdateOptions) *UpdateScope {
us.opts = &opts
return us
}
// SetFilter 设置过滤条件
func (us *UpdateScope) SetFilter(filter interface{}) *UpdateScope {
if filter == nil {
us.filter = bson.M{}
return us
}
v := reflect.ValueOf(filter)
if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice {
if v.IsNil() {
us.filter = bson.M{}
}
}
us.filter = filter
return us
}
func (us *UpdateScope) optionAssembled() {
// 配置项被直接调用重写过
if us.opts != nil {
return
}
us.opts = new(options.UpdateOptions)
if us.upsert != nil {
us.opts.Upsert = us.upsert
}
}
func (us *UpdateScope) assertErr() {
if us.err == nil {
return
}
if errors.Is(us.err, context.DeadlineExceeded) {
us.err = &ErrRequestBroken
return
}
err, ok := us.err.(mongo.CommandError)
if !ok {
return
}
if err.HasErrorMessage(context.DeadlineExceeded.Error()) {
us.err = &ErrRequestBroken
return
}
}
func (us *UpdateScope) debug() {
if !us.scope.debug && !us.scope.debugWhenError {
return
}
debugger := &Debugger{
collection: us.scope.tableName,
execT: us.scope.execT,
action: us,
}
// 当错误时优先输出
if us.scope.debugWhenError {
if us.err != nil {
debugger.errMsg = us.err.Error()
debugger.ErrorString()
}
return
}
// 所有bug输出
if us.scope.debug {
debugger.String()
}
}
func (us *UpdateScope) doReporter() string {
debugger := &Debugger{
collection: us.scope.tableName,
execT: us.scope.execT,
action: us,
}
return debugger.Echo()
}
func (us *UpdateScope) reporter() {
br := &BreakerReporter{
reportTitle: "mdbc::update",
reportMsg: "熔断错误",
reportErrorWrap: us.err,
breakerDo: us,
}
if us.err == &ErrRequestBroken {
br.Report()
}
}
func (us *UpdateScope) doString() string {
builder := RegisterTimestampCodec(nil).Build()
switch us.action {
case updateOne:
filter, _ := bson.MarshalExtJSONWithRegistry(builder, us.filter, true, true)
updateData, _ := bson.MarshalExtJSONWithRegistry(builder, us.updateData, true, true)
return fmt.Sprintf("updateOne(%s, %s)", string(filter), string(updateData))
case updateMany:
filter, _ := bson.MarshalExtJSONWithRegistry(builder, us.filter, true, true)
updateData, _ := bson.MarshalExtJSONWithRegistry(builder, us.updateData, true, true)
return fmt.Sprintf("updateMany(%s, %s)", string(filter), string(updateData))
default:
panic("not support update type")
}
}
// preCheck 预检查
func (us *UpdateScope) preCheck() {
var breakerTTL time.Duration
if us.scope.breaker == nil {
breakerTTL = defaultBreakerTime
} else if us.scope.breaker.ttl == 0 {
breakerTTL = defaultBreakerTime
} else {
breakerTTL = us.scope.breaker.ttl
}
if us.cw == nil {
us.cw = &ctxWrap{}
}
if us.cw.ctx == nil {
us.cw.ctx, us.cw.cancel = context.WithTimeout(context.Background(), breakerTTL)
}
}
func (us *UpdateScope) doUpdate(isMany bool) {
if isMany {
var starTime time.Time
if us.scope.debug {
starTime = time.Now()
}
us.result, us.err = db.Collection(us.scope.tableName).UpdateMany(us.getContext(), us.filter, us.updateData, us.opts)
us.assertErr()
if us.scope.debug {
us.scope.execT = time.Since(starTime)
us.action = updateMany
us.debug()
}
return
}
if us.id != nil {
var starTime time.Time
if us.scope.debug {
starTime = time.Now()
}
us.result, us.err = db.Collection(us.scope.tableName).UpdateByID(us.getContext(), us.id, us.updateData, us.opts)
us.assertErr()
if us.scope.debug {
us.scope.execT = time.Since(starTime)
us.filter = bson.M{"_id": us.id}
us.action = updateOne
us.debug()
}
return
}
var starTime time.Time
if us.scope.debug {
starTime = time.Now()
}
us.result, us.err = db.Collection(us.scope.tableName).UpdateOne(us.getContext(), us.filter, us.updateData, us.opts)
us.assertErr()
if us.scope.debug {
us.scope.execT = time.Since(starTime)
us.action = updateOne
us.debug()
}
}
func (us *UpdateScope) doClear() {
if us.cw != nil && us.cw.cancel != nil {
us.cw.cancel()
}
us.scope.execT = 0
us.scope.debug = false
}
// checkObj 检测需要更新的obj是否正常 没有$set将被注入
func (us *UpdateScope) checkObj(obj interface{}) {
if obj == nil {
us.err = &ErrObjectEmpty
return
}
us.updateData = obj
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() == reflect.Struct {
us.updateData = bson.M{
"$set": obj,
}
return
}
if vo.Kind() == reflect.Map {
setKey := vo.MapIndex(reflect.ValueOf("$set"))
if !setKey.IsValid() {
us.updateData = bson.M{
"$set": obj,
}
}
return
}
us.err = &ErrUpdateObjectTypeNotSupport
}
// checkMapSetObj 检测需要更新的obj是否正常
func (us *UpdateScope) checkMustMapObj(obj interface{}) {
if obj == nil {
us.err = &ErrObjectEmpty
return
}
us.updateData = obj
vo := reflect.ValueOf(obj)
if vo.Kind() == reflect.Ptr {
vo = vo.Elem()
}
if vo.Kind() == reflect.Map {
return
}
us.err = &ErrUpdateObjectTypeNotSupport
}
// One 单个更新
// obj: 为*struct时将被自动转换为bson.M并注入$set
// 当指定为map时 若未指定$set 会自动注入$set
// 当指定 非 *struct/ map 类型时 将报错
func (us *UpdateScope) One(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.preCheck()
us.checkObj(obj)
if us.err != nil {
return nil, us.err
}
us.doUpdate(false)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}
// MustMapOne 单个更新 必须传入map并强制更新这个map
func (us *UpdateScope) MustMapOne(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.preCheck()
us.checkMustMapObj(obj)
if us.err != nil {
return nil, us.err
}
us.doUpdate(false)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}
// RawOne 单个更新 不对更新内容进行校验
func (us *UpdateScope) RawOne(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.preCheck()
us.updateData = obj
us.doUpdate(false)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}
// Many 批量更新
// obj: 为*struct时将被自动转换为bson.M并注入$set
// 当指定为map时 若未指定$set 会自动注入$set
// 当指定 非 *struct/ map 类型时 将报错
func (us *UpdateScope) Many(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.checkObj(obj)
if us.err != nil {
return nil, us.err
}
us.doUpdate(true)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}
// MustMapMany 批量更新 必须传入map并强制更新这个map
func (us *UpdateScope) MustMapMany(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.checkMustMapObj(obj)
if us.err != nil {
return nil, us.err
}
us.doUpdate(true)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}
// RawMany 批量更新 不对更新内容进行校验
func (us *UpdateScope) RawMany(obj interface{}) (*mongo.UpdateResult, error) {
defer us.doClear()
us.optionAssembled()
us.preCheck()
us.updateData = obj
us.doUpdate(true)
if us.err != nil {
return nil, us.err
}
return us.result, nil
}

82
update_scope_test.go Normal file
View File

@ -0,0 +1,82 @@
package mdbc
import (
"context"
"fmt"
"testing"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson"
)
func TestUpdateScope_One(t *testing.T) {
cfg := &Config{
URI: "mongodb://admin:admin@10.0.0.135:27017/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
var record = ModelSchedTask{Id: "0e434c7d5c9eb8b129ce42ce07afef09"}
if err := m.SetDebug(true).FindOne().SetFilter(bson.M{"_id": record.Id}).Get(&record); err != nil {
logrus.Errorf("get err: %+v", err)
}
fmt.Printf("record %+v\n", record)
var ctx = context.Background()
record.TaskState = uint32(TaskState_TaskStateCompleted)
updateData := bson.M{
"task_state": uint32(TaskState_TaskStateFailed),
}
_, err = m.SetDebug(true).Update().SetContext(ctx).SetFilter(bson.M{"_id": record.Id}).One(updateData)
if err != nil {
logrus.Errorf("get err: %+v", err)
}
updateData = bson.M{
"task_state": uint32(TaskState_TaskStateFailed),
}
_, err = m.SetDebug(true).Update().SetFilter(bson.M{"_id": record.Id}).One(updateData)
if err != nil {
logrus.Errorf("get err: %+v", err)
}
}
func TestUpdateScope_Many(t *testing.T) {
cfg := &Config{
URI: "mongodb://mdbc:mdbc@10.0.0.135:27117/admin",
MinPoolSize: 32,
ConnTimeout: 10,
}
cfg.RegistryBuilder = RegisterTimestampCodec(nil)
client, err := ConnInit(cfg)
if err != nil {
logrus.Fatalf("get err: %+v", err)
}
Init(client).InitDB(client.Database("mdbc"))
var m = NewModel(&ModelSchedTask{})
//updateData := bson.M{
// "$inc": bson.M{
// "set_group_name": 1,
// },
//}
res, err := m.SetDebug(true).Update().SetFilter(bson.M{"task_type": "set_group_name"}).MustMapMany(m)
if err != nil {
logrus.Errorf("get err: %+v", err)
panic(err)
}
logrus.Infof("%+v", res)
}

43
utils.go Normal file
View File

@ -0,0 +1,43 @@
package mdbc
import "reflect"
type StructField struct {
StructFieldName string
DbFieldName string
}
type tabler interface {
TableName() string
}
// GetTableName 获取表名
func GetTableName(i interface{}) string {
if p, ok := i.(tabler); ok {
return p.TableName()
}
return ""
}
// isReflectNumber 检测是否是数值类型
func isReflectNumber(obj reflect.Value) bool {
var k = obj.Kind()
if k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || k == reflect.Int32 ||
k == reflect.Int64 || k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 ||
k == reflect.Uint32 || k == reflect.Uint64 || k == reflect.Float32 || k == reflect.Float64 {
return true
}
return false
}
// getReflectTypeField 获取j指定的字段的field 没有获取到则返回原字段名称
func getBsonTagByReflectTypeField(rt reflect.Type, fieldName string) string {
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
if field.Name != fieldName {
continue
}
return field.Tag.Get("bson")
}
return fieldName
}