Skip to content

Commit

Permalink
feat: add Config Policies GET API and modify CRUD functions to accept…
Browse files Browse the repository at this point in the history
… both Workload types (#9946)
  • Loading branch information
salonig23 authored Sep 17, 2024
1 parent a73c8db commit 88a4c67
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 141 deletions.
49 changes: 45 additions & 4 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,55 @@ func (*apiServer) PutGlobalConfigPolicies(
}

// Get workspace task config policies.
func (*apiServer) GetWorkspaceConfigPolicies(
func (a *apiServer) GetWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.GetWorkspaceConfigPoliciesRequest,
) (*apiv1.GetWorkspaceConfigPoliciesResponse, error) {
license.RequireLicense("manage config policies")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false)
if err != nil {
return nil, err
}

err = workspace.AuthZProvider.Get().CanViewWorkspaceConfigPolicies(ctx, *curUser, w)
if err != nil {
return nil, err
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
errMessage := fmt.Sprintf("invalid workload type: %s.", req.WorkloadType)
if len(req.WorkloadType) == 0 {
errMessage = noWorkloadErr
}
return nil, status.Errorf(codes.InvalidArgument, errMessage)
}
data, err := stubData()
return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: data}, err

configPolicies, err := configpolicy.GetTaskConfigPolicies(
ctx, ptrs.Ptr(int(req.WorkspaceId)), req.WorkloadType)
if err != nil {
return nil, err
}
policyMap := map[string]interface{}{}
if configPolicies.InvariantConfig != nil {
var configMap map[string]interface{}
if err := yaml.Unmarshal([]byte(*configPolicies.InvariantConfig), &configMap); err != nil {
return nil, fmt.Errorf("unable to unmarshal json: %w", err)
}
policyMap["invariant_config"] = configMap
}
if configPolicies.Constraints != nil {
var constraintsMap map[string]interface{}
if err := yaml.Unmarshal([]byte(*configPolicies.Constraints), &constraintsMap); err != nil {
return nil, fmt.Errorf("unable to unmarshal json: %w", err)
}
policyMap["constraints"] = constraintsMap
}
return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: configpolicy.MarshalConfigPolicy(policyMap)}, nil
}

// Get global task config policies.
Expand Down
136 changes: 131 additions & 5 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func TestDeleteWorkspaceConfigPolicies(t *testing.T) {

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
ntscPolicies := &model.NTSCTaskConfigPolicies{
ntscPolicies := &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(test.req.WorkspaceId)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
}
err = configpolicy.SetNTSCConfigPolicies(ctx, ntscPolicies)
err = configpolicy.SetTaskConfigPolicies(ctx, ntscPolicies)
require.NoError(t, err)

resp, err := api.DeleteWorkspaceConfigPolicies(ctx, test.req)
Expand All @@ -79,7 +79,7 @@ func TestDeleteWorkspaceConfigPolicies(t *testing.T) {
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, ptrs.Ptr(int(workspaceID)))
policies, err := configpolicy.GetTaskConfigPolicies(ctx, ptrs.Ptr(int(workspaceID)), test.req.WorkloadType)
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
Expand Down Expand Up @@ -130,7 +130,7 @@ func TestDeleteGlobalConfigPolicies(t *testing.T) {

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
err := configpolicy.SetNTSCConfigPolicies(ctx, &model.NTSCTaskConfigPolicies{
err := configpolicy.SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
})
Expand All @@ -146,7 +146,7 @@ func TestDeleteGlobalConfigPolicies(t *testing.T) {
require.NotNil(t, resp)

// Policies removed?
policies, err := configpolicy.GetNTSCConfigPolicies(ctx, nil)
policies, err := configpolicy.GetTaskConfigPolicies(ctx, nil, test.req.WorkloadType)
require.Nil(t, policies)
require.ErrorIs(t, err, sql.ErrNoRows)
})
Expand Down Expand Up @@ -212,3 +212,129 @@ func TestBasicRBACConfigPolicyPerms(t *testing.T) {
})
}
}

func TestGetWorkspaceConfigPolicies(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

wkspResp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID1 := wkspResp.Workspace.Id
wkspResp, err = api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID2 := wkspResp.Workspace.Id

// set only config policy
taskConfigPolicies := &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID1)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr),
}
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set only constraints policy
taskConfigPolicies = &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID1)),
WorkloadType: model.ExperimentType,
LastUpdatedBy: curUser.ID,
Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr),
}
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set both config and constraints policy
taskConfigPolicies = &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID2)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr),
Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr),
}
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

cases := []struct {
name string
req *apiv1.GetWorkspaceConfigPoliciesRequest
err error
hasConfig bool
hasConstraints bool
}{
{
"invalid workload type",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: "bad workload type",
},
fmt.Errorf("invalid workload type"),
false,
false,
},
{
"empty workload type",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: "",
},
fmt.Errorf(noWorkloadErr),
false,
false,
},
{
"valid request only config",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: model.NTSCType,
},
nil,
true,
false,
},
{
"valid request only constraints",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: model.ExperimentType,
},
nil,
false,
true,
},
{
"valid request both configs and constraints",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID2,
WorkloadType: model.NTSCType,
},
nil,
true,
true,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
resp, err := api.GetWorkspaceConfigPolicies(ctx, test.req)
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
}
require.NoError(t, err)
require.NotNil(t, resp)

