Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple statements per ALTER #68

Merged
merged 1 commit into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 84 additions & 47 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ func (*Delete) iStatement() {}
func (*Set) iStatement() {}
func (*DBDDL) iStatement() {}
func (*DDL) iStatement() {}
func (*MultiAlterDDL) iStatement() {}
func (*Explain) iStatement() {}
func (*Show) iStatement() {}
func (*Use) iStatement() {}
Expand Down Expand Up @@ -1432,6 +1433,36 @@ func (c Characteristic) String() string {
return string(c.Type)
}

// MultiAlterDDL represents multiple ALTER statements on a single table.
type MultiAlterDDL struct {
Table TableName
Statements []*DDL
}

var _ SQLNode = (*MultiAlterDDL)(nil)

// Format implements SQLNode.
func (m *MultiAlterDDL) Format(buf *TrackedBuffer) {
buf.Myprintf("alter table %v", m.Table)
for i, ddl := range m.Statements {
if i > 0 {
buf.Myprintf(",")
}
ddl.alterFormat(buf)
}
}

// walkSubtree implements SQLNode.
func (m *MultiAlterDDL) walkSubtree(visit Visit) error {
for _, ddl := range m.Statements {
err := ddl.walkSubtree(visit)
if err != nil {
return err
}
}
return nil
}

