diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 07b36a84..f03fac91 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -426,18 +426,14 @@ async def send_request(
     tokenizer: Any,
     input_request: InputRequest,
     pbar: tqdm,
-    session_cache: str,
-    priority: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
   # Tokenization on client side following MLPerf standard.
   token_ids = tokenizer.encode(input_request.prompt)
   request = jetstream_pb2.DecodeRequest(
-      session_cache=session_cache,
       token_content=jetstream_pb2.DecodeRequest.TokenContent(
           token_ids=token_ids
       ),
-      priority=priority,
       max_tokens=input_request.output_len,
   )
   output = RequestFuncOutput()
@@ -463,8 +459,6 @@ async def benchmark(
     input_requests: list[InputRequest],
     request_rate: float,
     disable_tqdm: bool,
-    session_cache: str,
-    priority: int,
 ):
   """Benchmark the online serving performance."""
   pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -481,8 +475,6 @@
                 tokenizer=tokenizer,
                 input_request=request,
                 pbar=pbar,
-                session_cache=session_cache,
-                priority=priority,
             )
         )
     )
@@ -614,8 +606,6 @@ def main(args: argparse.Namespace):
             input_requests=warmup_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
-            session_cache=args.session_cache,
-            priority=args.priority,
         )
     )
     print(f"{args.warmup_mode} warmup completed.")
@@ -631,8 +621,6 @@
             input_requests=input_requests,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
-            session_cache=args.session_cache,
-            priority=args.priority,
         )
     )

@@ -790,24 +778,6 @@
           " the form of a string."
       ),
   )
-  parser.add_argument(
-      "--priority",
-      type=int,
-      default=0,
-      help=(
-          "Message priority. (currently no business logic implemented, use"
-          " default 0)"
-      ),
-  )
-  parser.add_argument(
-      "--session-cache",
-      type=str,
-      default="",
-      help=(
-          "Location of any pre-cached results. (currently _load_cache_history"
-          " not implemented, use default empty str)"
-      ),
-  )
   parser.add_argument(
       "--save-request-outputs",
       action="store_true",
diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index d8fa9edd..23ca365f 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -133,7 +133,6 @@ class ActiveRequest:
   complete: Optional[np.ndarray] = None
   prefill_result: Any = None
   #################### Information relevant for prefill ########################
-  history_path: Optional[str] = None
   prefill_content: Optional[str | list[int]] = None
   padded_token_length: Optional[int] = None
   ################## Information relevant for detokenization ###################
@@ -491,14 +490,13 @@ def _prefill_thread(self, idx: int):

       if request is None:
         break
-      is_bos = not bool(request.history_path)
+      is_bos = True
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
-          " is_bos: %s, history: %s",
+          " is_bos: %s",
           idx,
           self._prefill_backlog.qsize(),
           is_bos,
-          request.history_path,
       )
       # Tokenize and padding the text or token input.
       padded_tokens, true_length = self._process_prefill_content(
@@ -895,7 +893,6 @@ async def Decode(  # pylint: disable=invalid-overridden-method
     # Wrap request as an ActiveRequest.
     active_request = ActiveRequest(
         max_tokens=request.max_tokens,
-        history_path=request.session_cache,
         prefill_content=prefill_content,
         is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto
index 5f2e8869..9fc7076f 100644
--- a/jetstream/core/proto/jetstream.proto
+++ b/jetstream/core/proto/jetstream.proto
@@ -26,9 +26,6 @@ service Orchestrator {
 }

 message DecodeRequest {
-  // Where to load any pre-existing kv cache from.
-  string session_cache = 1;
-  int32 priority = 3;
   // The maximum output length of a sequence. It's used in JetStream to control
   // the output/decode length of a sequence. It would not be used in the engine.
   // We should always set max_tokens <= (max_target_length -
@@ -51,7 +48,7 @@ message DecodeRequest {
     TextContent text_content = 5;
     TokenContent token_content = 6;
   }
-  reserved 2;
+  reserved 1, 2, 3;
   // Next ID: 7
 }

diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py
index 3fadd54c..07a5f313 100644
--- a/jetstream/core/proto/jetstream_pb2.py
+++ b/jetstream/core/proto/jetstream_pb2.py
@@ -28,7 +28,7 @@


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
+    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
 )

 _globals = globals()
@@ -39,23 +39,23 @@
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals["_DECODEREQUEST"]._serialized_start = 58
-  _globals["_DECODEREQUEST"]._serialized_end = 353
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336
-  _globals["_DECODERESPONSE"]._serialized_start = 356
-  _globals["_DECODERESPONSE"]._serialized_end = 687
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
-  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
-  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
-  _globals["_ORCHESTRATOR"]._serialized_start = 752
-  _globals["_ORCHESTRATOR"]._serialized_end = 937
+  _globals["_DECODEREQUEST"]._serialized_end = 324
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295
+  _globals["_DECODERESPONSE"]._serialized_start = 327
+  _globals["_DECODERESPONSE"]._serialized_end = 658
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641
+  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 660
+  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 680
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720
+  _globals["_ORCHESTRATOR"]._serialized_start = 723
+  _globals["_ORCHESTRATOR"]._serialized_end = 908
 # @@protoc_insertion_point(module_scope)
diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py
index 49494bef..00e2e1c1 100644
--- a/jetstream/tests/core/test_orchestrator.py
+++ b/jetstream/tests/core/test_orchestrator.py
@@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self):

     text = "AB"
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-        priority=1,
         max_tokens=3,
     )
     iterator = client.Decode(request)
