Skip to content

Commit

Permalink
chore: implement Delete config policies API handlers (#9927)
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 authored Sep 16, 2024
1 parent 2d12be1 commit 26b0954
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 93 deletions.
69 changes: 65 additions & 4 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@ import (
"context"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
"gopkg.in/yaml.v3"

"github.com/determined-ai/determined/master/internal/cluster"
"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/license"
"github.com/determined-ai/determined/master/internal/workspace"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

const (
noWorkloadErr = "no workload type specified."
)

func stubData() (*structpb.Struct, error) {
const yamlString = `
invariant_config:
Expand Down Expand Up @@ -74,21 +85,71 @@ func (*apiServer) GetGlobalConfigPolicies(
}

// Delete workspace task config policies.
func (*apiServer) DeleteWorkspaceConfigPolicies(
func (a *apiServer) DeleteWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.DeleteWorkspaceConfigPoliciesRequest,
) (*apiv1.DeleteWorkspaceConfigPoliciesResponse, error) {
license.RequireLicense("manage config policies")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false)
if err != nil {
return nil, err
}

err = workspace.AuthZProvider.Get().CanModifyWorkspaceConfigPolicies(ctx, *curUser, w)
if err != nil {
return nil, err
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
errMessage := fmt.Sprintf("invalid workload type: %s.", req.WorkloadType)
if len(req.WorkloadType) == 0 {
errMessage = noWorkloadErr
}
return nil, status.Errorf(codes.InvalidArgument, errMessage)
}

err = configpolicy.DeleteConfigPolicies(ctx, ptrs.Ptr(int(req.WorkspaceId)),
req.WorkloadType)
if err != nil {
return nil, err
}
return &apiv1.DeleteWorkspaceConfigPoliciesResponse{}, nil
}

// Delete global task config policies.
func (*apiServer) DeleteGlobalConfigPolicies(
func (a *apiServer) DeleteGlobalConfigPolicies(
ctx context.Context, req *apiv1.DeleteGlobalConfigPoliciesRequest,
) (*apiv1.DeleteGlobalConfigPoliciesResponse, error) {
license.RequireLicense("manage config policies")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

permErr, err := cluster.AuthZProvider.Get().CanModifyGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
} else if permErr != nil {
return nil, permErr
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
errMessage := fmt.Sprintf("invalid workload type: %s.", req.WorkloadType)
if len(req.WorkloadType) == 0 {
errMessage = noWorkloadErr
}
return nil, status.Errorf(codes.InvalidArgument, errMessage)
}

err = configpolicy.DeleteConfigPolicies(ctx, nil, req.WorkloadType)
if err != nil {
return nil, err
}
return &apiv1.DeleteGlobalConfigPoliciesResponse{}, nil
}
214 changes: 214 additions & 0 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package internal

import (
"database/sql"
"fmt"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/user"
"github.com/determined-ai/determined/master/internal/workspace"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/test/testutils"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

func TestDeleteWorkspaceConfigPolicies(t *testing.T) {
// TODO (CM-520): Make test cases for experiment config policies.

// Create one workspace and continuously set and delete config policies from there
api, curUser, ctx := setupAPITest(t, nil)
testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

wkspResp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID := wkspResp.Workspace.Id
cases := []struct {
name string
req *apiv1.DeleteWorkspaceConfigPoliciesRequest
err error
}{
{
"invalid workload type",
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: "bad workload type",
},
fmt.Errorf("invalid workload type"),
},
{
"empty workload type",
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: "",
},
fmt.Errorf(noWorkloadErr),
},
{
"valid request",
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: model.NTSCType,
},
nil,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
ntscPolicies := &model.NTSCTaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(test.req.WorkspaceId)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
}
err = configpolicy.SetNTSCConfigPolicies(ctx, ntscPolicies)
require.NoError(t, err)

resp, err := api.DeleteWorkspaceConfigPolicies(ctx, test.req)
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
}
// Delete successful?
require.NoError(t, err)
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, ptrs.Ptr(int(workspaceID)))
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
}

// Test invalid workspace ID.
resp, err := api.DeleteWorkspaceConfigPolicies(ctx, &apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: -1,
WorkloadType: model.NTSCType,
})
require.Nil(t, resp)
require.ErrorContains(t, err, "not found")
}

func TestDeleteGlobalConfigPolicies(t *testing.T) {
// TODO (CM-520): Make test cases for experiment config policies.

api, curUser, ctx := setupAPITest(t, nil)
testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

cases := []struct {
name string
req *apiv1.DeleteGlobalConfigPoliciesRequest
err error
}{
{
"invalid workload type",
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: "invalid workload type",
},
fmt.Errorf("invalid workload type"),
},
{
"empty workload type",
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: "",
},
fmt.Errorf(noWorkloadErr),
},
{
"valid request",
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: model.NTSCType,
},
nil,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
err := configpolicy.SetNTSCConfigPolicies(ctx, &model.NTSCTaskConfigPolicies{
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
})
require.NoError(t, err)

resp, err := api.DeleteGlobalConfigPolicies(ctx, test.req)
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
}
// Delete successful?
require.NoError(t, err)
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, nil)
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
}
}

func TestBasicRBACConfigPolicyPerms(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
curUser.Admin = false
err := user.Update(ctx, &curUser, []string{"admin"}, nil)
require.NoError(t, err)

testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

resp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
wkspID := resp.Workspace.Id

wksp, err := workspace.WorkspaceByName(ctx, resp.Workspace.Name)
require.NoError(t, err)
newUser, err := db.HackAddUser(ctx, &model.User{Username: uuid.NewString()})
require.NoError(t, err)

wksp.UserID = newUser
_, err = db.Bun().NewUpdate().Model(wksp).Where("id = ?", wksp.ID).Exec(ctx)
require.NoError(t, err)

cases := []struct {
name string
req func() error
err error
}{
{
"delete workspace config policies",
func() error {
_, err := api.DeleteWorkspaceConfigPolicies(ctx,
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: wkspID,
WorkloadType: model.NTSCType,
},
)
return err
},
fmt.Errorf("only admins may set config policies for workspaces"),
},
{
"delete global config policies",
func() error {
_, err := api.DeleteGlobalConfigPolicies(ctx,
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: model.NTSCType,
},
)
return err
},
fmt.Errorf("PermissionDenied"),
},
}
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
err := test.req()
require.ErrorContains(t, err, test.err.Error())
})
}
}
20 changes: 8 additions & 12 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ const (
// SetNTSCConfigPolicies adds the NTSC invariant config and constraints config policies to
// the database.
func SetNTSCConfigPolicies(ctx context.Context,
experimentTCP *model.NTSCTaskConfigPolicies,
ntscTCPs *model.NTSCTaskConfigPolicies,
) error {
return db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return SetNTSCConfigPoliciesTx(ctx, &tx, experimentTCP)
return SetNTSCConfigPoliciesTx(ctx, &tx, ntscTCPs)
})
}

Expand All @@ -35,7 +35,7 @@ func SetNTSCConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
) error {
if ntscTCPs.WorkloadType != model.NTSCType {
return status.Error(codes.InvalidArgument,
"invalid workload type for config policies: "+ntscTCPs.WorkloadType.String())
"invalid workload type for config policies: "+ntscTCPs.WorkloadType)
}

q := `
Expand All @@ -54,7 +54,7 @@ func SetNTSCConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
constraints = ?
`
}
_, err := db.Bun().NewRaw(q, ntscTCPs.WorkspaceID, model.NTSCType.String(),
_, err := db.Bun().NewRaw(q, ntscTCPs.WorkspaceID, model.NTSCType,
ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime, ntscTCPs.InvariantConfig,
ntscTCPs.Constraints, ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime,
ntscTCPs.InvariantConfig, ntscTCPs.Constraints).
Expand Down Expand Up @@ -91,12 +91,8 @@ func GetNTSCConfigPolicies(ctx context.Context,
// DeleteConfigPolicies deletes the invariant experiment config and constraints for the
// given scope (global or workspace-level) and workload type.
func DeleteConfigPolicies(ctx context.Context,
scope *int, workloadType model.WorkloadType,
scope *int, workloadType string,
) error {
if workloadType == model.UnknownType {
return status.Error(codes.InvalidArgument,
"invalid workload type for config policies: "+workloadType.String())
}
wkspQuery := wkspIDQuery
if scope == nil {
wkspQuery = wkspIDGlobalQuery
Expand All @@ -105,15 +101,15 @@ func DeleteConfigPolicies(ctx context.Context,
_, err := db.Bun().NewDelete().
Table("task_config_policies").
Where(wkspQuery, scope).
Where("workload_type = ?", workloadType.String()).
Where("workload_type = ?", workloadType).
Exec(ctx)
if err != nil {
if scope == nil {
return fmt.Errorf("error deleting global %s config policies:%w",
strings.ToLower(workloadType.String()), err)
strings.ToLower(workloadType), err)
}
return fmt.Errorf("error deleting %s config policies for workspace with ID %d: %w",
strings.ToLower(workloadType.String()), scope, err)
strings.ToLower(workloadType), scope, err)
}
return nil
}
Loading

0 comments on commit 26b0954

Please sign in to comment.