package mdbc import ( "context" "errors" "fmt" "reflect" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type updateAction string const ( updateOne updateAction = "updateOne" updateMany updateAction = "updateMany" ) type UpdateScope struct { scope *Scope cw *ctxWrap err error upsert *bool id interface{} filter interface{} action updateAction updateData interface{} opts *options.UpdateOptions result *mongo.UpdateResult } // SetContext 设置上下文 func (us *UpdateScope) SetContext(ctx context.Context) *UpdateScope { if us.cw == nil { us.cw = &ctxWrap{} } us.cw.ctx = ctx return us } func (us *UpdateScope) getContext() context.Context { return us.cw.ctx } // SetUpsert 设置upsert属性 func (us *UpdateScope) SetUpsert(upsert bool) *UpdateScope { us.upsert = &upsert return us } // SetID 基于ID进行文档更新 这将忽略 filter 只按照ID过滤 func (us *UpdateScope) SetID(id interface{}) *UpdateScope { us.id = id return us } // SetUpdateOption 设置UpdateOption func (us *UpdateScope) SetUpdateOption(opts options.UpdateOptions) *UpdateScope { us.opts = &opts return us } // SetFilter 设置过滤条件 func (us *UpdateScope) SetFilter(filter interface{}) *UpdateScope { if filter == nil { us.filter = bson.M{} return us } v := reflect.ValueOf(filter) if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice { if v.IsNil() { us.filter = bson.M{} } } us.filter = filter return us } func (us *UpdateScope) optionAssembled() { // 配置项被直接调用重写过 if us.opts != nil { return } us.opts = new(options.UpdateOptions) if us.upsert != nil { us.opts.Upsert = us.upsert } } func (us *UpdateScope) assertErr() { if us.err == nil { return } if errors.Is(us.err, context.DeadlineExceeded) { us.err = &ErrRequestBroken return } err, ok := us.err.(mongo.CommandError) if !ok { return } if err.HasErrorMessage(context.DeadlineExceeded.Error()) { us.err = &ErrRequestBroken return } } func (us *UpdateScope) debug() { if !us.scope.debug && !us.scope.debugWhenError { return } debugger := &Debugger{ collection: us.scope.tableName, execT: us.scope.execT, action: us, } // 当错误时优先输出 if us.scope.debugWhenError { if us.err != nil { debugger.errMsg = us.err.Error() debugger.ErrorString() } return } // 所有bug输出 if us.scope.debug { debugger.String() } } func (us *UpdateScope) doReporter() string { debugger := &Debugger{ collection: us.scope.tableName, execT: us.scope.execT, action: us, } return debugger.Echo() } func (us *UpdateScope) reporter() { br := &BreakerReporter{ reportTitle: "mdbc::update", reportMsg: "熔断错误", reportErrorWrap: us.err, breakerDo: us, } if us.err == &ErrRequestBroken { br.Report() } } func (us *UpdateScope) doString() string { builder := RegisterTimestampCodec(nil).Build() switch us.action { case updateOne: filter, _ := bson.MarshalExtJSONWithRegistry(builder, us.filter, true, true) updateData, _ := bson.MarshalExtJSONWithRegistry(builder, us.updateData, true, true) return fmt.Sprintf("updateOne(%s, %s)", string(filter), string(updateData)) case updateMany: filter, _ := bson.MarshalExtJSONWithRegistry(builder, us.filter, true, true) updateData, _ := bson.MarshalExtJSONWithRegistry(builder, us.updateData, true, true) return fmt.Sprintf("updateMany(%s, %s)", string(filter), string(updateData)) default: panic("not support update type") } } // preCheck 预检查 func (us *UpdateScope) preCheck() { var breakerTTL time.Duration if us.scope.breaker == nil { breakerTTL = defaultBreakerTime } else if us.scope.breaker.ttl == 0 { breakerTTL = defaultBreakerTime } else { breakerTTL = us.scope.breaker.ttl } if us.cw == nil { us.cw = &ctxWrap{} } if us.cw.ctx == nil { us.cw.ctx, us.cw.cancel = context.WithTimeout(context.Background(), breakerTTL) } } func (us *UpdateScope) doUpdate(isMany bool) { if isMany { var starTime time.Time if us.scope.debug { starTime = time.Now() } us.result, us.err = db.Collection(us.scope.tableName).UpdateMany(us.getContext(), us.filter, us.updateData, us.opts) us.assertErr() if us.scope.debug { us.scope.execT = time.Since(starTime) us.action = updateMany us.debug() } return } if us.id != nil { var starTime time.Time if us.scope.debug { starTime = time.Now() } us.result, us.err = db.Collection(us.scope.tableName).UpdateByID(us.getContext(), us.id, us.updateData, us.opts) us.assertErr() if us.scope.debug { us.scope.execT = time.Since(starTime) us.filter = bson.M{"_id": us.id} us.action = updateOne us.debug() } return } var starTime time.Time if us.scope.debug { starTime = time.Now() } us.result, us.err = db.Collection(us.scope.tableName).UpdateOne(us.getContext(), us.filter, us.updateData, us.opts) us.assertErr() if us.scope.debug { us.scope.execT = time.Since(starTime) us.action = updateOne us.debug() } } func (us *UpdateScope) doClear() { if us.cw != nil && us.cw.cancel != nil { us.cw.cancel() } us.scope.execT = 0 us.scope.debug = false } // checkObj 检测需要更新的obj是否正常 没有$set将被注入 func (us *UpdateScope) checkObj(obj interface{}) { if obj == nil { us.err = &ErrObjectEmpty return } us.updateData = obj vo := reflect.ValueOf(obj) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } if vo.Kind() == reflect.Struct { us.updateData = bson.M{ "$set": obj, } return } if vo.Kind() == reflect.Map { setKey := vo.MapIndex(reflect.ValueOf("$set")) if !setKey.IsValid() { us.updateData = bson.M{ "$set": obj, } } return } us.err = &ErrUpdateObjectTypeNotSupport } // checkMapSetObj 检测需要更新的obj是否正常 func (us *UpdateScope) checkMustMapObj(obj interface{}) { if obj == nil { us.err = &ErrObjectEmpty return } us.updateData = obj vo := reflect.ValueOf(obj) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } if vo.Kind() == reflect.Map { return } us.err = &ErrUpdateObjectTypeNotSupport } // One 单个更新 // obj: 为*struct时将被自动转换为bson.M并注入$set // 当指定为map时 若未指定$set 会自动注入$set // 当指定 非 *struct/ map 类型时 将报错 func (us *UpdateScope) One(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.preCheck() us.checkObj(obj) if us.err != nil { return nil, us.err } us.doUpdate(false) if us.err != nil { return nil, us.err } return us.result, nil } // MustMapOne 单个更新 必须传入map并强制更新这个map func (us *UpdateScope) MustMapOne(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.preCheck() us.checkMustMapObj(obj) if us.err != nil { return nil, us.err } us.doUpdate(false) if us.err != nil { return nil, us.err } return us.result, nil } // RawOne 单个更新 不对更新内容进行校验 func (us *UpdateScope) RawOne(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.preCheck() us.updateData = obj us.doUpdate(false) if us.err != nil { return nil, us.err } return us.result, nil } // Many 批量更新 // obj: 为*struct时将被自动转换为bson.M并注入$set // 当指定为map时 若未指定$set 会自动注入$set // 当指定 非 *struct/ map 类型时 将报错 func (us *UpdateScope) Many(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.checkObj(obj) if us.err != nil { return nil, us.err } us.doUpdate(true) if us.err != nil { return nil, us.err } return us.result, nil } // MustMapMany 批量更新 必须传入map并强制更新这个map func (us *UpdateScope) MustMapMany(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.checkMustMapObj(obj) if us.err != nil { return nil, us.err } us.doUpdate(true) if us.err != nil { return nil, us.err } return us.result, nil } // RawMany 批量更新 不对更新内容进行校验 func (us *UpdateScope) RawMany(obj interface{}) (*mongo.UpdateResult, error) { defer us.doClear() us.optionAssembled() us.preCheck() us.updateData = obj us.doUpdate(true) if us.err != nil { return nil, us.err } return us.result, nil }