@@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):

     token_ids = [65, 66]
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         token_content=jetstream_pb2.DecodeRequest.TokenContent(
             token_ids=token_ids
         ),
-        priority=1,
         max_tokens=3,
     )
     iterator = client.Decode(request)
diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py
index 731a72b5..9114f2fd 100644
--- a/jetstream/tests/core/test_server.py
+++ b/jetstream/tests/core/test_server.py
@@ -93,9 +93,7 @@ async def test_server(
     # as BOS
     text = "AB"
     request = jetstream_pb2.DecodeRequest(
-        session_cache="",
         text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-        priority=1,
         max_tokens=3,
     )
     iterator = stub.Decode(request)
diff --git a/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py
index 4d6445be..5f791efd 100644
--- a/jetstream/tools/load_tester.py
+++ b/jetstream/tools/load_tester.py
@@ -50,14 +50,11 @@ def api_call(
     stub: jetstream_pb2_grpc.OrchestratorStub,
     text: str,
     max_tokens: int,
-    session_cache: str = "",
     print_interim: bool = True,
 ) -> str:
   """Sends a request to server and returns text."""
   request = jetstream_pb2.DecodeRequest(
-      session_cache=session_cache,
       text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
-      priority=1,
       max_tokens=max_tokens,
   )
   response = stub.Decode(request)
diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py
index 8fcde556..30d7ac40 100644
--- a/jetstream/tools/requester.py
+++ b/jetstream/tools/requester.py
@@ -26,11 +26,7 @@

 _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address")
 _PORT = flags.DEFINE_string("port", "9000", "port to ping")
-_SESSION_CACHE = flags.DEFINE_string(
-    "session_cache", "", "Location of any pre-cached results"
-)
 _TEXT = flags.DEFINE_string("text", "Today is a good day", "The message")
-_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority")
 _MAX_TOKENS = flags.DEFINE_integer(
     "max_tokens", 3, "Maximum number of output/decode tokens of a sequence"
 )
@@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None:
     vocab = load_vocab(_TOKENIZER.value)
     token_ids = vocab.tokenizer.encode(_TEXT.value)
     request = jetstream_pb2.DecodeRequest(
-        session_cache=_SESSION_CACHE.value,
         token_content=jetstream_pb2.DecodeRequest.TokenContent(
             token_ids=token_ids
        ),
-        priority=_PRIORITY.value,
         max_tokens=_MAX_TOKENS.value,
     )
   else:
     request = jetstream_pb2.DecodeRequest(
-        session_cache=_SESSION_CACHE.value,
         text_content=jetstream_pb2.DecodeRequest.TextContent(
             text=_TEXT.value
         ),
-        priority=_PRIORITY.value,
         max_tokens=_MAX_TOKENS.value,
     )
   return _GetResponseAsync(stub, request)