Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support DECLARE in BEGIN...END BLOCK in TRIGGER #2446

Merged
merged 14 commits into from
Apr 12, 2024
42 changes: 16 additions & 26 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,44 +206,34 @@ func newUpdateResult(matched, updated int) types.OkResult {

// Convenience test for debugging a single query. Unskip and set to the desired query.
func TestSingleScript(t *testing.T) {
t.Skip()
//t.Skip()
var scripts = []queries.ScriptTest{
{
Name: "physical columns added after virtual one",
Name: "trigger before inserts, use updated reference to other table",
SetUpScript: []string{
"create table t (pk int primary key, col1 int as (pk + 1));",
"insert into t (pk) values (1), (3)",
"alter table t add index idx1 (col1, pk);",
"alter table t add index idx2 (col1);",
"alter table t add column col2 int;",
"alter table t add column col3 int;",
"insert into t (pk, col2, col3) values (2, 4, 5);",
"create table a (i int primary key, j int)",
"create table b (x int primary key)",
"create trigger trig before insert on a for each row begin set new.j = (select coalesce(max(x),1) from b); update b set x = x + 1; end;",
"insert into b values (1)",
"insert into a values (1,0), (2,0), (3,0)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select * from t order by pk",
Expected: []sql.Row{
{1, 2, nil, nil},
{2, 3, 4, 5},
{3, 4, nil, nil},
},
},
{
Query: "select * from t where col1 = 2",
Query: "select * from a order by i",
Expected: []sql.Row{
{1, 2, nil, nil},
{1, 1}, {2, 2}, {3, 3},
},
},
{
Query: "select * from t where col1 = 3 and pk = 2",
Query: "select x from b",
Expected: []sql.Row{
{2, 3, 4, 5},
{4},
},
},
{
Query: "select * from t where pk = 2",
Query: "insert into a values (4,0), (5,0)",
Expected: []sql.Row{
{2, 3, 4, 5},
{types.OkResult{RowsAffected: 2}},
},
},
},
Expand All @@ -252,13 +242,13 @@ func TestSingleScript(t *testing.T) {

for _, test := range scripts {
harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil)
harness.Setup(setup.MydbData, setup.Parent_childData)
//harness.Setup(setup.MydbData, setup.Parent_childData)
engine, err := harness.NewEngine(t)
if err != nil {
panic(err)
}
engine.EngineAnalyzer().Debug = true
engine.EngineAnalyzer().Verbose = true
//engine.EngineAnalyzer().Debug = true
//engine.EngineAnalyzer().Verbose = true

enginetest.TestScriptWithEngine(t, engine, harness, test)
}
Expand Down
76 changes: 71 additions & 5 deletions enginetest/queries/trigger_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2206,18 +2206,84 @@ INSERT INTO t0 (v1, v2) VALUES (i, s); END;`,
},

{
Name: "delete me",
Name: "triggers with declare statements and select into",
SetUpScript: []string{
"create table t (i int primary key);",
"create trigger trig before insert on t for each row begin declare x int; select new.i + 10 into x; set new.i = x; end;",
},
Assertions: []ScriptTestAssertion{
{
Query: "create trigger trig1 before insert on t for each row begin declare x int; select new.i + 10 into x; set new.i = x; end;",
ExpectedErr: sql.ErrUnsupportedFeature,
Query: "insert into t values (1), (2), (3);",
Expected: []sql.Row{
{types.NewOkResult(3)},
},
},
{
Query: "select * from t;",
Expected: []sql.Row{
{11},
{12},
{13},
},
},
},
},
{
Name: "triggers with declare statements and set",
SetUpScript: []string{
"create table t (i int primary key);",
"create trigger trig before insert on t for each row begin declare x int; set x = new.i + 10; set new.i = x; end;",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into t values (1), (2), (3);",
Expected: []sql.Row{
{types.NewOkResult(3)},
},
},
{
Query: "select * from t;",
Expected: []sql.Row{
{11},
{12},
{13},
},
},
},
},
{
Name: "triggers with nested begin-end blocks",
SetUpScript: []string{
"create table t (i int primary key);",
`
create trigger trig
before insert on t
for each row
begin
declare x int;
set x = new.i * 10;
begin
declare y int;
set y = new.i + 10;
set new.i = x + y;
end;
end;
`,
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into t values (1), (2), (3);",
Expected: []sql.Row{
{types.NewOkResult(3)},
},
},
{
Query: "create trigger trig2 before insert on t for each row begin declare x int; set x = new.i * 10; set new.i = x; end;",
ExpectedErr: sql.ErrUnsupportedFeature,
Query: "select * from t;",
Expected: []sql.Row{
{21},
{32},
{43},
},
},
},
},
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ func applyProceduresCall(ctx *sql.Context, a *Analyzer, call *plan.Call, scope *
}
return n, transform.SameTree, nil
case expression.ProcedureReferencable:
// BeginEndBlocks need to reference the same ParameterReference as the Call
return n.WithParamReference(pRef), transform.NewTree, nil
default:
return transform.NodeExprsWithOpaque(n, procParamTransformFunc)
Expand Down
35 changes: 35 additions & 0 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,41 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
return nil, transform.SameTree, err
}

if _, ok := triggerLogic.(*plan.TriggerBeginEndBlock); ok {
pRef := expression.NewProcedureReference()
triggerLogic, _, err = transform.NodeWithOpaque(triggerLogic, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch n := node.(type) {
case expression.ProcedureReferencable:
return n.WithParamReference(pRef), transform.NewTree, nil
case sql.Expressioner:
newExprs, same, err := transform.Exprs(n.Expressions(), func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
switch e := expr.(type) {
case *expression.ProcedureParam:
return e.WithParamReference(pRef), transform.NewTree, nil
default:
return expr, transform.SameTree, nil
}
})
if err != nil {
return nil, transform.SameTree, err
}
if !same {
newNode, err := n.WithExpressions(newExprs...)
if err != nil {
return nil, transform.SameTree, err
}
return newNode, transform.NewTree, nil
}
return node, transform.SameTree, nil
default:
return node, transform.SameTree, nil
}
})
if err != nil {
return nil, transform.SameTree, err
}
}

return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
// Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the
// parent is a trigger body.
Expand Down
40 changes: 36 additions & 4 deletions sql/expression/procedurereference.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,38 @@ func (ppr *ProcedureReference) InitializeVariable(name string, sqlType sql.Type,
}

// InitializeCursor sets the initial state for the cursor.
func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) {
func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) error {
if ppr == nil || ppr.InnermostScope == nil {
return fmt.Errorf("cannot initialize cursor `%s` in an empty procedure reference", name)
}
lowerName := strings.ToLower(name)
ppr.InnermostScope.Cursors[lowerName] = &procedureCursorReferenceValue{
Name: lowerName,
SelectStmt: selectStmt,
RowIter: nil,
}
return nil
}

// InitializeHandler sets the given handler's statement.
func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) {
func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) error {
if ppr == nil || ppr.InnermostScope == nil {
return fmt.Errorf("cannot initialize handler in an empty procedure reference")
}
ppr.InnermostScope.Handlers = append(ppr.InnermostScope.Handlers, &procedureHandlerReferenceValue{
Stmt: stmt,
Cond: cond,
Action: action,
ScopeHeight: ppr.height,
})
return nil
}

// GetVariableValue returns the value of the given parameter.
func (ppr *ProcedureReference) GetVariableValue(name string) (interface{}, error) {
if ppr == nil {
return nil, fmt.Errorf("cannot find value for parameter `%s`", name)
}
lowerName := strings.ToLower(name)
scope := ppr.InnermostScope
for scope != nil {
Expand Down Expand Up @@ -128,6 +139,9 @@ func (ppr *ProcedureReference) GetVariableType(name string) sql.Type {

// SetVariable updates the value of the given parameter.
func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType sql.Type) error {
if ppr == nil {
return fmt.Errorf("cannot find value for parameter `%s`", name)
}
lowerName := strings.ToLower(name)
scope := ppr.InnermostScope
for scope != nil {
Expand All @@ -148,6 +162,9 @@ func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType

// VariableHasBeenSet returns whether the parameter has had its value altered from the initial value.
func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool {
if ppr == nil {
return false
}
lowerName := strings.ToLower(name)
scope := ppr.InnermostScope
for scope != nil {
Expand All @@ -161,6 +178,9 @@ func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool {

// CloseCursor closes the designated cursor.
func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error {
if ppr == nil {
return nil
}
lowerName := strings.ToLower(name)
scope := ppr.InnermostScope
for scope != nil {
Expand All @@ -179,6 +199,9 @@ func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error

// FetchCursor returns the next row from the designated cursor.
func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.Row, sql.Schema, error) {
if ppr == nil || ppr.InnermostScope == nil {
return nil, nil, fmt.Errorf("cannot find cursor `%s`", name)
}
lowerName := strings.ToLower(name)
scope := ppr.InnermostScope
for scope != nil {
Expand All @@ -196,6 +219,9 @@ func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.R

// PushScope creates a new scope inside the current one.
func (ppr *ProcedureReference) PushScope() {
if ppr == nil {
return
}
ppr.InnermostScope = &procedureScope{
Parent: ppr.InnermostScope,
variables: make(map[string]*procedureVariableReferenceValue),
Expand All @@ -208,7 +234,7 @@ func (ppr *ProcedureReference) PushScope() {
// PopScope removes the innermost scope, returning to its parent. Also closes all open cursors.
func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error {
var err error
if ppr.InnermostScope == nil {
if ppr == nil || ppr.InnermostScope == nil {
return fmt.Errorf("attempted to pop an empty scope")
}
for _, cursorRefVal := range ppr.InnermostScope.Cursors {
Expand All @@ -227,14 +253,17 @@ func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error {

// CloseAllCursors closes all cursors that are still open.
func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error {
if ppr == nil {
return nil
}
var err error
scope := ppr.InnermostScope
for scope != nil {
for _, cursorRefVal := range scope.Cursors {
if cursorRefVal.RowIter != nil {
nErr := cursorRefVal.RowIter.Close(ctx)
cursorRefVal.RowIter = nil
if err == nil {
if nErr != nil {
err = nErr
}
}
Expand All @@ -246,6 +275,9 @@ func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error {

// CurrentHeight returns the current height of the scope stack.
func (ppr *ProcedureReference) CurrentHeight() int {
if ppr == nil {
return 0
}
return ppr.height
}

Expand Down
1 change: 1 addition & 0 deletions sql/plan/ddl_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ func prepareCreateEventDefinitionNode(definition sql.Node) sql.Node {
// analyzer for ProcedureCalls, so we initialize it here.
// TODO: How does this work for triggers, which would have the same issue; seems like there
// should be a cleaner way to handle this
// TODO: treat this the same way we treat triggers and stored procedures
beginEndBlock.Pref = expression.NewProcedureReference()

newChildren := make([]sql.Node, len(beginEndBlock.Children()))
Expand Down
8 changes: 8 additions & 0 deletions sql/plan/trigger_begin_end_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package plan

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
)

// TriggerBeginEndBlock represents a BEGIN/END block specific to TRIGGER execution, which has special considerations
Expand Down Expand Up @@ -46,3 +47,10 @@ func (b *TriggerBeginEndBlock) WithChildren(children ...sql.Node) (sql.Node, err
func (b *TriggerBeginEndBlock) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return b.Block.CheckPrivileges(ctx, opChecker)
}

// WithParamReference implements the interface expression.ProcedureReferencable.
func (b *TriggerBeginEndBlock) WithParamReference(pRef *expression.ProcedureReference) sql.Node {
nb := *b
nb.BeginEndBlock = b.BeginEndBlock.WithParamReference(pRef).(*BeginEndBlock)
return &nb
}
10 changes: 0 additions & 10 deletions sql/planbuilder/create_ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ func (b *Builder) buildCreateTrigger(inScope *scope, query string, c *ast.DDL) (
}
}

// TODO: /~https://github.com/dolthub/dolt/issues/7720
if block, isBEBlock := bodyScope.node.(*plan.BeginEndBlock); isBEBlock {
for _, child := range block.Children() {
if _, ok := child.(*plan.DeclareVariables); ok {
err := sql.ErrUnsupportedFeature.New("DECLARE in BEGIN END block in TRIGGER")
b.handleErr(err)
}
}
}

outScope.node = plan.NewCreateTrigger(
db,
c.TriggerSpec.TrigName.Name.String(),
Expand Down
2 changes: 1 addition & 1 deletion sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func prependRowInPlanForTriggerExecution(row sql.Row) func(c transform.Context)
case *plan.Project:
// Only prepend rows for projects that aren't the input to inserts and other triggers
switch c.Parent.(type) {
case *plan.InsertInto, *plan.TriggerExecutor:
case *plan.InsertInto, *plan.Into, *plan.TriggerExecutor:
return n, transform.SameTree, nil
default:
return plan.NewPrependNode(n, row), transform.NewTree, nil
Expand Down
Loading
Loading