if test.hasConfig {
require.Contains(t, resp.ConfigPolicies.String(), "config")
} else {
require.NotContains(t, resp.ConfigPolicies.String(), "config")
}

if test.hasConstraints {
require.Contains(t, resp.ConfigPolicies.String(), "constraints")
} else {
require.NotContains(t, resp.ConfigPolicies.String(), "constraints")
}
})
}
}
90 changes: 45 additions & 45 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"strings"

"github.com/uptrace/bun"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/pkg/model"
Expand All @@ -16,76 +14,78 @@ import (
const (
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
// DefaultInvariantConfigStr is the default invariant config val used for tests.
DefaultInvariantConfigStr = `{"description": "random description", "resources": {"slots": 4, "max_slots": 8}}`
// DefaultConstraintsStr is the default constraints val used for tests.
DefaultConstraintsStr = `{"priority_limit": 10, "resources": {"max_slots": 8}}`
)

// SetNTSCConfigPolicies adds the NTSC invariant config and constraints config policies to
// SetTaskConfigPolicies adds the task invariant config and constraints config policies to
// the database.
func SetNTSCConfigPolicies(ctx context.Context,
ntscTCPs *model.NTSCTaskConfigPolicies,
func SetTaskConfigPolicies(ctx context.Context,
tcp *model.TaskConfigPolicies,
) error {
return db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return SetNTSCConfigPoliciesTx(ctx, &tx, ntscTCPs)
return SetTaskConfigPoliciesTx(ctx, &tx, tcp)
})
}

// SetNTSCConfigPoliciesTx adds the NTSC invariant config and constraints config policies to
// SetTaskConfigPoliciesTx adds the task invariant config and constraints config policies to
// the database.
func SetNTSCConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
ntscTCPs *model.NTSCTaskConfigPolicies,
func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
tcp *model.TaskConfigPolicies,
) error {
if ntscTCPs.WorkloadType != model.NTSCType {
return status.Error(codes.InvalidArgument,
"invalid workload type for config policies: "+ntscTCPs.WorkloadType)
q := db.Bun().NewInsert().
Model(tcp)

if tcp.InvariantConfig == nil {
q = q.ExcludeColumn("invariant_config")
}
if tcp.Constraints == nil {
q = q.ExcludeColumn("constraints")
}

q := `
INSERT INTO task_config_policies (workspace_id, workload_type, last_updated_by,
last_updated_time, invariant_config, constraints) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (workspace_id, workload_type) WHERE workspace_id IS NOT NULL
DO UPDATE SET last_updated_by = ?, last_updated_time = ?, invariant_config = ?,
constraints = ?
`
if ntscTCPs.WorkspaceID == nil {
q = `
INSERT INTO task_config_policies (workspace_id, workload_type, last_updated_by,
last_updated_time, invariant_config, constraints) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (workload_type) WHERE workspace_id IS NULL
DO UPDATE SET last_updated_by = ?, last_updated_time = ?, invariant_config = ?,
constraints = ?
`
if tcp.WorkspaceID == nil {
q = q.On("CONFLICT (workload_type) WHERE workspace_id IS NULL DO UPDATE")
} else {
q = q.On("CONFLICT (workspace_id, workload_type) WHERE workspace_id IS NOT NULL DO UPDATE")
}
_, err := db.Bun().NewRaw(q, ntscTCPs.WorkspaceID, model.NTSCType,
ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime, ntscTCPs.InvariantConfig,
ntscTCPs.Constraints, ntscTCPs.LastUpdatedBy, ntscTCPs.LastUpdatedTime,
ntscTCPs.InvariantConfig, ntscTCPs.Constraints).
Exec(ctx)
if err != nil {
return fmt.Errorf("error setting NTSC task config policies: %w", err)

q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime)
if tcp.InvariantConfig != nil {
q = q.Set("invariant_config = ?", tcp.InvariantConfig)
}
if tcp.Constraints != nil {
q = q.Set("constraints = ?", tcp.Constraints)
}

_, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error setting task config policies: %w", err)
}
return nil
}

// GetNTSCConfigPolicies retrieves the invariant NTSC config and constraints for the
// given scope (global or workspace-level).
func GetNTSCConfigPolicies(ctx context.Context,
scope *int,
) (*model.NTSCTaskConfigPolicies, error) {
var ntscTCP model.NTSCTaskConfigPolicies
// GetTaskConfigPolicies retrieves the invariant config and constraints for the
// given scope (global or workspace-level) and workload Type.
func GetTaskConfigPolicies(ctx context.Context,
scope *int, workloadType string,
) (*model.TaskConfigPolicies, error) {
var tcp model.TaskConfigPolicies
wkspQuery := wkspIDQuery
if scope == nil {
wkspQuery = wkspIDGlobalQuery
}
err := db.Bun().NewSelect().
Model(&ntscTCP).
Model(&tcp).
Where(wkspQuery, scope).
Where("workload_type = ?", model.NTSCType).
Where("workload_type = ?", workloadType).
Scan(ctx)
if err != nil {
return nil, fmt.Errorf("error retrieving NTSC task config policies for "+
"workspace with ID %d: %w", scope, err)
return nil, fmt.Errorf("error retrieving %v task config policies for "+
"workspace with ID %d: %w", workloadType, scope, err)
}
return &ntscTCP, nil
return &tcp, nil
}

// DeleteConfigPolicies deletes the invariant experiment config and constraints for the
Expand Down
Loading

0 comments on commit 88a4c67

Please sign in to comment.