package mdbc import ( "context" "fmt" "sync" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // trMap 事务集合 key:string(trid) value: mongo.Session var trMap = sync.Map{} // trid 事务ID type trid = bson.Raw // ctxTrIDKey 后续用于检测session与ctx是否一致 const ctxTrIDKey = "mdbc-tx-id" // opentracingKey 用于链路追踪时span的存储 const opentracingKey = "mdbc-tx-span" // TransactionScope 事务支持 // 由于事务不太好从gotk.driver.mongodb里面实现链路追踪 所以在这里单独实现一下opentracing type TransactionScope struct { scope *Scope // 能力范围 cw *ctxWrap // 操作上下文 wrap err error // 错误 trID trid // 事务ID session *sessionC // 会话 hasStartTx bool // 是否开启事务 sOpts []*options.SessionOptions // 有关会话的隔离级别配置项 trOpts []*options.TransactionOptions // 有关事务的隔离级别配置项 //hooks []TrHook // 事务钩子 } type sessionC struct { s mongo.Session // 会话 ctx mongo.SessionContext // 会话上下文 } // getContext 获取操作上下文而非会话上下文 func (tr *TransactionScope) getContext() context.Context { return tr.cw.ctx } // setContext 设置操作上下文 func (tr *TransactionScope) setContext(ctx context.Context) { tr.cw.ctx = ctx } // checkContext 检测操作上下文是否配置 func (tr *TransactionScope) checkContext() { if tr.cw == nil || tr.cw.ctx == nil { tr.err = fmt.Errorf("op context empty") panic(tr.err) } } func (tr *TransactionScope) getSessionContext() mongo.SessionContext { return tr.session.ctx } // setSessionContext 设置会话session上下文 func (tr *TransactionScope) setSessionContext(ctx mongo.SessionContext) { tr.session.ctx = ctx } // getSession 获取内部session func (tr *TransactionScope) getSession() mongo.Session { return tr.session.s } func (tr *TransactionScope) getSessionC() *sessionC { return tr.session } func (tr *TransactionScope) checkSessionC() { if tr.session == nil { tr.session = new(sessionC) } } func (tr *TransactionScope) checkSessionPanic() { if tr.session == nil || tr.session.ctx == nil || tr.session.s == nil { tr.err = fmt.Errorf("session context empty") panic(tr.err) } } // GetTrID 获取事务ID的字符串 func (tr *TransactionScope) GetTrID() string { tr.checkSessionPanic() //if tr.trID == nil { // tr.err = fmt.Errorf("session id empty") // return "" //} // //id, err := getTrID(tr.trID) //if err != nil { // tr.err = err // return "" //} return "id" } // SetContext 设置操作的上下文 func (tr *TransactionScope) SetContext(ctx context.Context) *TransactionScope { tr.cw.ctx = ctx return tr } // SetSessionContext 设置会话的上下文 func (tr *TransactionScope) SetSessionContext(ctx mongo.SessionContext) *TransactionScope { tr.checkSessionC() tr.setSessionContext(ctx) return tr } // NewSessionContext 生成 [会话/事务] 对应的ctx // 如果 ctx 传递空 将panic func (tr *TransactionScope) NewSessionContext(ctx context.Context) mongo.SessionContext { tr.checkSessionC() if ctx == nil { panic("ctx nil") } sctx := mongo.NewSessionContext(ctx, tr.getSession()) tr.setSessionContext(sctx) if ctx.Value(ctxTrIDKey) == nil { ctx = context.WithValue(ctx, ctxTrIDKey, tr.GetTrID()) } return sctx } // SetSessionOptions 设置事务的各项机制 [读写关注 一致性] func (tr *TransactionScope) SetSessionOptions(opts ...*options.SessionOptions) *TransactionScope { tr.sOpts = opts return tr } // SetTransactionOptions 设置事务的各项机制 [读写关注 一致性] func (tr *TransactionScope) SetTransactionOptions(opts ...*options.TransactionOptions) *TransactionScope { tr.trOpts = opts return tr } // StartSession 开启一个新的session 一个session对应一个sessionID func (tr *TransactionScope) StartSession() *TransactionScope { // 先检测 ctx 没有设置外部的ctx 将直接PANIC tr.checkContext() // 检测 sessionC 不能让他为空 否则空指针异常 tr.checkSessionC() // 开启会话 tr.session.s, tr.err = rawClient.StartSession(tr.sOpts...) if tr.err != nil { return tr } // 检测 session 上下文 未设置将默认指定一个 if tr.getSessionContext() == nil { tr.session.ctx = tr.NewSessionContext(tr.getContext()) } // 会话ID 先不用理会 //tr.setSessionID() return tr } // CloseSession 关闭会话 func (tr *TransactionScope) CloseSession() *TransactionScope { tr.getSession().EndSession(tr.cw.ctx) //tr.deleteSessionID() return tr } // Commit 提交事务 如果事务不存在、已提交或者已回滚 会触发error func (tr *TransactionScope) Commit() error { tr.checkSessionPanic() tr.err = tr.getSession().CommitTransaction(tr.getSessionContext()) return tr.err } // Rollback 回滚事务 如果事务不存在、已提交或者已回滚 会触发error func (tr *TransactionScope) Rollback() error { tr.checkSessionPanic() tr.err = tr.getSession().AbortTransaction(tr.getSessionContext()) return tr.err } // StartTransaction 开启事务 // 如果session未设置或未开启 将panic func (tr *TransactionScope) StartTransaction() *TransactionScope { tr.checkSessionPanic() tr.err = tr.getSession().StartTransaction(tr.trOpts...) tr.hasStartTx = true return tr } // RollbackOnError 基于回调函数执行并提交事务 失败则自动回滚 // 该方法是一个会话对应一个事务 若希望一个会话对应多个事务 请使用 SingleRollbackOnError // 不要在 回调函数内进行 rollback 和 commit 否则该方法恒为 error // 注:该方法会在结束后自动结束会话 // 若希望手动结束会话 请使用 SingleRollbackOnError //func (tr *TransactionScope) RollbackOnError(cb TrFunc) error { // defer tr.CloseSession() // if tr.err != nil { // return tr.err // } // // if !tr.hasStartTx { // tr.StartTransaction() // } // if tr.err != nil { // return tr.err // } // // // 执行事务前钩子 // var hookIndex int // var retErr error // for ; hookIndex < len(tr.hooks); hookIndex++ { // retErr = tr.hooks[hookIndex].BeforeTransaction(tr.getContext(), tr) // if retErr != nil { // return retErr // } // } // // tr.err = cb(tr) // if tr.err != nil { // if err := tr.Rollback(); err != nil { // return &ErrRollbackTransaction // } // return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("tx rollback: %+v", tr.err)) // } // // tr.err = tr.Commit() // if tr.err != nil { // return core.CreateErrorWithMsg(ErrOpTransaction, fmt.Sprintf("commit tx error: %+v", tr.err)) // } // // // 执行事务后钩子 // var aHookIndex int // for ; aHookIndex < len(tr.hooks); aHookIndex++ { // tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr) // if tr.err != nil { // return tr.err // } // } // // return nil //} // SingleRollbackOnError 基于已有会话 通过链式调用 可以多次单个事务执行 // 注:请手动结束会话 EndSession //func (tr *TransactionScope) SingleRollbackOnError(cb TrFunc) *TransactionScope { // if tr.err != nil { // return tr // } // // tr.checkSessionPanic() // // // 执行事务前钩子 // var bHookIndex int // for ; bHookIndex < len(tr.hooks); bHookIndex++ { // tr.err = tr.hooks[bHookIndex].BeforeTransaction(tr.getContext(), tr) // if tr.err != nil { // return tr // } // } // // // 该方法会自动提交 失败自动回滚 // _, tr.err = tr.getSession().WithTransaction(tr.getSessionContext(), // func(sessCtx mongo.SessionContext) (interface{}, error) { // return nil, cb(tr) // }) // // // 执行事务后钩子 // var aHookIndex int // for ; aHookIndex < len(tr.hooks); aHookIndex++ { // tr.err = tr.hooks[aHookIndex].AfterTransaction(tr.getContext(), tr) // if tr.err != nil { // return tr // } // } // // return tr //} // Error 获取执行的error func (tr *TransactionScope) Error() error { if tr.err != nil { return tr.err } return nil }