Skip to content

Commit

Permalink
Add webhook for RolloutStrategy
Browse files Browse the repository at this point in the history
Signed-off-by: kerthcet <kerthcet@gmail.com>
  • Loading branch information
kerthcet committed Feb 27, 2025
1 parent 01c67c5 commit 5a690ff
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 9 deletions.
3 changes: 2 additions & 1 deletion api/inference/v1alpha1/service_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ type ServiceSpec struct {
WorkloadTemplate lws.LeaderWorkerTemplate `json:"workloadTemplate"`
// RolloutStrategy defines the strategy that will be applied to update replicas
// when a revision is made to the leaderWorkerTemplate.
// +kubebuilder:default:={type: "RollingUpdate", rollingUpdateConfiguration: {"maxUnavailable": 1, "maxSurge": 0}}
// +optional
RolloutStrategy lws.RolloutStrategy `json:"rolloutStrategy,omitempty"`
RolloutStrategy *lws.RolloutStrategy `json:"rolloutStrategy,omitempty"`
}

const (
Expand Down
7 changes: 6 additions & 1 deletion api/inference/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions config/crd/bases/inference.llmaz.io_services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ spec:
format: int32
type: integer
rolloutStrategy:
default:
rollingUpdateConfiguration:
maxSurge: 0
maxUnavailable: 1
type: RollingUpdate
description: |-
RolloutStrategy defines the strategy that will be applied to update replicas
when a revision is made to the leaderWorkerTemplate.
Expand Down
13 changes: 9 additions & 4 deletions pkg/controller/inference/service_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,15 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
spec.WithLeaderWorkerTemplate(leaderWorkerTemplate)
spec.LeaderWorkerTemplate.WithSize(*service.Spec.WorkloadTemplate.Size)
spec.WithReplicas(*service.Spec.Replicas)
spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type))
if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil {
spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge)
spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable)
if service.Spec.RolloutStrategy != nil {
spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type))
if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil {
spec.RolloutStrategy.WithRollingUpdateConfiguration(
applyconfigurationv1.RollingUpdateConfiguration().
WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge).
WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable),
)
}
}
spec.WithStartupPolicy(lws.LeaderReadyStartupPolicy)

Expand Down
28 changes: 28 additions & 0 deletions test/integration/webhook/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ limitations under the License.
package webhook

import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
lws "sigs.k8s.io/lws/api/leaderworkerset/v1"

inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1"
"github.com/inftyai/llmaz/test/util"
Expand All @@ -44,6 +46,32 @@ var _ = ginkgo.Describe("service default and validation", func() {
gomega.Expect(k8sClient.Delete(ctx, ns)).To(gomega.Succeed())
})

type testDefaultingCase struct {
service func() *inferenceapi.Service
wantService func() *inferenceapi.Service
}
ginkgo.DescribeTable("Defaulting test",
func(tc *testDefaultingCase) {
svc := tc.service()
gomega.Expect(k8sClient.Create(ctx, svc)).To(gomega.Succeed())
gomega.Expect(svc).To(gomega.BeComparableTo(tc.wantService(),
cmpopts.IgnoreTypes(inferenceapi.ServiceStatus{}),
cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields")))
},
ginkgo.Entry("apply service rollingUpdate strategy", &testDefaultingCase{
service: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).WorkerTemplate().Obj()
},
wantService: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).
RolloutStrategy(string(lws.RollingUpdateStrategyType), 1, 0).
RestartPolicy("RecreateGroupOnPodRestart").
Replicas(1).Size(1).
WorkerTemplate().Obj()
},
}),
)

type testValidatingCase struct {
service func() *inferenceapi.Service
failed bool
Expand Down
32 changes: 29 additions & 3 deletions test/util/wrapper/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package wrapper
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
lws "sigs.k8s.io/lws/api/leaderworkerset/v1"

coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
Expand Down Expand Up @@ -65,9 +67,6 @@ func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavor
}

func (w *ServiceWrapper) WorkerTemplate() *ServiceWrapper {
w.Spec.RolloutStrategy = lws.RolloutStrategy{
Type: lws.RollingUpdateStrategyType,
}
w.Spec.WorkloadTemplate.WorkerTemplate = corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
Expand All @@ -90,3 +89,30 @@ func (w *ServiceWrapper) InitContainerName(name string) *ServiceWrapper {
w.Spec.WorkloadTemplate.WorkerTemplate.Spec.InitContainers[0].Name = name
return w
}

func (w *ServiceWrapper) RolloutStrategy(typ string, maxUnavailable int, maxSurge int) *ServiceWrapper {
if w.Spec.RolloutStrategy == nil {
w.Spec.RolloutStrategy = &lws.RolloutStrategy{}
}
w.Spec.RolloutStrategy.Type = lws.RolloutStrategyType(typ)
w.Spec.RolloutStrategy.RollingUpdateConfiguration = &lws.RollingUpdateConfiguration{
MaxUnavailable: intstr.FromInt(maxUnavailable),
MaxSurge: intstr.FromInt(maxSurge),
}
return w
}

func (w *ServiceWrapper) Size(size int32) *ServiceWrapper {
w.Spec.WorkloadTemplate.Size = ptr.To[int32](size)
return w
}

func (w *ServiceWrapper) Replicas(replicas int32) *ServiceWrapper {
w.Spec.Replicas = ptr.To[int32](replicas)
return w
}

func (w *ServiceWrapper) RestartPolicy(policy string) *ServiceWrapper {
w.Spec.WorkloadTemplate.RestartPolicy = lws.RestartPolicyType(policy)
return w
}

0 comments on commit 5a690ff

Please sign in to comment.