Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type coercion to top-down fix expression types #2150

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func injectBindVarsAndPrepare(
case sqlparser.HexNum, sqlparser.HexVal:
return false, nil
}
expr := b.ConvertVal(n)
expr := b.ConvertVal(n, nil)
var val interface{}
if l, ok := expr.(*expression.Literal); ok {
val, _, err = expr.Type().Promote().Convert(l.Value())
Expand Down
102 changes: 51 additions & 51 deletions enginetest/queries/imdb_plans.go

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions enginetest/queries/tpch_plans.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
package server

import (
sqle "github.com/dolthub/go-mysql-server"
"sort"

"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/vitess/go/sqltypes"
querypb "github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/vitess/go/vt/sqlparser"
ast "github.com/dolthub/vitess/go/vt/sqlparser"
"sort"

sqle "github.com/dolthub/go-mysql-server"
)

func Intercept(h Interceptor) {
Expand Down
2 changes: 0 additions & 2 deletions sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,6 @@ func newInsertSourceSelector(sel RuleSelector) RuleSelector {
// Analyze applies the transformation rules to the node given. In the case of an error, the last successfully
// transformed node is returned along with the error.
func (a *Analyzer) Analyze(ctx *sql.Context, n sql.Node, scope *plan.Scope) (sql.Node, error) {
//a.Verbose = true
//a.Debug = true
n, _, err := a.analyzeWithSelector(ctx, n, scope, SelectAllBatches, DefaultRuleSelector)
return n, err
}
Expand Down
2 changes: 2 additions & 0 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *pl
err = sql.ErrInvalidSyntax.New("negative limit")
return false
}
case *expression.CoerceInternal:

case *expression.BindVar:
return true
default:
Expand Down
68 changes: 68 additions & 0 deletions sql/expression/convert_internal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package expression

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

func NewCoerceInternal(e sql.Expression, typ sql.Type) *CoerceInternal {
return &CoerceInternal{e: e, typ: typ}
}

type CoerceInternal struct {
e sql.Expression
typ sql.Type
}

var _ sql.Expression = (*CoerceInternal)(nil)

func (c *CoerceInternal) Resolved() bool {
return true
}

func (c *CoerceInternal) String() string {
return fmt.Sprintf("coerce(%s->%s)", c.e, c.typ)
}

func (c *CoerceInternal) Type() sql.Type {
return c.typ
}

func (c *CoerceInternal) IsNullable() bool {
return c.e.IsNullable()
}

func (c *CoerceInternal) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
val, err := c.e.Eval(ctx, row)
if err != nil {
return nil, err
}
ret, inRange, err := c.typ.Convert(val)
if err != nil {
switch c.typ {
case types.Boolean:
return false, nil
default:
return nil, err
}
}
if !inRange {
ctx.Warn(0, "coercion %s to %s failed, out of range", val, c.typ)
}
return ret, nil
}

func (c *CoerceInternal) Children() []sql.Expression {
return []sql.Expression{c.e}
}

func (c *CoerceInternal) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
}
ret := *c
ret.e = children[0]
return &ret, nil
}
8 changes: 4 additions & 4 deletions sql/planbuilder/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro
col = projScope.cols[intIdx-1]
}
default:
expr := b.buildScalar(fromScope, e)
expr := b.buildScalar(fromScope, e, nil)
col = scopeColumn{
tableId: sql.TableID{},
col: expr.String(),
Expand Down Expand Up @@ -592,7 +592,7 @@ func (b *Builder) buildWindowDef(fromScope *scope, def *ast.WindowDef) *sql.Wind
var sortFields sql.SortFields
for _, c := range def.OrderBy {
// resolve col in fromScope
e := b.buildScalar(fromScope, c.Expr)
e := b.buildScalar(fromScope, c.Expr, nil)
so := sql.Ascending
if c.Direction == ast.DescScr {
so = sql.Descending
Expand All @@ -606,7 +606,7 @@ func (b *Builder) buildWindowDef(fromScope *scope, def *ast.WindowDef) *sql.Wind

partitions := make([]sql.Expression, len(def.PartitionBy))
for i, expr := range def.PartitionBy {
partitions[i] = b.buildScalar(fromScope, expr)
partitions[i] = b.buildScalar(fromScope, expr, nil)
}

frame := b.NewFrame(fromScope, def.Frame)
Expand Down Expand Up @@ -774,7 +774,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
}
}
havingScope.groupBy = fromScope.groupBy
h := b.buildScalar(havingScope, having.Expr)
h := b.buildScalar(havingScope, having.Expr, types.Boolean)
outScope.node = plan.NewHaving(h, outScope.node)
return
}
25 changes: 10 additions & 15 deletions sql/planbuilder/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
ast "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/stats"
"github.com/dolthub/go-mysql-server/sql/types"
)

