From 108f62aa4d61008fa02e5bafbc470029a6fa8eae Mon Sep 17 00:00:00 2001 From: xuthus5 Date: Sat, 16 Sep 2023 18:09:34 +0800 Subject: [PATCH] feat: support crud --- delete_scope.go | 59 +++++++++++++++++++++ delete_scope_test.go | 36 +++++++++++++ errors.go | 7 +++ find_scope.go | 121 +++++++++++++++++++++++++++++++++++++++++++ find_scope_test.go | 52 +++++++++++++++++++ go.mod | 1 + go.sum | 2 + insert_scope.go | 34 ++++++++---- insert_scope_test.go | 3 +- scope.go | 42 +++++++++------ sdbc.go | 4 ++ update_scope.go | 70 +++++++++++++++++++++++++ update_scope_test.go | 74 ++++++++++++++++++++++++++ 13 files changed, 478 insertions(+), 27 deletions(-) create mode 100644 delete_scope.go create mode 100644 delete_scope_test.go create mode 100644 errors.go create mode 100644 find_scope.go create mode 100644 find_scope_test.go create mode 100644 update_scope.go create mode 100644 update_scope_test.go diff --git a/delete_scope.go b/delete_scope.go new file mode 100644 index 0000000..07d32ef --- /dev/null +++ b/delete_scope.go @@ -0,0 +1,59 @@ +package sdbc + +import ( + "context" + "gorm.io/gorm" +) + +type _delete struct { + scope *scope + sdbc *sdbc + ctx context.Context + where []any + isHard bool +} + +// SetWhere 设置删除条件 +func (d *_delete) SetWhere(query any, args ...any) *_delete { + d.where = append(d.where, query, args) + return d +} + +// SetHardDelete 设置硬删除删除 +func (d *_delete) SetHardDelete(hard bool) *_delete { + d.isHard = hard + return d +} + +// DeleteByID 通过主键删除记录 +func (d *_delete) DeleteByID(id any) error { + db := d.sdbc.client + if d.isHard { + db = db.Unscoped() + } + return db.Delete(d.scope.model.ptr(), id).Error +} + +// Delete 基于条件做删除 无条件删除使用 MustDelete +func (d *_delete) Delete() error { + db := d.sdbc.client + if d.isHard { + db = db.Unscoped() + } + if len(d.where) != 0 { + db = db.Where(d.where[0], d.where[1:]...) + } + return db.Delete(d.scope.model.ptr()).Error +} + +// MustDelete 基于条件做删除 无条件则全部删除 +func (d *_delete) MustDelete() error { + db := d.sdbc.client + if d.isHard { + db = db.Unscoped() + } + if len(d.where) != 0 { + db = db.Where(d.where[0], d.where[1:]...) + } + return db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(d.scope.model.ptr()).Error +} diff --git a/delete_scope_test.go b/delete_scope_test.go new file mode 100644 index 0000000..f8d1044 --- /dev/null +++ b/delete_scope_test.go @@ -0,0 +1,36 @@ +package sdbc + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func Test__delete_DeleteByID(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var doc = &ModelArticles{ + Title: "hello world1", + CreateTime: time.Now().Unix(), + } + err := driver.Delete().DeleteByID(1) + assert.NoError(t, err) + t.Logf("doc id: %v", doc.Id) +} + +func Test__delete_Delete(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + err := driver.Delete().SetWhere("id = ?", 2).Delete() + assert.NoError(t, err) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..a2ba0d5 --- /dev/null +++ b/errors.go @@ -0,0 +1,7 @@ +package sdbc + +import "errors" + +var ( + ErrorUnavailableType error = errors.New("unavailable type") +) diff --git a/find_scope.go b/find_scope.go new file mode 100644 index 0000000..1297e62 --- /dev/null +++ b/find_scope.go @@ -0,0 +1,121 @@ +package sdbc + +import ( + "context" + "gitter.top/common/goref" + "reflect" +) + +type _find struct { + scope *scope + sdbc *sdbc + ctx context.Context + where []any + selectOpts []interface{} + omitOpts []string + notOpts []interface{} + orOpts []interface{} + orderOpts string + limit int + offset int +} + +// SetWhere 设置更新条件 +func (f *_find) SetWhere(query any, args ...any) *_find { + f.where = append(f.where, query, args) + return f +} + +// SetOr 设置或条件 +func (f *_find) SetOr(query any, args ...any) *_find { + f.orOpts = append(f.orOpts, query, args) + return f +} + +// SetNot 设置非条件 +func (f *_find) SetNot(query any, args ...any) *_find { + f.notOpts = append(f.notOpts, query, args) + return f +} + +// SetSelect 只更新字段 对Updates有效 +func (f *_find) SetSelect(selects ...any) *_find { + f.selectOpts = selects + return f +} + +// SetOmit 忽略字段 对Updates有效 +func (f *_find) SetOmit(omits ...string) *_find { + f.omitOpts = omits + return f +} + +// SetLimit 限制条数 +func (f *_find) SetLimit(limit int) *_find { + f.limit = limit + return f +} + +// SetOffset 跳跃条数 +func (f *_find) SetOffset(offset int) *_find { + f.offset = offset + return f +} + +// SetOrder 跳跃条数 +func (f *_find) SetOrder(order string) *_find { + f.orderOpts = order + return f +} + +// FindById 主键检索 +func (f *_find) FindById(id any, bind any) error { + // 不是结构体指针 返回错误 + if !goref.IsPointer(bind) && !goref.IsBaseStruct(bind) { + return ErrorUnavailableType + } + db := f.sdbc.client + return db.Model(f.scope.model.ptr()).First(bind, id).Error +} + +// FindByIds 主键检索 +func (f *_find) FindByIds(ids any, binds any) error { + // 不是一个id列表 + if !goref.IsBaseList(ids) { + return ErrorUnavailableType + } + // 不是结构体指针 返回错误 + if !goref.IsBaseList(binds) && goref.GetBaseType(binds).Kind() != reflect.Struct { + return ErrorUnavailableType + } + db := f.sdbc.client + return db.Model(f.scope.model.ptr()).Where(ids).Find(binds).Error +} + +// Find 使用where条件查询 +func (f *_find) Find(binds any) error { + // 不是结构体指针 返回错误 + if !goref.IsPointer(binds) || !goref.IsBaseList(binds) || goref.GetBaseType(binds).Kind() != reflect.Struct { + return ErrorUnavailableType + } + db := f.sdbc.client + if len(f.where) != 0 { + db = db.Where(f.where[0], f.where[1:]...) + } + if f.limit != 0 { + db = db.Limit(f.limit) + } + if f.offset != 0 { + db = db.Offset(f.offset) + } + if f.orderOpts != "" { + db = db.Order(f.orderOpts) + } + if len(f.orOpts) != 0 { + db = db.Or(f.orOpts[0], f.orOpts[1:]...) + } + if len(f.notOpts) != 0 { + db = db.Not(f.notOpts[0], f.notOpts[1:]...) + } + return db.Model(f.scope.model.ptr()).Find(binds).Error +} diff --git a/find_scope_test.go b/find_scope_test.go new file mode 100644 index 0000000..e1dcf4e --- /dev/null +++ b/find_scope_test.go @@ -0,0 +1,52 @@ +package sdbc + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func Test__find_FindById(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var doc = &ModelArticles{ + Title: "hello world1", + CreateTime: time.Now().Unix(), + } + err := driver.Find().FindById(5, doc) + assert.NoError(t, err) + t.Logf("doc: %v", doc) +} + +func Test__find_FindByIds(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var docs []ModelArticles + err := driver.Find().FindByIds([]int{4, 5, 6}, &docs) + assert.NoError(t, err) + t.Logf("doc: %v", docs) +} + +func Test__find_Find(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var docs []ModelArticles + err := driver.Find().SetNot("id = ?", 1).SetWhere("create_time = ?", 0).SetOr("update_time = ?", 0).Find(&docs) + assert.NoError(t, err) + t.Logf("doc: %v", docs) +} diff --git a/go.mod b/go.mod index 8ee1eaa..96e29bf 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/glebarez/sqlite v1.9.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 + gitter.top/common/goref v0.0.0-20230916075900-7b64840146ae gitter.top/common/lormatter v0.0.0-20230910075849-28d49dccd03a google.golang.org/protobuf v1.31.0 gorm.io/gorm v1.25.4 diff --git a/go.sum b/go.sum index c966244..cbf936d 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gitter.top/common/goref v0.0.0-20230916075900-7b64840146ae h1:XvVCCs6tnDQTl8JFgjO6tVRX1l5Xsp1LSPPG0TYlVss= +gitter.top/common/goref v0.0.0-20230916075900-7b64840146ae/go.mod h1:9ZCvSyMgyJ6ODKdgvHgnNuRlBhvlzIOBcwhP3Buz5SA= gitter.top/common/lormatter v0.0.0-20230910075849-28d49dccd03a h1:bOn73Ju5BmAn+MXT3bukh7fsn9FXGwqHLM03QixCavY= gitter.top/common/lormatter v0.0.0-20230910075849-28d49dccd03a/go.mod h1:/Zue/gLVDDSvCCRJKytEfpX0LP/JHkIeDzwVE5cA254= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/insert_scope.go b/insert_scope.go index ecb98d2..8e073c1 100644 --- a/insert_scope.go +++ b/insert_scope.go @@ -2,7 +2,7 @@ package sdbc import "context" -type insert struct { +type _insert struct { scope *scope sdbc *sdbc ctx context.Context @@ -10,29 +10,43 @@ type insert struct { omitOpts []string } -func (i *insert) SetContext(ctx context.Context) { +func (i *_insert) SetContext(ctx context.Context) *_insert { i.ctx = ctx + return i } -func (i *insert) SetSelect(selects ...any) { +// SetSelect 只插入字段 +func (i *_insert) SetSelect(selects ...any) *_insert { i.selectOpts = selects + return i } -func (i *insert) SetOmit(omits ...string) { +// SetOmit 忽略字段 +func (i *_insert) SetOmit(omits ...string) *_insert { i.omitOpts = omits + return i } -func (i *insert) InsertOne(doc any) error { +// Insert 插入数据 +func (i *_insert) Insert(docs any) error { db := i.sdbc.client if len(i.selectOpts) != 0 { - db.Select(i.selectOpts[0], i.selectOpts[1:]...) + db = db.Select(i.selectOpts[0], i.selectOpts[1:]...) } if len(i.omitOpts) != 0 { - db.Omit(i.omitOpts...) + db = db.Omit(i.omitOpts...) } - return db.Create(doc).Error + return db.Model(i.scope.model.ptr()).Create(docs).Error } -func (i *insert) InsertMany(docs any) error { - return nil +// InsertBatch 按批插入数据 +func (i *_insert) InsertBatch(docs any) error { + db := i.sdbc.client + if len(i.selectOpts) != 0 { + db = db.Select(i.selectOpts[0], i.selectOpts[1:]...) + } + if len(i.omitOpts) != 0 { + db = db.Omit(i.omitOpts...) + } + return db.Model(i.scope.model.ptr()).CreateInBatches(docs, 1000).Error } diff --git a/insert_scope_test.go b/insert_scope_test.go index 304ba80..682db20 100644 --- a/insert_scope_test.go +++ b/insert_scope_test.go @@ -12,12 +12,13 @@ func Test_insert_InsertOne(t *testing.T) { MaxIdleConn: 10, MaxOpenConn: 100, MaxLifetime: time.Hour, + Debug: true, }).BindModel(&ModelArticles{}) var doc = &ModelArticles{ Title: "hello world1", CreateTime: time.Now().Unix(), } - err := driver.Insert().InsertOne(doc) + err := driver.Insert().Insert(doc) assert.NoError(t, err) t.Logf("doc id: %v", doc.Id) } diff --git a/scope.go b/scope.go index e9881a6..5914a76 100644 --- a/scope.go +++ b/scope.go @@ -14,37 +14,47 @@ type model struct { tableName string // model mapping table name } +// ptr 基于类型获取实例指针 +func (m *model) ptr() any { + return reflect.New(m.modelKind).Interface() +} + type scope struct { *model *sdbc } -func (s *scope) Insert() *insert { - return &insert{ +func (s *scope) Insert() *_insert { + return &_insert{ scope: s, sdbc: s.sdbc, } - } -func (s *scope) Find() { - //TODO implement me - panic("implement me") +func (s *scope) Find() *_find { + return &_find{ + scope: s, + sdbc: s.sdbc, + } } -func (s *scope) Update() { - //TODO implement me - panic("implement me") +func (s *scope) Update() *_update { + return &_update{ + scope: s, + sdbc: s.sdbc, + } } -func (s *scope) Delete() { - //TODO implement me - panic("implement me") +func (s *scope) Delete() *_delete { + return &_delete{ + scope: s, + sdbc: s.sdbc, + } } type Operator interface { - Insert() *insert - Find() - Update() - Delete() + Insert() *_insert + Find() *_find + Update() *_update + Delete() *_delete } diff --git a/sdbc.go b/sdbc.go index 0cbb461..da677e9 100644 --- a/sdbc.go +++ b/sdbc.go @@ -26,6 +26,7 @@ type Config struct { MaxIdleConn int // 最大空闲连接 MaxOpenConn int // 最大连接 MaxLifetime time.Duration // 最大生存时长 + Debug bool // 是否开启debug } type sdbc struct { @@ -45,6 +46,9 @@ func NewSDBC(cfg *Config) *sdbc { if err != nil { logrus.Fatalf("open sqlite file failed: %v", err) } + if cfg.Debug { + driver.client = driver.client.Debug() + } rawDB, err := driver.client.DB() if err != nil { logrus.Fatalf("get sqlite instance failed: %v", err) diff --git a/update_scope.go b/update_scope.go new file mode 100644 index 0000000..ed65db8 --- /dev/null +++ b/update_scope.go @@ -0,0 +1,70 @@ +package sdbc + +import ( + "context" + "gitter.top/common/goref" +) + +type _update struct { + scope *scope + sdbc *sdbc + ctx context.Context + where []any + selectOpts []interface{} + omitOpts []string +} + +// SetWhere 设置更新条件 +func (u *_update) SetWhere(query any, args ...any) *_update { + u.where = append(u.where, query, args) + return u +} + +// SetSelect 只更新字段 对Updates有效 +func (u *_update) SetSelect(selects ...any) *_update { + u.selectOpts = selects + return u +} + +// SetOmit 忽略字段 对Updates有效 +func (u *_update) SetOmit(omits ...string) *_update { + u.omitOpts = omits + return u +} + +// Replace 更新保存所有的字段,即使字段是零值 +func (u *_update) Replace(doc any) error { + // 不是结构体 返回错误 + if !goref.IsBaseStruct(doc) { + return ErrorUnavailableType + } + db := u.sdbc.client + if len(u.where) != 0 { + db = db.Where(u.where[0], u.where[1:]...) + } + return db.Save(doc).Error +} + +// Update 更新单个字段 +func (u *_update) Update(column string, value any) error { + db := u.sdbc.client + if len(u.where) != 0 { + db = db.Where(u.where[0], u.where[1:]...) + } + return db.Model(u.scope.model.ptr()).Update(column, value).Error +} + +// Updates 更新多个字段 +func (u *_update) Updates(m map[string]interface{}) error { + db := u.sdbc.client + if len(u.where) != 0 { + db = db.Where(u.where[0], u.where[1:]...) + } + if len(u.selectOpts) != 0 { + db = db.Select(u.selectOpts[0], u.selectOpts[1:]...) + } + if len(u.omitOpts) != 0 { + db = db.Omit(u.omitOpts...) + } + return db.Model(u.scope.model.ptr()).Updates(m).Error +} diff --git a/update_scope_test.go b/update_scope_test.go new file mode 100644 index 0000000..a0a5d1d --- /dev/null +++ b/update_scope_test.go @@ -0,0 +1,74 @@ +package sdbc + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func Test__update_Replace(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var doc = &ModelArticles{ + Id: 2, + Title: "hello world1", + AvatarUrl: "https://github.com/xuthus5/profile", + Phone: "10086", + CreateTime: time.Now().Unix() - 10000000, + UpdateTime: time.Now().Unix(), + } + err := driver.Update().Replace(doc) + assert.NoError(t, err) + t.Logf("doc id: %v", doc.Id) +} + +func Test__update_Update(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var doc = &ModelArticles{ + Id: 2, + Title: "hello world1", + AvatarUrl: "https://github.com/xuthus5/profile", + Phone: "10086", + CreateTime: time.Now().Unix() - 10000000, + UpdateTime: time.Now().Unix(), + } + err := driver.Update().SetWhere("id = ?", doc.Id).Update("phone", "10010") + assert.NoError(t, err) + t.Logf("doc id: %v", doc.Id) +} + +func Test__update_Updates(t *testing.T) { + var driver = NewSDBC(&Config{ + Dbname: "test.db", + MaxIdleConn: 10, + MaxOpenConn: 100, + MaxLifetime: time.Hour, + Debug: true, + }).BindModel(&ModelArticles{}) + var doc = &ModelArticles{ + Id: 2, + Title: "hello world1", + AvatarUrl: "https://github.com/xuthus5/profile", + Phone: "10086", + CreateTime: time.Now().Unix() - 10000000, + UpdateTime: time.Now().Unix(), + } + err := driver.Update().SetWhere("update_time = ?", 0).SetOmit("create_time", "update_time"). + Updates(map[string]interface{}{ + "title": "hello", + "phone": "110", + }) + assert.NoError(t, err) + t.Logf("doc id: %v", doc.Id) +}