chore: add check for sqlite cascades before changing user id

This commit is contained in:
Ferdinand Mütsch
2025-02-02 21:56:22 +01:00
parent 2fef990d96
commit 8bd23c99ae
18 changed files with 1072 additions and 873 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,6 @@ package main
import (
"embed"
"flag"
"github.com/muety/wakapi/models"
"io/fs"
"log"
"log/slog"

View File

@@ -6,6 +6,7 @@ import (
)
type AliasRepositoryMock struct {
BaseRepositoryMock
mock.Mock
}

24
mocks/base_repository.go Normal file
View File

@@ -0,0 +1,24 @@
package mocks
import (
"github.com/stretchr/testify/mock"
)
type BaseRepositoryMock struct {
mock.Mock
}
func (m *BaseRepositoryMock) GetDialector() string {
args := m.Called()
return args.Get(0).(string)
}
func (m *BaseRepositoryMock) GetTableDDLMysql(s string) (string, error) {
args := m.Called(s)
return args.Get(0).(string), args.Error(1)
}
func (m *BaseRepositoryMock) GetTableDDLSqlite(s string) (string, error) {
args := m.Called(s)
return args.Get(0).(string), args.Error(1)
}

View File

@@ -7,6 +7,7 @@ import (
)
type SummaryRepositoryMock struct {
BaseRepositoryMock
mock.Mock
}

View File

@@ -7,11 +7,11 @@ import (
)
type AliasRepository struct {
db *gorm.DB
BaseRepository
}
func NewAliasRepository(db *gorm.DB) *AliasRepository {
return &AliasRepository{db: db}
return &AliasRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *AliasRepository) GetAll() ([]*models.Alias, error) {

40
repositories/base.go Normal file
View File

@@ -0,0 +1,40 @@
package repositories
import (
"errors"
"gorm.io/gorm"
)
type BaseRepository struct {
db *gorm.DB
}
func NewBaseRepository(db *gorm.DB) BaseRepository {
return BaseRepository{db: db}
}
func (r *BaseRepository) GetDialector() string {
return r.db.Dialector.Name()
}
func (r *BaseRepository) GetTableDDLMysql(tableName string) (result string, err error) {
if dialector := r.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
err = r.db.Raw("show create table ?", tableName).Scan(&result).Error
} else {
err = errors.New("not a mysql database")
}
return result, err
}
func (r *BaseRepository) GetTableDDLSqlite(tableName string) (result string, err error) {
if dialector := r.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
err = r.db.Table("sqlite_master").
Select("sql").
Where("type = ?", "table").
Where("name = ?", tableName).
Take(&result).Error
} else {
err = errors.New("not an sqlite database")
}
return result, err
}

View File

