Cleanup orchestrator proto (#112)
* Cleanup orchestrator proto

* Update JetStream based on proto cleanup
JoeZijunZhou authored Jul 16, 2024
1 parent 196beda commit 46c152f
Showing 8 changed files with 23 additions and 76 deletions.
30 changes: 0 additions & 30 deletions benchmarks/benchmark_serving.py
@@ -426,18 +426,14 @@ async def send_request(
     tokenizer: Any,
     input_request: InputRequest,
     pbar: tqdm,
-    session_cache: str,
-    priority: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
   # Tokenization on client side following MLPerf standard.
   token_ids = tokenizer.encode(input_request.prompt)
   request = jetstream_pb2.DecodeRequest(
-      session_cache=session_cache,
       token_content=jetstream_pb2.DecodeRequest.TokenContent(
           token_ids=token_ids
       ),
-      priority=priority,
       max_tokens=input_request.output_len,
   )
   output = RequestFuncOutput()
@@ -463,8 +459,6 @@ async def benchmark(
     input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
-    session_cache: str,
-    priority: int,
 ):
   """Benchmark the online serving performance."""
   pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -481,8 +475,6 @@ async def benchmark(
                 tokenizer=tokenizer,
                 input_request=request,
                 pbar=pbar,
-                session_cache=session_cache,
-                priority=priority,
             )
         )
     )
@@ -614,8 +606,6 @@ def main(args: argparse.Namespace):
             input_requests=warmup_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
-            session_cache=args.session_cache,
-            priority=args.priority,
         )
     )
     print(f"{args.warmup_mode} warmup completed.")
@@ -631,8 +621,6 @@ def main(args: argparse.Namespace):
           input_requests=input_requests,
           request_rate=args.request_rate,
           disable_tqdm=args.disable_tqdm,
-          session_cache=args.session_cache,
-          priority=args.priority,
       )
   )
 
@@ -790,24 +778,6 @@ def main(args: argparse.Namespace):
           " the form of a string."
       ),
   )
-  parser.add_argument(
-      "--priority",
-      type=int,
-      default=0,
-      help=(
-          "Message priority. (currently no business logic implemented, use"
-          " default 0)"
-      ),
-  )
-  parser.add_argument(
-      "--session-cache",
-      type=str,
-      default="",
-      help=(
-          "Location of any pre-cached results. (currently _load_cache_history"
-          " not implemented, use default empty str)"
-      ),
-  )
   parser.add_argument(
       "--save-request-outputs",
       action="store_true",
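
After this cleanup the benchmark client sends only the content oneof plus max_tokens. A minimal sketch of the resulting request, assuming the regenerated jetstream_pb2 module is importable (the token IDs stand in for tokenizer.encode output):

from jetstream.core.proto import jetstream_pb2

# Exactly one of text_content / token_content may be set: they live in the
# proto's `oneof content`. The benchmark uses client-side tokenization.
request = jetstream_pb2.DecodeRequest(
    token_content=jetstream_pb2.DecodeRequest.TokenContent(token_ids=[65, 66]),
    max_tokens=64,
)
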
7 changes: 2 additions & 5 deletions jetstream/core/orchestrator.py
@@ -133,7 +133,6 @@ class ActiveRequest:
   complete: Optional[np.ndarray] = None
   prefill_result: Any = None
   #################### Information relevant for prefill ########################
-  history_path: Optional[str] = None
   prefill_content: Optional[str | list[int]] = None
   padded_token_length: Optional[int] = None
   ################## Information relevant for detokenization ###################
@@ -491,14 +490,13 @@ def _prefill_thread(self, idx: int):
 
       if request is None:
         break
-      is_bos = not bool(request.history_path)
+      is_bos = True
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
-          " is_bos: %s, history: %s",
+          " is_bos: %s",
           idx,
           self._prefill_backlog.qsize(),
           is_bos,
-          request.history_path,
       )
       # Tokenize and padding the text or token input.
       padded_tokens, true_length = self._process_prefill_content(
@@ -895,7 +893,6 @@ async def Decode(  # pylint: disable=invalid-overridden-method
     # Wrap request as an ActiveRequest.
     active_request = ActiveRequest(
         max_tokens=request.max_tokens,
-        history_path=request.session_cache,
         prefill_content=prefill_content,
         is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
5 changes: 1 addition & 4 deletions jetstream/core/proto/jetstream.proto
@@ -26,9 +26,6 @@ service Orchestrator {
 }
 
 message DecodeRequest {
-  // Where to load any pre-existing kv cache from.
-  string session_cache = 1;
-  int32 priority = 3;
   // The maximum output length of a sequence. It's used in JetStream to control
   // the output/decode length of a sequence. It would not be used in the engine.
   // We should always set max_tokens <= (max_target_length -
@@ -51,7 +48,7 @@ message DecodeRequest {
     TextContent text_content = 5;
     TokenContent token_content = 6;
   }
-  reserved 2;
+  reserved 1, 2, 3;
  // Next ID: 7
 }
 
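
Note on the proto change: deleting session_cache (field 1) and priority (field 3) and widening `reserved 2;` to `reserved 1, 2, 3;` keeps those field numbers permanently off-limits, so a future field cannot reuse them and silently decode bytes that a stale client still serializes there. After the stubs are regenerated, setting a removed field fails at construction time; a quick sanity check (the ValueError wording is protobuf's usual message, quoted from memory):

from jetstream.core.proto import jetstream_pb2

try:
  jetstream_pb2.DecodeRequest(session_cache="", priority=0)
except ValueError as err:
  # e.g. 'Protocol message DecodeRequest has no "session_cache" field.'
  print(err)
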
40 changes: 20 additions & 20 deletions jetstream/core/proto/jetstream_pb2.py
@@ -28,7 +28,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
+    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
 )
 
 _globals = globals()
@@ -39,23 +39,23 @@
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals["_DECODEREQUEST"]._serialized_start = 58
-  _globals["_DECODEREQUEST"]._serialized_end = 353
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336
-  _globals["_DECODERESPONSE"]._serialized_start = 356
-  _globals["_DECODERESPONSE"]._serialized_end = 687
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
-  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
-  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
-  _globals["_ORCHESTRATOR"]._serialized_start = 752
-  _globals["_ORCHESTRATOR"]._serialized_end = 937
+  _globals["_DECODEREQUEST"]._serialized_end = 324
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295
+  _globals["_DECODERESPONSE"]._serialized_start = 327
+  _globals["_DECODERESPONSE"]._serialized_end = 658
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641
+  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 660
+  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 680
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720
+  _globals["_ORCHESTRATOR"]._serialized_start = 723
+  _globals["_ORCHESTRATOR"]._serialized_end = 908
 # @@protoc_insertion_point(module_scope)
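
jetstream_pb2.py is generated code, so the serialized descriptor and the byte offsets above shift wholesale whenever the .proto changes; they are never hand-edited. One way to regenerate the message and gRPC stubs, assuming grpcio-tools is installed and this runs from the repository root:

from grpc_tools import protoc

# Equivalent to `python -m grpc_tools.protoc ...`; returns 0 on success.
protoc.main([
    "grpc_tools.protoc",
    "-I.",
    "--python_out=.",
    "--grpc_python_out=.",
    "jetstream/core/proto/jetstream.proto",
])
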
4 changes: 0 additions & 4 deletions jetstream/tests/core/test_orchestrator.py
@@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self):
     text = "AB"
 
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-        priority=1,
         max_tokens=3,
     )
     iterator = client.Decode(request)
@@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
     token_ids = [65, 66]
 
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         token_content=jetstream_pb2.DecodeRequest.TokenContent(
             token_ids=token_ids
         ),
-        priority=1,
         max_tokens=3,
     )
     iterator = client.Decode(request)
2 changes: 0 additions & 2 deletions jetstream/tests/core/test_server.py
@@ -93,9 +93,7 @@ async def test_server(
     # as BOS
     text = "AB"
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-        priority=1,
         max_tokens=3,
     )
     iterator = stub.Decode(request)
3 changes: 0 additions & 3 deletions jetstream/tools/load_tester.py
@@ -50,14 +50,11 @@ def api_call(
     stub: jetstream_pb2_grpc.OrchestratorStub,
     text: str,
     max_tokens: int,
-    session_cache: str = "",
     print_interim: bool = True,
 ) -> str:
   """Sends a request to server and returns text."""
   request = jetstream_pb2.DecodeRequest(
-      session_cache=session_cache,
       text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-      priority=1,
       max_tokens=max_tokens,
   )
   response = stub.Decode(request)
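
For reference, a hedged end-to-end sketch of the trimmed client path that api_call wraps: open a channel, send a text request, and drain the server-streaming Decode RPC (the address is a placeholder; the text and max_tokens values mirror the tests above):

import grpc

from jetstream.core.proto import jetstream_pb2, jetstream_pb2_grpc

with grpc.insecure_channel("localhost:9000") as channel:
  stub = jetstream_pb2_grpc.OrchestratorStub(channel)
  request = jetstream_pb2.DecodeRequest(
      text_content=jetstream_pb2.DecodeRequest.TextContent(text="AB"),
      max_tokens=3,
  )
  # Decode streams DecodeResponse messages; sampled text arrives in
  # stream_content.samples.
  for response in stub.Decode(request):
    for sample in response.stream_content.samples:
      print(sample.text, end="")
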
8 changes: 0 additions & 8 deletions jetstream/tools/requester.py
@@ -26,11 +26,7 @@
 
 _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address")
 _PORT = flags.DEFINE_string("port", "9000", "port to ping")
-_SESSION_CACHE = flags.DEFINE_string(
-    "session_cache", "", "Location of any pre-cached results"
-)
 _TEXT = flags.DEFINE_string("text", "Today is a good day", "The message")
-_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority")
 _MAX_TOKENS = flags.DEFINE_integer(
     "max_tokens", 3, "Maximum number of output/decode tokens of a sequence"
 )
@@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None:
     vocab = load_vocab(_TOKENIZER.value)
     token_ids = vocab.tokenizer.encode(_TEXT.value)
     request = jetstream_pb2.DecodeRequest(
-        session_cache=_SESSION_CACHE.value,
         token_content=jetstream_pb2.DecodeRequest.TokenContent(
             token_ids=token_ids
         ),
-        priority=_PRIORITY.value,
         max_tokens=_MAX_TOKENS.value,
     )
   else:
     request = jetstream_pb2.DecodeRequest(
-        session_cache=_SESSION_CACHE.value,
         text_content=jetstream_pb2.DecodeRequest.TextContent(
             text=_TEXT.value
         ),
-        priority=_PRIORITY.value,
         max_tokens=_MAX_TOKENS.value,
     )
   return _GetResponseAsync(stub, request)
