Cleanup orchestrator proto #112

Merged · 2 commits · Jul 16, 2024
30 changes: 0 additions & 30 deletions benchmarks/benchmark_serving.py
@@ -426,18 +426,14 @@ async def send_request(
tokenizer: Any,
input_request: InputRequest,
pbar: tqdm,
session_cache: str,
priority: int,
) -> RequestFuncOutput:
"""Send the request to JetStream server."""
# Tokenization on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
request = jetstream_pb2.DecodeRequest(
session_cache=session_cache,
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
priority=priority,
max_tokens=input_request.output_len,
)
output = RequestFuncOutput()
@@ -463,8 +459,6 @@ async def benchmark(
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
session_cache: str,
priority: int,
):
"""Benchmark the online serving performance."""
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -481,8 +475,6 @@
tokenizer=tokenizer,
input_request=request,
pbar=pbar,
session_cache=session_cache,
priority=priority,
)
)
)
@@ -614,8 +606,6 @@ def main(args: argparse.Namespace):
input_requests=warmup_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
session_cache=args.session_cache,
priority=args.priority,
)
)
print(f"{args.warmup_mode} warmup completed.")
@@ -631,8 +621,6 @@
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
session_cache=args.session_cache,
priority=args.priority,
)
)

@@ -790,24 +778,6 @@ def main(args: argparse.Namespace):
" the form of a string."
),
)
parser.add_argument(
"--priority",
type=int,
default=0,
help=(
"Message priority. (currently no business logic implemented, use"
" default 0)"
),
)
parser.add_argument(
"--session-cache",
type=str,
default="",
help=(
"Location of any pre-cached results. (currently _load_cache_history"
" not implemented, use default empty str)"
),
)
parser.add_argument(
"--save-request-outputs",
action="store_true",
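For reference, this is what the benchmark's request construction reduces to after the removal — a minimal sketch based on the surviving lines of `send_request` above. `build_decode_request` is a hypothetical helper name; `tokenizer` and `input_request` are the same objects `benchmark_serving.py` already passes around.

```python
from jetstream.core.proto import jetstream_pb2


def build_decode_request(tokenizer, input_request) -> jetstream_pb2.DecodeRequest:
  """Builds the cleaned-up DecodeRequest: content plus max_tokens, nothing else."""
  # Tokenization stays on the client side, following the MLPerf standard.
  token_ids = tokenizer.encode(input_request.prompt)
  return jetstream_pb2.DecodeRequest(
      token_content=jetstream_pb2.DecodeRequest.TokenContent(
          token_ids=token_ids
      ),
      max_tokens=input_request.output_len,
  )
```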
7 changes: 2 additions & 5 deletions jetstream/core/orchestrator.py
@@ -133,7 +133,6 @@ class ActiveRequest:
complete: Optional[np.ndarray] = None
prefill_result: Any = None
#################### Information relevant for prefill ########################
history_path: Optional[str] = None
prefill_content: Optional[str | list[int]] = None
padded_token_length: Optional[int] = None
################## Information relevant for detokenization ###################
@@ -491,14 +490,13 @@ def _prefill_thread(self, idx: int):

if request is None:
break
is_bos = not bool(request.history_path)
is_bos = True
logging.info(
"Prefilling on prefill engine %d : prefill queue size, %d,"
" is_bos: %s, history: %s",
" is_bos: %s",
idx,
self._prefill_backlog.qsize(),
is_bos,
request.history_path,
)
# Tokenize and padding the text or token input.
padded_tokens, true_length = self._process_prefill_content(
@@ -895,7 +893,6 @@ async def Decode(  # pylint: disable=invalid-overridden-method
# Wrap request as an ActiveRequest.
active_request = ActiveRequest(
max_tokens=request.max_tokens,
history_path=request.session_cache,
prefill_content=prefill_content,
is_client_side_tokenization=is_client_side_tokenization,
return_channel=return_channel,
5 changes: 1 addition & 4 deletions jetstream/core/proto/jetstream.proto
@@ -26,9 +26,6 @@ service Orchestrator {
}

message DecodeRequest {
// Where to load any pre-existing kv cache from.
string session_cache = 1;
Contributor: In the future we will have prefill cache logic. Is this field related to the prefill cache?

Collaborator (Author): This was originally for the history cache. We could add prefill cache fields when we implement it.

Contributor: Sounds good!

int32 priority = 3;
// The maximum output length of a sequence. It's used in JetStream to control
// the output/decode length of a sequence. It would not be used in the engine.
// We should always set max_tokens <= (max_target_length -
Expand All @@ -51,7 +48,7 @@ message DecodeRequest {
TextContent text_content = 5;
TokenContent token_content = 6;
}
reserved 2;
reserved 1, 2, 3;
// Next ID: 7
}

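The net effect of the proto change: `session_cache` (field 1) and `priority` (field 3) are gone, and their numbers are reserved alongside the previously reserved 2, so future fields cannot silently reuse those wire tags. A minimal sketch of what that means for callers, assuming the regenerated `jetstream_pb2` module is on the path:

```python
from jetstream.core.proto import jetstream_pb2

# Still valid after the cleanup: content (text or tokens) plus max_tokens.
request = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="Today is a good day"),
    max_tokens=16,
)

# The removed fields no longer exist on the generated class, so stale call
# sites fail loudly instead of being silently ignored.
try:
  jetstream_pb2.DecodeRequest(session_cache="", priority=0)
except ValueError as err:
  print(err)  # e.g. DecodeRequest has no "session_cache" field
```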
40 changes: 20 additions & 20 deletions jetstream/core/proto/jetstream_pb2.py
@@ -28,7 +28,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
)

_globals = globals()
@@ -39,23 +39,23 @@
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_DECODEREQUEST"]._serialized_start = 58
_globals["_DECODEREQUEST"]._serialized_end = 353
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336
_globals["_DECODERESPONSE"]._serialized_start = 356
_globals["_DECODERESPONSE"]._serialized_end = 687
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
_globals["_ORCHESTRATOR"]._serialized_start = 752
_globals["_ORCHESTRATOR"]._serialized_end = 937
_globals["_DECODEREQUEST"]._serialized_end = 324
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295
_globals["_DECODERESPONSE"]._serialized_start = 327
_globals["_DECODERESPONSE"]._serialized_end = 658
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 660
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 680
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720
_globals["_ORCHESTRATOR"]._serialized_start = 723
_globals["_ORCHESTRATOR"]._serialized_end = 908
# @@protoc_insertion_point(module_scope)
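`jetstream_pb2.py` is generated code, so the serialized-descriptor and offset changes above simply mirror the `.proto` edit. If it ever needs regenerating by hand, a sketch using `grpc_tools` (an assumption about available tooling, not a documented project command) would look roughly like this, run from the repository root:

```python
from grpc_tools import protoc

# Regenerates jetstream_pb2.py and jetstream_pb2_grpc.py next to the .proto.
protoc.main([
    "protoc",
    "-I.",
    "--python_out=.",
    "--grpc_python_out=.",
    "jetstream/core/proto/jetstream.proto",
])
```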
4 changes: 0 additions & 4 deletions jetstream/tests/core/test_orchestrator.py
@@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self):
text = "AB"

request = jetstream_pb2.DecodeRequest(
session_cache="",
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
priority=1,
max_tokens=3,
)
iterator = client.Decode(request)
@@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
token_ids = [65, 66]

request = jetstream_pb2.DecodeRequest(
session_cache="",
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
priority=1,
max_tokens=3,
)
iterator = client.Decode(request)
2 changes: 0 additions & 2 deletions jetstream/tests/core/test_server.py
@@ -93,9 +93,7 @@ async def test_server(
# as BOS
text = "AB"
request = jetstream_pb2.DecodeRequest(
session_cache="",
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
priority=1,
max_tokens=3,
)
iterator = stub.Decode(request)
3 changes: 0 additions & 3 deletions jetstream/tools/load_tester.py
@@ -50,14 +50,11 @@ def api_call(
stub: jetstream_pb2_grpc.OrchestratorStub,
text: str,
max_tokens: int,
session_cache: str = "",
print_interim: bool = True,
) -> str:
"""Sends a request to server and returns text."""
request = jetstream_pb2.DecodeRequest(
session_cache=session_cache,
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
priority=1,
max_tokens=max_tokens,
)
response = stub.Decode(request)
8 changes: 0 additions & 8 deletions jetstream/tools/requester.py
@@ -26,11 +26,7 @@

_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address")
_PORT = flags.DEFINE_string("port", "9000", "port to ping")
_SESSION_CACHE = flags.DEFINE_string(
"session_cache", "", "Location of any pre-cached results"
)
_TEXT = flags.DEFINE_string("text", "Today is a good day", "The message")
_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority")
_MAX_TOKENS = flags.DEFINE_integer(
"max_tokens", 3, "Maximum number of output/decode tokens of a sequence"
)
@@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None:
vocab = load_vocab(_TOKENIZER.value)
token_ids = vocab.tokenizer.encode(_TEXT.value)
request = jetstream_pb2.DecodeRequest(
session_cache=_SESSION_CACHE.value,
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
priority=_PRIORITY.value,
max_tokens=_MAX_TOKENS.value,
)
else:
request = jetstream_pb2.DecodeRequest(
session_cache=_SESSION_CACHE.value,
text_content=jetstream_pb2.DecodeRequest.TextContent(
text=_TEXT.value
),
priority=_PRIORITY.value,
max_tokens=_MAX_TOKENS.value,
)
return _GetResponseAsync(stub, request)
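Putting the pieces together, here is a minimal end-to-end client sketch of the post-cleanup API, modeled on jetstream/tools/requester.py and load_tester.py. `decode_once` is a hypothetical helper, and the address, port, and `max_tokens` values are the tools' defaults used for illustration, not requirements.

```python
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


def decode_once(text: str, max_tokens: int = 3) -> str:
  """Sends one DecodeRequest and concatenates the streamed sample text."""
  with grpc.insecure_channel("0.0.0.0:9000") as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
        max_tokens=max_tokens,
    )
    pieces = []
    for response in stub.Decode(request):  # Decode is a server-streaming RPC.
      for sample in response.stream_content.samples:
        pieces.append(sample.text)
    return "".join(pieces)


if __name__ == "__main__":
  print(decode_once("Today is a good day"))
```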