From eca5bf76f8f0f91eb2384191b57dd89d0b211632 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 16 Dec 2024 16:46:11 -0800 Subject: [PATCH 1/3] allow using function and table function --- enginetest/join_stats_tests.go | 6 +- memory/provider.go | 6 +- sql/analyzer/catalog.go | 13 +- sql/catalog_map.go | 6 +- sql/databases.go | 2 +- .../tablefunction/table_function.go | 125 ++++++++++++++++++ sql/planbuilder/from.go | 29 ++-- test/test_catalog.go | 2 +- 8 files changed, 154 insertions(+), 35 deletions(-) create mode 100644 sql/expression/tablefunction/table_function.go diff --git a/enginetest/join_stats_tests.go b/enginetest/join_stats_tests.go index 7df713e6fd..32e43eb7cd 100644 --- a/enginetest/join_stats_tests.go +++ b/enginetest/join_stats_tests.go @@ -360,12 +360,12 @@ func (t TestProvider) Function(ctx *sql.Context, name string) (sql.Function, boo return nil, false } -func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) { +func (t TestProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) { if tf, ok := t.tableFunctions[strings.ToLower(name)]; ok { - return tf, nil + return tf, true } - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } func (t TestProvider) WithTableFunctions(fns ...sql.TableFunction) (sql.TableFunctionProvider, error) { diff --git a/memory/provider.go b/memory/provider.go index 3023fc6e65..624e66b549 100644 --- a/memory/provider.go +++ b/memory/provider.go @@ -194,10 +194,10 @@ func (pro *DbProvider) ExternalStoredProcedures(_ *sql.Context, name string) ([] } // TableFunction implements sql.TableFunctionProvider -func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) { +func (pro *DbProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, bool) { if tableFunction, ok := pro.tableFunctions[name]; ok { - return tableFunction, nil + return tableFunction, true } - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index fedddae31a..167778f7dc 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -384,17 +384,14 @@ func (c *Catalog) ExternalStoredProcedures(ctx *sql.Context, name string) ([]sql } // TableFunction implements the TableFunctionProvider interface -func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) { +func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) { if fp, ok := c.DbProvider.(sql.TableFunctionProvider); ok { - tf, err := fp.TableFunction(ctx, name) - if err != nil { - return nil, err - } else if tf != nil { - return tf, nil + tf, found := fp.TableFunction(ctx, name) + if found && tf != nil { + return tf, true } } - - return nil, sql.ErrTableFunctionNotFound.New(name) + return nil, false } func (c *Catalog) RefreshTableStats(ctx *sql.Context, table sql.Table, db string) error { diff --git a/sql/catalog_map.go b/sql/catalog_map.go index 3f23b03a6b..ecbf0f9567 100644 --- a/sql/catalog_map.go +++ b/sql/catalog_map.go @@ -25,11 +25,11 @@ func (t MapCatalog) Function(ctx *Context, name string) (Function, bool) { return nil, false } -func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, error) { +func (t MapCatalog) TableFunction(ctx *Context, name string) (TableFunction, bool) { if f, ok := t.tabFuncs[name]; ok { - return f, nil + return f, true } - return nil, fmt.Errorf("table func not found") + return nil, false } func (t MapCatalog) ExternalStoredProcedure(ctx *Context, name string, numOfParams int) (*ExternalStoredProcedureDetails, error) { diff --git a/sql/databases.go b/sql/databases.go index f7dfbe2f49..cefaafc221 100644 --- a/sql/databases.go +++ b/sql/databases.go @@ -58,7 +58,7 @@ type CollatedDatabaseProvider interface { // always) implemented by a DatabaseProvider. type TableFunctionProvider interface { // TableFunction returns the table function with the name provided, case-insensitive - TableFunction(ctx *Context, name string) (TableFunction, error) + TableFunction(ctx *Context, name string) (TableFunction, bool) // WithTableFunctions returns a new provider with (only) the list of table functions arguments WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error) } diff --git a/sql/expression/tablefunction/table_function.go b/sql/expression/tablefunction/table_function.go new file mode 100644 index 0000000000..7bae4c12e9 --- /dev/null +++ b/sql/expression/tablefunction/table_function.go @@ -0,0 +1,125 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dtablefunctions + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" +) + +var _ sql.TableFunction = &TableFunction{} +var _ sql.ExecSourceRel = &TableFunction{} + +type TableFunction struct { + underlyingFunc sql.Function + + args []sql.Expression + database sql.Database + funcExpr sql.Expression +} + +func NewTableFunction(f sql.Function) sql.TableFunction { + return &TableFunction{ + underlyingFunc: f, + } +} + +func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { + nt := *t + nt.database = db + nt.args = args + f, err := nt.underlyingFunc.NewInstance(args) + if err != nil { + return nil, err + } + nt.funcExpr = f + return &nt, nil +} + +func (t *TableFunction) Children() []sql.Node { + return nil +} + +func (t *TableFunction) Database() sql.Database { + return t.database +} + +func (t *TableFunction) Expressions() []sql.Expression { + return t.funcExpr.Children() +} + +func (t *TableFunction) IsReadOnly() bool { + return true +} + +func (t *TableFunction) Name() string { + return t.underlyingFunc.FunctionName() +} + +func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + v, err := t.funcExpr.Eval(ctx, r) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(sql.Row{v}), nil +} + +func (t *TableFunction) Resolved() bool { + for _, expr := range t.args { + return expr.Resolved() + } + return true +} + +func (t *TableFunction) Schema() sql.Schema { + return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}} +} + +func (t *TableFunction) String() string { + var args []string + for _, expr := range t.args { + args = append(args, expr.String()) + } + return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", ")) +} + +func (t *TableFunction) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, fmt.Errorf("unexpected children") + } + return t, nil +} + +func (t *TableFunction) WithDatabase(database sql.Database) (sql.Node, error) { + nt := *t + nt.database = database + return &nt, nil +} + +func (t *TableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + l := len(t.funcExpr.Children()) + if len(exprs) != l { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l) + } + nt := *t + nf, err := nt.funcExpr.WithChildren(exprs...) + if err != nil { + return nil, err + } + nt.funcExpr = nf + return &nt, nil +} diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 51283365b3..df57f0801f 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -16,6 +16,7 @@ package planbuilder import ( "fmt" + dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction" "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" @@ -447,20 +448,11 @@ func (b *Builder) resolveTable(tab, db string, asOf interface{}) *plan.ResolvedT func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope *scope) { //TODO what are valid mysql table arguments args := make([]sql.Expression, 0, len(t.Exprs)) - for _, e := range t.Exprs { - switch e := e.(type) { + for _, expr := range t.Exprs { + switch e := expr.(type) { case *ast.AliasedExpr: - expr := b.buildScalar(inScope, e.Expr) - - if !e.As.IsEmpty() { - b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) - } - - if selectExprNeedsAlias(e, expr) { - b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) - } - - args = append(args, expr) + scalarExpr := b.buildScalar(inScope, e.Expr) + args = append(args, scalarExpr) default: b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(e))) } @@ -468,9 +460,14 @@ func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope utf := expression.NewUnresolvedTableFunction(t.Name, args) - tableFunction, err := b.cat.TableFunction(b.ctx, utf.Name()) - if err != nil { - b.handleErr(err) + tableFunction, found := b.cat.TableFunction(b.ctx, utf.Name()) + if !found { + // try getting regular function + f, funcFound := b.cat.Function(b.ctx, utf.Name()) + if !funcFound { + b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name())) + } + tableFunction = dtablefunctions.NewTableFunction(f) } database := b.currentDb() diff --git a/test/test_catalog.go b/test/test_catalog.go index 1f94f439f1..9500271e04 100644 --- a/test/test_catalog.go +++ b/test/test_catalog.go @@ -159,7 +159,7 @@ func (c *Catalog) UnlockTables(ctx *sql.Context, id uint32) error { return nil } -func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, error) { +func (c *Catalog) TableFunction(ctx *sql.Context, name string) (sql.TableFunction, bool) { //TODO implement me panic("implement me") } From 4ee456d36059768b799f1679be4c632bfc032ba1 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 17 Dec 2024 01:08:56 +0000 Subject: [PATCH 2/3] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/planbuilder/from.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index df57f0801f..2aeecb2c65 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -16,13 +16,13 @@ package planbuilder import ( "fmt" - dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction" "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + dtablefunctions "github.com/dolthub/go-mysql-server/sql/expression/tablefunction" "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" From dda20025ef3dfe33089f456083e5541134603caf Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 17 Dec 2024 12:48:29 -0800 Subject: [PATCH 3/3] fixes --- sql/databases.go | 3 +- .../tablefunction/table_function.go | 54 ++++++++++++------- sql/planbuilder/from.go | 2 +- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/sql/databases.go b/sql/databases.go index cefaafc221..e1ec634ace 100644 --- a/sql/databases.go +++ b/sql/databases.go @@ -57,7 +57,8 @@ type CollatedDatabaseProvider interface { // TableFunctionProvider is an interface that allows custom table functions to be provided. It's usually (but not // always) implemented by a DatabaseProvider. type TableFunctionProvider interface { - // TableFunction returns the table function with the name provided, case-insensitive + // TableFunction returns the table function with the name provided, case-insensitive. + // It also returns boolean param for whether the table function was found. TableFunction(ctx *Context, name string) (TableFunction, bool) // WithTableFunctions returns a new provider with (only) the list of table functions arguments WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error) diff --git a/sql/expression/tablefunction/table_function.go b/sql/expression/tablefunction/table_function.go index 7bae4c12e9..4491b1f8d9 100644 --- a/sql/expression/tablefunction/table_function.go +++ b/sql/expression/tablefunction/table_function.go @@ -21,10 +21,12 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -var _ sql.TableFunction = &TableFunction{} -var _ sql.ExecSourceRel = &TableFunction{} +var _ sql.TableFunction = &TableFunctionWrapper{} +var _ sql.ExecSourceRel = &TableFunctionWrapper{} -type TableFunction struct { +// TableFunctionWrapper represents a table function with underlying +// regular function. It allows using regular function as table function. +type TableFunctionWrapper struct { underlyingFunc sql.Function args []sql.Expression @@ -32,13 +34,15 @@ type TableFunction struct { funcExpr sql.Expression } -func NewTableFunction(f sql.Function) sql.TableFunction { - return &TableFunction{ +// NewTableFunctionWrapper creates new TableFunction +// with given Function as underlying function. +func NewTableFunctionWrapper(f sql.Function) sql.TableFunction { + return &TableFunctionWrapper{ underlyingFunc: f, } } -func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { +func (t *TableFunctionWrapper) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) { nt := *t nt.database = db nt.args = args @@ -50,27 +54,30 @@ func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sq return &nt, nil } -func (t *TableFunction) Children() []sql.Node { +func (t *TableFunctionWrapper) Children() []sql.Node { return nil } -func (t *TableFunction) Database() sql.Database { +func (t *TableFunctionWrapper) Database() sql.Database { return t.database } -func (t *TableFunction) Expressions() []sql.Expression { +func (t *TableFunctionWrapper) Expressions() []sql.Expression { + if t.funcExpr == nil { + return nil + } return t.funcExpr.Children() } -func (t *TableFunction) IsReadOnly() bool { +func (t *TableFunctionWrapper) IsReadOnly() bool { return true } -func (t *TableFunction) Name() string { +func (t *TableFunctionWrapper) Name() string { return t.underlyingFunc.FunctionName() } -func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { +func (t *TableFunctionWrapper) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { v, err := t.funcExpr.Eval(ctx, r) if err != nil { return nil, err @@ -78,18 +85,20 @@ func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error return sql.RowsToRowIter(sql.Row{v}), nil } -func (t *TableFunction) Resolved() bool { +func (t *TableFunctionWrapper) Resolved() bool { for _, expr := range t.args { - return expr.Resolved() + if !expr.Resolved() { + return false + } } return true } -func (t *TableFunction) Schema() sql.Schema { +func (t *TableFunctionWrapper) Schema() sql.Schema { return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}} } -func (t *TableFunction) String() string { +func (t *TableFunctionWrapper) String() string { var args []string for _, expr := range t.args { args = append(args, expr.String()) @@ -97,20 +106,25 @@ func (t *TableFunction) String() string { return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", ")) } -func (t *TableFunction) WithChildren(children ...sql.Node) (sql.Node, error) { +func (t *TableFunctionWrapper) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 0 { - return nil, fmt.Errorf("unexpected children") + return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0) } return t, nil } -func (t *TableFunction) WithDatabase(database sql.Database) (sql.Node, error) { +func (t *TableFunctionWrapper) WithDatabase(database sql.Database) (sql.Node, error) { nt := *t nt.database = database return &nt, nil } -func (t *TableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { +func (t *TableFunctionWrapper) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if t.funcExpr == nil { + if len(exprs) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), 0) + } + } l := len(t.funcExpr.Children()) if len(exprs) != l { return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l) diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 2aeecb2c65..c7070fe69f 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -467,7 +467,7 @@ func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope if !funcFound { b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name())) } - tableFunction = dtablefunctions.NewTableFunction(f) + tableFunction = dtablefunctions.NewTableFunctionWrapper(f) } database := b.currentDb()