Skip to content

Commit

Permalink
Fix sql identifier escaping in datastore feed (woodpecker-ci#4746)
Browse files Browse the repository at this point in the history
  • Loading branch information
xoxys authored Jan 19, 2025
1 parent 0b65723 commit 08021ca
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 7 deletions.
16 changes: 11 additions & 5 deletions server/store/datastore/feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
package datastore

import (
"fmt"

"xorm.io/builder"

"go.woodpecker-ci.org/woodpecker/v3/server/model"
)

var feedItemSelect = `repos.id as repo_id,
func (s storage) getFeedSelect() string {
const feedTemplate = `repos.id as repo_id,
pipelines.id as pipeline_id,
pipelines.number as pipeline_number,
pipelines.event as pipeline_event,
pipelines.status as pipeline_status,
pipelines.created as pipeline_created,
pipelines.started as pipeline_started,
pipelines.finished as pipeline_finished,
'pipelines.commit' as pipeline_commit,
pipelines.%s as pipeline_commit,
pipelines.branch as pipeline_branch,
pipelines.ref as pipeline_ref,
pipelines.refspec as pipeline_refspec,
Expand All @@ -38,10 +41,13 @@ pipelines.author as pipeline_author,
pipelines.email as pipeline_email,
pipelines.avatar as pipeline_avatar`

return fmt.Sprintf(feedTemplate, s.quoteIdentifier("commit"))
}

func (s storage) GetPipelineQueue() ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage)
err := s.engine.Table("pipelines").
Select(feedItemSelect).
Select(s.getFeedSelect()).
Join("INNER", "repos", "pipelines.repo_id = repos.id").
In("pipelines.status", model.StatusPending, model.StatusRunning).
Find(&feed)
Expand All @@ -51,7 +57,7 @@ func (s storage) GetPipelineQueue() ([]*model.Feed, error) {
func (s storage) UserFeed(user *model.User) ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage)
err := s.engine.Table("repos").
Select(feedItemSelect).
Select(s.getFeedSelect()).
Join("INNER", "perms", "repos.id = perms.repo_id").
Join("INNER", "pipelines", "repos.id = pipelines.repo_id").
Where(userPushOrAdminCondition(user.ID)).
Expand All @@ -66,7 +72,7 @@ func (s storage) RepoListLatest(user *model.User) ([]*model.Feed, error) {
feed := make([]*model.Feed, 0, perPage)

err := s.engine.Table("repos").
Select(feedItemSelect).
Select(s.getFeedSelect()).
Join("INNER", "perms", "repos.id = perms.repo_id").
Join("LEFT", "pipelines", "pipelines.id = "+`(
SELECT pipelines.id FROM pipelines
Expand Down
27 changes: 25 additions & 2 deletions server/store/datastore/feed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,37 @@ func TestGetPipelineQueue(t *testing.T) {
assert.NoError(t, store.PermUpsert(perm))
}
pipeline1 := &model.Pipeline{
RepoID: repo1.ID,
Status: model.StatusPending,
RepoID: repo1.ID,
Status: model.StatusPending,
Number: 1,
Event: "push",
Commit: "abc123",
Branch: "main",
Ref: "refs/heads/main",
Message: "Initial commit",
Author: "joe",
Email: "foo@bar.com",
Title: "First pipeline",
}
assert.NoError(t, store.CreatePipeline(pipeline1))

feed, err := store.GetPipelineQueue()
assert.NoError(t, err)
assert.Len(t, feed, 1)

feedItem := feed[0]
assert.Equal(t, repo1.ID, feedItem.RepoID)
assert.Equal(t, pipeline1.ID, feedItem.ID)
assert.Equal(t, pipeline1.Number, feedItem.Number)
assert.EqualValues(t, pipeline1.Event, feedItem.Event)
assert.EqualValues(t, pipeline1.Status, feedItem.Status)
assert.Equal(t, pipeline1.Commit, feedItem.Commit)
assert.Equal(t, pipeline1.Branch, feedItem.Branch)
assert.Equal(t, pipeline1.Ref, feedItem.Ref)
assert.Equal(t, pipeline1.Title, feedItem.Title)
assert.Equal(t, pipeline1.Message, feedItem.Message)
assert.Equal(t, pipeline1.Author, feedItem.Author)
assert.Equal(t, pipeline1.Email, feedItem.Email)
}

func TestUserFeed(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions server/store/datastore/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ func callerName(skip int) string {
}
return fnName
}

func (s storage) quoteIdentifier(identifier string) string {
driver := s.engine.DriverName()
switch driver {
case DriverMysql:
return "`" + identifier + "`"
case DriverPostgres, DriverSqlite:
return "\"" + identifier + "\""
default:
return identifier
}
}

0 comments on commit 08021ca

Please sign in to comment.