Skip to content

Commit

Permalink
Allow setting and querying timezone
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 9, 2025
1 parent ce7a315 commit dce453e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 14 deletions.
18 changes: 10 additions & 8 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type QueryHandler struct {

type PreparedStatement struct {
Name string
OriginalQuery string
Query string
Statement *sql.Stmt
ParameterOIDs []uint32
Expand Down Expand Up @@ -199,7 +200,7 @@ func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.
return nil, err
}
messages = append(messages, descriptionMessages...)
dataMessages, err := queryHandler.rowsToDataMessages(rows, query)
dataMessages, err := queryHandler.rowsToDataMessages(rows, originalQuery)
if err != nil {
return nil, err
}
Expand All @@ -224,6 +225,7 @@ func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]p

preparedStatement := &PreparedStatement{
Name: message.Name,
OriginalQuery: originalQuery,
Query: query,
Statement: statement,
ParameterOIDs: message.ParameterOIDs,
Expand Down Expand Up @@ -320,7 +322,7 @@ func (queryHandler *QueryHandler) HandleExecuteQuery(message *pgproto3.Execute,

defer preparedStatement.Rows.Close()

return queryHandler.rowsToDataMessages(preparedStatement.Rows, preparedStatement.Query)
return queryHandler.rowsToDataMessages(preparedStatement.Rows, preparedStatement.OriginalQuery)
}

func (queryHandler *QueryHandler) createSchemas() {
Expand Down Expand Up @@ -355,30 +357,30 @@ func (queryHandler *QueryHandler) rowsToDescriptionMessages(rows *sql.Rows, quer
return messages, nil
}

func (queryHandler *QueryHandler) rowsToDataMessages(rows *sql.Rows, query string) ([]pgproto3.Message, error) {
func (queryHandler *QueryHandler) rowsToDataMessages(rows *sql.Rows, originalQuery string) ([]pgproto3.Message, error) {
cols, err := rows.ColumnTypes()
if err != nil {
LogError(queryHandler.config, "Couldn't get column types", query+"\n"+err.Error())
LogError(queryHandler.config, "Couldn't get column types", originalQuery+"\n"+err.Error())
return nil, err
}

var messages []pgproto3.Message
for rows.Next() {
dataRow, err := queryHandler.generateDataRow(rows, cols)
if err != nil {
LogError(queryHandler.config, "Couldn't get data row", query+"\n"+err.Error())
LogError(queryHandler.config, "Couldn't get data row", originalQuery+"\n"+err.Error())
return nil, err
}
messages = append(messages, dataRow)
}

commandTag := FALLBACK_SQL_QUERY
switch {
case strings.HasPrefix(query, "SET "):
case strings.HasPrefix(originalQuery, "SET "):
commandTag = "SET"
case strings.HasPrefix(query, "SHOW "):
case strings.HasPrefix(originalQuery, "SHOW "):
commandTag = "SHOW"
case strings.HasPrefix(query, "DISCARD ALL"):
case strings.HasPrefix(originalQuery, "DISCARD ALL"):
commandTag = "DISCARD ALL"
}

Expand Down
29 changes: 25 additions & 4 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,24 @@ func TestHandleQuery(t *testing.T) {
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.CommandComplete{},
})
commandComplete := messages[0].(*pgproto3.CommandComplete)
if string(commandComplete.CommandTag) != "SET" {
t.Errorf("Expected the command tag to be 'SET', got %v", string(commandComplete.CommandTag))
}
testCommandCompleteTag(t, messages[0], "SET")
})

t.Run("Allows setting and querying timezone", func(t *testing.T) {
queryHandler := initQueryHandler()
queryHandler.HandleQuery("SET timezone = 'UTC'")

messages, err := queryHandler.HandleQuery("SHOW timezone")

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.DataRow{},
&pgproto3.CommandComplete{},
})
testRowDescription(t, messages[0], []string{"timezone"}, []string{Uint32ToString(pgtype.TextOID)})
testDataRowValues(t, messages[1], []string{"UTC"})
testCommandCompleteTag(t, messages[2], "SHOW")
})
}

Expand Down Expand Up @@ -1169,6 +1183,13 @@ func testDataRowValues(t *testing.T, dataRowMessage pgproto3.Message, expectedVa
}
}

func testCommandCompleteTag(t *testing.T, message pgproto3.Message, expectedTag string) {
commandComplete := message.(*pgproto3.CommandComplete)
if string(commandComplete.CommandTag) != expectedTag {
t.Errorf("Expected the command tag to be %v, got %v", expectedTag, string(commandComplete.CommandTag))
}
}

func Uint32ToString(i uint32) string {
return strconv.FormatUint(uint64(i), 10)
}
11 changes: 9 additions & 2 deletions src/query_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import (
pgQuery "github.com/pganalyze/pg_query_go/v5"
)

var SUPPORTED_SET_STATEMENTS = NewSet([]string{
"timezone", // SET SESSION timezone TO 'UTC'
})

var KNOWN_SET_STATEMENTS = NewSet([]string{
"client_encoding", // SET client_encoding TO 'UTF8'
"client_min_messages", // SET client_min_messages TO 'warning'
"standard_conforming_strings", // SET standard_conforming_strings = on
"intervalstyle", // SET intervalstyle = iso_8601
"timezone", // SET SESSION timezone TO 'UTC'
"extra_float_digits", // SET extra_float_digits = 3
"application_name", // SET application_name = 'psql'
"datestyle", // SET datestyle TO 'ISO'
Expand Down Expand Up @@ -95,8 +98,12 @@ func (remapper *QueryRemapper) RemapStatements(statements []*pgQuery.RawStmt) ([
func (remapper *QueryRemapper) remapSetStatement(stmt *pgQuery.RawStmt) *pgQuery.RawStmt {
setStatement := stmt.Stmt.GetVariableSetStmt()

if SUPPORTED_SET_STATEMENTS.Contains(strings.ToLower(setStatement.Name)) {
return stmt
}

if !KNOWN_SET_STATEMENTS.Contains(strings.ToLower(setStatement.Name)) {
LogWarn(remapper.config, "Unsupported SET ", setStatement.Name, ":", setStatement)
LogWarn(remapper.config, "Unknown SET ", setStatement.Name, ":", setStatement)
}

return FALLBACK_SET_QUERY_TREE.Stmts[0]
Expand Down

0 comments on commit dce453e

Please sign in to comment.