Skip to content

Commit

Permalink
fix coalesce type() (#2301)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored Jan 30, 2024
1 parent 24b7a21 commit 140a153
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 10 deletions.
2 changes: 1 addition & 1 deletion enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -3646,7 +3646,7 @@ CREATE TABLE tab3 (
{
Query: "select COLUMN_NAME, DATA_TYPE from INFORMATION_SCHEMA.COLUMNS where TABLE_NAME='c';",
Expected: []sql.Row{
{"coalesce(NULL,1)", "tinyint"},
{"coalesce(NULL,1)", "int"},
},
},
},
Expand Down
47 changes: 43 additions & 4 deletions sql/expression/function/coalesce.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values.
Expand Down Expand Up @@ -53,17 +53,56 @@ func (c *Coalesce) Description() string {
// Type implements the sql.Expression interface.
// The return type of Type() is the aggregated type of the argument types.
func (c *Coalesce) Type() sql.Type {
typ := types.Null
for _, arg := range c.args {
if arg == nil {
continue
}
t := arg.Type()
// special case for signed and unsigned integers
if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) {
typ = types.MustCreateDecimalType(20, 0)
continue
}

if t != nil && t != types.Null {
return t
convType := expression.GetConvertToType(typ, t)
switch convType {
case expression.ConvertToChar:
// Can't get any larger than this
return types.LongText
case expression.ConvertToDecimal:
if typ == types.Float64 || t == types.Float64 {
typ = types.Float64
} else if types.IsDecimal(t) {
typ = t
} else if !types.IsDecimal(typ) {
typ = types.MustCreateDecimalType(10, 0)
}
case expression.ConvertToUnsigned:
if typ == types.Uint64 || t == types.Uint64 {
typ = types.Uint64
} else {
typ = types.Uint32
}
case expression.ConvertToSigned:
if typ == types.Int64 || t == types.Int64 {
typ = types.Int64
} else {
typ = types.Int32
}
case expression.ConvertToFloat:
if typ == types.Float64 || t == types.Float64 {
typ = types.Float64
} else {
typ = types.Float32
}
default:
}
}
}

return types.Null
return typ
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand Down
102 changes: 97 additions & 5 deletions sql/expression/function/coalesce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package function
import (
"testing"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"

"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -37,11 +38,102 @@ func TestCoalesce(t *testing.T) {
typ sql.Type
nullable bool
}{
{"coalesce(1, 2, 3)", []sql.Expression{expression.NewLiteral(1, types.Int32), expression.NewLiteral(2, types.Int32), expression.NewLiteral(3, types.Int32)}, 1, types.Int32, false},
{"coalesce(NULL, NULL, 3)", []sql.Expression{nil, nil, expression.NewLiteral(3, types.Int32)}, 3, types.Int32, false},
{"coalesce(NULL, NULL, '3')", []sql.Expression{nil, nil, expression.NewLiteral("3", types.LongText)}, "3", types.LongText, false},
{"coalesce(NULL, '2', 3)", []sql.Expression{nil, expression.NewLiteral("2", types.LongText), expression.NewLiteral(3, types.Int32)}, "2", types.LongText, false},
{"coalesce(NULL, NULL, NULL)", []sql.Expression{nil, nil, nil}, nil, types.Null, true},
{
name: "coalesce(1, 2, 3)",
input: []sql.Expression{
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(2, types.Int32),
expression.NewLiteral(3, types.Int32),
},
expected: 1,
typ: types.Int32,
nullable: false,
},
{
name: "coalesce(NULL, NULL, 3)",
input: []sql.Expression{
nil,
nil,
expression.NewLiteral(3, types.Int32),
},
expected: 3,
typ: types.Int32,
nullable: false,
},
{
name: "coalesce(NULL, NULL, '3')",
input: []sql.Expression{
nil,
nil,
expression.NewLiteral("3", types.LongText),
},
expected: "3",
typ: types.LongText,
nullable: false,
},
{
name: "coalesce(NULL, '2', 3)",
input: []sql.Expression{
nil,
expression.NewLiteral("2", types.LongText),
expression.NewLiteral(3, types.Int32),
},
expected: "2",
typ: types.LongText,
nullable: false,
},
{
name: "coalesce(NULL, NULL, NULL)",
input: []sql.Expression{
nil,
nil,
nil,
},
expected: nil,
typ: types.Null,
nullable: true,
},
{
name: "coalesce(int(1), decimal(2.0), string('3'))",
input: []sql.Expression{
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(decimal.NewFromFloat(2.0), types.MustCreateDecimalType(10, 0)),
expression.NewLiteral("3", types.LongText),
},
expected: 1,
typ: types.LongText,
nullable: false,
},
{
name: "coalesce(signed(1), unsigned(2))",
input: []sql.Expression{
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(2, types.Uint32),
},
expected: 1,
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
{
name: "coalesce(signed(1), unsigned(2))",
input: []sql.Expression{
expression.NewLiteral(1, types.Int32),
expression.NewLiteral(2, types.Uint32),
},
expected: 1,
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
{
name: "coalesce(decimal(1.0), float64(2.0))",
input: []sql.Expression{
expression.NewLiteral(1, types.MustCreateDecimalType(10, 0)),
expression.NewLiteral(2, types.Float64),
},
expected: 1,
typ: types.Float64,
nullable: false,
},
}

for _, tt := range testCases {
Expand Down

0 comments on commit 140a153

Please sign in to comment.