mdbc/transaction_scope_v2.go
2022-02-23 16:59:45 +08:00

296 lines
8.0 KiB
Go

package mdbc
import (
"context"
"fmt"
"sync"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// trMap 事务集合 key:string(trid) value: mongo.Session
var trMap = sync.Map{}
// trid 事务ID
type trid = bson.Raw
// ctxTrIDKey 后续用于检测session与ctx是否一致
const ctxTrIDKey = "mdbc-tx-id"
// opentracingKey 用于链路追踪时span的存储
const opentracingKey = "mdbc-tx-span"
// TransactionScope 事务支持
// 由于事务不太好从gotk.driver.mongodb里面实现链路追踪 所以在这里单独实现一下opentracing
type TransactionScope struct {
scope *Scope // 能力范围
cw *ctxWrap // 操作上下文 wrap
err error // 错误
trID trid // 事务ID
session *sessionC // 会话
hasStartTx bool // 是否开启事务
sOpts []*options.SessionOptions // 有关会话的隔离级别配置项
trOpts []*options.TransactionOptions // 有关事务的隔离级别配置项
//hooks []TrHook // 事务钩子
}
type sessionC struct {
s mongo.Session // 会话
ctx mongo.SessionContext // 会话上下文
}
// getContext 获取操作上下文而非会话上下文
func (tr *TransactionScope) getContext() context.Context {
return tr.cw.ctx
}
// setContext 设置操作上下文
func (tr *TransactionScope) setContext(ctx context.Context) {
tr.cw.ctx = ctx
}
// checkContext 检测操作上下文是否配置
func (tr *TransactionScope) checkContext() {
if tr.cw == nil || tr.cw.ctx == nil {
tr.err = fmt.Errorf("op context empty")
panic(tr.err)
}
}
func (tr *TransactionScope) getSessionContext() mongo.SessionContext {
return tr.session.ctx
}
// setSessionContext 设置会话session上下文
func (tr *TransactionScope) setSessionContext(ctx mongo.SessionContext) {
tr.session.ctx = ctx
}
// getSession 获取内部session
func (tr *TransactionScope) getSession() mongo.Session {
return tr.session.s
}
func (tr *TransactionScope) getSessionC() *sessionC {
return tr.session
}
func (tr *TransactionScope) checkSessionC() {
if tr.session == nil {
tr.session = new(sessionC)
}
}
func (tr *TransactionScope) checkSessionPanic() {
if tr.session == nil || tr.session.ctx == nil || tr.session.s == nil {
tr.err = fmt.Errorf("session context empty")
panic(tr.err)
}
}
// GetTrID 获取事务ID的字符串
func (tr *TransactionScope) GetTrID() string {
tr.checkSessionPanic()
//if tr.trID == nil {
// tr.err = fmt.Errorf("session id empty")
// return ""
//}
//
//id, err := getTrID(tr.trID)
//if err != nil {
// tr.err = err
// return ""
//}
return "id"
}
// SetContext 设置操作的上下文
func (tr *TransactionScope) SetContext(ctx context.Context) *TransactionScope {
tr.cw.ctx = ctx
return tr
}
// SetSessionContext 设置会话的上下文
func (tr *TransactionScope) SetSessionContext(ctx mongo.SessionContext) *TransactionScope {
tr.checkSessionC()
tr.setSessionContext(ctx)
return tr
}
// NewSessionContext 生成 [会话/事务] 对应的ctx
// 如果 ctx 传递空 将panic
func (tr *TransactionScope) NewSessionContext(ctx context.Context) mongo.SessionContext {
tr.checkSessionC()
if ctx == nil {
panic("ctx nil")
}
sctx := mongo.NewSessionContext(ctx, tr.getSession())
tr.setSessionContext(sctx)
if ctx.Value(ctxTrIDKey) == nil {
ctx = context.WithValue(ctx, ctxTrIDKey, tr.GetTrID())
}
return sctx
}
// SetSessionOptions 设置事务的各项机制 [读写关注 一致性]
func (tr *TransactionScope) SetSessionOptions(opts ...*options.SessionOptions) *TransactionScope {
tr.sOpts = opts
return tr
}
// SetTransactionOptions 设置事务的各项机制 [读写关注 一致性]
func (tr *TransactionScope) SetTransactionOptions(opts ...*options.TransactionOptions) *TransactionScope {
tr.trOpts = opts
return tr
}
// StartSession 开启一个新的session 一个session对应一个sessionID
func (tr *TransactionScope) StartSession() *TransactionScope {
// 先检测 ctx 没有设置外部的ctx 将直接PANIC
tr.checkContext()
// 检测 sessionC 不能让他为空 否则空指针异常
tr.checkSessionC()
// 开启会话
tr.session.s, tr.err = rawClient.StartSession(tr.sOpts...)
if tr.err != nil {
return tr
}
// 检测 session 上下文 未设置将默认指定一个
if tr.getSessionContext() == nil {
tr.session.ctx = tr.NewSessionContext(tr.getContext())
}
// 会话ID 先不用理会
//tr.setSessionID()
return tr
}
// CloseSession 关闭会话
func (tr *TransactionScope) CloseSession() *TransactionScope {
tr.getSession().EndSession(tr.cw.ctx)
//tr.deleteSessionID()
return tr
}
// Commit 提交事务 如果事务不存在、已提交或者已回滚 会触发error
func (tr *TransactionScope) Commit() error {
tr.checkSessionPanic()
tr.err = tr.getSession().CommitTransaction(tr.getSessionContext())
return tr.err
}
// Rollback 回滚事务 如果事务不存在、已提交或者已回滚 会触发error
func (tr *TransactionScope) Rollback() error {
tr.checkSessionPanic()
tr.err = tr.getSession().AbortTransaction(tr.getSessionContext())
return tr.err
}
// StartTransaction 开启事务
// 如果session未设置或未开启 将panic
func (tr *TransactionScope) StartTransaction() *TransactionScope {
tr.checkSessionPanic()
tr.err = tr.getSession().StartTransaction(tr.trOpts...)
tr.hasStartTx = true
return tr
}
// RollbackOnError 基于回调函数执行并提交事务 失败则自动回滚
// 该方法是一个会话对应一个事务 若希望一个会话对应多个事务 请使用 SingleRollbackOnError
// 不要在 回调函数内进行 rollback 和 commit 否则该方法恒为 error
// 注:该方法会在结束后自动结束会话
// 若希望手动结束会话 请使用 SingleRollbackOnError
//func (tr *TransactionScope) RollbackOnError(cb TrFunc) error {
// defer tr.CloseSession()
// if tr.err != nil {
// return tr.err
// }
//
// if !tr.hasStartTx {
// tr.StartTransaction()
// }
// if tr.err != nil {
// return tr.err
// }
//
// // 执行事务前钩子
// var hookIndex int
// var retErr error
// for ; hookIndex < len(tr.hooks); hookIndex++ {
// retErr = tr.hooks[hookIndex].BeforeTransaction(tr.getContext(), tr)
// if retErr != nil {
// return retErr
// }
// }
//
// tr.err = cb(tr)
// if tr.err != nil {
// if err := tr.Rollback(); err != nil {
// return &ErrRollbackTransaction
// }
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("tx rollback: %+v", tr.err))
// }
//
// tr.err = tr.Commit()
// if tr.err != nil {
// return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("commit tx error: %+v", tr.err))
// }
//
// // 执行事务后钩子
// var aHookIndex int
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr.err
// }
// }
//
// return nil
//}
// SingleRollbackOnError 基于已有会话 通过链式调用 可以多次单个事务执行
// 注:请手动结束会话 EndSession
//func (tr *TransactionScope) SingleRollbackOnError(cb TrFunc) *TransactionScope {
// if tr.err != nil {
// return tr
// }
//
// tr.checkSessionPanic()
//
// // 执行事务前钩子
// var bHookIndex int
// for ; bHookIndex < len(tr.hooks); bHookIndex++ {
// tr.err = tr.hooks[bHookIndex].BeforeTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr
// }
// }
//
// // 该方法会自动提交 失败自动回滚
// _, tr.err = tr.getSession().WithTransaction(tr.getSessionContext(),
// func(sessCtx mongo.SessionContext) (interface{}, error) {
// return nil, cb(tr)
// })
//
// // 执行事务后钩子
// var aHookIndex int
// for ; aHookIndex < len(tr.hooks); aHookIndex++ {
// tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr)
// if tr.err != nil {
// return tr
// }
// }
//
// return tr
//}
// Error 获取执行的error
func (tr *TransactionScope) Error() error {
if tr.err != nil {
return tr.err
}
return nil
}