Skip to content

Commit

Permalink
planner/core: fix a bug that check update privilege use wrong `AsName…
Browse files Browse the repository at this point in the history
…` and DBName (#9003) (#10157)
  • Loading branch information
bb7133 authored and zz-jason committed Apr 17, 2019
1 parent 33f1f79 commit 8295bc0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
25 changes: 18 additions & 7 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {

func tblInfoFromCol(from ast.ResultSetNode, col *expression.Column) *model.TableInfo {
var tableList []*ast.TableName
tableList = extractTableList(from, tableList)
tableList = extractTableList(from, tableList, true)
for _, field := range tableList {
if field.Name.L == col.TblName.L {
return field.TableInfo
Expand Down Expand Up @@ -2144,7 +2144,7 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) {
}

var tableList []*ast.TableName
tableList = extractTableList(sel.From.TableRefs, tableList)
tableList = extractTableList(sel.From.TableRefs, tableList, false)
for _, t := range tableList {
dbName := t.Schema.L
if dbName == "" {
Expand Down Expand Up @@ -2262,6 +2262,15 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A
p = np
newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr})
}
for _, assign := range newList {
col := assign.Col

dbName := col.DBName.L
if dbName == "" {
dbName = b.ctx.GetSessionVars().CurrentDB
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, col.OrigTblName.L, "")
}
return newList, p, nil
}

Expand Down Expand Up @@ -2363,7 +2372,7 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) {
del.SetSchema(expression.NewSchema())

var tableList []*ast.TableName
tableList = extractTableList(delete.TableRefs.TableRefs, tableList)
tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true)

// Collect visitInfo.
if delete.Tables != nil {
Expand Down Expand Up @@ -2416,14 +2425,16 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) {
}

// extractTableList extracts all the TableNames from node.
func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName {
// If asName is true, extract AsName prior to OrigName.
// Privilege check should use OrigName, while expression may use AsName.
func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName {
switch x := node.(type) {
case *ast.Join:
input = extractTableList(x.Left, input)
input = extractTableList(x.Right, input)
input = extractTableList(x.Left, input, asName)
input = extractTableList(x.Right, input, asName)
case *ast.TableSource:
if s, ok := x.Source.(*ast.TableName); ok {
if x.AsName.L != "" {
if x.AsName.L != "" && asName {
newTableName := *s
newTableName.Name = x.AsName
input = append(input, &newTableName)
Expand Down
7 changes: 7 additions & 0 deletions planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,13 @@ func (s *testPlanSuite) TestVisitInfo(c *C) {
{mysql.SelectPriv, "test", "t", ""},
},
},
{
sql: "update t a1 set a1.a = a1.a + 1",
ans: []visitInfo{
{mysql.UpdatePriv, "test", "t", ""},
{mysql.SelectPriv, "test", "t", ""},
},
},
{
sql: "select a, sum(e) from t group by a",
ans: []visitInfo{
Expand Down
18 changes: 18 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,24 @@ func (s *testSessionSuite) TestSetGroupConcatMaxLen(c *C) {
c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err))
}

func (s *testSessionSuite) TestUpdatePrivilege(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)

// Fix issue 8911
tk.MustExec("create database weperk")
tk.MustExec("use weperk")
tk.MustExec("create table tb_wehub_server (id int, active_count int, used_count int)")
tk.MustExec("create user 'weperk'")
tk.MustExec("grant all privileges on weperk.* to 'weperk'@'%'")
tk.MustExec("flush privileges;")

tk1 := testkit.NewTestKitWithInit(c, s.store)
c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "weperk", Hostname: "%"},
[]byte(""), []byte("")), IsTrue)
tk1.MustExec("use weperk")
tk1.MustExec("update tb_wehub_server a set a.active_count=a.active_count+1,a.used_count=a.used_count+1 where id=1")
}

func (s *testSessionSuite) TestTxnGoString(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("drop table if exists gostr;")
Expand Down

0 comments on commit 8295bc0

Please sign in to comment.