func (b *Builder) buildAnalyze(inScope *scope, n *ast.Analyze, query string) (outScope *scope) {
Expand Down Expand Up @@ -119,25 +119,20 @@ func (b *Builder) buildAnalyzeTables(inScope *scope, n *ast.Analyze, query strin
return
}

func (b *Builder) buildAnalyzeUpdate(inScope *scope, n *ast.Analyze, dbName, tableName string, sch sql.Schema, columns []string, types []sql.Type) (outScope *scope) {
func (b *Builder) buildAnalyzeUpdate(inScope *scope, n *ast.Analyze, dbName, tableName string, sch sql.Schema, columns []string, typs []sql.Type) (outScope *scope) {
outScope = inScope.push()
statistic := new(stats.Statistic)
using := b.buildScalar(inScope, n.Using)
if l, ok := using.(*expression.Literal); ok {
if typ, ok := l.Type().(sql.StringType); ok {
val, _, err := typ.Convert(l.Value())
using := b.buildScalar(inScope, n.Using, types.LongText)
if lit, err := using.Eval(nil, nil); err == nil {
if str, ok := lit.(string); ok {
err := json.Unmarshal([]byte(str), statistic)
if err != nil {
err = ErrFailedToParseStats.New(err.Error(), str)
b.handleErr(err)
}
if str, ok := val.(string); ok {
err := json.Unmarshal([]byte(str), statistic)
if err != nil {
err = ErrFailedToParseStats.New(err.Error(), str)
b.handleErr(err)
}
}

}
} else {
b.handleErr(err)
}
if statistic == nil {
err := fmt.Errorf("no statistics found for update")
Expand All @@ -149,7 +144,7 @@ func (b *Builder) buildAnalyzeUpdate(inScope *scope, n *ast.Analyze, dbName, tab
}
statistic.SetQualifier(sql.NewStatQualifier(dbName, tableName, indexName))
statistic.SetColumns(columns)
statistic.SetTypes(types)
statistic.SetTypes(typs)

statCols := sql.NewFastIntSet()
for _, c := range columns {
Expand Down
2 changes: 1 addition & 1 deletion sql/planbuilder/create_ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (b *Builder) buildCreateEvent(inScope *scope, query string, c *ast.DDL) (ou
}

func (b *Builder) buildEventScheduleTimeSpec(inScope *scope, spec *ast.EventScheduleTimeSpec) (sql.Expression, []sql.Expression) {
ts := b.buildScalar(inScope, spec.EventTimestamp)
ts := b.buildScalar(inScope, spec.EventTimestamp, nil)
if len(spec.EventIntervals) == 0 {
return ts, nil
}
Expand Down
13 changes: 10 additions & 3 deletions sql/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -703,7 +704,13 @@ func (b *Builder) convertConstraintDefinition(inScope *scope, cd *ast.Constraint
} else if chConstraint, ok := cd.Details.(*ast.CheckConstraintDefinition); ok {
var c sql.Expression
if chConstraint.Expr != nil {
c = b.buildScalar(inScope, chConstraint.Expr)
c = b.buildScalar(inScope, chConstraint.Expr, nil)
c, _, _ = transform.Expr(c, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if e, _ := e.(*expression.CoerceInternal); e != nil {
return e.Children()[0], transform.NewTree, nil
}
return e, transform.SameTree, nil
})
}

return &sql.CheckConstraint{
Expand Down Expand Up @@ -910,7 +917,7 @@ func (b *Builder) buildDefaultExpression(inScope *scope, defaultExpr ast.Expr) *
if defaultExpr == nil {
return nil
}
parsedExpr := b.buildScalar(inScope, defaultExpr)
parsedExpr := b.buildScalar(inScope, defaultExpr, nil)

// Function expressions must be enclosed in parentheses (except for current_timestamp() and now())
_, isParenthesized := defaultExpr.(*ast.ParenExpr)
Expand Down Expand Up @@ -1274,7 +1281,7 @@ func (b *Builder) convertDefaultExpression(inScope *scope, defaultExpr ast.Expr,
if defaultExpr == nil {
return nil
}
resExpr := b.buildScalar(inScope, defaultExpr)
resExpr := b.buildScalar(inScope, defaultExpr, nil)

// Function expressions must be enclosed in parentheses (except for current_timestamp() and now())
_, isParenthesized := defaultExpr.(*ast.ParenExpr)
Expand Down
26 changes: 5 additions & 21 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []
// isolation (no access to the destination schema)
exprs[j] = assignColumnIndexes(exprs[j], reorderSchema(columnNames, destSchema))
default:
exprs[j] = b.buildScalar(inScope, e)
exprs[j] = b.buildScalar(inScope, e, nil)
}
}
}
Expand All @@ -193,22 +193,6 @@ func reorderSchema(names []string, schema sql.Schema) sql.Schema {
return newSch
}

func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) {
// TODO add literals to outScope?
exprTuples := make([][]sql.Expression, len(v))
for i, vt := range v {
exprs := make([]sql.Expression, len(vt))
exprTuples[i] = exprs
for j, e := range vt {
exprs[j] = b.buildScalar(inScope, e)
}
}

outScope = inScope.push()
outScope.node = plan.NewValues(exprTuples)
return
}

func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentExprs) []sql.Expression {
updateExprs := make([]sql.Expression, len(e))
var startAggCnt int
Expand All @@ -223,7 +207,7 @@ func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentE
tableSch := inScope.node.Schema()

for i, updateExpr := range e {
colName := b.buildScalar(inScope, updateExpr.Name)
colName := b.buildScalar(inScope, updateExpr.Name, nil)

// Prevent update of generated columns
if gf, ok := colName.(*expression.GetField); ok {
Expand All @@ -236,7 +220,7 @@ func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentE
}
}

innerExpr := b.buildScalar(inScope, updateExpr.Expr)
innerExpr := b.buildScalar(inScope, updateExpr.Expr, nil)
updateExprs[i] = expression.NewSetField(colName, innerExpr)
if inScope.groupBy != nil {
if len(inScope.groupBy.aggs) > startAggCnt {
Expand Down Expand Up @@ -282,7 +266,7 @@ func (b *Builder) buildOnDupUpdateExprs(combinedScope, destScope *scope, e ast.A
}
for i, updateExpr := range e {
colName := b.buildOnDupLeft(destScope, updateExpr.Name)
innerExpr := b.buildScalar(combinedScope, updateExpr.Expr)
innerExpr := b.buildScalar(combinedScope, updateExpr.Expr, nil)

res[i] = expression.NewSetField(colName, innerExpr)
if combinedScope.groupBy != nil {
Expand Down Expand Up @@ -613,7 +597,7 @@ func (b *Builder) buildCheckConstraint(inScope *scope, check *sql.CheckDefinitio
b.handleErr(err)
}

c := b.buildScalar(inScope, ae.Expr)
c := b.buildScalar(inScope, ae.Expr, nil)

return &sql.CheckConstraint{
Name: check.Name,
Expand Down
Loading