Commit
Merge pull request #247 from kerthcet/feat/fungibility
Support hostpath models
InftyAI-Agent authored Jan 21, 2025
2 parents 5d7acaf + 37c40f7 commit c1479d4
Showing 10 changed files with 98 additions and 23 deletions.
6 changes: 4 additions & 2 deletions api/core/v1alpha1/model_types.go
@@ -82,8 +82,10 @@ type ModelSource struct {
// ModelHub represents the model registry for model downloads.
// +optional
ModelHub *ModelHub `json:"modelHub,omitempty"`
// URI represents a various kinds of model sources following the uri protocol, e.g.:
// - OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
// URI represents various kinds of model sources following the uri protocol, protocol://<address>, e.g.:
// - oss://<bucket>.<endpoint>/<path-to-your-model>
// - ollama://llama3.3
// - host://<path-to-your-model>
//
// +optional
URI *URIProtocol `json:"uri,omitempty"`
6 changes: 4 additions & 2 deletions config/crd/bases/llmaz.io_openmodels.yaml
@@ -155,8 +155,10 @@ spec:
type: object
uri:
description: |-
URI represents a various kinds of model sources following the uri protocol, e.g.:
- OSS: oss://<bucket>.<endpoint>/<path-to-your-model>
URI represents various kinds of model sources following the uri protocol, protocol://<address>, e.g.:
- oss://<bucket>.<endpoint>/<path-to-your-model>
- ollama://llama3.3
- host://<path-to-your-model>
type: string
type: object
required:
5 changes: 5 additions & 0 deletions docs/examples/README.md
@@ -13,6 +13,7 @@ We provide a set of examples to help you serve large language models, by default
- [Deploy models via ollama](#ollama)
- [Speculative Decoding with vLLM](#speculative-decoding-with-vllm)
- [Deploy multi-host inference](#multi-host-inference)
- [Deploy host models](#deploy-host-models)

### Deploy models from Huggingface

@@ -59,3 +60,7 @@ By default, we use [vLLM](/~https://github.com/vllm-project/vllm) as the inference
### Multi-Host Inference

Model sizes keep growing: Llama 3.1 405B in FP16 requires more than 750 GB of GPU memory for the weights alone, not counting the KV cache. Even 8 x Nvidia H100 GPUs with 80 GB of HBM each cannot fit it on a single host, so a multi-host deployment is required; see the [example](./multi-nodes/) here.

### Deploy Host Models

Models can be preloaded onto the hosts in advance, which is especially useful for extremely large models; see the [example](./hostpath/) for serving local models.
13 changes: 13 additions & 0 deletions docs/examples/hostpath/model.yaml
@@ -0,0 +1,13 @@
apiVersion: llmaz.io/v1alpha1
kind: OpenModel
metadata:
name: qwen2-0--5b-instruct
spec:
familyName: qwen2
source:
uri: host:///workspace/Qwen2-0.5B-Instruct
inferenceConfig:
flavors:
- name: t4 # GPU type
requests:
nvidia.com/gpu: 1
8 changes: 8 additions & 0 deletions docs/examples/hostpath/playground.yaml
@@ -0,0 +1,8 @@
apiVersion: inference.llmaz.io/v1alpha1
kind: Playground
metadata:
name: qwen2-0--5b-instruct
spec:
replicas: 1
modelClaim:
modelName: qwen2-0--5b-instruct
11 changes: 7 additions & 4 deletions pkg/controller_helper/model_source/modelsource.go
@@ -71,13 +71,16 @@ func NewModelSourceProvider(model *coreapi.OpenModel) ModelSourceProvider {

if model.Spec.Source.URI != nil {
// We'll validate the format in the webhook, so generally no error should happen here.
protocol, address, _ := util.ParseURI(string(*model.Spec.Source.URI))
provider := &URIProvider{modelName: model.Name, protocol: protocol, modelAddress: address}
protocol, value, _ := util.ParseURI(string(*model.Spec.Source.URI))
provider := &URIProvider{modelName: model.Name, protocol: protocol}

switch protocol {
case OSS:
provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(address)
case OLLAMA:
provider.endpoint, provider.bucket, provider.modelPath, _ = util.ParseOSS(value)
case HostPath:
provider.modelPath = value
case Ollama:
provider.modelPath = value
default:
// This should be validated at webhooks.
panic("protocol not supported")
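For context on the switch above: `util.ParseURI` is not shown in this diff, but it evidently splits a model URI into a protocol and the remaining value. Below is a minimal sketch of such a helper, assuming the `protocol://<address>` form documented in `model_types.go` and an upper-cased protocol to match the `OSS`/`OLLAMA`/`HOST` constants; the function name and behavior are illustrative, not the project's actual implementation.

```go
package main

import (
	"fmt"
	"strings"
)

// parseURI is a hypothetical stand-in for util.ParseURI: it splits a model
// source URI of the form protocol://<address> into an upper-cased protocol
// and the address that follows the scheme separator.
func parseURI(uri string) (protocol, address string, err error) {
	parts := strings.SplitN(uri, "://", 2)
	if len(parts) != 2 {
		return "", "", fmt.Errorf("invalid model URI %q: expected protocol://<address>", uri)
	}
	return strings.ToUpper(parts[0]), parts[1], nil
}

func main() {
	for _, uri := range []string{
		"oss://bucket.endpoint/models/Qwen2-0.5B-Instruct",
		"ollama://llama3.3",
		"host:///workspace/Qwen2-0.5B-Instruct",
	} {
		protocol, address, _ := parseURI(uri)
		fmt.Printf("protocol=%s address=%s\n", protocol, address)
	}
}
```

With an input like `host:///workspace/Qwen2-0.5B-Instruct`, the address is the absolute host path, which the `HostPath` case above stores directly as `modelPath`.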
59 changes: 47 additions & 12 deletions pkg/controller_helper/model_source/uri.go
@@ -26,22 +26,24 @@ import (
var _ ModelSourceProvider = &URIProvider{}

const (
OSS = "OSS"
OLLAMA = "OLLAMA"
OSS = "OSS"
Ollama = "OLLAMA"
HostPath = "HOST"
)

type URIProvider struct {
modelName string
protocol string
bucket string
endpoint string
modelPath string
modelAddress string
modelName string
protocol string
bucket string
endpoint string
modelPath string
}

func (p *URIProvider) ModelName() string {
if p.protocol == OLLAMA {
return p.modelAddress
if p.protocol == Ollama {
// modelPath stores the ollama model name,
// while the model name is the name of the model CRD.
return p.modelPath
}
return p.modelName
}
@@ -54,18 +56,51 @@ func (p *URIProvider) ModelName() string {
// - uri: bucket.endpoint/modelPath/model.gguf
// modelPath: /workspace/models/model.gguf
func (p *URIProvider) ModelPath() string {
if p.protocol == HostPath {
return p.modelPath
}

// protocol is oss.

splits := strings.Split(p.modelPath, "/")

if strings.Contains(p.modelPath, ".") {
if strings.Contains(p.modelPath, ".gguf") {
return CONTAINER_MODEL_PATH + splits[len(splits)-1]
}
return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1]
}

func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index int) {
if p.protocol == OLLAMA {
// We don't have additional operations for Ollama, models are loaded at runtime.
if p.protocol == Ollama {
return
}

if p.protocol == HostPath {
template.Spec.Volumes = append(template.Spec.Volumes, corev1.Volume{
Name: MODEL_VOLUME_NAME,
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
Path: p.modelPath,
},
},
})

for i, container := range template.Spec.Containers {
// We only consider this container.
if container.Name == MODEL_RUNNER_CONTAINER_NAME {
template.Spec.Containers[i].VolumeMounts = append(template.Spec.Containers[i].VolumeMounts, corev1.VolumeMount{
Name: MODEL_VOLUME_NAME,
MountPath: p.modelPath,
ReadOnly: true,
})
}
}
return
}

// Other protocols.

initContainerName := MODEL_LOADER_CONTAINER_NAME
if index != 0 {
initContainerName += "-" + strconv.Itoa(index)
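To make the HostPath branch of `InjectModelLoader` above easier to picture in isolation, here is a self-contained sketch (an illustration, not code from this commit) that performs the same injection: a read-only hostPath volume mounted into the model runner container at the model's host path. The volume and container names below are hypothetical placeholders for `MODEL_VOLUME_NAME` and `MODEL_RUNNER_CONTAINER_NAME`.

```go
package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
)

// injectHostPathModel mirrors the HostPath branch shown above: the host
// directory containing the model is exposed as a hostPath volume and mounted
// read-only into the model runner container at the same path, so the runtime
// loads it directly without a model-loader init container.
func injectHostPathModel(template *corev1.PodTemplateSpec, modelPath, volumeName, runnerContainer string) {
	template.Spec.Volumes = append(template.Spec.Volumes, corev1.Volume{
		Name: volumeName,
		VolumeSource: corev1.VolumeSource{
			HostPath: &corev1.HostPathVolumeSource{Path: modelPath},
		},
	})
	for i := range template.Spec.Containers {
		if template.Spec.Containers[i].Name != runnerContainer {
			continue
		}
		template.Spec.Containers[i].VolumeMounts = append(template.Spec.Containers[i].VolumeMounts, corev1.VolumeMount{
			Name:      volumeName,
			MountPath: modelPath,
			ReadOnly:  true,
		})
	}
}

func main() {
	tmpl := &corev1.PodTemplateSpec{
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{{Name: "model-runner"}}, // hypothetical runner container name
		},
	}
	injectHostPathModel(tmpl, "/workspace/Qwen2-0.5B-Instruct", "model-volume", "model-runner")
	fmt.Printf("volume=%s mounts=%v\n", tmpl.Spec.Volumes[0].Name, tmpl.Spec.Containers[0].VolumeMounts)
}
```

Unlike the other protocols handled further down (truncated in this diff), no init container is added for host models; the files are assumed to already exist on the node, which also ties the workload to nodes that actually hold the model at that path.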
5 changes: 3 additions & 2 deletions pkg/webhook/openmodel_webhook.go
@@ -47,8 +47,9 @@ func SetupOpenModelWebhook(mgr ctrl.Manager) error {
var _ webhook.CustomDefaulter = &OpenModelWebhook{}

var SUPPORTED_OBJ_STORES = map[string]struct{}{
modelSource.OSS: {},
modelSource.OLLAMA: {},
modelSource.OSS: {},
modelSource.Ollama: {},
modelSource.HostPath: {},
}

// Default implements webhook.Defaulter so a webhook will be registered for the type
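The map above is what the webhook consults when validating a model URI; the validation code itself is outside this diff. A hedged sketch of how such a check could look, reusing the `parseURI` sketch from earlier (the function name and error messages are assumptions, not the project's actual webhook logic):

```go
// validateModelURI shows how a supported-protocol map like SUPPORTED_OBJ_STORES
// can gate admission: split the URI, upper-case the protocol, and reject any
// protocol that is not a key of the map (e.g. the "unknown://" test case below).
func validateModelURI(uri string, supported map[string]struct{}) error {
	protocol, _, err := parseURI(uri) // parseURI: see the earlier sketch
	if err != nil {
		return err
	}
	if _, ok := supported[protocol]; !ok {
		return fmt.Errorf("unsupported protocol %q in model URI %q", protocol, uri)
	}
	return nil
}
```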
6 changes: 6 additions & 0 deletions test/integration/webhook/model_test.go
@@ -105,6 +105,12 @@ var _ = ginkgo.Describe("model default and validation", func() {
},
failed: false,
}),
ginkgo.Entry("model creation with host protocol", &testValidatingCase{
model: func() *coreapi.OpenModel {
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("host:///models/meta-llama-3-8B").Obj()
},
failed: false,
}),
ginkgo.Entry("model creation with protocol unknown URI", &testValidatingCase{
model: func() *coreapi.OpenModel {
return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithURI("unknown://bucket.endpoint/models/meta-llama-3-8B").Obj()
2 changes: 1 addition & 1 deletion test/util/wrapper/playground.go
@@ -160,7 +160,7 @@ func (w *PlaygroundWrapper) BackendRuntimeLimit(r, v string) *PlaygroundWrapper {
return w
}

func (w *PlaygroundWrapper) ElasticConfig(maxReplicas, minReplicas int32) *PlaygroundWrapper {
func (w *PlaygroundWrapper) ElasticConfig(minReplicas, maxReplicas int32) *PlaygroundWrapper {
w.Spec.ElasticConfig = &inferenceapi.ElasticConfig{
MaxReplicas: ptr.To[int32](maxReplicas),
MinReplicas: ptr.To[int32](minReplicas),
