mdbc/count_scope.go
2022-02-23 16:59:45 +08:00

218 lines
4.1 KiB
Go

package mdbc
import (
"context"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// CountScope collects the parameters for a CountDocuments call on the
// collection named by scope.tableName and holds the outcome of that call.
//
// limit, skip, hint, maxTimeMS and collation are filled by the Set* helpers
// and copied into opts by optionAssembled (unless opts is already non-nil).
// filter is the query passed to CountDocuments. cw carries the context used
// for the call. err and result hold the error and count of the last run.
type CountScope struct {
	scope *Scope
	cw    *ctxWrap
	err   error

	limit, skip  *int64
	hint, filter interface{}
	maxTimeMS    *time.Duration
	collation    *options.Collation
	opts         *options.CountOptions

	result int64
}
// SetLimit sets the maximum number of documents to count and returns cs for
// chaining.
func (cs *CountScope) SetLimit(limit int64) *CountScope {
	n := limit
	cs.limit = &n
	return cs
}
// SetSkip sets how many matching documents to skip before counting and
// returns cs for chaining.
func (cs *CountScope) SetSkip(skip int64) *CountScope {
	n := skip
	cs.skip = &n
	return cs
}
// SetHint sets the index hint copied into the CountOptions by
// optionAssembled and returns cs for chaining.
//
// A nil hint is normalized to an empty bson.M. This covers both an untyped
// nil and a nil pointer, map or slice stored inside the interface.
//
// NOTE(review): the empty bson.M is still forwarded as Hint by
// optionAssembled (it is non-nil); confirm the server accepts an empty hint
// document, or consider leaving hint unset instead.
func (cs *CountScope) SetHint(hint interface{}) *CountScope {
	if hint == nil {
		cs.hint = bson.M{}
		return cs
	}
	v := reflect.ValueOf(hint)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice:
		if v.IsNil() {
			// Return here. Falling through would overwrite the normalized
			// value with the typed nil.
			cs.hint = bson.M{}
			return cs
		}
	}
	cs.hint = hint
	return cs
}
// SetFilter sets the query filter passed to CountDocuments and returns cs
// for chaining.
//
// A nil filter is replaced by an empty bson.M (no filter conditions). This
// covers both an untyped nil and a nil pointer, map or slice stored inside
// the interface, so a typed nil is never handed to the driver.
func (cs *CountScope) SetFilter(filter interface{}) *CountScope {
	if filter == nil {
		cs.filter = bson.M{}
		return cs
	}
	v := reflect.ValueOf(filter)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice:
		if v.IsNil() {
			// Return here. Falling through would overwrite the normalized
			// value with the typed nil.
			cs.filter = bson.M{}
			return cs
		}
	}
	cs.filter = filter
	return cs
}
// SetMaxTime sets the MaxTime option copied into the CountOptions and
// returns cs for chaining.
func (cs *CountScope) SetMaxTime(maxTime time.Duration) *CountScope {
	d := maxTime
	cs.maxTimeMS = &d
	return cs
}
// SetCollation sets the collation copied into the CountOptions and returns
// cs for chaining. A copy of the argument is stored.
func (cs *CountScope) SetCollation(collation options.Collation) *CountScope {
	c := collation
	cs.collation = &c
	return cs
}
// SetContext sets the context used for the count and returns cs for
// chaining. The ctxWrap holder is created on first use.
func (cs *CountScope) SetContext(ctx context.Context) *CountScope {
	if cs.cw != nil {
		cs.cw.ctx = ctx
		return cs
	}
	cs.cw = &ctxWrap{ctx: ctx}
	return cs
}
// doClear releases per-call state once Count has finished.
//
// If a cancel func is present (preCheck sets one when it creates the
// fallback timeout context), it is called. Then both ctx and cancel are
// reset. Otherwise a reused CountScope would keep the already-cancelled
// context, skip creating a fresh one in preCheck, and fail immediately. A
// context supplied through SetContext has no cancel func and is left as is.
// The scope's execT and debug flag are also reset.
func (cs *CountScope) doClear() {
	if cs.cw != nil && cs.cw.cancel != nil {
		cs.cw.cancel()
		cs.cw.ctx = nil
		cs.cw.cancel = nil
	}
	cs.scope.execT = 0
	cs.scope.debug = false
}
// getContext returns the context held by cw. cw must be non-nil; doCount
// calls preCheck, which ensures this, before using it.
func (cs *CountScope) getContext() context.Context {
	wrap := cs.cw
	return wrap.ctx
}
// Count runs CountDocuments with the configured filter and options and
// returns the number of matching documents. On failure it returns 0 and the
// error recorded by doCount. Per-call state is released by doClear before
// returning.
func (cs *CountScope) Count() (int64, error) {
	defer cs.doClear()

	cs.doCount()
	if err := cs.err; err != nil {
		return 0, err
	}
	return cs.result, nil
}
// optionAssembled builds cs.opts from the individually configured fields.
// If cs.opts is already non-nil (the options were set directly), it is kept
// unchanged.
func (cs *CountScope) optionAssembled() {
	if cs.opts != nil {
		return
	}
	// Unset fields are nil pointers or a nil interface, which equal the zero
	// value of the matching CountOptions field, so copy them unconditionally.
	cs.opts = &options.CountOptions{
		Skip:      cs.skip,
		Limit:     cs.limit,
		Collation: cs.collation,
		Hint:      cs.hint,
		MaxTime:   cs.maxTimeMS,
	}
}
// assertErr normalizes cs.err after a count (it is deferred by doCount).
// A context deadline error, or a mongo.CommandError anywhere in the wrap
// chain whose message contains the deadline-exceeded text, is replaced by
// &ErrRequestBroken. Any other error is left untouched.
func (cs *CountScope) assertErr() {
	if cs.err == nil {
		return
	}
	if errors.Is(cs.err, context.DeadlineExceeded) {
		cs.err = &ErrRequestBroken
		return
	}
	// errors.As also finds a CommandError that has been wrapped, which a bare
	// type assertion would miss.
	var cmdErr mongo.CommandError
	if errors.As(cs.err, &cmdErr) && cmdErr.HasErrorMessage(context.DeadlineExceeded.Error()) {
		cs.err = &ErrRequestBroken
	}
}
// debug emits debug output for this count through a Debugger. It only does
// so when debugging is enabled on the scope. The Debugger receives the
// collection name, the measured execution time and cs itself as the action.
func (cs *CountScope) debug() {
	if cs.scope.debug {
		d := &Debugger{
			collection: cs.scope.tableName,
			execT:      cs.scope.execT,
			action:     cs,
		}
		d.String()
	}
}
// doString renders the count as "count(<filter>)" for debug output. The
// filter is encoded as canonical extended JSON using a registry that
// includes the timestamp codec.
func (cs *CountScope) doString() string {
	registry := RegisterTimestampCodec(nil).Build()
	filter, err := bson.MarshalExtJSONWithRegistry(registry, cs.filter, true, true)
	if err != nil {
		// This is debug output only, so degrade gracefully. Print the filter
		// with Go formatting instead of an empty "count()" that hides it.
		return fmt.Sprintf("count(%v)", cs.filter)
	}
	return fmt.Sprintf("count(%s)", string(filter))
}
// doCount assembles options, applies defaults, and runs CountDocuments
// against the scope's collection, storing the count and error on cs. When
// debugging is enabled it records the elapsed time in scope.execT and emits
// debug output. assertErr normalizes the error on the way out.
func (cs *CountScope) doCount() {
	defer cs.assertErr()

	cs.optionAssembled()
	cs.preCheck()

	var startedAt time.Time
	if cs.scope.debug {
		startedAt = time.Now()
	}

	coll := db.Collection(cs.scope.tableName)
	cs.result, cs.err = coll.CountDocuments(cs.getContext(), cs.filter, cs.opts)

	if !cs.scope.debug {
		return
	}
	cs.scope.execT = time.Since(startedAt)
	cs.debug()
}
// preCheck fills in defaults before the count runs:
//   - an empty bson.M filter when none was set;
//   - when no context has been supplied, a timeout context derived from
//     context.Background(). Its duration is the breaker's ttl, or
//     defaultBreakerTime when there is no breaker or its ttl is zero.
//     The cancel func is stored in cw.cancel.
func (cs *CountScope) preCheck() {
	if cs.filter == nil {
		cs.filter = bson.M{}
	}

	var ttl time.Duration
	switch b := cs.scope.breaker; {
	case b == nil, b.ttl == 0:
		// Case expressions are evaluated left to right and stop at the first
		// match, so b.ttl is never read when b is nil.
		ttl = defaultBreakerTime
	default:
		ttl = b.ttl
	}

	if cs.cw == nil {
		cs.cw = new(ctxWrap)
	}
	if cs.cw.ctx == nil {
		cs.cw.ctx, cs.cw.cancel = context.WithTimeout(context.Background(), ttl)
	}
}