From e296886d92642fe5623a666fd1381fd3e7aa7f70 Mon Sep 17 00:00:00 2001 From: Chen Junda Date: Thu, 11 Jan 2024 23:02:36 +0800 Subject: [PATCH] add quotesql function to quote identifiers in sql --- repositories/heartbeat.go | 11 +++-------- repositories/key_value.go | 5 +---- repositories/summary.go | 8 +++++--- utils/db.go | 16 ++++++++++++++-- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/repositories/heartbeat.go b/repositories/heartbeat.go index 07a1a0a..5176db7 100644 --- a/repositories/heartbeat.go +++ b/repositories/heartbeat.go @@ -1,7 +1,6 @@ package repositories import ( - "fmt" "time" "github.com/duke-git/lancet/v2/slice" @@ -17,10 +16,6 @@ type HeartbeatRepository struct { config *conf.Config } -func (r *HeartbeatRepository) QuoteDbIdentifier(id string) string { - return utils.QuoteDbIdentifier(r.db, id) -} - func NewHeartbeatRepository(db *gorm.DB) *HeartbeatRepository { return &HeartbeatRepository{config: conf.Get(), db: db} } @@ -120,7 +115,7 @@ func (r *HeartbeatRepository) GetLatestByFilters(user *models.User, filterMap ma func (r *HeartbeatRepository) GetFirstByUsers() ([]*models.TimeByUser, error) { var result []*models.TimeByUser r.db.Model(&models.User{}). - Select(fmt.Sprintf("users.id as %s, min(time) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("time"))). + Select(utils.QuoteSql(r.db, "users.id as %s, min(time) as %s", "user", "time")). Joins("left join heartbeats on users.id = heartbeats.user_id"). Group("users.id"). Scan(&result) @@ -130,7 +125,7 @@ func (r *HeartbeatRepository) GetFirstByUsers() ([]*models.TimeByUser, error) { func (r *HeartbeatRepository) GetLastByUsers() ([]*models.TimeByUser, error) { var result []*models.TimeByUser r.db.Model(&models.User{}). - Select(fmt.Sprintf("users.id as %s, max(time) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("time"))). + Select(utils.QuoteSql(r.db, "users.id as %s, max(time) as %s", "user", "time")). Joins("left join heartbeats on users.id = heartbeats.user_id"). Group("user"). Scan(&result) @@ -179,7 +174,7 @@ func (r *HeartbeatRepository) CountByUsers(users []*models.User) ([]*models.Coun if err := r.db. Model(&models.Heartbeat{}). - Select(fmt.Sprintf("user_id as %s, count(id) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("count"))). + Select(utils.QuoteSql(r.db, "user_id as %s, count(id) as %s", "user", "count")). Where("user_id in ?", userIds). Group("user"). Find(&counts).Error; err != nil { diff --git a/repositories/key_value.go b/repositories/key_value.go index ed6eef3..450aa92 100644 --- a/repositories/key_value.go +++ b/repositories/key_value.go @@ -2,7 +2,6 @@ package repositories import ( "errors" - "fmt" "github.com/muety/wakapi/models" "github.com/muety/wakapi/utils" @@ -39,10 +38,8 @@ func (r *KeyValueRepository) GetString(key string) (*models.KeyStringValue, erro func (r *KeyValueRepository) Search(like string) ([]*models.KeyStringValue, error) { var keyValues []*models.KeyStringValue - condition := fmt.Sprintf("%s like ?", utils.QuoteDbIdentifier(r.db, "key")) - if err := r.db.Table("key_string_values"). - Where(condition, like). + Where(utils.QuoteSql(r.db, "%s like ?", "key"), like). Find(&keyValues). Error; err != nil { return nil, err diff --git a/repositories/summary.go b/repositories/summary.go index e7d2d7f..04cacde 100644 --- a/repositories/summary.go +++ b/repositories/summary.go @@ -1,11 +1,13 @@ package repositories import ( + "time" + "github.com/duke-git/lancet/v2/slice" "github.com/muety/wakapi/models" + "github.com/muety/wakapi/utils" "gorm.io/gorm" "gorm.io/gorm/clause" - "time" ) type SummaryRepository struct { @@ -70,9 +72,9 @@ func (r *SummaryRepository) GetByUserWithin(user *models.User, from, to time.Tim func (r *SummaryRepository) GetLastByUser() ([]*models.TimeByUser, error) { var result []*models.TimeByUser r.db.Model(&models.User{}). - Select("users.id as user, max(to_time) as time"). + Select(utils.QuoteSql(r.db, "users.id as %s, max(to_time) as time", "user")). Joins("left join summaries on users.id = summaries.user_id"). - Group("user"). + Group("users.id"). Scan(&result) return result, nil } diff --git a/utils/db.go b/utils/db.go index 33c43bc..8528b29 100644 --- a/utils/db.go +++ b/utils/db.go @@ -65,11 +65,23 @@ func (s stringWriter) WriteString(str string) (int, error) { } // QuoteDbIdentifier quotes a column name used in a query. -func QuoteDbIdentifier(query *gorm.DB, identifier string) string { +func QuoteDbIdentifier(db *gorm.DB, identifier string) string { builder := stringWriter{Builder: &strings.Builder{}} - query.Dialector.QuoteTo(builder, identifier) + db.Dialector.QuoteTo(builder, identifier) return builder.Builder.String() } + +// QuoteSql quotes a SQL statement with the given identifiers. +func QuoteSql(db *gorm.DB, queryTemplate string, identifiers ...string) string { + + quotedIdentifiers := make([]interface{}, len(identifiers)) + + for i, identifier := range identifiers { + quotedIdentifiers[i] = QuoteDbIdentifier(db, identifier) + } + + return fmt.Sprintf(queryTemplate, quotedIdentifiers...) +}