@@ -6,11 +6,11 @@ import (
)
type DiagnosticsRepository struct {
db *gorm.DB
BaseRepository
}
func NewDiagnosticsRepository(db *gorm.DB) *DiagnosticsRepository {
return &DiagnosticsRepository{db: db}
return &DiagnosticsRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *DiagnosticsRepository) Insert(diagnostics *models.Diagnostics) (*models.Diagnostics, error) {

View File

@@ -15,12 +15,12 @@ import (
)
type HeartbeatRepository struct {
db *gorm.DB
BaseRepository
config *conf.Config
}
func NewHeartbeatRepository(db *gorm.DB) *HeartbeatRepository {
return &HeartbeatRepository{config: conf.Get(), db: db}
return &HeartbeatRepository{BaseRepository: NewBaseRepository(db), config: conf.Get()}
}
// Use with caution!!

View File

@@ -11,11 +11,11 @@ import (
)
type KeyValueRepository struct {
db *gorm.DB
BaseRepository
}
func NewKeyValueRepository(db *gorm.DB) *KeyValueRepository {
return &KeyValueRepository{db: db}
return &KeyValueRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *KeyValueRepository) GetAll() ([]*models.KeyStringValue, error) {

View File

@@ -8,12 +8,12 @@ import (
)
type LanguageMappingRepository struct {
BaseRepository
config *config.Config
db *gorm.DB
}
func NewLanguageMappingRepository(db *gorm.DB) *LanguageMappingRepository {
return &LanguageMappingRepository{config: config.Get(), db: db}
return &LanguageMappingRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
}
func (r *LanguageMappingRepository) GetAll() ([]*models.LanguageMapping, error) {

View File

@@ -8,11 +8,11 @@ import (
)
type LeaderboardRepository struct {
db *gorm.DB
BaseRepository
}
func NewLeaderboardRepository(db *gorm.DB) *LeaderboardRepository {
return &LeaderboardRepository{db: db}
return &LeaderboardRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *LeaderboardRepository) InsertBatch(items []*models.LeaderboardItem) error {

View File

@@ -6,8 +6,8 @@ import (
)
type MetricsRepository struct {
BaseRepository
config *config.Config
db *gorm.DB
}
const sizeTplMysql = `
@@ -23,7 +23,7 @@ SELECT page_count * page_size as size
FROM pragma_page_count(), pragma_page_size();`
func NewMetricsRepository(db *gorm.DB) *MetricsRepository {
return &MetricsRepository{config: config.Get(), db: db}
return &MetricsRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
}
func (srv *MetricsRepository) GetDatabaseSize() (size int64, err error) {

View File

@@ -8,12 +8,12 @@ import (
)
type ProjectLabelRepository struct {
BaseRepository
config *config.Config
db *gorm.DB
}
func NewProjectLabelRepository(db *gorm.DB) *ProjectLabelRepository {
return &ProjectLabelRepository{config: config.Get(), db: db}
return &ProjectLabelRepository{BaseRepository: NewBaseRepository(db), config: config.Get()}
}
func (r *ProjectLabelRepository) GetAll() ([]*models.ProjectLabel, error) {

View File

@@ -5,7 +5,14 @@ import (
"time"
)
type IBaseRepository interface {
GetDialector() string
GetTableDDLMysql(string) (string, error)
GetTableDDLSqlite(string) (string, error)
}
type IAliasRepository interface {
IBaseRepository
Insert(*models.Alias) (*models.Alias, error)
Delete(uint) error
DeleteBatch([]uint) error
@@ -17,6 +24,7 @@ type IAliasRepository interface {
}
type IHeartbeatRepository interface {
IBaseRepository
InsertBatch([]*models.Heartbeat) error
GetAll() ([]*models.Heartbeat, error)
GetAllWithin(time.Time, time.Time, *models.User) ([]*models.Heartbeat, error)
@@ -37,10 +45,12 @@ type IHeartbeatRepository interface {
}
type IDiagnosticsRepository interface {
IBaseRepository
Insert(diagnostics *models.Diagnostics) (*models.Diagnostics, error)
}
type IKeyValueRepository interface {
IBaseRepository
GetAll() ([]*models.KeyStringValue, error)
GetString(string) (*models.KeyStringValue, error)
PutString(*models.KeyStringValue) error
@@ -50,6 +60,7 @@ type IKeyValueRepository interface {
}
type ILanguageMappingRepository interface {
IBaseRepository
GetAll() ([]*models.LanguageMapping, error)
GetById(uint) (*models.LanguageMapping, error)
GetByUser(string) ([]*models.LanguageMapping, error)
@@ -58,6 +69,7 @@ type ILanguageMappingRepository interface {
}
type IProjectLabelRepository interface {
IBaseRepository
GetAll() ([]*models.ProjectLabel, error)
GetById(uint) (*models.ProjectLabel, error)
GetByUser(string) ([]*models.ProjectLabel, error)
@@ -66,6 +78,7 @@ type IProjectLabelRepository interface {
}
type ISummaryRepository interface {
IBaseRepository
Insert(*models.Summary) error
GetAll() ([]*models.Summary, error)
GetByUserWithin(*models.User, time.Time, time.Time) ([]*models.Summary, error)
@@ -75,6 +88,7 @@ type ISummaryRepository interface {
}
type IUserRepository interface {
IBaseRepository
FindOne(user models.User) (*models.User, error)
GetByIds([]string) ([]*models.User, error)
GetAll() ([]*models.User, error)
@@ -92,6 +106,7 @@ type IUserRepository interface {
}
type ILeaderboardRepository interface {
IBaseRepository
InsertBatch([]*models.LeaderboardItem) error
CountAllByUser(string) (int64, error)
CountUsers(bool) (int64, error)

View File

@@ -11,11 +11,11 @@ import (
)
type SummaryRepository struct {
db *gorm.DB
BaseRepository
}
func NewSummaryRepository(db *gorm.DB) *SummaryRepository {
return &SummaryRepository{db: db}
return &SummaryRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *SummaryRepository) GetAll() ([]*models.Summary, error) {

View File

@@ -12,11 +12,11 @@ import (
)
type UserRepository struct {
db *gorm.DB
BaseRepository
}
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
return &UserRepository{BaseRepository: NewBaseRepository(db)}
}
func (r *UserRepository) FindOne(attributes models.User) (*models.User, error) {

View File

@@ -13,6 +13,7 @@ import (
"github.com/muety/wakapi/utils"
"github.com/patrickmn/go-cache"
"log/slog"
"strings"
"time"
)
@@ -200,6 +201,10 @@ func (srv *UserService) Update(user *models.User) (*models.User, error) {
}
func (srv *UserService) ChangeUserId(user *models.User, newUserId string) (*models.User, error) {
if !srv.checkUpdateCascade() {
return nil, errors.New("sqlite database too old to perform user id change consistently")
}
// https://github.com/muety/wakapi/issues/739
oldUserId := user.ID
defer srv.FlushUserCache(oldUserId)
@@ -288,3 +293,11 @@ func (srv *UserService) notifyDelete(user *models.User) {
Fields: map[string]interface{}{config.FieldPayload: user},
})
}
func (srv *UserService) checkUpdateCascade() bool {
if dialector := srv.repository.GetDialector(); dialector == "sqlite" || dialector == "sqlite3" {
ddl, _ := srv.repository.GetTableDDLSqlite("heartbeats")
return strings.Contains(ddl, "ON UPDATE CASCADE")
}
return true
}