Skip to content

Commit

Permalink
ddl: don't rely on expression.Column.ColName (#11255)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored Jul 23, 2019
1 parent 3a1ba35 commit d47b655
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
5 changes: 2 additions & 3 deletions ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/ddl/util"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -596,9 +595,9 @@ func generateOriginDefaultValue(col *model.ColumnInfo) (interface{}, error) {
return odValue, nil
}

func findColumnInIndexCols(c *expression.Column, cols []*ast.IndexColName) bool {
func findColumnInIndexCols(c *model.ColumnInfo, cols []*ast.IndexColName) bool {
for _, c1 := range cols {
if c.ColName.L == c1.Column.Name.L {
if c.Name.L == c1.Column.Name.L {
return true
}
}
Expand Down
68 changes: 51 additions & 17 deletions ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -179,13 +180,13 @@ func checkPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo)

// See /~https://github.com/mysql/mysql-server/blob/5.7/sql/item_func.h#L387
func hasTimestampField(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) (bool, error) {
partCols, err := partitionColumns(ctx, tblInfo, expr)
partCols, err := checkPartitionColumns(tblInfo, expr)
if err != nil {
return false, err
}

for _, c := range partCols {
if c.GetType().Tp == mysql.TypeTimestamp {
if c.FieldType.Tp == mysql.TypeTimestamp {
return true, nil
}
}
Expand All @@ -195,13 +196,13 @@ func hasTimestampField(ctx sessionctx.Context, tblInfo *model.TableInfo, expr as

// See /~https://github.com/mysql/mysql-server/blob/5.7/sql/item_func.h#L399
func hasDateField(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) (bool, error) {
partCols, err := partitionColumns(ctx, tblInfo, expr)
partCols, err := checkPartitionColumns(tblInfo, expr)
if err != nil {
return false, err
}

for _, c := range partCols {
if c.GetType().Tp == mysql.TypeDate || c.GetType().Tp == mysql.TypeDatetime {
if c.FieldType.Tp == mysql.TypeDate || c.FieldType.Tp == mysql.TypeDatetime {
return true, nil
}
}
Expand All @@ -211,13 +212,13 @@ func hasDateField(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.Exp

// See /~https://github.com/mysql/mysql-server/blob/5.7/sql/item_func.h#L412
func hasTimeField(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) (bool, error) {
partCols, err := partitionColumns(ctx, tblInfo, expr)
partCols, err := checkPartitionColumns(tblInfo, expr)
if err != nil {
return false, err
}

for _, c := range partCols {
if c.GetType().Tp == mysql.TypeDatetime || c.GetType().Tp == mysql.TypeDuration {
if c.FieldType.Tp == mysql.TypeDatetime || c.FieldType.Tp == mysql.TypeDuration {
return true, nil
}
}
Expand Down Expand Up @@ -283,7 +284,7 @@ func checkPartitionFuncValid(ctx sessionctx.Context, tblInfo *model.TableInfo, e
}

// check constant.
_, err := partitionColumns(ctx, tblInfo, expr)
_, err := checkPartitionColumns(tblInfo, expr)
return err
}

Expand All @@ -302,12 +303,12 @@ func checkPartitionFunc(isTimezoneDependent bool, err error) error {
return nil
}

func partitionColumns(ctx sessionctx.Context, tblInfo *model.TableInfo, expr ast.ExprNode) ([]*expression.Column, error) {
func checkPartitionColumns(tblInfo *model.TableInfo, expr ast.ExprNode) ([]*model.ColumnInfo, error) {
buf := new(bytes.Buffer)
expr.Format(buf)
partCols, err := extractPartitionColumns(ctx, buf.String(), tblInfo)
partCols, err := extractPartitionColumns(buf.String(), tblInfo)
if err != nil {
return nil, errors.Trace(err)
return nil, err
}

if len(partCols) == 0 {
Expand Down Expand Up @@ -597,7 +598,7 @@ func checkRangePartitioningKeysConstraints(sctx sessionctx.Context, s *ast.Creat
// Parse partitioning key, extract the column names in the partitioning key to slice.
buf := new(bytes.Buffer)
s.Partition.Expr.Format(buf)
partCols, err := extractPartitionColumns(sctx, buf.String(), tblInfo)
partCols, err := extractPartitionColumns(buf.String(), tblInfo)
if err != nil {
return err
}
Expand All @@ -621,7 +622,7 @@ func checkRangePartitioningKeysConstraints(sctx sessionctx.Context, s *ast.Creat

func checkPartitionKeysConstraint(sctx sessionctx.Context, partExpr string, idxColNames []*ast.IndexColName, tblInfo *model.TableInfo) error {
// Parse partitioning key, extract the column names in the partitioning key to slice.
partCols, err := extractPartitionColumns(sctx, partExpr, tblInfo)
partCols, err := extractPartitionColumns(partExpr, tblInfo)
if err != nil {
return err
}
Expand All @@ -634,16 +635,49 @@ func checkPartitionKeysConstraint(sctx sessionctx.Context, partExpr string, idxC
return nil
}

func extractPartitionColumns(sctx sessionctx.Context, partExpr string, tblInfo *model.TableInfo) ([]*expression.Column, error) {
e, err := expression.ParseSimpleExprWithTableInfo(sctx, partExpr, tblInfo)
type columnNameExtractor struct {
extractedColumns []*model.ColumnInfo
tblInfo *model.TableInfo
err error
}

func (cne *columnNameExtractor) Enter(node ast.Node) (ast.Node, bool) {
return node, false
}

func (cne *columnNameExtractor) Leave(node ast.Node) (ast.Node, bool) {
if c, ok := node.(*ast.ColumnNameExpr); ok {
for _, info := range cne.tblInfo.Columns {
if info.Name.L == c.Name.Name.L {
cne.extractedColumns = append(cne.extractedColumns, info)
return node, true
}
}
cne.err = ErrBadField.GenWithStackByArgs(c.Name.Name.O, "expression")
return nil, false
}
return node, true
}

func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*model.ColumnInfo, error) {
partExpr = "select " + partExpr
stmts, _, err := parser.New().Parse(partExpr, "", "")
if err != nil {
return nil, errors.Trace(err)
return nil, err
}
extractor := &columnNameExtractor{
tblInfo: tblInfo,
extractedColumns: make([]*model.ColumnInfo, 0),
}
stmts[0].Accept(extractor)
if extractor.err != nil {
return nil, extractor.err
}
return expression.ExtractColumns(e), nil
return extractor.extractedColumns, nil
}

// checkUniqueKeyIncludePartKey checks that the partitioning key is included in the constraint.
func checkUniqueKeyIncludePartKey(partCols []*expression.Column, idxCols []*ast.IndexColName) bool {
func checkUniqueKeyIncludePartKey(partCols []*model.ColumnInfo, idxCols []*ast.IndexColName) bool {
for _, partCol := range partCols {
if !findColumnInIndexCols(partCol, idxCols) {
return false
Expand Down

0 comments on commit d47b655

Please sign in to comment.