Skip to content

Commit

Permalink
chore: implement Delete config policies API handler for workspace and…
Browse files Browse the repository at this point in the history
… global [CM-488]
  • Loading branch information
amandavialva01 committed Sep 12, 2024
1 parent 3900742 commit adfea53
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 15 deletions.
65 changes: 61 additions & 4 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@ import (
"context"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
"gopkg.in/yaml.v3"

"github.com/determined-ai/determined/master/internal/cluster"
"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/workspace"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

Expand Down Expand Up @@ -74,21 +82,70 @@ func (*apiServer) GetGlobalConfigPolicies(
}

// Delete workspace task config policies.
func (*apiServer) DeleteWorkspaceConfigPolicies(
func (a *apiServer) DeleteWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.DeleteWorkspaceConfigPoliciesRequest,
) (*apiv1.DeleteWorkspaceConfigPoliciesResponse, error) {
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

err = workspace.AuthZProvider.Get().CanGetWorkspaceID(
ctx, *curUser, req.WorkspaceId,
)
if err != nil {
return nil, err
}

var w model.Workspace
err = db.Bun().NewSelect().Model(&w).Where("id = ?", req.WorkspaceId).Scan(ctx)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

wksp, err := w.ToProto()
err = workspace.AuthZProvider.Get().CanModifyWorkspaceConfigPolicies(ctx, *curUser, wksp)
if err != nil {
return nil, err
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
return nil, status.Errorf(codes.InvalidArgument, "invalid workload type :%s",
req.WorkloadType)
}

err = configpolicy.DeleteConfigPolicies(ctx, ptrs.Ptr(int(req.WorkspaceId)),
model.WorkloadType(req.WorkloadType))
if err != nil {
return nil, err
}
return &apiv1.DeleteWorkspaceConfigPoliciesResponse{}, nil
}

// Delete global task config policies.
func (*apiServer) DeleteGlobalConfigPolicies(
func (a *apiServer) DeleteGlobalConfigPolicies(
ctx context.Context, req *apiv1.DeleteGlobalConfigPoliciesRequest,
) (*apiv1.DeleteGlobalConfigPoliciesResponse, error) {
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

permErr, err := cluster.AuthZProvider.Get().CanModifyGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
} else if permErr != nil {
return nil, permErr
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
return nil, status.Errorf(codes.InvalidArgument, "invalid workload type :%s",
req.WorkloadType)
}

err = configpolicy.DeleteConfigPolicies(ctx, nil, model.WorkloadType(req.WorkloadType))
if err != nil {
return nil, err
}
return &apiv1.DeleteGlobalConfigPoliciesResponse{}, nil
}
126 changes: 126 additions & 0 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package internal

import (
"database/sql"
"fmt"
"testing"

"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func TestDeleteWorkspaceConfigPolicies(t *testing.T) {
// TODO (CM-520): Make test cases for experiment config policies.
// TODO (CM-486): Replace database function calls with actual PUT requests.

// Create one workspace and continuously set and delete config policies from there
api, _, ctx := setupAPITest(t, nil)
wkspResp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID := wkspResp.Workspace.Id
cases := []struct {
name string
req *apiv1.DeleteWorkspaceConfigPoliciesRequest
err error
}{
{"invalid workload type",
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: "bad workload type",
},
fmt.Errorf("invalid workload type"),
},
{
"valid request",
&apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: model.NTSCType.String(),
},
nil,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
ntscPolicies := &model.NTSCTaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(test.req.WorkspaceId)),
WorkloadType: model.NTSCType,
}
configpolicy.SetNTSCConfigPolicies(ctx, ntscPolicies)

resp, err := api.DeleteWorkspaceConfigPolicies(ctx, test.req)
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
}
// Delete successful?
require.Nil(t, err)
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, ptrs.Ptr(int(workspaceID)))
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
}

// Test invalid workspace ID.
resp, err := api.DeleteWorkspaceConfigPolicies(ctx, &apiv1.DeleteWorkspaceConfigPoliciesRequest{
WorkspaceId: -1,
WorkloadType: model.NTSCType.String(),
})
require.Nil(t, resp)
require.ErrorContains(t, err, "InvalidArgument")
}

func TestDeleteGlobalConfigPolicies(t *testing.T) {
// TODO (CM-520): Make test cases for experiment config policies.
// TODO (CM-486): Replace database function calls with actual PUT requests.

api, _, ctx := setupAPITest(t, nil)
cases := []struct {
name string
req *apiv1.DeleteGlobalConfigPoliciesRequest
err error
}{
{"invalid workload type",
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: "bad workload type",
},
fmt.Errorf("invalid workload type"),
},
{
"valid request",
&apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: model.NTSCType.String(),
},
nil,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
configpolicy.SetNTSCConfigPolicies(ctx, &model.NTSCTaskConfigPolicies{
WorkloadType: model.NTSCType,
})

resp, err := api.DeleteGlobalConfigPolicies(ctx, test.req)
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
}
// Delete successful?
require.Nil(t, err)
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, nil)
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
}
}
8 changes: 2 additions & 6 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ const (
// SetNTSCConfigPolicies adds the NTSC invariant config and constraints config policies to
// the database.
func SetNTSCConfigPolicies(ctx context.Context,
experimentTCP *model.NTSCTaskConfigPolicies,
ntscTCPs *model.NTSCTaskConfigPolicies,
) error {
return db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return SetNTSCConfigPoliciesTx(ctx, &tx, experimentTCP)
return SetNTSCConfigPoliciesTx(ctx, &tx, ntscTCPs)
})
}

Expand Down Expand Up @@ -93,10 +93,6 @@ func GetNTSCConfigPolicies(ctx context.Context,
func DeleteConfigPolicies(ctx context.Context,
scope *int, workloadType model.WorkloadType,
) error {
if workloadType == model.UnknownType {
return status.Error(codes.InvalidArgument,
"invalid workload type for config policies: "+workloadType.String())
}
wkspQuery := wkspIDQuery
if scope == nil {
wkspQuery = wkspIDGlobalQuery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestSetNTSCConfigPolicies(t *testing.T) {
nil,
},
{
"experiment workload type for NTCP policies",
"experiment workload type for NTSC policies",
&model.NTSCTaskConfigPolicies{
WorkloadType: model.ExperimentType,
LastUpdatedBy: user.ID,
Expand Down Expand Up @@ -293,10 +293,6 @@ func TestDeleteConfigPolicies(t *testing.T) {
{"ntsc no config no constraint", false, model.NTSCType, false, false, nil},
{"global ntsc config no constraint", true, model.NTSCType, true, false, nil},
{"global ntsc config and constraint", true, model.NTSCType, true, true, nil},
{
"global unspecified workload type", true, model.UnknownType, true, true,
ptrs.Ptr("invalid workload type"),
},
}

for _, test := range tests {
Expand Down

0 comments on commit adfea53

Please sign in to comment.