Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jennifersp committed Dec 17, 2024
1 parent 4ee456d commit dda2002
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 22 deletions.
3 changes: 2 additions & 1 deletion sql/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ type CollatedDatabaseProvider interface {
// TableFunctionProvider is an interface that allows custom table functions to be provided. It's usually (but not
// always) implemented by a DatabaseProvider.
type TableFunctionProvider interface {
// TableFunction returns the table function with the name provided, case-insensitive
// TableFunction returns the table function with the name provided, case-insensitive.
// It also returns boolean param for whether the table function was found.
TableFunction(ctx *Context, name string) (TableFunction, bool)
// WithTableFunctions returns a new provider with (only) the list of table functions arguments
WithTableFunctions(fns ...TableFunction) (TableFunctionProvider, error)
Expand Down
54 changes: 34 additions & 20 deletions sql/expression/tablefunction/table_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,28 @@ import (
"github.com/dolthub/go-mysql-server/sql"
)

var _ sql.TableFunction = &TableFunction{}
var _ sql.ExecSourceRel = &TableFunction{}
var _ sql.TableFunction = &TableFunctionWrapper{}
var _ sql.ExecSourceRel = &TableFunctionWrapper{}

type TableFunction struct {
// TableFunctionWrapper represents a table function with underlying
// regular function. It allows using regular function as table function.
type TableFunctionWrapper struct {
underlyingFunc sql.Function

args []sql.Expression
database sql.Database
funcExpr sql.Expression
}

func NewTableFunction(f sql.Function) sql.TableFunction {
return &TableFunction{
// NewTableFunctionWrapper creates new TableFunction
// with given Function as underlying function.
func NewTableFunctionWrapper(f sql.Function) sql.TableFunction {
return &TableFunctionWrapper{
underlyingFunc: f,
}
}

func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
func (t *TableFunctionWrapper) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
nt := *t
nt.database = db
nt.args = args
Expand All @@ -50,67 +54,77 @@ func (t *TableFunction) NewInstance(ctx *sql.Context, db sql.Database, args []sq
return &nt, nil
}

func (t *TableFunction) Children() []sql.Node {
func (t *TableFunctionWrapper) Children() []sql.Node {
return nil
}

func (t *TableFunction) Database() sql.Database {
func (t *TableFunctionWrapper) Database() sql.Database {
return t.database
}

func (t *TableFunction) Expressions() []sql.Expression {
func (t *TableFunctionWrapper) Expressions() []sql.Expression {
if t.funcExpr == nil {
return nil
}
return t.funcExpr.Children()
}

func (t *TableFunction) IsReadOnly() bool {
func (t *TableFunctionWrapper) IsReadOnly() bool {
return true
}

func (t *TableFunction) Name() string {
func (t *TableFunctionWrapper) Name() string {
return t.underlyingFunc.FunctionName()
}

func (t *TableFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
func (t *TableFunctionWrapper) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
v, err := t.funcExpr.Eval(ctx, r)
if err != nil {
return nil, err
}
return sql.RowsToRowIter(sql.Row{v}), nil
}

func (t *TableFunction) Resolved() bool {
func (t *TableFunctionWrapper) Resolved() bool {
for _, expr := range t.args {
return expr.Resolved()
if !expr.Resolved() {
return false
}
}
return true
}

func (t *TableFunction) Schema() sql.Schema {
func (t *TableFunctionWrapper) Schema() sql.Schema {
return sql.Schema{&sql.Column{Name: t.underlyingFunc.FunctionName(), Type: t.funcExpr.Type()}}
}

func (t *TableFunction) String() string {
func (t *TableFunctionWrapper) String() string {
var args []string
for _, expr := range t.args {
args = append(args, expr.String())
}
return fmt.Sprintf("%s(%s)", t.underlyingFunc.FunctionName(), strings.Join(args, ", "))
}

func (t *TableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
func (t *TableFunctionWrapper) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, fmt.Errorf("unexpected children")
return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 0)
}
return t, nil
}

func (t *TableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
func (t *TableFunctionWrapper) WithDatabase(database sql.Database) (sql.Node, error) {
nt := *t
nt.database = database
return &nt, nil
}

func (t *TableFunction) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
func (t *TableFunctionWrapper) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
if t.funcExpr == nil {
if len(exprs) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), 0)
}
}
l := len(t.funcExpr.Children())
if len(exprs) != l {
return nil, sql.ErrInvalidChildrenNumber.New(t, len(exprs), l)
Expand Down
2 changes: 1 addition & 1 deletion sql/planbuilder/from.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope
if !funcFound {
b.handleErr(sql.ErrTableFunctionNotFound.New(utf.Name()))
}
tableFunction = dtablefunctions.NewTableFunction(f)
tableFunction = dtablefunctions.NewTableFunctionWrapper(f)
}

database := b.currentDb()
Expand Down

0 comments on commit dda2002

Please sign in to comment.