diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index d7b645d07f9..e1b28a211a1 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -159,9 +159,9 @@ func checkConstraintConflicts(constraints *model.Constraints, maxSlots, slots, p if maxSlots != nil && *constraints.ResourceConstraints.MaxSlots != *maxSlots { return fmt.Errorf("invariant config & constraints are trying to set the max slots") } - if slots != nil && *constraints.ResourceConstraints.MaxSlots > *slots { - return fmt.Errorf("invariant config & constraints are attempting to set an invalid max slot123: %v vs %v", - *constraints.ResourceConstraints.MaxSlots, *slots) + if slots != nil && *constraints.ResourceConstraints.MaxSlots < *slots { + return fmt.Errorf("invariant config has %v slots per trial. violates constraints max slots of %v", + *slots, *constraints.ResourceConstraints.MaxSlots) } return nil diff --git a/master/internal/configpolicy/utils_test.go b/master/internal/configpolicy/utils_test.go index 3649fa4c68d..e6d68ea53b1 100644 --- a/master/internal/configpolicy/utils_test.go +++ b/master/internal/configpolicy/utils_test.go @@ -7,6 +7,7 @@ import ( "gotest.tools/assert" "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas/expconf" ) @@ -309,3 +310,45 @@ func TestUnmarshalJSONNTSC(t *testing.T) { }) } } + +func TestCheckConstraintsConflicts(t *testing.T) { + constraints := &model.Constraints{ + ResourceConstraints: &model.ResourceConstraints{ + MaxSlots: ptrs.Ptr(10), + }, + PriorityLimit: ptrs.Ptr(50), + } + t.Run("max_slots differs to high", func(t *testing.T) { + err := checkConstraintConflicts(constraints, ptrs.Ptr(11), ptrs.Ptr(5), nil) + require.Error(t, err) + }) + t.Run("max_slots differs to low", func(t *testing.T) { + err := checkConstraintConflicts(constraints, ptrs.Ptr(9), ptrs.Ptr(5), nil) + require.Error(t, err) + }) + + t.Run("slots_per_trial too high", func(t *testing.T) { + err := checkConstraintConflicts(constraints, ptrs.Ptr(5), ptrs.Ptr(11), nil) + require.Error(t, err) + }) + + t.Run("slots_per_trial within range", func(t *testing.T) { + err := checkConstraintConflicts(constraints, ptrs.Ptr(10), ptrs.Ptr(8), nil) + require.NoError(t, err) + }) + + t.Run("priority differs too high", func(t *testing.T) { + err := checkConstraintConflicts(constraints, nil, nil, ptrs.Ptr(100)) + require.Error(t, err) + }) + + t.Run("priority differs too low", func(t *testing.T) { + err := checkConstraintConflicts(constraints, nil, nil, ptrs.Ptr(10)) + require.Error(t, err) + }) + + t.Run("all comply", func(t *testing.T) { + err := checkConstraintConflicts(constraints, ptrs.Ptr(10), ptrs.Ptr(10), ptrs.Ptr(50)) + require.NoError(t, err) + }) +}