package mdbc import ( "context" "errors" "fmt" "reflect" "time" jsoniter "github.com/json-iterator/go" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type insertAction string const ( insertOne insertAction = "insertOne" insertMany insertAction = "insertMany" ) type InsertResult struct { id interface{} } // GetString 获取单条插入的字符串ID func (i *InsertResult) GetString() string { vo := reflect.ValueOf(i.id) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } if vo.Kind() == reflect.String { return vo.String() } return "" } // GetListString 获取多条插入的字符串ID func (i *InsertResult) GetListString() []string { var list []string vo := reflect.ValueOf(i.id) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } if vo.Kind() == reflect.Slice || vo.Kind() == reflect.Array { for i := 0; i < vo.Len(); i++ { val := vo.Index(i).Interface() valVo := reflect.ValueOf(val) if valVo.Kind() == reflect.Ptr { valVo = valVo.Elem() } if valVo.Kind() == reflect.String { list = append(list, valVo.String()) } } return list } return nil } type InsertScope struct { scope *Scope cw *ctxWrap err error filter interface{} action insertAction insertObject interface{} oopts *options.InsertOneOptions mopts *options.InsertManyOptions oResult *mongo.InsertOneResult mResult *mongo.InsertManyResult } // SetContext 设置上下文 func (is *InsertScope) SetContext(ctx context.Context) *InsertScope { if is.cw == nil { is.cw = &ctxWrap{} } return is } func (is *InsertScope) getContext() context.Context { return is.cw.ctx } func (is *InsertScope) doClear() { if is.cw != nil && is.cw.cancel != nil { is.cw.cancel() } is.scope.debug = false is.scope.execT = 0 } // SetInsertOneOption 设置插InsertOneOption func (is *InsertScope) SetInsertOneOption(opts options.InsertOneOptions) *InsertScope { is.oopts = &opts return is } // SetInsertManyOption 设置InsertManyOption func (is *InsertScope) SetInsertManyOption(opts options.InsertManyOptions) *InsertScope { is.mopts = &opts return is } // SetFilter 设置过滤条件 func (is *InsertScope) SetFilter(filter interface{}) *InsertScope { if filter == nil { is.filter = bson.M{} return is } v := reflect.ValueOf(filter) if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice { if v.IsNil() { is.filter = bson.M{} } } is.filter = filter return is } // preCheck 预检查 func (is *InsertScope) preCheck() { var breakerTTL time.Duration if is.scope.breaker == nil { breakerTTL = defaultBreakerTime } else if is.scope.breaker.ttl == 0 { breakerTTL = defaultBreakerTime } else { breakerTTL = is.scope.breaker.ttl } if is.cw == nil { is.cw = &ctxWrap{} } if is.cw.ctx == nil { is.cw.ctx, is.cw.cancel = context.WithTimeout(context.Background(), breakerTTL) } } func (is *InsertScope) assertErr() { if is.err == nil { return } if errors.Is(is.err, context.DeadlineExceeded) { is.err = &ErrRequestBroken return } err, ok := is.err.(mongo.CommandError) if !ok { return } if err.HasErrorMessage(context.DeadlineExceeded.Error()) { is.err = &ErrRequestBroken } } func (is *InsertScope) debug() { if !is.scope.debug && !is.scope.debugWhenError { return } debugger := &Debugger{ collection: is.scope.tableName, execT: is.scope.execT, action: is, } // 当错误时优先输出 if is.scope.debugWhenError { if is.err != nil { debugger.errMsg = is.err.Error() debugger.ErrorString() } return } // 所有bug输出 if is.scope.debug { debugger.String() } } func (is *InsertScope) doString() string { builder := RegisterTimestampCodec(nil).Build() switch is.action { case insertOne: b, _ := bson.MarshalExtJSONWithRegistry(builder, is.insertObject, true, true) return fmt.Sprintf("insertOne(%s)", string(b)) case insertMany: vo := reflect.ValueOf(is.insertObject) if vo.Kind() == reflect.Ptr { vo = vo.Elem() } if vo.Kind() != reflect.Slice { panic("insertMany object is not slice") } var data []interface{} for i := 0; i < vo.Len(); i++ { var body interface{} b, _ := bson.MarshalExtJSONWithRegistry(builder, vo.Index(i).Interface(), true, true) _ = jsoniter.Unmarshal(b, &body) data = append(data, body) } b, _ := jsoniter.Marshal(data) return fmt.Sprintf("insertMany(%s)", string(b)) default: panic("not support insert type") } } func (is *InsertScope) doInsert(obj interface{}) { var starTime time.Time if is.scope.debug { starTime = time.Now() } is.oResult, is.err = db.Collection(is.scope.tableName).InsertOne(is.getContext(), obj, is.oopts) is.assertErr() if is.scope.debug { is.action = insertOne is.insertObject = obj is.scope.execT = time.Since(starTime) is.debug() } } func (is *InsertScope) doManyInsert(obj []interface{}) { defer is.assertErr() var starTime time.Time if is.scope.debug { starTime = time.Now() } is.mResult, is.err = db.Collection(is.scope.tableName).InsertMany(is.getContext(), obj, is.mopts) is.assertErr() if is.scope.debug { is.action = insertMany is.insertObject = obj is.scope.execT = time.Since(starTime) is.debug() } } // One 插入一个对象 返回的id通过InsertResult中的方法进行获取 func (is *InsertScope) One(obj interface{}) (id *InsertResult, err error) { defer is.doClear() is.preCheck() is.doInsert(obj) if is.err != nil { return nil, is.err } if is.oResult == nil { return nil, fmt.Errorf("insert success but result empty") } return &InsertResult{id: is.oResult.InsertedID}, nil } // Many 插入多个对象 返回的id通过InsertResult中的方法进行获取 func (is *InsertScope) Many(obj []interface{}) (ids *InsertResult, err error) { defer is.doClear() is.preCheck() is.doManyInsert(obj) if is.err != nil { return nil, is.err } if is.mResult == nil { return nil, fmt.Errorf("insert success but result empty") } return &InsertResult{id: is.mResult.InsertedIDs}, nil }