diff --git a/enginetest/queries/trigger_queries.go b/enginetest/queries/trigger_queries.go index 505c204d41..2b5bb50c13 100644 --- a/enginetest/queries/trigger_queries.go +++ b/enginetest/queries/trigger_queries.go @@ -2206,18 +2206,244 @@ INSERT INTO t0 (v1, v2) VALUES (i, s); END;`, }, { - Name: "delete me", + Name: "triggers with nested begin-end blocks", SetUpScript: []string{ "create table t (i int primary key);", + ` + create trigger trig + before insert on t + for each row + begin + declare x int; + set x = new.i * 10; + begin + declare y int; + set y = new.i + 10; + set new.i = x + y; + end; + end; + `, }, Assertions: []ScriptTestAssertion{ { - Query: "create trigger trig1 before insert on t for each row begin declare x int; select new.i + 10 into x; set new.i = x; end;", - ExpectedErr: sql.ErrUnsupportedFeature, + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, }, { - Query: "create trigger trig2 before insert on t for each row begin declare x int; set x = new.i * 10; set new.i = x; end;", - ExpectedErr: sql.ErrUnsupportedFeature, + Query: "select * from t;", + Expected: []sql.Row{ + {21}, + {32}, + {43}, + }, + }, + }, + }, + { + Name: "triggers with declare statements and select into", + SetUpScript: []string{ + "create table t (i int primary key);", + "create trigger trig before insert on t for each row begin declare x int; select new.i + 10 into x; set new.i = x; end;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {11}, + {12}, + {13}, + }, + }, + }, + }, + { + Name: "triggers with declare statements and set", + SetUpScript: []string{ + "create table t (i int primary key);", + "create trigger trig before insert on t for each row begin declare x int; set x = new.i + 10; set new.i = x; end;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {11}, + {12}, + {13}, + }, + }, + }, + }, + { + Name: "triggers with declare statements and insert", + SetUpScript: []string{ + "create table t (i int primary key);", + "create table t2 (i int primary key);", + ` +create trigger trig before +insert on t for each row begin + declare x int; + set x = new.i * 10; + insert into t2 values (x); +end; +`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + {2}, + {3}, + }, + }, + { + Query: "select * from t2;", + Expected: []sql.Row{ + {10}, + {20}, + {30}, + }, + }, + }, + }, + { + Name: "triggers with declare statements and update", + SetUpScript: []string{ + "create table t (i int primary key);", + "create table t2 (i int primary key);", + "insert into t2 values (1), (2), (3);", + ` +create trigger trig before +insert on t for each row begin + declare x int; + set x = new.i * 10; + update t2 set i = x where i = new.i; +end; +`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + {2}, + {3}, + }, + }, + { + Query: "select * from t2;", + Expected: []sql.Row{ + {10}, + {20}, + {30}, + }, + }, + }, + }, + { + Name: "triggers with declare statements and delete", + SetUpScript: []string{ + "create table t (i int primary key);", + "create table t2 (i int primary key);", + "insert into t2 values (1), (2), (3);", + ` +create trigger trig before +insert on t for each row begin + declare x int; + set x = new.i; + delete from t2 where i = x; +end; +`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + {2}, + {3}, + }, + }, + { + Query: "select * from t2;", + Expected: []sql.Row{}, + }, + }, + }, + { + Name: "triggers with declare statements and stored procedure", + SetUpScript: []string{ + "create table t (i int primary key);", + "create table t2 (i int primary key);", + ` +create procedure proc(in i int) +begin + insert into t2 values (i); +end; +`, + ` +create trigger trig before +insert on t for each row begin + declare x int; + set x = new.i + 10; + call proc(x); +end; +`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1), (2), (3);", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + {2}, + {3}, + }, + }, + { + Query: "select * from t2;", + Expected: []sql.Row{ + {11}, + {12}, + {13}, + }, }, }, }, diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index a07c220928..bfacd631d0 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -445,6 +445,7 @@ func applyProceduresCall(ctx *sql.Context, a *Analyzer, call *plan.Call, scope * } return n, transform.SameTree, nil case expression.ProcedureReferencable: + // BeginEndBlocks need to reference the same ParameterReference as the Call return n.WithParamReference(pRef), transform.NewTree, nil default: return transform.NodeExprsWithOpaque(n, procParamTransformFunc) diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index e5baaa4eb1..aca2d0254d 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -249,6 +249,60 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope return nil, transform.SameTree, err } + if _, ok := triggerLogic.(*plan.TriggerBeginEndBlock); ok { + pRef := expression.NewProcedureReference() + // assignProcParam transforms any ProcedureParams to reference the ProcedureReference + assignProcParam := func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch e := expr.(type) { + case *expression.ProcedureParam: + return e.WithParamReference(pRef), transform.NewTree, nil + default: + return expr, transform.SameTree, nil + } + } + // assignProcRef calls assignProcParam on all nodes that sql.Expressioner + assignProcRef := func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { + switch n := node.(type) { + case sql.Expressioner: + newExprs, same, err := transform.Exprs(n.Expressions(), assignProcParam) + if err != nil { + return nil, transform.SameTree, err + } + if same { + return node, transform.SameTree, nil + } + newNode, err := n.WithExpressions(newExprs...) + if err != nil { + return nil, transform.SameTree, err + } + return newNode, transform.NewTree, nil + default: + return node, transform.SameTree, nil + } + } + assignProcs := func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { + switch n := node.(type) { + case *plan.InsertInto: + newSource, same, err := transform.NodeWithOpaque(n.Source, assignProcRef) + if err != nil { + return nil, transform.SameTree, err + } + if same { + return node, transform.SameTree, nil + } + return n.WithSource(newSource), transform.NewTree, nil + case expression.ProcedureReferencable: + return n.WithParamReference(pRef), transform.NewTree, nil + default: + return assignProcRef(node) + } + } + triggerLogic, _, err = transform.NodeWithOpaque(triggerLogic, assignProcs) + if err != nil { + return nil, transform.SameTree, err + } + } + return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { // Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the // parent is a trigger body. diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index 496a5b025c..7baf8d8f66 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -78,27 +78,38 @@ func (ppr *ProcedureReference) InitializeVariable(name string, sqlType sql.Type, } // InitializeCursor sets the initial state for the cursor. -func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) { +func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) error { + if ppr == nil || ppr.InnermostScope == nil { + return fmt.Errorf("cannot initialize cursor `%s` in an empty procedure reference", name) + } lowerName := strings.ToLower(name) ppr.InnermostScope.Cursors[lowerName] = &procedureCursorReferenceValue{ Name: lowerName, SelectStmt: selectStmt, RowIter: nil, } + return nil } // InitializeHandler sets the given handler's statement. -func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) { +func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) error { + if ppr == nil || ppr.InnermostScope == nil { + return fmt.Errorf("cannot initialize handler in an empty procedure reference") + } ppr.InnermostScope.Handlers = append(ppr.InnermostScope.Handlers, &procedureHandlerReferenceValue{ Stmt: stmt, Cond: cond, Action: action, ScopeHeight: ppr.height, }) + return nil } // GetVariableValue returns the value of the given parameter. func (ppr *ProcedureReference) GetVariableValue(name string) (interface{}, error) { + if ppr == nil { + return nil, fmt.Errorf("cannot find value for parameter `%s`", name) + } lowerName := strings.ToLower(name) scope := ppr.InnermostScope for scope != nil { @@ -128,6 +139,9 @@ func (ppr *ProcedureReference) GetVariableType(name string) sql.Type { // SetVariable updates the value of the given parameter. func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType sql.Type) error { + if ppr == nil { + return fmt.Errorf("cannot find value for parameter `%s`", name) + } lowerName := strings.ToLower(name) scope := ppr.InnermostScope for scope != nil { @@ -148,6 +162,9 @@ func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType // VariableHasBeenSet returns whether the parameter has had its value altered from the initial value. func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool { + if ppr == nil { + return false + } lowerName := strings.ToLower(name) scope := ppr.InnermostScope for scope != nil { @@ -161,6 +178,9 @@ func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool { // CloseCursor closes the designated cursor. func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error { + if ppr == nil { + return nil + } lowerName := strings.ToLower(name) scope := ppr.InnermostScope for scope != nil { @@ -179,6 +199,9 @@ func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error // FetchCursor returns the next row from the designated cursor. func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.Row, sql.Schema, error) { + if ppr == nil || ppr.InnermostScope == nil { + return nil, nil, fmt.Errorf("cannot find cursor `%s`", name) + } lowerName := strings.ToLower(name) scope := ppr.InnermostScope for scope != nil { @@ -196,6 +219,9 @@ func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.R // PushScope creates a new scope inside the current one. func (ppr *ProcedureReference) PushScope() { + if ppr == nil { + return + } ppr.InnermostScope = &procedureScope{ Parent: ppr.InnermostScope, variables: make(map[string]*procedureVariableReferenceValue), @@ -208,7 +234,7 @@ func (ppr *ProcedureReference) PushScope() { // PopScope removes the innermost scope, returning to its parent. Also closes all open cursors. func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error { var err error - if ppr.InnermostScope == nil { + if ppr == nil || ppr.InnermostScope == nil { return fmt.Errorf("attempted to pop an empty scope") } for _, cursorRefVal := range ppr.InnermostScope.Cursors { @@ -227,6 +253,9 @@ func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error { // CloseAllCursors closes all cursors that are still open. func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error { + if ppr == nil { + return nil + } var err error scope := ppr.InnermostScope for scope != nil { @@ -234,7 +263,7 @@ func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error { if cursorRefVal.RowIter != nil { nErr := cursorRefVal.RowIter.Close(ctx) cursorRefVal.RowIter = nil - if err == nil { + if nErr != nil { err = nErr } } @@ -246,6 +275,9 @@ func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error { // CurrentHeight returns the current height of the scope stack. func (ppr *ProcedureReference) CurrentHeight() int { + if ppr == nil { + return 0 + } return ppr.height } diff --git a/sql/plan/ddl_event.go b/sql/plan/ddl_event.go index ed307d9c55..1ea874cb25 100644 --- a/sql/plan/ddl_event.go +++ b/sql/plan/ddl_event.go @@ -339,6 +339,7 @@ func prepareCreateEventDefinitionNode(definition sql.Node) sql.Node { // analyzer for ProcedureCalls, so we initialize it here. // TODO: How does this work for triggers, which would have the same issue; seems like there // should be a cleaner way to handle this + // TODO: treat this the same way we treat triggers and stored procedures beginEndBlock.Pref = expression.NewProcedureReference() newChildren := make([]sql.Node, len(beginEndBlock.Children())) diff --git a/sql/plan/trigger_begin_end_block.go b/sql/plan/trigger_begin_end_block.go index 93001f1f35..2d83a24554 100644 --- a/sql/plan/trigger_begin_end_block.go +++ b/sql/plan/trigger_begin_end_block.go @@ -16,6 +16,7 @@ package plan import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" ) // TriggerBeginEndBlock represents a BEGIN/END block specific to TRIGGER execution, which has special considerations @@ -46,3 +47,10 @@ func (b *TriggerBeginEndBlock) WithChildren(children ...sql.Node) (sql.Node, err func (b *TriggerBeginEndBlock) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { return b.Block.CheckPrivileges(ctx, opChecker) } + +// WithParamReference implements the interface expression.ProcedureReferencable. +func (b *TriggerBeginEndBlock) WithParamReference(pRef *expression.ProcedureReference) sql.Node { + nb := *b + nb.BeginEndBlock = b.BeginEndBlock.WithParamReference(pRef).(*BeginEndBlock) + return &nb +} diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 38d0634415..a3190f839b 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -108,16 +108,6 @@ func (b *Builder) buildCreateTrigger(inScope *scope, query string, c *ast.DDL) ( } } - // TODO: /~https://github.com/dolthub/dolt/issues/7720 - if block, isBEBlock := bodyScope.node.(*plan.BeginEndBlock); isBEBlock { - for _, child := range block.Children() { - if _, ok := child.(*plan.DeclareVariables); ok { - err := sql.ErrUnsupportedFeature.New("DECLARE in BEGIN END block in TRIGGER") - b.handleErr(err) - } - } - } - outScope.node = plan.NewCreateTrigger( db, c.TriggerSpec.TrigName.Name.String(), diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index b1de809c9e..7a2d1db595 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -170,7 +170,7 @@ func prependRowInPlanForTriggerExecution(row sql.Row) func(c transform.Context) case *plan.Project: // Only prepend rows for projects that aren't the input to inserts and other triggers switch c.Parent.(type) { - case *plan.InsertInto, *plan.TriggerExecutor: + case *plan.InsertInto, *plan.Into, *plan.TriggerExecutor: return n, transform.SameTree, nil default: return plan.NewPrependNode(n, row), transform.NewTree, nil diff --git a/sql/rowexec/other_iters.go b/sql/rowexec/other_iters.go index 64a507a13d..09f0b5961b 100644 --- a/sql/rowexec/other_iters.go +++ b/sql/rowexec/other_iters.go @@ -263,7 +263,9 @@ var _ sql.RowIter = (*declareCursorIter)(nil) // Next implements the interface sql.RowIter. func (d *declareCursorIter) Next(ctx *sql.Context) (sql.Row, error) { - d.Pref.InitializeCursor(d.Name, d.Select) + if err := d.Pref.InitializeCursor(d.Name, d.Select); err != nil { + return nil, err + } return nil, io.EOF } diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index f8cb338c77..56a6eddddb 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -754,7 +754,9 @@ var _ sql.RowIter = (*declareHandlerIter)(nil) // Next implements the interface sql.RowIter. func (d *declareHandlerIter) Next(ctx *sql.Context) (sql.Row, error) { - d.Pref.InitializeHandler(d.Statement, d.Action, d.Condition) + if err := d.Pref.InitializeHandler(d.Statement, d.Action, d.Condition); err != nil { + return nil, err + } return nil, io.EOF }