mirror of
https://github.com/muety/wakapi.git
synced 2025-12-05 22:20:24 -08:00
chore: add check for sqlite cascades before changing user id
This commit is contained in:
File diff suppressed because it is too large
Load Diff
1
main.go
1
main.go
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"embed"
|
||||
"flag"
|
||||
"github.com/muety/wakapi/models"
|
||||
"io/fs"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type AliasRepositoryMock struct {
|
||||
BaseRepositoryMock
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
|
||||
24
mocks/base_repository.go
Normal file
24
mocks/base_repository.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type BaseRepositoryMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *BaseRepositoryMock) GetDialector() string {
|
||||
args := m.Called()
|
||||
return args.Get(0).(string)
|
||||
}
|
||||
|
||||
func (m *BaseRepositoryMock) GetTableDDLMysql(s string) (string, error) {
|
||||
args := m.Called(s)
|
||||
return args.Get(0).(string), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *BaseRepositoryMock) GetTableDDLSqlite(s string) (string, error) {
|
||||
args := m.Called(s)
|
||||
return args.Get(0).(string), args.Error(1)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type SummaryRepositoryMock struct {
|
||||
BaseRepositoryMock
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type AliasRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewAliasRepository(db *gorm.DB) *AliasRepository {
|
||||
return &AliasRepository{db: db}
|
||||
return &AliasRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *AliasRepository) GetAll() ([]*models.Alias, error) {
|
||||
|
||||
40
repositories/base.go
Normal file
40
repositories/base.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BaseRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewBaseRepository(db *gorm.DB) BaseRepository {
|
||||
return BaseRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *BaseRepository) GetDialector() string {
|
||||
return r.db.Dialector.Name()
|
||||
}
|
||||
|
||||
func (r *BaseRepository) GetTableDDLMysql(tableName string) (result string, err error) {
|
||||
if dialector := r.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
|
||||
err = r.db.Raw("show create table ?", tableName).Scan(&result).Error
|
||||
} else {
|
||||
err = errors.New("not a mysql database")
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (r *BaseRepository) GetTableDDLSqlite(tableName string) (result string, err error) {
|
||||
if dialector := r.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
|
||||
err = r.db.Table("sqlite_master").
|
||||
Select("sql").
|
||||
Where("type = ?", "table").
|
||||
Where("name = ?", tableName).
|
||||
Take(&result).Error
|
||||
} else {
|
||||
err = errors.New("not an sqlite database")
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
)
|
||||
|
||||
type DiagnosticsRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewDiagnosticsRepository(db *gorm.DB) *DiagnosticsRepository {
|
||||
return &DiagnosticsRepository{db: db}
|
||||
return &DiagnosticsRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *DiagnosticsRepository) Insert(diagnostics *models.Diagnostics) (*models.Diagnostics, error) {
|
||||
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
)
|
||||
|
||||
type HeartbeatRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
config *conf.Config
|
||||
}
|
||||
|
||||
func NewHeartbeatRepository(db *gorm.DB) *HeartbeatRepository {
|
||||
return &HeartbeatRepository{config: conf.Get(), db: db}
|
||||
return &HeartbeatRepository{BaseRepository: NewBaseRepository(db), config: conf.Get()}
|
||||
}
|
||||
|
||||
// Use with caution!!
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type KeyValueRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewKeyValueRepository(db *gorm.DB) *KeyValueRepository {
|
||||
return &KeyValueRepository{db: db}
|
||||
return &KeyValueRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *KeyValueRepository) GetAll() ([]*models.KeyStringValue, error) {
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
)
|
||||
|
||||
type LanguageMappingRepository struct {
|
||||
BaseRepository
|
||||
config *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLanguageMappingRepository(db *gorm.DB) *LanguageMappingRepository {
|
||||
return &LanguageMappingRepository{config: config.Get(), db: db}
|
||||
return &LanguageMappingRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
|
||||
}
|
||||
|
||||
func (r *LanguageMappingRepository) GetAll() ([]*models.LanguageMapping, error) {
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
)
|
||||
|
||||
type LeaderboardRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewLeaderboardRepository(db *gorm.DB) *LeaderboardRepository {
|
||||
return &LeaderboardRepository{db: db}
|
||||
return &LeaderboardRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *LeaderboardRepository) InsertBatch(items []*models.LeaderboardItem) error {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
)
|
||||
|
||||
type MetricsRepository struct {
|
||||
BaseRepository
|
||||
config *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
const sizeTplMysql = `
|
||||
@@ -23,7 +23,7 @@ SELECT page_count * page_size as size
|
||||
FROM pragma_page_count(), pragma_page_size();`
|
||||
|
||||
func NewMetricsRepository(db *gorm.DB) *MetricsRepository {
|
||||
return &MetricsRepository{config: config.Get(), db: db}
|
||||
return &MetricsRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
|
||||
}
|
||||
|
||||
func (srv *MetricsRepository) GetDatabaseSize() (size int64, err error) {
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
)
|
||||
|
||||
type ProjectLabelRepository struct {
|
||||
BaseRepository
|
||||
config *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProjectLabelRepository(db *gorm.DB) *ProjectLabelRepository {
|
||||
return &ProjectLabelRepository{config: config.Get(), db: db}
|
||||
return &ProjectLabelRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
|
||||
}
|
||||
|
||||
func (r *ProjectLabelRepository) GetAll() ([]*models.ProjectLabel, error) {
|
||||
|
||||
@@ -5,7 +5,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type IBaseRepository interface {
|
||||
GetDialector() string
|
||||
GetTableDDLMysql(string) (string, error)
|
||||
GetTableDDLSqlite(string) (string, error)
|
||||
}
|
||||
|
||||
type IAliasRepository interface {
|
||||
IBaseRepository
|
||||
Insert(*models.Alias) (*models.Alias, error)
|
||||
Delete(uint) error
|
||||
DeleteBatch([]uint) error
|
||||
@@ -17,6 +24,7 @@ type IAliasRepository interface {
|
||||
}
|
||||
|
||||
type IHeartbeatRepository interface {
|
||||
IBaseRepository
|
||||
InsertBatch([]*models.Heartbeat) error
|
||||
GetAll() ([]*models.Heartbeat, error)
|
||||
GetAllWithin(time.Time, time.Time, *models.User) ([]*models.Heartbeat, error)
|
||||
@@ -37,10 +45,12 @@ type IHeartbeatRepository interface {
|
||||
}
|
||||
|
||||
type IDiagnosticsRepository interface {
|
||||
IBaseRepository
|
||||
Insert(diagnostics *models.Diagnostics) (*models.Diagnostics, error)
|
||||
}
|
||||
|
||||
type IKeyValueRepository interface {
|
||||
IBaseRepository
|
||||
GetAll() ([]*models.KeyStringValue, error)
|
||||
GetString(string) (*models.KeyStringValue, error)
|
||||
PutString(*models.KeyStringValue) error
|
||||
@@ -50,6 +60,7 @@ type IKeyValueRepository interface {
|
||||
}
|
||||
|
||||
type ILanguageMappingRepository interface {
|
||||
IBaseRepository
|
||||
GetAll() ([]*models.LanguageMapping, error)
|
||||
GetById(uint) (*models.LanguageMapping, error)
|
||||
GetByUser(string) ([]*models.LanguageMapping, error)
|
||||
@@ -58,6 +69,7 @@ type ILanguageMappingRepository interface {
|
||||
}
|
||||
|
||||
type IProjectLabelRepository interface {
|
||||
IBaseRepository
|
||||
GetAll() ([]*models.ProjectLabel, error)
|
||||
GetById(uint) (*models.ProjectLabel, error)
|
||||
GetByUser(string) ([]*models.ProjectLabel, error)
|
||||
@@ -66,6 +78,7 @@ type IProjectLabelRepository interface {
|
||||
}
|
||||
|
||||
type ISummaryRepository interface {
|
||||
IBaseRepository
|
||||
Insert(*models.Summary) error
|
||||
GetAll() ([]*models.Summary, error)
|
||||
GetByUserWithin(*models.User, time.Time, time.Time) ([]*models.Summary, error)
|
||||
@@ -75,6 +88,7 @@ type ISummaryRepository interface {
|
||||
}
|
||||
|
||||
type IUserRepository interface {
|
||||
IBaseRepository
|
||||
FindOne(user models.User) (*models.User, error)
|
||||
GetByIds([]string) ([]*models.User, error)
|
||||
GetAll() ([]*models.User, error)
|
||||
@@ -92,6 +106,7 @@ type IUserRepository interface {
|
||||
}
|
||||
|
||||
type ILeaderboardRepository interface {
|
||||
IBaseRepository
|
||||
InsertBatch([]*models.LeaderboardItem) error
|
||||
CountAllByUser(string) (int64, error)
|
||||
CountUsers(bool) (int64, error)
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type SummaryRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewSummaryRepository(db *gorm.DB) *SummaryRepository {
|
||||
return &SummaryRepository{db: db}
|
||||
return &SummaryRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *SummaryRepository) GetAll() ([]*models.Summary, error) {
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
BaseRepository
|
||||
}
|
||||
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
return &UserRepository{BaseRepository: NewBaseRepository(db)}
|
||||
}
|
||||
|
||||
func (r *UserRepository) FindOne(attributes models.User) (*models.User, error) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/muety/wakapi/utils"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -200,6 +201,10 @@ func (srv *UserService) Update(user *models.User) (*models.User, error) {
|
||||
}
|
||||
|
||||
func (srv *UserService) ChangeUserId(user *models.User, newUserId string) (*models.User, error) {
|
||||
if !srv.checkUpdateCascade() {
|
||||
return nil, errors.New("sqlite database too old to perform user id change consistently")
|
||||
}
|
||||
|
||||
// https://github.com/muety/wakapi/issues/739
|
||||
oldUserId := user.ID
|
||||
defer srv.FlushUserCache(oldUserId)
|
||||
@@ -288,3 +293,11 @@ func (srv *UserService) notifyDelete(user *models.User) {
|
||||
Fields: map[string]interface{}{config.FieldPayload: user},
|
||||
})
|
||||
}
|
||||
|
||||
func (srv *UserService) checkUpdateCascade() bool {
|
||||
if dialector := srv.repository.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
|
||||
ddl, _ := srv.repository.GetTableDDLSqlite("heartbeats")
|
||||
return strings.Contains(ddl, "ON UPDATE CASCADE")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user