Add ray disaggregated serving support #87
Conversation
Looks good, thanks Fanhai! Just minor nits
LGTM! May need some unit tests.
jetstream/core/server_lib.py
Outdated
@@ -97,6 +97,7 @@ def run(
     metrics_server_config: config_lib.MetricsServerConfig | None = None,
     enable_jax_profiler: bool = False,
     jax_profiler_port: int = 9999,
+    ray_multiple_host: bool = False,
Maybe move this flag to ServerConfig, since it's used to control server mode?
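As a rough sketch of that suggestion, the flag could live in the server config rather than in run()'s signature. This is purely illustrative: everything here besides the ray_multiple_host name (taken from the diff) is assumed, not JetStream's actual config_lib.ServerConfig.

```python
# Hypothetical sketch: moving ray_multiple_host into the server config.
# Field names other than ray_multiple_host are illustrative, not the
# real config_lib.ServerConfig.

from dataclasses import dataclass


@dataclass
class ServerConfig:
  enable_jax_profiler: bool = False
  jax_profiler_port: int = 9999
  ray_multiple_host: bool = False  # controls multi-host Ray server mode


def run(config: ServerConfig) -> str:
  # The server reads the mode from the config instead of a keyword arg.
  return "ray-multi-host" if config.ray_multiple_host else "single-host"


print(run(ServerConfig(ray_multiple_host=True)))  # → ray-multi-host
```

Callers would then pass one config object instead of a growing list of keyword arguments.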
+1
Thanks! Done.
Unit tests, please.
jetstream/core/server_lib.py
Outdated
@@ -97,6 +97,7 @@ def run(
     metrics_server_config: config_lib.MetricsServerConfig | None = None,
     enable_jax_profiler: bool = False,
     jax_profiler_port: int = 9999,
+    ray_multiple_host: bool = False,
+1
jetstream/core/orchestrator.py
Outdated
  def _ray_transfer_prefill_result(self, new_request, target_idx):
    self._generate_engines[target_idx].transfer(new_request.prefill_result)
I don't see anything Ray-specific here; the transfer code is abstracted into engine.transfer(). Can we use a generic name such as "non_jax_transfer" instead of "ray_transfer" here?
JetStream doesn't need to know whether it's Ray or some other mechanism being used. Also, let's move this setting to the server config for better control.
Thanks, done for the server config part.
The code here is interface-level; the actual logic (the real implementation) lives on the engine side (PyTorch or MaxText). The implementation logic would be:
- Gather the prefill result from the TPU chips in the Ray worker
- Transfer the gathered result from TPU to CPU RAM through PCIe in the Ray worker
- Transfer the prefill result from the prefill server to the decode server through DCN via the Ray head
Right now, and in the near future, I don't see us introducing another transfer mechanism besides Ray or JAX, so I feel naming it ray_transfer is clear and straightforward.
I agree that JetStream doesn't know whether it's Ray or another mechanism; going even further, JetStream shouldn't need to know whether it's JAX, Ray, or anything else, since the engine should handle it. But right now Pathways handles transfer in the orchestrator, so we have to decide which method needs to be called. The ideal case is to just call engine.transfer() even with Pathways; that needs more effort to explore.
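The three transfer hops described above could be sketched roughly as follows. This is a toy illustration only: every class and function name here is hypothetical, and each hop is modeled as a plain copy rather than a real TPU, PCIe, or DCN transfer (which lives in the engine, not in JetStream).

```python
# Illustrative sketch of the three transfer hops. All names are
# hypothetical; the real logic lives on the engine side.

from dataclasses import dataclass


@dataclass
class PrefillResult:
  """Toy stand-in for a prefill result sharded across TPU chips."""
  shards: list  # one entry per TPU chip on the prefill worker


def gather_from_chips(result: PrefillResult) -> list:
  # Hop 1: gather the sharded prefill result inside the Ray worker.
  gathered = []
  for shard in result.shards:
    gathered.extend(shard)
  return gathered


def device_to_host(gathered: list) -> list:
  # Hop 2: move the gathered result from TPU HBM to CPU RAM over PCIe.
  # Modeled here as a simple copy.
  return list(gathered)


def send_over_dcn(host_buffer: list) -> list:
  # Hop 3: ship the buffer from the prefill server to the decode server
  # over DCN, coordinated by the Ray head. Also modeled as a copy.
  return list(host_buffer)


def ray_transfer(result: PrefillResult) -> list:
  return send_over_dcn(device_to_host(gather_from_chips(result)))


print(ray_transfer(PrefillResult(shards=[[1, 2], [3, 4]])))  # → [1, 2, 3, 4]
```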
Based on the implementation of _ray_transfer_prefill_result, could we assign the responsibility of implementing transfer to the generate engine?
What I mean is that within the orchestrator, we can generically use:
  def _transfer_prefill_result(
      self, new_request: ActiveRequest, target_idx: int
  ):
    self._generate_engines[target_idx].transfer(new_request.prefill_result)
and so your Pathways engine needs to implement transfer (the jax.device_put logic) and your Ray engine needs to implement transfer with ray.remote primitives.
Or does this break some assumption? I think having JAX involved in the orchestrator is a somewhat leaky abstraction.
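The dispatch pattern being suggested could look something like the sketch below. Only engine.transfer() and _generate_engines come from the discussion above; the class names and return values are invented for illustration, and the real implementations would call jax.device_put or Ray primitives rather than returning tuples.

```python
# Hedged sketch of engine-owned transfer: each engine implements its own
# transfer(), so the orchestrator never branches on the mechanism.
# Names besides transfer() and _generate_engines are hypothetical.

from abc import ABC, abstractmethod


class GenerateEngine(ABC):
  @abstractmethod
  def transfer(self, prefill_result):
    """Move a prefill result onto this engine's devices."""


class PathwaysEngine(GenerateEngine):
  def transfer(self, prefill_result):
    # A real implementation would use jax.device_put(...).
    return ("jax.device_put", prefill_result)


class RayEngine(GenerateEngine):
  def transfer(self, prefill_result):
    # A real implementation would use ray.remote / object-store calls.
    return ("ray", prefill_result)


class Orchestrator:
  def __init__(self, generate_engines):
    self._generate_engines = generate_engines

  def _transfer_prefill_result(self, prefill_result, target_idx):
    # Generic call: no JAX- or Ray-specific logic in the orchestrator.
    return self._generate_engines[target_idx].transfer(prefill_result)


orch = Orchestrator([PathwaysEngine(), RayEngine()])
print(orch._transfer_prefill_result("kv_cache", 1))  # → ('ray', 'kv_cache')
```

The point of the design is that adding a third transfer mechanism later would mean adding one engine subclass, with no orchestrator change.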
Agree. @JoeZijunZhou may have more insights, but based on my understanding, this could break Pathways.
Synced with @vipannalla offline; we should remove the jax dependencies in JetStream in the long term. The current concern is the impact on Pathways; we will figure it out and make sure Pathways can adopt a jax-free JetStream.
Let's explore how to add disaggregated serving unit tests, and add them before enabling disaggregated serving.
We could add a simple unit test in
This PR adds Ray disaggregated serving support to JetStream; the underlying engine implementations live on the PyTorch and MaxText side.
This PR does not impact any current interleaved or Pathways disaggregated behavior.