From 33cd2ef75cf670fec9f2497435c6ae2b55eacd16 Mon Sep 17 00:00:00 2001 From: Maximilian Hoffman Date: Mon, 5 Feb 2024 17:23:09 -0800 Subject: [PATCH] Stored procedures can use params as LIMIT,OFFSET (#2315) * Stored procedures can use params as LIMIT,OFFSET * error tests and fix proc param types --- enginetest/queries/procedure_queries.go | 38 ++++++++++-- go.mod | 2 +- go.sum | 4 +- .../resolve_external_stored_procedures.go | 4 +- sql/analyzer/validation_rules.go | 4 +- sql/expression/procedurereference.go | 7 ++- sql/plan/procedure.go | 2 +- sql/planbuilder/create_ddl.go | 2 +- sql/planbuilder/proc.go | 2 +- sql/planbuilder/select.go | 60 +++++++++++++------ 10 files changed, 91 insertions(+), 34 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 28d469bfdb..2a933d030c 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -715,15 +715,15 @@ END;`, { Query: "CALL p1(3, 4)", Expected: []sql.Row{ - {"4", "6"}, - {"3", "4"}, + {4, 6}, + {3, 4}, }, }, { Query: "CALL p2(5, 6)", Expected: []sql.Row{ - {"6", "8"}, - {"5", "6"}, + {6, 8}, + {5, 6}, }, }, }, @@ -1153,6 +1153,36 @@ END;`, }, }, }, + { + Name: "issue 7458: proc params as limit values", + SetUpScript: []string{ + "create table t (i int primary key);", + "insert into t values (0), (1), (2), (3)", + "CREATE PROCEDURE limited(the_limit int, the_offset bigint) SELECT * FROM t LIMIT the_limit OFFSET the_offset", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "call limited(1,0)", + Expected: []sql.Row{{0}}, + }, + { + Query: "call limited(2,0)", + Expected: []sql.Row{{0}, {1}}, + }, + { + Query: "call limited(2,2)", + Expected: []sql.Row{{2}, {3}}, + }, + { + Query: "CREATE PROCEDURE limited_inv(the_limit CHAR(3), the_offset INT) SELECT * FROM t LIMIT the_limit OFFSET the_offset", + ExpectedErrStr: "the variable 'the_limit' has a non-integer based type: char(3) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_bin", + }, + { + Query: "CREATE PROCEDURE limited_inv(the_limit float, the_offset INT) SELECT * FROM t LIMIT the_limit OFFSET the_offset", + ExpectedErrStr: "the variable 'the_limit' has a non-integer based type: float", + }, + }, + }, { Name: "FETCH captures state at OPEN", SetUpScript: []string{ diff --git a/go.mod b/go.mod index 3e3cc30e60..f7cac4419c 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20240129233432-aec9daef6af7 + github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index dc2109814b..b24c5b1100 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15 h1:sfTETOpsrNJP github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20240129233432-aec9daef6af7 h1:AhmDCMtoEh2PwYsfblCaWIVvpHgDmWhz1YNNwl67vm4= -github.com/dolthub/vitess v0.0.0-20240129233432-aec9daef6af7/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813 h1:tGwsoLAMFQ+7FDEyIWOIJ1Vc/nptbFi0Fh7SQahB8ro= +github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/sql/analyzer/resolve_external_stored_procedures.go b/sql/analyzer/resolve_external_stored_procedures.go index 30e4026ba5..7062cfd4b2 100644 --- a/sql/analyzer/resolve_external_stored_procedures.go +++ b/sql/analyzer/resolve_external_stored_procedures.go @@ -131,7 +131,7 @@ func resolveExternalStoredProcedure(_ *sql.Context, externalProcedure sql.Extern Type: sqlType, Variadic: paramIsVariadic, } - paramReferences[i] = expression.NewProcedureParam(paramName) + paramReferences[i] = expression.NewProcedureParam(paramName, sqlType) } else if sqlType, ok = externalStoredProcedurePointerTypes[funcParamType]; ok { paramDefinitions[i] = plan.ProcedureParam{ Direction: plan.ProcedureParamDirection_Inout, @@ -139,7 +139,7 @@ func resolveExternalStoredProcedure(_ *sql.Context, externalProcedure sql.Extern Type: sqlType, Variadic: paramIsVariadic, } - paramReferences[i] = expression.NewProcedureParam(paramName) + paramReferences[i] = expression.NewProcedureParam(paramName, sqlType) } else { return nil, sql.ErrExternalProcedureInvalidParamType.New(funcParamType.String()) } diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index f116877e47..3fb20a67a1 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -55,7 +55,7 @@ func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *pl err = sql.ErrInvalidSyntax.New("negative limit") return false } - case *expression.BindVar: + case *expression.BindVar, *expression.ProcedureParam: return true default: err = sql.ErrInvalidType.New(e.Type().String()) @@ -81,7 +81,7 @@ func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *pl err = sql.ErrInvalidSyntax.New("negative offset") return false } - case *expression.BindVar: + case *expression.BindVar, *expression.ProcedureParam: return true default: err = sql.ErrInvalidType.New(e.Type().String()) diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index 258fd1257f..d86ccd8420 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -259,6 +259,7 @@ func NewProcedureReference() *ProcedureReference { type ProcedureParam struct { name string pRef *ProcedureReference + typ sql.Type hasBeenSet bool } @@ -266,8 +267,8 @@ var _ sql.Expression = (*ProcedureParam)(nil) var _ sql.CollationCoercible = (*ProcedureParam)(nil) // NewProcedureParam creates a new ProcedureParam expression. -func NewProcedureParam(name string) *ProcedureParam { - return &ProcedureParam{name: strings.ToLower(name)} +func NewProcedureParam(name string, typ sql.Type) *ProcedureParam { + return &ProcedureParam{name: strings.ToLower(name), typ: typ} } // Children implements the sql.Expression interface. @@ -287,7 +288,7 @@ func (*ProcedureParam) IsNullable() bool { // Type implements the sql.Expression interface. func (pp *ProcedureParam) Type() sql.Type { - return pp.pRef.GetVariableType(pp.name) + return pp.typ } // CollationCoercibility implements the sql.CollationCoercible interface. diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index aee48b65f7..945792a497 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -210,7 +210,7 @@ func (p *Procedure) ExtendVariadic(ctx *sql.Context, length int) *Procedure { Type: variadicParam.Type, Variadic: variadicParam.Variadic, } - newParams[i] = expression.NewProcedureParam(paramName) + newParams[i] = expression.NewProcedureParam(paramName, variadicParam.Type) } } } diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 2377356e59..ee9b4a8d48 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -196,7 +196,7 @@ func (b *Builder) buildCreateProcedure(inScope *scope, query string, c *ast.DDL) // populate inScope with the procedure parameters. this will be // subject maybe a bug where an inner procedure has access to // outer procedure parameters. - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name))) + inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) } bodyStr := strings.TrimSpace(query[c.SubStatementPositionStart:c.SubStatementPositionEnd]) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 0616e81ae1..520c8b47ef 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -324,7 +324,7 @@ func (b *Builder) buildDeclareVariables(inScope *scope, d *ast.Declare) (outScop for i, variable := range dVars.Names { varName := strings.ToLower(variable.String()) names[i] = varName - param := expression.NewProcedureParam(varName) + param := expression.NewProcedureParam(varName, typ) inScope.proc.AddVar(param) inScope.newColumn(scopeColumn{col: varName, typ: typ, scalar: param}) } diff --git a/sql/planbuilder/select.go b/sql/planbuilder/select.go index f58c07f0e3..ad8d864f8e 100644 --- a/sql/planbuilder/select.go +++ b/sql/planbuilder/select.go @@ -126,7 +126,49 @@ func (b *Builder) buildSelect(inScope *scope, s *ast.Select) (outScope *scope) { func (b *Builder) buildLimit(inScope *scope, limit *ast.Limit) sql.Expression { if limit != nil { - l := b.buildScalar(inScope, limit.Rowcount) + return b.buildLimitVal(inScope, limit.Rowcount) + } + return nil +} + +func (b *Builder) buildOffset(inScope *scope, limit *ast.Limit) sql.Expression { + if limit != nil && limit.Offset != nil { + e := b.buildLimitVal(inScope, limit.Offset) + if lit, ok := e.(*expression.Literal); ok { + // Check if offset starts at 0, if so, we can just remove the offset node. + // Only cast to int8, as a larger int type just means a non-zero offset. + if val, err := lit.Eval(b.ctx, nil); err == nil { + if v, ok := val.(int64); ok && v == 0 { + return nil + } + } + } + return e + } + return nil +} + +// buildLimitVal resolves a literal numeric type or a numeric +// prodecure parameter +func (b *Builder) buildLimitVal(inScope *scope, e ast.Expr) sql.Expression { + switch e := e.(type) { + case *ast.ColName: + if inScope.procActive() { + if col, ok := inScope.proc.GetVar(e.String()); ok { + // proc param is OK + if pp, ok := col.scalarGf().(*expression.ProcedureParam); ok { + if !pp.Type().Promote().Equals(types.Int64) { + err := fmt.Errorf("the variable '%s' has a non-integer based type: %s", pp.Name(), pp.Type().String()) + b.handleErr(err) + } + return pp + } + } + } + err := fmt.Errorf("limit expression expected to be numeric or prodecure parameter, found invalid column: %s", e.String()) + b.handleErr(err) + default: + l := b.buildScalar(inScope, e) return b.typeCoerceLiteral(l) } return nil @@ -150,22 +192,6 @@ func (b *Builder) typeCoerceLiteral(e sql.Expression) sql.Expression { return nil } -func (b *Builder) buildOffset(inScope *scope, limit *ast.Limit) sql.Expression { - if limit != nil && limit.Offset != nil { - rowCount := b.buildScalar(inScope, limit.Offset) - rowCount = b.typeCoerceLiteral(rowCount) - // Check if offset starts at 0, if so, we can just remove the offset node. - // Only cast to int8, as a larger int type just means a non-zero offset. - if val, err := rowCount.Eval(b.ctx, nil); err == nil { - if v, ok := val.(int64); ok && v == 0 { - return nil - } - } - return rowCount - } - return nil -} - // buildDistinct creates a new plan.Distinct node if the query has a DISTINCT option. // If the query has both DISTINCT and ALL, an error is returned. func (b *Builder) buildDistinct(inScope *scope, distinct bool) {