add quotesql function to quote identifiers in sql

This commit is contained in:
Chen Junda
2024-01-11 23:02:36 +08:00
parent 9169560336
commit e296886d92
4 changed files with 23 additions and 17 deletions

View File

@@ -1,7 +1,6 @@
package repositories
import (
"fmt"
"time"
"github.com/duke-git/lancet/v2/slice"
@@ -17,10 +16,6 @@ type HeartbeatRepository struct {
config *conf.Config
}
func (r *HeartbeatRepository) QuoteDbIdentifier(id string) string {
return utils.QuoteDbIdentifier(r.db, id)
}
func NewHeartbeatRepository(db *gorm.DB) *HeartbeatRepository {
return &HeartbeatRepository{config: conf.Get(), db: db}
}
@@ -120,7 +115,7 @@ func (r *HeartbeatRepository) GetLatestByFilters(user *models.User, filterMap ma
func (r *HeartbeatRepository) GetFirstByUsers() ([]*models.TimeByUser, error) {
var result []*models.TimeByUser
r.db.Model(&models.User{}).
Select(fmt.Sprintf("users.id as %s, min(time) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("time"))).
Select(utils.QuoteSql(r.db, "users.id as %s, min(time) as %s", "user", "time")).
Joins("left join heartbeats on users.id = heartbeats.user_id").
Group("users.id").
Scan(&result)
@@ -130,7 +125,7 @@ func (r *HeartbeatRepository) GetFirstByUsers() ([]*models.TimeByUser, error) {
func (r *HeartbeatRepository) GetLastByUsers() ([]*models.TimeByUser, error) {
var result []*models.TimeByUser
r.db.Model(&models.User{}).
Select(fmt.Sprintf("users.id as %s, max(time) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("time"))).
Select(utils.QuoteSql(r.db, "users.id as %s, max(time) as %s", "user", "time")).
Joins("left join heartbeats on users.id = heartbeats.user_id").
Group("user").
Scan(&result)
@@ -179,7 +174,7 @@ func (r *HeartbeatRepository) CountByUsers(users []*models.User) ([]*models.Coun
if err := r.db.
Model(&models.Heartbeat{}).
Select(fmt.Sprintf("user_id as %s, count(id) as %s", r.QuoteDbIdentifier("user"), r.QuoteDbIdentifier("count"))).
Select(utils.QuoteSql(r.db, "user_id as %s, count(id) as %s", "user", "count")).
Where("user_id in ?", userIds).
Group("user").
Find(&counts).Error; err != nil {

View File

@@ -2,7 +2,6 @@ package repositories
import (
"errors"
"fmt"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/utils"
@@ -39,10 +38,8 @@ func (r *KeyValueRepository) GetString(key string) (*models.KeyStringValue, erro
func (r *KeyValueRepository) Search(like string) ([]*models.KeyStringValue, error) {
var keyValues []*models.KeyStringValue
condition := fmt.Sprintf("%s like ?", utils.QuoteDbIdentifier(r.db, "key"))
if err := r.db.Table("key_string_values").
Where(condition, like).
Where(utils.QuoteSql(r.db, "%s like ?", "key"), like).
Find(&keyValues).
Error; err != nil {
return nil, err

View File

@@ -1,11 +1,13 @@
package repositories
import (
"time"
"github.com/duke-git/lancet/v2/slice"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"time"
)
type SummaryRepository struct {
@@ -70,9 +72,9 @@ func (r *SummaryRepository) GetByUserWithin(user *models.User, from, to time.Tim
func (r *SummaryRepository) GetLastByUser() ([]*models.TimeByUser, error) {
var result []*models.TimeByUser
r.db.Model(&models.User{}).
Select("users.id as user, max(to_time) as time").
Select(utils.QuoteSql(r.db, "users.id as %s, max(to_time) as time", "user")).
Joins("left join summaries on users.id = summaries.user_id").
Group("user").
Group("users.id").
Scan(&result)
return result, nil
}

View File

@@ -65,11 +65,23 @@ func (s stringWriter) WriteString(str string) (int, error) {
}
// QuoteDbIdentifier quotes a column name used in a query.
func QuoteDbIdentifier(query *gorm.DB, identifier string) string {
func QuoteDbIdentifier(db *gorm.DB, identifier string) string {
builder := stringWriter{Builder: &strings.Builder{}}
query.Dialector.QuoteTo(builder, identifier)
db.Dialector.QuoteTo(builder, identifier)
return builder.Builder.String()
}
// QuoteSql quotes a SQL statement with the given identifiers.
func QuoteSql(db *gorm.DB, queryTemplate string, identifiers ...string) string {
quotedIdentifiers := make([]interface{}, len(identifiers))
for i, identifier := range identifiers {
quotedIdentifiers[i] = QuoteDbIdentifier(db, identifier)
}
return fmt.Sprintf(queryTemplate, quotedIdentifiers...)
}