
Add ray disaggregated serving support #87

Merged
merged 5 commits into AI-Hypercomputer:main on May 23, 2024

Conversation

FanhaiLu1 (Contributor)

This PR adds Ray disaggregated serving support to JetStream; the underlying engine implementations will be on the PyTorch and MaxText side.

This PR does not impact any current interleave or Pathways disaggregated behavior.

@FanhaiLu1 FanhaiLu1 requested a review from JoeZijunZhou May 23, 2024 17:30
@FanhaiLu1 FanhaiLu1 requested a review from vipannalla as a code owner May 23, 2024 17:30
@FanhaiLu1 FanhaiLu1 requested a review from allenwang28 May 23, 2024 18:28
@allenwang28 (Collaborator) left a comment

Looks good, thanks Fanhai! Just minor nits

jetstream/core/orchestrator.py (two outdated review threads, resolved)
@JoeZijunZhou (Collaborator) left a comment

LGTM! May need some unit tests.

@@ -97,6 +97,7 @@ def run(
metrics_server_config: config_lib.MetricsServerConfig | None = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
ray_multiple_host: bool = False,
Collaborator

Maybe move this flag to ServerConfig, since it's used to control server mode?
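For illustration, a minimal sketch of what that could look like (the existing ServerConfig fields are elided, and placing the flag here is an assumption, not the actual config_lib definition):

# Hypothetical sketch only, not the real config_lib.ServerConfig.
import dataclasses


@dataclasses.dataclass
class ServerConfig:
  # ...existing server fields elided...
  # True when the engines run across multiple hosts via Ray, so the server
  # knows to use the Ray transfer path between prefill and generate engines.
  ray_multiple_host: bool = False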

Collaborator

+1

Contributor Author

Thanks! Done.

@vipannalla (Collaborator) left a comment

Unit tests, please.

@@ -97,6 +97,7 @@ def run(
metrics_server_config: config_lib.MetricsServerConfig | None = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
ray_multiple_host: bool = False,
Collaborator

+1

Comment on lines 521 to 522
def _ray_transfer_prefill_result(self, new_request, target_idx):
self._generate_engines[target_idx].transfer(new_request.prefill_result)
Collaborator

I don't see anything Ray-specific here; the transfer code is abstracted into engine.transfer(). Can we use a generic name such as "non_jax_transfer" instead of "ray_transfer" here?
JetStream doesn't need to know whether it's Ray or some other mechanism being used. Also, let's move this setting to the server config for better control.

@FanhaiLu1 (Contributor, Author) commented May 23, 2024

Thanks, done for the server config part.

The code here is at the interface level; the actual logic (the real implementation) is on the engine side (PyTorch or MaxText). The implementation logic would be (see the sketch after this list):

  1. Gather the prefill result from the TPU chips in the Ray worker.
  2. Transfer the gathered result from TPU to CPU RAM through PCIe in the Ray worker.
  3. Transfer the prefill result from the prefill server to the decode server through DCN via the Ray head.
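A rough sketch of those three steps from the engine side (RayGenerateEngine and receive_prefill are illustrative names, not the actual PyTorch/MaxText code):

# Illustrative only; the real implementation lives in the engines.
import jax
import ray


class RayGenerateEngine:

  def __init__(self, decode_worker):
    # decode_worker is assumed to be a Ray actor handle on the decode server.
    self._decode_worker = decode_worker

  def transfer(self, prefill_result):
    # Steps 1-2: gather the (possibly sharded) prefill result from the TPU
    # chips in this Ray worker and land it in CPU RAM (TPU -> host over PCIe).
    host_result = jax.device_get(prefill_result)
    # Step 3: ship the host copy from the prefill server to the decode
    # server over DCN; the Ray head coordinates the actor call.
    return ray.get(self._decode_worker.receive_prefill.remote(host_result))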

Right now, and in the near future, I don't see us introducing another transfer mechanism besides Ray or JAX, so I feel naming it ray_transfer is clear and straightforward.

I agree that JetStream doesn't need to know whether it's Ray or some other mechanism; going further, JetStream shouldn't even need to know whether it's JAX, Ray, or anything else, since the engine should handle it. But right now Pathways handles transfer in the orchestrator, so we have to decide which method needs to be called. The ideal case is to just call engine.transfer() even with Pathways; that needs more effort to explore.

Collaborator

Based on the implementation of _ray_transfer_prefill_result, could we assign the responsibility of implementing transfer to the generate engine?

What I mean by this is that within the orchestrator, we can generically use:

def _transfer_prefill_result(
    self, new_request: ActiveRequest, target_idx: int
):
  self._generate_engines[target_idx].transfer(new_request.prefill_result)

and so your Pathways engine needs to implement transfer (the jax.device_put logic) and your Ray engine needs to implement transfer with ray.remote primitives.

Or does this break some assumption? I think having JAX involved in the orchestrator is a somewhat leaky abstraction.
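For illustration, the Pathways side of that interface might look like this (class and constructor names are assumptions; the Ray engine would wrap its ray.remote call behind the same transfer() method):

import jax


class PathwaysGenerateEngine:
  # Illustrative sketch: the jax.device_put logic lives inside the engine,
  # so the orchestrator never has to touch JAX directly.

  def __init__(self, generate_device):
    self._generate_device = generate_device

  def transfer(self, prefill_result):
    # Place the prefill result onto this engine's generate device.
    return jax.device_put(prefill_result, self._generate_device)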

Contributor Author

Agreed. @JoeZijunZhou may have more insights, but based on my understanding, this could break Pathways.

Contributor Author

Synced with @vipannalla offline: we should remove the JAX dependencies from JetStream in the long term. The current concern is the impact on Pathways; we will figure that out and make sure Pathways can adopt a JAX-free JetStream.

@FanhaiLu1 (Contributor, Author)

LGTM! May need some unit tests.

Let's explore how to add disaggregated serving unit tests, and add them before enabling disaggregated serving.

@JoeZijunZhou (Collaborator)

LGTM! May need some unit tests.

Let's explore how to add disaggregated serving unit tests, and add them before enabling disaggregated serving.

We could add a simple unit test in test_orchestrator.py and test_server.py just to cover the flag's if/else logic. The e2e functional test could be added later, I guess?
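For example, a minimal sketch of that kind of flag test (the dispatch here is inlined as a stand-in for the orchestrator's actual branch, so names and structure are assumptions):

import unittest
from unittest import mock


class RayTransferFlagTest(unittest.TestCase):
  # Illustrative only; a real test would drive the orchestrator itself.

  def test_ray_path_calls_engine_transfer(self):
    engine = mock.Mock()
    prefill_result = object()
    ray_multiple_host = True  # stand-in for the new server config flag
    # Stand-in for the orchestrator's if/else between the JAX transfer
    # path and the Ray transfer path.
    if ray_multiple_host:
      engine.transfer(prefill_result)
    engine.transfer.assert_called_once_with(prefill_result)


if __name__ == "__main__":
  unittest.main()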

@FanhaiLu1 FanhaiLu1 merged commit e19a790 into AI-Hypercomputer:main May 23, 2024
3 checks passed