package mdbc import ( "context" "errors" "fmt" "reflect" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type CountScope struct { scope *Scope cw *ctxWrap err error limit *int64 skip *int64 hint interface{} filter interface{} maxTimeMS *time.Duration collation *options.Collation opts *options.CountOptions result int64 } // SetLimit 设置获取数量 func (cs *CountScope) SetLimit(limit int64) *CountScope { cs.limit = &limit return cs } // SetSkip 设置跳过的数量 func (cs *CountScope) SetSkip(skip int64) *CountScope { cs.skip = &skip return cs } // SetHint 设置hit func (cs *CountScope) SetHint(hint interface{}) *CountScope { if hint == nil { cs.hint = bson.M{} return cs } v := reflect.ValueOf(hint) if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice { if v.IsNil() { cs.hint = bson.M{} } } cs.hint = hint return cs } // SetFilter 设置过滤条件 func (cs *CountScope) SetFilter(filter interface{}) *CountScope { if filter == nil { cs.filter = bson.M{} return cs } v := reflect.ValueOf(filter) if v.Kind() == reflect.Ptr || v.Kind() == reflect.Map || v.Kind() == reflect.Slice { if v.IsNil() { cs.filter = bson.M{} } } cs.filter = filter return cs } // SetMaxTime 设置MaxTime func (cs *CountScope) SetMaxTime(maxTime time.Duration) *CountScope { cs.maxTimeMS = &maxTime return cs } // SetCollation 设置文档 func (cs *CountScope) SetCollation(collation options.Collation) *CountScope { cs.collation = &collation return cs } // SetContext 设置上下文 func (cs *CountScope) SetContext(ctx context.Context) *CountScope { if cs.cw == nil { cs.cw = &ctxWrap{} } cs.cw.ctx = ctx return cs } func (cs *CountScope) doClear() { if cs.cw != nil && cs.cw.cancel != nil { cs.cw.cancel() } cs.scope.execT = 0 cs.scope.debug = false } func (cs *CountScope) getContext() context.Context { return cs.cw.ctx } // Count 执行计数 func (cs *CountScope) Count() (int64, error) { defer cs.doClear() cs.doCount() if cs.err != nil { return 0, cs.err } return cs.result, nil } func (cs *CountScope) optionAssembled() { // 配置项被直接调用重写过 if cs.opts != nil { return } cs.opts = new(options.CountOptions) if cs.skip != nil { cs.opts.Skip = cs.skip } if cs.limit != nil { cs.opts.Limit = cs.limit } if cs.collation != nil { cs.opts.Collation = cs.collation } if cs.hint != nil { cs.opts.Hint = cs.hint } if cs.maxTimeMS != nil { cs.opts.MaxTime = cs.maxTimeMS } } func (cs *CountScope) assertErr() { if cs.err == nil { return } if errors.Is(cs.err, context.DeadlineExceeded) { cs.err = &ErrRequestBroken return } err, ok := cs.err.(mongo.CommandError) if !ok { return } if err.HasErrorMessage(context.DeadlineExceeded.Error()) { cs.err = &ErrRequestBroken return } } func (cs *CountScope) debug() { if !cs.scope.debug { return } debugger := &Debugger{ collection: cs.scope.tableName, execT: cs.scope.execT, action: cs, } debugger.String() } func (cs *CountScope) doString() string { builder := RegisterTimestampCodec(nil).Build() filter, _ := bson.MarshalExtJSONWithRegistry(builder, cs.filter, true, true) return fmt.Sprintf("count(%s)", string(filter)) } func (cs *CountScope) doCount() { defer cs.assertErr() cs.optionAssembled() cs.preCheck() var starTime time.Time if cs.scope.debug { starTime = time.Now() } cs.result, cs.err = db.Collection(cs.scope.tableName).CountDocuments(cs.getContext(), cs.filter, cs.opts) if cs.scope.debug { cs.scope.execT = time.Since(starTime) cs.debug() } } func (cs *CountScope) preCheck() { if cs.filter == nil { cs.filter = bson.M{} } var breakerTTL time.Duration if cs.scope.breaker == nil { breakerTTL = defaultBreakerTime } else if cs.scope.breaker.ttl == 0 { breakerTTL = defaultBreakerTime } else { breakerTTL = cs.scope.breaker.ttl } if cs.cw == nil { cs.cw = &ctxWrap{} } if cs.cw.ctx == nil { cs.cw.ctx, cs.cw.cancel = context.WithTimeout(context.Background(), breakerTTL) } }