diff --git a/sql/analyzer/apply_indexes_from_outer_scope.go b/sql/analyzer/apply_indexes_from_outer_scope.go index 62415c33bf..d457e9c138 100644 --- a/sql/analyzer/apply_indexes_from_outer_scope.go +++ b/sql/analyzer/apply_indexes_from_outer_scope.go @@ -18,8 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/go-mysql-server/sql/fixidx" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" @@ -282,7 +280,7 @@ func getSubqueryIndexes( func tablesInScope(scope *plan.Scope) []string { tables := make(map[string]bool) for _, node := range scope.InnerToOuter() { - for _, col := range fixidx.Schemas(node.Children()) { + for _, col := range Schemas(node.Children()) { tables[col.Source] = true } } @@ -292,3 +290,12 @@ func tablesInScope(scope *plan.Scope) []string { } return tableSlice } + +// Schemas returns the Schemas for the nodes given appended in to a single one +func Schemas(nodes []sql.Node) sql.Schema { + var schema sql.Schema + for _, n := range nodes { + schema = append(schema, n.Schema()...) + } + return schema +} diff --git a/sql/analyzer/inserts.go b/sql/analyzer/inserts.go index 1581132c42..5fc21f3b82 100644 --- a/sql/analyzer/inserts.go +++ b/sql/analyzer/inserts.go @@ -15,11 +15,11 @@ package analyzer import ( + "fmt" "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/fixidx" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" @@ -149,10 +149,18 @@ func wrapRowSource(ctx *sql.Context, scope *plan.Scope, logFn func(string, ...an } var err error + colIdx := make(map[string]int) + for i, c := range schema { + colIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i + } def, _, err := transform.Expr(defaultExpr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { switch e := e.(type) { case *expression.GetField: - return fixidx.FixFieldIndexes(scope, logFn, schema, e.WithTable(destTbl.Name())) + idx, ok := colIdx[strings.ToLower(e.WithTable(destTbl.Name()).String())] + if !ok { + return nil, transform.SameTree, fmt.Errorf("field not found: %s", e.String()) + } + return e.WithIndex(idx), transform.NewTree, nil default: return e, transform.SameTree, nil } diff --git a/sql/analyzer/resolve_create_select.go b/sql/analyzer/resolve_create_select.go index 8f3b0ab285..deb5249581 100644 --- a/sql/analyzer/resolve_create_select.go +++ b/sql/analyzer/resolve_create_select.go @@ -6,6 +6,10 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" ) +// todo this should be split into two rules. The first should be in +// planbuilder and only bind the child select, strip/merge schemas. +// a second rule should finalize analysis of the source/dest nodes +// (skipping passthrough rule). func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { ct, ok := n.(*plan.CreateTable) if !ok || ct.Select() == nil { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index baf616b213..0af483789d 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -19,8 +19,6 @@ import ( "reflect" "strings" - "github.com/dolthub/go-mysql-server/sql/fixidx" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors" "github.com/dolthub/go-mysql-server/sql/expression" @@ -619,7 +617,7 @@ func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *p return true } - outerScopeRowLen := len(scope.Schema()) + len(fixidx.Schemas(n.Children())) + outerScopeRowLen := len(scope.Schema()) + len(Schemas(n.Children())) transform.Inspect(s.Query, func(n sql.Node) bool { if n == nil { return true @@ -634,7 +632,7 @@ func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *p default: } if es, ok := n.(sql.Expressioner); ok { - childSchemaLen := len(fixidx.Schemas(n.Children())) + childSchemaLen := len(Schemas(n.Children())) for _, e := range es.Expressions() { sql.Inspect(e, func(e sql.Expression) bool { if gf, ok := e.(*expression.GetField); ok { diff --git a/sql/fixidx/fix_field_indexes.go b/sql/fixidx/fix_field_indexes.go deleted file mode 100644 index 11f4820fee..0000000000 --- a/sql/fixidx/fix_field_indexes.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2020-2021 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fixidx - -import ( - "strings" - - "gopkg.in/src-d/go-errors.v1" - - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/transform" -) - -// ErrFieldMissing is returned when the field is not on the schema. -var ErrFieldMissing = errors.NewKind("field %q is not on schema") - -// FixFieldIndexes transforms the given expression by correcting the indexes of columns in GetField expressions, -// according to the schema given. Used when combining multiple tables together into a single join result, or when -// otherwise changing / combining schemas in the node tree. -func FixFieldIndexes(scope *plan.Scope, logFn func(string, ...any), schema sql.Schema, exp sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - scopeLen := len(scope.Schema()) - - return transform.Expr(exp, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - switch e := e.(type) { - // For each GetField expression, re-index it with the appropriate index from the schema. - case *expression.GetField: - partial := -1 - for i, col := range schema { - newIndex := scopeLen + i - if strings.EqualFold(e.Name(), col.Name) && strings.EqualFold(e.Table(), col.Source) { - if e.Table() == "" && e.Name() != col.Name { - // aliases with same lowered representation need to case-sensitive match - partial = newIndex - continue - } - if newIndex != e.Index() { - if logFn != nil { - logFn("Rewriting field %s.%s from index %d to %d", e.Table(), e.Name(), e.Index(), newIndex) - } - return e.WithIndex(newIndex), transform.NewTree, nil - } - return e, transform.SameTree, nil - } - } - if partial >= 0 { - if partial != e.Index() { - if logFn != nil { - logFn("Rewriting field %s.%s from index %d to %d", e.Table(), e.Name(), e.Index(), partial) - } - return e.WithIndex(partial), transform.NewTree, nil - } - return e, transform.SameTree, nil - } - - // If we didn't find the column in the schema of the node itself, look outward in surrounding scopes. Work - // inner-to-outer, in accordance with MySQL scope naming precedence rules. - offset := 0 - for _, n := range scope.InnerToOuter() { - schema := Schemas(n.Children()) - offset += len(schema) - for i, col := range schema { - if strings.EqualFold(e.Name(), col.Name) && strings.EqualFold(e.Table(), col.Source) { - newIndex := scopeLen - offset + i - if e.Index() != newIndex { - if logFn != nil { - logFn("Rewriting field %s.%s from index %d to %d", e.Table(), e.Name(), e.Index(), newIndex) - } - return e.WithIndex(newIndex), transform.NewTree, nil - } - return e, transform.SameTree, nil - } - } - } - - return nil, transform.SameTree, ErrFieldMissing.New(e.Name()) - } - - return e, transform.SameTree, nil - }) -} - -// Schemas returns the Schemas for the nodes given appended in to a single one -func Schemas(nodes []sql.Node) sql.Schema { - var schema sql.Schema - for _, n := range nodes { - schema = append(schema, n.Schema()...) - } - return schema -} diff --git a/sql/func_deps.go b/sql/func_deps.go index 166d9ced78..81b884401c 100644 --- a/sql/func_deps.go +++ b/sql/func_deps.go @@ -82,6 +82,9 @@ func (k *Key) implies(other Key) bool { // - Do a set of grouping columns constitute a strict key // (only_full_group_by) // +// The docs here provide a summary of how functional dependencies work: +// - /~https://github.com/cockroachdb/cockroach/blob/5a6aa768cd945118e795d1086ba6f6365f6d1284/pkg/sql/opt/props/func_dep.go#L420 +// // This object expects fields to be set in the following order: // - notNull: what columns are non-nullable? // - consts: what columns are constant? diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index a01bd2fed8..7bc79e547f 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -18,7 +18,6 @@ import ( "bufio" "fmt" "io" - "log" "strings" "sync" "time" @@ -31,7 +30,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/fixidx" "github.com/dolthub/go-mysql-server/sql/fulltext" "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/go-mysql-server/sql/plan" @@ -197,10 +195,18 @@ func (l loadDataIter) parseFields(ctx *sql.Context, line string) ([]sql.Expressi } var def sql.Expression = f.Default var err error + colIdx := make(map[string]int) + for i, c := range l.destSch { + colIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i + } def, _, err = transform.Expr(f.Default, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { switch e := e.(type) { case *expression.GetField: - return fixidx.FixFieldIndexes(nil, log.Printf, l.destSch, e.WithTable(l.destSch[0].Source)) + idx, ok := colIdx[strings.ToLower(e.String())] + if !ok { + return nil, transform.SameTree, fmt.Errorf("field not found: %s", e.String()) + } + return e.WithIndex(idx), transform.NewTree, nil default: return e, transform.SameTree, nil }