// DDL represents a CREATE, ALTER, DROP, RENAME, TRUNCATE or ANALYZE statement.
type DDL struct {
Action string
Expand Down Expand Up @@ -1608,53 +1639,8 @@ func (node *DDL) Format(buf *TrackedBuffer) {
buf.Myprintf(", %v to %v", node.FromTables[i], node.ToTables[i])
}
case AlterStr:
if node.PartitionSpec != nil {
buf.Myprintf("%s table %v %v", node.Action, node.Table, node.PartitionSpec)
} else if node.ColumnAction == AddStr {
after := ""
if node.ColumnOrder != nil {
if node.ColumnOrder.First {
after = " first"
} else {
after = " after " + node.ColumnOrder.AfterColumn.String()
}
}
buf.Myprintf("%s table %v %s column %v%s", node.Action, node.Table, node.ColumnAction, node.TableSpec, after)
} else if node.ColumnAction == ModifyStr || node.ColumnAction == ChangeStr {
after := ""
if node.ColumnOrder != nil {
if node.ColumnOrder.First {
after = " first"
} else {
after = " after " + node.ColumnOrder.AfterColumn.String()
}
}
buf.Myprintf("%s table %v %s column %v %v%s", node.Action, node.Table, node.ColumnAction, node.Column, node.TableSpec, after)
} else if node.ColumnAction == DropStr {
buf.Myprintf("%s table %v %s column %v", node.Action, node.Table, node.ColumnAction, node.Column)
} else if node.ColumnAction == RenameStr {
buf.Myprintf("%s table %v %s column %v to %v", node.Action, node.Table, node.ColumnAction, node.Column, node.ToColumn)
} else if node.IndexSpec != nil {
buf.Myprintf("%s table %v %v", node.Action, node.Table, node.IndexSpec)
} else if node.ConstraintAction == AddStr && node.TableSpec != nil && len(node.TableSpec.Constraints) == 1 {
switch node.TableSpec.Constraints[0].Details.(type) {
case *ForeignKeyDefinition, *CheckConstraintDefinition:
buf.Myprintf("%s table %v add %v", node.Action, node.Table, node.TableSpec.Constraints[0])
default:
buf.Myprintf("%s table %v", node.Action, node.Table)
}
} else if node.ConstraintAction == DropStr && node.TableSpec != nil && len(node.TableSpec.Constraints) == 1 {
switch node.TableSpec.Constraints[0].Details.(type) {
case *ForeignKeyDefinition:
buf.Myprintf("%s table %v drop foreign key %s", node.Action, node.Table, node.TableSpec.Constraints[0].Name)
case *CheckConstraintDefinition:
buf.Myprintf("%s table %v drop check %s", node.Action, node.Table, node.TableSpec.Constraints[0].Name)
default:
buf.Myprintf("%s table %v drop constraint %s", node.Action, node.Table, node.TableSpec.Constraints[0].Name)
}
} else {
buf.Myprintf("%s table %v", node.Action, node.Table)
}
buf.Myprintf("%s table %v", node.Action, node.Table)
node.alterFormat(buf)
case FlushStr:
buf.Myprintf("%s", node.Action)
case AddAutoIncStr:
Expand All @@ -1676,6 +1662,57 @@ func (node *DDL) walkSubtree(visit Visit) error {
return nil
}

func (node *DDL) alterFormat(buf *TrackedBuffer) {
if node.Action == RenameStr {
buf.Myprintf(" %s to %v", node.Action, node.ToTables[0])
for i := 1; i < len(node.FromTables); i++ {
buf.Myprintf(", %v to %v", node.FromTables[i], node.ToTables[i])
}
} else if node.PartitionSpec != nil {
buf.Myprintf(" %v", node.PartitionSpec)
} else if node.ColumnAction == AddStr {
after := ""
if node.ColumnOrder != nil {
if node.ColumnOrder.First {
after = " first"
} else {
after = " after " + node.ColumnOrder.AfterColumn.String()
}
}
buf.Myprintf(" %s column %v%s", node.ColumnAction, node.TableSpec, after)
} else if node.ColumnAction == ModifyStr || node.ColumnAction == ChangeStr {
after := ""
if node.ColumnOrder != nil {
if node.ColumnOrder.First {
after = " first"
} else {
after = " after " + node.ColumnOrder.AfterColumn.String()
}
}
buf.Myprintf(" %s column %v %v%s",node.ColumnAction, node.Column, node.TableSpec, after)
} else if node.ColumnAction == DropStr {
buf.Myprintf(" %s column %v", node.ColumnAction, node.Column)
} else if node.ColumnAction == RenameStr {
buf.Myprintf(" %s column %v to %v", node.ColumnAction, node.Column, node.ToColumn)
} else if node.IndexSpec != nil {
buf.Myprintf(" %v", node.IndexSpec)
} else if node.ConstraintAction == AddStr && node.TableSpec != nil && len(node.TableSpec.Constraints) == 1 {
switch node.TableSpec.Constraints[0].Details.(type) {
case *ForeignKeyDefinition, *CheckConstraintDefinition:
buf.Myprintf(" add %v", node.TableSpec.Constraints[0])
}
} else if node.ConstraintAction == DropStr && node.TableSpec != nil && len(node.TableSpec.Constraints) == 1 {
switch node.TableSpec.Constraints[0].Details.(type) {
case *ForeignKeyDefinition:
buf.Myprintf(" drop foreign key %s", node.TableSpec.Constraints[0].Name)
case *CheckConstraintDefinition:
buf.Myprintf(" drop check %s", node.TableSpec.Constraints[0].Name)
default:
buf.Myprintf(" drop constraint %s", node.TableSpec.Constraints[0].Name)
}
}
}

// AffectedTables returns the list table names affected by the DDL.
func (node *DDL) AffectedTables() TableNames {
if node.Action == RenameStr || node.Action == DropStr {
Expand Down
30 changes: 19 additions & 11 deletions go/vt/sqlparser/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestSetLimit(t *testing.T) {
func TestDDL(t *testing.T) {
testcases := []struct {
query string
output *DDL
output Statement
affected []string
}{{
query: "create table a",
Expand Down Expand Up @@ -250,18 +250,24 @@ func TestDDL(t *testing.T) {
affected: []string{"a", "b"},
}, {
query: "alter table a auto_increment 19",
output: &DDL{
Action: AlterStr,
Table: TableName{Name: NewTableIdent("a")},
AutoIncSpec: &AutoIncSpec{Value: newIntVal("19")},
output: &MultiAlterDDL{
Table: TableName{Name: NewTableIdent("a")},
Statements: []*DDL{{
Action: AlterStr,
Table: TableName{Name: NewTableIdent("a")},
AutoIncSpec: &AutoIncSpec{Value: newIntVal("19")},
}},
},
affected: []string{"a"},
}, {
query: "alter table a auto_increment 19.9",
output: &DDL{
Action: AlterStr,
Table: TableName{Name: NewTableIdent("a")},
AutoIncSpec: &AutoIncSpec{Value: newFloatVal("19.9")},
output: &MultiAlterDDL{
Table: TableName{Name: NewTableIdent("a")},
Statements: []*DDL{{
Action: AlterStr,
Table: TableName{Name: NewTableIdent("a")},
AutoIncSpec: &AutoIncSpec{Value: newFloatVal("19.9")},
}},
},
affected: []string{"a"},
}}
Expand All @@ -277,8 +283,10 @@ func TestDDL(t *testing.T) {
for _, t := range tcase.affected {
want = append(want, TableName{Name: NewTableIdent(t)})
}
if affected := got.(*DDL).AffectedTables(); !reflect.DeepEqual(affected, want) {
t.Errorf("Affected(%s): %v, want %v", tcase.query, affected, want)
if ddl, ok := got.(*DDL); ok {
if affected := ddl.AffectedTables(); !reflect.DeepEqual(affected, want) {
t.Errorf("Affected(%s): %v, want %v", tcase.query, affected, want)
}
}
}
}
Expand Down
18 changes: 13 additions & 5 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1066,16 +1066,15 @@ var (
output: "alter table a",
}, {
input: "alter table a rename b",
output: "rename table a to b",
output: "alter table a rename to b",
}, {
input: "alter table `By` rename `bY`",
output: "rename table `By` to `bY`",
output: "alter table `By` rename to `bY`",
}, {
input: "alter table a rename to b",
output: "rename table a to b",
}, {
input: "alter table a rename as b",
output: "rename table a to b",
output: "alter table a rename to b",
}, {
input: "alter table a rename index foo to bar",
output: "alter table a rename index foo to bar",
Expand Down Expand Up @@ -1905,6 +1904,16 @@ var (
}, {
input: "alter table a modify foo int unique comment 'a comment here' auto_increment on update current_timestamp() default 0 not null after bar",
output: "alter table a modify column foo (\n\tfoo int not null default 0 on update current_timestamp() auto_increment comment 'a comment here' unique\n) after bar",
}, {
input: "alter table t add column c int unique comment 'a comment here' auto_increment on update current_timestamp() default 0 not null," +
" change foo bar int not null auto_increment first," +
" reorganize partition b into (partition c values less than (:v1), partition d values less than (maxvalue))," +
" add spatial index idx (id)",
output: `alter table t add column (
c int not null default 0 on update current_timestamp() auto_increment comment 'a comment here' unique
), change column foo (
bar int not null auto_increment
) first, reorganize partition b into (partition c values less than (:v1), partition d values less than (maxvalue)), add spatial index idx (id)`,
}, {
input: "delete a.*, b.* from tbl_a a, tbl_b b where a.id = b.id and b.name = 'test'",
output: "delete a, b from tbl_a as a, tbl_b as b where a.id = b.id and b.name = 'test'",
Expand Down Expand Up @@ -2365,7 +2374,6 @@ func TestCaseSensitivity(t *testing.T) {
output: "alter table a",
}, {
input: "alter table A rename to B",
output: "rename table A to B",
}, {
input: "rename table A to B",
}, {
Expand Down
Loading