296 lines
8.0 KiB
Go
296 lines
8.0 KiB
Go
package mdbc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|
)
|
|
|
|
// trMap 事务集合 key:string(trid) value: mongo.Session
|
|
var trMap = sync.Map{}
|
|
|
|
// trid 事务ID
|
|
type trid = bson.Raw
|
|
|
|
// ctxTrIDKey 后续用于检测session与ctx是否一致
|
|
const ctxTrIDKey = "mdbc-tx-id"
|
|
|
|
// opentracingKey 用于链路追踪时span的存储
|
|
const opentracingKey = "mdbc-tx-span"
|
|
|
|
// TransactionScope 事务支持
|
|
// 由于事务不太好从gotk.driver.mongodb里面实现链路追踪 所以在这里单独实现一下opentracing
|
|
type TransactionScope struct {
|
|
scope *Scope // 能力范围
|
|
cw *ctxWrap // 操作上下文 wrap
|
|
err error // 错误
|
|
trID trid // 事务ID
|
|
session *sessionC // 会话
|
|
hasStartTx bool // 是否开启事务
|
|
sOpts []*options.SessionOptions // 有关会话的隔离级别配置项
|
|
trOpts []*options.TransactionOptions // 有关事务的隔离级别配置项
|
|
//hooks []TrHook // 事务钩子
|
|
}
|
|
|
|
type sessionC struct {
|
|
s mongo.Session // 会话
|
|
ctx mongo.SessionContext // 会话上下文
|
|
}
|
|
|
|
// getContext 获取操作上下文而非会话上下文
|
|
func (tr *TransactionScope) getContext() context.Context {
|
|
return tr.cw.ctx
|
|
}
|
|
|
|
// setContext 设置操作上下文
|
|
func (tr *TransactionScope) setContext(ctx context.Context) {
|
|
tr.cw.ctx = ctx
|
|
}
|
|
|
|
// checkContext 检测操作上下文是否配置
|
|
func (tr *TransactionScope) checkContext() {
|
|
if tr.cw == nil || tr.cw.ctx == nil {
|
|
tr.err = fmt.Errorf("op context empty")
|
|
panic(tr.err)
|
|
}
|
|
}
|
|
|
|
func (tr *TransactionScope) getSessionContext() mongo.SessionContext {
|
|
return tr.session.ctx
|
|
}
|
|
|
|
// setSessionContext 设置会话session上下文
|
|
func (tr *TransactionScope) setSessionContext(ctx mongo.SessionContext) {
|
|
tr.session.ctx = ctx
|
|
}
|
|
|
|
// getSession 获取内部session
|
|
func (tr *TransactionScope) getSession() mongo.Session {
|
|
return tr.session.s
|
|
}
|
|
|
|
func (tr *TransactionScope) getSessionC() *sessionC {
|
|
return tr.session
|
|
}
|
|
|
|
func (tr *TransactionScope) checkSessionC() {
|
|
if tr.session == nil {
|
|
tr.session = new(sessionC)
|
|
}
|
|
}
|
|
|
|
func (tr *TransactionScope) checkSessionPanic() {
|
|
if tr.session == nil || tr.session.ctx == nil || tr.session.s == nil {
|
|
tr.err = fmt.Errorf("session context empty")
|
|
panic(tr.err)
|
|
}
|
|
}
|
|
|
|
// GetTrID 获取事务ID的字符串
|
|
func (tr *TransactionScope) GetTrID() string {
|
|
tr.checkSessionPanic()
|
|
//if tr.trID == nil {
|
|
// tr.err = fmt.Errorf("session id empty")
|
|
// return ""
|
|
//}
|
|
//
|
|
//id, err := getTrID(tr.trID)
|
|
//if err != nil {
|
|
// tr.err = err
|
|
// return ""
|
|
//}
|
|
|
|
return "id"
|
|
}
|
|
|
|
// SetContext 设置操作的上下文
|
|
func (tr *TransactionScope) SetContext(ctx context.Context) *TransactionScope {
|
|
tr.cw.ctx = ctx
|
|
return tr
|
|
}
|
|
|
|
// SetSessionContext 设置会话的上下文
|
|
func (tr *TransactionScope) SetSessionContext(ctx mongo.SessionContext) *TransactionScope {
|
|
tr.checkSessionC()
|
|
tr.setSessionContext(ctx)
|
|
return tr
|
|
}
|
|
|
|
// NewSessionContext 生成 [会话/事务] 对应的ctx
|
|
// 如果 ctx 传递空 将panic
|
|
func (tr *TransactionScope) NewSessionContext(ctx context.Context) mongo.SessionContext {
|
|
tr.checkSessionC()
|
|
if ctx == nil {
|
|
panic("ctx nil")
|
|
}
|
|
sctx := mongo.NewSessionContext(ctx, tr.getSession())
|
|
tr.setSessionContext(sctx)
|
|
if ctx.Value(ctxTrIDKey) == nil {
|
|
ctx = context.WithValue(ctx, ctxTrIDKey, tr.GetTrID())
|
|
}
|
|
return sctx
|
|
}
|
|
|
|
// SetSessionOptions 设置事务的各项机制 [读写关注 一致性]
|
|
func (tr *TransactionScope) SetSessionOptions(opts ...*options.SessionOptions) *TransactionScope {
|
|
tr.sOpts = opts
|
|
return tr
|
|
}
|
|
|
|
// SetTransactionOptions 设置事务的各项机制 [读写关注 一致性]
|
|
func (tr *TransactionScope) SetTransactionOptions(opts ...*options.TransactionOptions) *TransactionScope {
|
|
tr.trOpts = opts
|
|
return tr
|
|
}
|
|
|
|
// StartSession 开启一个新的session 一个session对应一个sessionID
|
|
func (tr *TransactionScope) StartSession() *TransactionScope {
|
|
// 先检测 ctx 没有设置外部的ctx 将直接PANIC
|
|
tr.checkContext()
|
|
// 检测 sessionC 不能让他为空 否则空指针异常
|
|
tr.checkSessionC()
|
|
// 开启会话
|
|
tr.session.s, tr.err = rawClient.StartSession(tr.sOpts...)
|
|
if tr.err != nil {
|
|
return tr
|
|
}
|
|
// 检测 session 上下文 未设置将默认指定一个
|
|
if tr.getSessionContext() == nil {
|
|
tr.session.ctx = tr.NewSessionContext(tr.getContext())
|
|
}
|
|
// 会话ID 先不用理会
|
|
//tr.setSessionID()
|
|
return tr
|
|
}
|
|
|
|
// CloseSession 关闭会话
|
|
func (tr *TransactionScope) CloseSession() *TransactionScope {
|
|
tr.getSession().EndSession(tr.cw.ctx)
|
|
//tr.deleteSessionID()
|
|
return tr
|
|
}
|
|
|
|
// Commit 提交事务 如果事务不存在、已提交或者已回滚 会触发error
|
|
func (tr *TransactionScope) Commit() error {
|
|
tr.checkSessionPanic()
|
|
tr.err = tr.getSession().CommitTransaction(tr.getSessionContext())
|
|
return tr.err
|
|
}
|
|
|
|
// Rollback 回滚事务 如果事务不存在、已提交或者已回滚 会触发error
|
|
func (tr *TransactionScope) Rollback() error {
|
|
tr.checkSessionPanic()
|
|
tr.err = tr.getSession().AbortTransaction(tr.getSessionContext())
|
|
return tr.err
|
|
}
|
|
|
|
// StartTransaction 开启事务
|
|
// 如果session未设置或未开启 将panic
|
|
func (tr *TransactionScope) StartTransaction() *TransactionScope {
|
|
tr.checkSessionPanic()
|
|
tr.err = tr.getSession().StartTransaction(tr.trOpts...)
|
|
tr.hasStartTx = true
|
|
return tr
|
|
}
|
|
|
|
// RollbackOnError 基于回调函数执行并提交事务 失败则自动回滚
|
|
// 该方法是一个会话对应一个事务 若希望一个会话对应多个事务 请使用 SingleRollbackOnError
|
|
// 不要在 回调函数内进行 rollback 和 commit 否则该方法恒为 error
|
|
// 注:该方法会在结束后自动结束会话
|
|
// 若希望手动结束会话 请使用 SingleRollbackOnError
|
|
//func (tr *TransactionScope) RollbackOnError(cb TrFunc) error {
|
|
// defer tr.CloseSession()
|
|
// if tr.err != nil {
|
|
// return tr.err
|
|
// }
|
|
//
|
|
// if !tr.hasStartTx {
|
|
// tr.StartTransaction()
|
|
// }
|
|
// if tr.err != nil {
|
|
// return tr.err
|
|
// }
|
|
//
|
|
// // 执行事务前钩子
|
|
// var hookIndex int
|
|
// var retErr error
|
|
// for ; hookIndex < len(tr.hooks); hookIndex++ {
|
|
// retErr = tr.hooks[hookIndex].BeforeTransaction(tr.getContext(), tr)
|
|
// if retErr != nil {
|
|
// return retErr
|
|
// }
|
|
// }
|
|
//
|
|
// tr.err = cb(tr)
|
|
// if tr.err != nil {
|
|
// if err := tr.Rollback(); err != nil {
|
|
// return &ErrRollbackTransaction
|
|
// }
|
|
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("tx rollback: %+v", tr.err))
|
|
// }
|
|
//
|
|
// tr.err = tr.Commit()
|
|
// if tr.err != nil {
|
|
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("commit tx error: %+v", tr.err))
|
|
// }
|
|
//
|
|
// // 执行事务后钩子
|
|
// var aHookIndex int
|
|
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
|
|
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
|
|
// if tr.err != nil {
|
|
// return tr.err
|
|
// }
|
|
// }
|
|
//
|
|
// return nil
|
|
//}
|
|
|
|
// SingleRollbackOnError 基于已有会话 通过链式调用 可以多次单个事务执行
|
|
// 注:请手动结束会话 EndSession
|
|
//func (tr *TransactionScope) SingleRollbackOnError(cb TrFunc) *TransactionScope {
|
|
// if tr.err != nil {
|
|
// return tr
|
|
// }
|
|
//
|
|
// tr.checkSessionPanic()
|
|
//
|
|
// // 执行事务前钩子
|
|
// var bHookIndex int
|
|
// for ; bHookIndex < len(tr.hooks); bHookIndex++ {
|
|
// tr.err = tr.hooks[bHookIndex].BeforeTransaction(tr.getContext(), tr)
|
|
// if tr.err != nil {
|
|
// return tr
|
|
// }
|
|
// }
|
|
//
|
|
// // 该方法会自动提交 失败自动回滚
|
|
// _, tr.err = tr.getSession().WithTransaction(tr.getSessionContext(),
|
|
// func(sessCtx mongo.SessionContext) (interface{}, error) {
|
|
// return nil, cb(tr)
|
|
// })
|
|
//
|
|
// // 执行事务后钩子
|
|
// var aHookIndex int
|
|
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
|
|
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
|
|
// if tr.err != nil {
|
|
// return tr
|
|
// }
|
|
// }
|
|
//
|
|
// return tr
|
|
//}
|
|
|
|
// Error 获取执行的error
|
|
func (tr *TransactionScope) Error() error {
|
|
if tr.err != nil {
|
|
return tr.err
|
|
}
|
|
return nil
|
|
}
|