From a0df320cb4a720bf123e2579973034628dbe3344 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Wed, 24 Apr 2024 15:25:48 -0700 Subject: [PATCH] Align Tokenizer in JetStream (#40) * Align Tokenizer in JetStream * Update requirements with pytest dep * Remove mix_decode unit test --- .github/workflows/unit_tests.yaml | 3 +- README.md | 8 +-- benchmarks/benchmark_serving.py | 65 +++++++++++----------- benchmarks/requirements.in | 3 +- jetstream/core/orchestrator.py | 11 +++- jetstream/core/proto/jetstream.proto | 8 ++- jetstream/core/proto/jetstream_pb2.py | 10 ++-- jetstream/engine/mock_utils.py | 8 ++- jetstream/engine/token_utils.py | 34 ++--------- jetstream/tests/core/test_orchestrator.py | 10 ++-- jetstream/tests/core/test_server.py | 54 +++++++++--------- jetstream/tests/engine/test_token_utils.py | 24 -------- jetstream/tests/engine/test_utils.py | 10 ++-- jetstream/tools/requester.py | 17 ++++-- requirements.in | 3 +- requirements.txt | 11 ++++ 16 files changed, 136 insertions(+), 143 deletions(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 5fa0a2b1..12bbb4f0 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -47,9 +47,10 @@ jobs: pip install pylint pip install pyink pip install -r requirements.txt + pip install -r benchmarks/requirements.in - name: Typecheck the code with pytype run: | - pytype --jobs auto --disable import-error --disable module-attr jetstream/ + pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/ - name: Analysing the code with pylint run: | pylint jetstream/ benchmarks/ diff --git a/README.md b/README.md index b92a4fea..7bf28610 100644 --- a/README.md +++ b/README.md @@ -57,15 +57,15 @@ python -m jetstream.tools.load_tester ### Test core modules ``` # Test JetStream core orchestrator -python -m jetstream.core.orchestrator_test +python -m jetstream.tests.core.test_orchestrator # Test JetStream core server library -python -m 
jetstream.core.server_test +python -m jetstream.tests.core.test_server # Test mock JetStream engine implementation -python -m jetstream.engine.mock_engine_test +python -m jetstream.tests.engine.test_mock_engine # Test mock JetStream token utils -python -m jetstream.engine.utils_test +python -m jetstream.tests.engine.test_utils ``` diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a09cc66c..46fd05e4 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -65,15 +65,14 @@ import json import random import time -from typing import Any, AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Optional import grpc from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab import numpy as np -import tensorflow as tf -import tensorflow_text as tftxt -from tqdm.asyncio import tqdm +from tqdm.asyncio import tqdm # pytype: disable=pyi-error from eval_accuracy import eval_accuracy @@ -106,9 +105,9 @@ class InputRequest: @dataclass class RequestFuncOutput: - input_request: InputRequest = None - generated_token_list: list[str] = None - generated_text: str = None + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = [] + generated_text: str = "" success: bool = False latency: float = 0 ttft: float = 0 @@ -132,18 +131,16 @@ def get_tokenizer(tokenizer_name: str) -> Any: if tokenizer_name == "test": return "test" else: - with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp: - sp_model = model_fp.read() - sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=True, add_eos=False, reverse=False - ) - return sp_tokenizer + # Use JetStream tokenizer util. It's using the sentencepiece wrapper in + # seqio library. 
+ vocab = load_vocab(tokenizer_name) + return vocab.tokenizer def load_sharegpt_dataset( dataset_path: str, conversation_starter: str, -) -> List[tuple[str]]: +) -> list[tuple[Any, Any]]: # Load the dataset. with open(dataset_path, "r", encoding="utf-8") as f: dataset = json.load(f) @@ -166,7 +163,7 @@ def load_sharegpt_dataset( return dataset -def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]: +def load_openorca_dataset(dataset_path: str) -> list[tuple[Any, Any]]: # Load the dataset. with open(dataset_path, "r", encoding="utf-8") as f: dataset = json.load(f) @@ -179,9 +176,9 @@ def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]: def tokenize_dataset( - dataset: List[tuple[str]], + dataset: list[tuple[Any, Any, Any]], tokenizer: Any, -) -> List[tuple[Any]]: +) -> list[tuple[str, Any, str, int, int, int]]: n = len(dataset) @@ -194,10 +191,10 @@ def tokenize_dataset( outputs.append(output) indices.append(idx) - prompt_token_ids = tokenizer.tokenize( + prompt_token_ids = tokenizer.encode( prompts ) # adjust this code based on tokenizer method - outputs_token_ids = tokenizer.tokenize( + outputs_token_ids = tokenizer.encode( outputs ) # adjust this code based on tokenizer method @@ -218,8 +215,9 @@ def tokenize_dataset( def filter_dataset( - tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None -) -> List[InputRequest]: + tokenized_dataset: list[tuple[str, Any, str, int, int, int]], + max_output_length: Optional[int] = None, +) -> list[InputRequest]: if max_output_length is None: print("In InputRequest, pass in actual output_length for each sample") else: @@ -229,7 +227,7 @@ def filter_dataset( ) # Filter out too long sequences. 
- filtered_dataset: List[InputRequest] = [] + filtered_dataset: list[InputRequest] = [] for ( prompt, _, @@ -258,12 +256,12 @@ def filter_dataset( def sample_requests( - dataset: List[tuple[str]], + dataset: list[tuple[Any, Any]], tokenizer: Any, num_requests: int, max_output_length: Optional[int] = None, oversample_multiplier: float = 1.2, -) -> List[InputRequest]: +) -> list[InputRequest]: # Original dataset size n = len(dataset) @@ -304,7 +302,7 @@ def sample_requests( async def get_request( - input_requests: List[InputRequest], + input_requests: list[InputRequest], request_rate: float, ) -> AsyncGenerator[InputRequest, None]: input_requests = iter(input_requests) @@ -321,8 +319,8 @@ async def get_request( def calculate_metrics( - input_requests: List[InputRequest], - outputs: List[RequestFuncOutput], + input_requests: list[InputRequest], + outputs: list[RequestFuncOutput], dur_s: float, tokenizer: Any, ) -> BenchmarkMetrics: @@ -374,16 +372,17 @@ async def grpc_async_request( token_list = [] request_start_time = time.perf_counter() response = stub.Decode(request) - async for token in response: + async for sample_list in response: if ttft == 0: ttft = time.perf_counter() - request_start_time - token_list.append(token.response[0]) + token_list.extend(sample_list.response[0].token_ids) latency = time.perf_counter() - request_start_time return token_list, ttft, latency async def send_request( api_url: str, + tokenizer: Any, input_request: InputRequest, pbar: tqdm, session_cache: str, @@ -405,7 +404,8 @@ async def send_request( output.ttft = ttft output.latency = latency output.generated_token_list = generated_token_list - output.generated_text = "".join(generated_token_list) + # generated_token_list is a list of token ids, decode it to generated_text. 
+ output.generated_text = tokenizer.decode(generated_token_list) output.success = True if pbar: pbar.update(1) @@ -415,7 +415,7 @@ async def send_request( async def benchmark( api_url: str, tokenizer: Any, - input_requests: List[InputRequest], + input_requests: list[InputRequest], request_rate: float, disable_tqdm: bool, session_cache: str, @@ -433,6 +433,7 @@ async def benchmark( asyncio.create_task( send_request( api_url=api_url, + tokenizer=tokenizer, input_request=request, pbar=pbar, session_cache=session_cache, @@ -442,7 +443,7 @@ async def benchmark( ) outputs = await asyncio.gather(*tasks) - if not disable_tqdm: + if not disable_tqdm and pbar: pbar.close() benchmark_duration = time.perf_counter() - benchmark_start_time diff --git a/benchmarks/requirements.in b/benchmarks/requirements.in index 7e19e557..b49edb00 100644 --- a/benchmarks/requirements.in +++ b/benchmarks/requirements.in @@ -1,3 +1,4 @@ nltk evaluate -rouge-score \ No newline at end of file +rouge-score +tqdm \ No newline at end of file diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index f2b8797f..6e5dbf55 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -127,7 +127,7 @@ class ActiveRequest: # We keep prefill and decode information together in the same object so that # there is less indirection about where this return channel is. # The return channel returns a list of strings, one per sample for that query. - return_channel: async_multifuture.AsyncMultifuture[list[str]] + return_channel: async_multifuture.AsyncMultifuture[list[list[int]]] # [num_samples,] which corresponds to whether each sample is complete for the # requests. complete: Optional[np.ndarray] = None @@ -139,7 +139,7 @@ class ActiveRequest: # Which generate step this was added at. 
generate_timestep_added: Optional[int] = None - def enqueue_tokens(self, generated_tokens: list[str]): + def enqueue_tokens(self, generated_tokens: list[list[int]]): """Records information about the step. Args: @@ -662,4 +662,9 @@ async def Decode( # pylint: disable=invalid-overridden-method # The DecodeResponse stream should consume all generated tokens in # return_channel when complete signal is received. It should check if # return_channel is empty to decide if it should exit the while loop. - yield jetstream_pb2.DecodeResponse(response=response) + repeated_token_ids = [] + for token_ids in response: + repeated_token_ids.append( + jetstream_pb2.RepeatedTokenIds(token_ids=token_ids) + ) + yield jetstream_pb2.DecodeResponse(response=repeated_token_ids) diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto index 228ddd05..78ab88ee 100644 --- a/jetstream/core/proto/jetstream.proto +++ b/jetstream/core/proto/jetstream.proto @@ -37,6 +37,10 @@ message DecodeRequest { int32 max_tokens = 4; } message DecodeResponse { - // List of responses, one per sample. - repeated string response = 1; + // List of responses, one per sample. The list size depends on the text generation strategy the engine uses. + repeated RepeatedTokenIds response = 1; } +message RepeatedTokenIds { + // List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; when speculative decoding is enabled, the list size should be >= 1. 
+ repeated int32 token_ids = 1; +} \ No newline at end of file diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py index 4d01a307..a799ce8e 100644 --- a/jetstream/core/proto/jetstream_pb2.py +++ b/jetstream/core/proto/jetstream_pb2.py @@ -28,7 +28,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05""\n\x0e\x44\x65\x63odeResponse\x12\x10\n\x08response\x18\x01 \x03(\t2]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3' + b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05"E\n\x0e\x44\x65\x63odeResponse\x12\x33\n\x08response\x18\x01 \x03(\x0b\x32!.jetstream_proto.RepeatedTokenIds"%\n\x10RepeatedTokenIds\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3' ) _globals = globals() @@ -41,7 +41,9 @@ _globals["_DECODEREQUEST"]._serialized_start = 57 _globals["_DECODEREQUEST"]._serialized_end = 158 _globals["_DECODERESPONSE"]._serialized_start = 160 - _globals["_DECODERESPONSE"]._serialized_end = 194 - _globals["_ORCHESTRATOR"]._serialized_start = 196 - _globals["_ORCHESTRATOR"]._serialized_end = 289 + _globals["_DECODERESPONSE"]._serialized_end = 229 + _globals["_REPEATEDTOKENIDS"]._serialized_start = 231 + _globals["_REPEATEDTOKENIDS"]._serialized_end = 268 + _globals["_ORCHESTRATOR"]._serialized_start = 270 + 
_globals["_ORCHESTRATOR"]._serialized_end = 363 # @@protoc_insertion_point(module_scope) diff --git a/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py index d52d7895..5ccc3b45 100644 --- a/jetstream/engine/mock_utils.py +++ b/jetstream/engine/mock_utils.py @@ -62,9 +62,7 @@ def _encode(self, s: str) -> Sequence[int]: def _decode(self, ids: np.ndarray): """Converts a numpy array into a string.""" - # 'We use array methods, not python iterables so we don't - # implement this method in the mock vocab. - raise NotImplementedError + return "".join([chr(r) for r in list(ids)]) def _encode_tf(self, s: str) -> np.ndarray: """Converts a string into a numpy array.""" @@ -78,6 +76,10 @@ def _decode_tf(self, ids: np.ndarray) -> List[str]: results = np.split(ids, ids.shape[0]) return ["".join([chr(r) for r in list(line[0])]) for line in results] + def decode(self, ids: np.ndarray): + """Converts a numpy array into a string.""" + return self._decode(ids) + def encode_tf(self, s: str) -> np.ndarray: """Converts a string into a numpy array.""" return self._encode_tf(s) diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 4a096d94..f7d4b1eb 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -28,21 +28,6 @@ from jetstream.engine import mock_utils -def mix_decode(vocab: Vocabulary, tok_id: int): - """ - The IdToPiece and decode results differ for 344 tokens in Llama2. - Use the decode function to generate the correct strings for these 344 tokens. - If IdToPiece returns a hex string (e.g., '<0x0A>') for a token within these - 344, utilize IdToPiece to convert it into a string, likely with a space - placeholder (' ') for the corresponding tokens. 
- """ - p_token = vocab.tokenizer.IdToPiece(tok_id) - # SentencePiece escapes the whitespace with a meta symbol "▁" (U+2581) - p_token = p_token.replace("▁", " ") - d_token = vocab.tokenizer.decode([tok_id]) - return p_token if p_token.lstrip() == d_token else d_token - - def take_nearest_length(lengths: list[int], length: int) -> int: """Gets the nearest length to the right in a set of lengths.""" pos = bisect_left(lengths, length) @@ -131,7 +116,7 @@ def process_result_tokens( vocab: Vocabulary, complete: np.ndarray, debug: bool = False, -) -> Tuple[List[str], np.ndarray]: +) -> Tuple[List[List[int]], np.ndarray]: """Processes a result tokens into a list of strings, handling multiple samples. @@ -145,7 +130,7 @@ def process_result_tokens( debug: Whether to log step by step detokenisation. Returns: - sample_return: List of strings, one per sample. + sample_return: List of tok_id list, one list per sample. complete: Updated complete. """ # tokens: [samples, speculations] @@ -166,7 +151,7 @@ def process_result_tokens( ) sample_return = [] for idx in range(samples): - string_so_far = "" + tok_id_so_far = [] if not complete[idx].item(): for spec_idx in range(speculations): tok_id = slot_tokens[idx, spec_idx].item() @@ -182,17 +167,8 @@ def process_result_tokens( complete[idx] = True break else: - try: - # pytype: disable=attribute-error - token = mix_decode(vocab, tok_id) - except ValueError: - # This error only occurs when using tests where the vocab range is - # computed via addition and int->char is computed using chr(). Real - # models have vocab logits which are at max the size of the vocab. 
- logging.warning("%d exceeded vocab range", tok_id) - token = "" - string_so_far += token - sample_return.append(string_so_far) + tok_id_so_far.append(tok_id) + sample_return.append(tok_id_so_far) if debug: logging.info("Sampled return %s", str(sample_return)) return sample_return, complete diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py index 48b21c0a..d899bd05 100644 --- a/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -41,6 +41,7 @@ tokenizer returns). """ +import pytest from absl.testing import absltest from jetstream.core import orchestrator from jetstream.core.proto import jetstream_pb2 @@ -66,6 +67,7 @@ def _setup_driver(self): ) return driver + @pytest.mark.asyncio async def test_orchestrator(self): """Test the multithreaded orchestration.""" driver = self._setup_driver() @@ -87,12 +89,10 @@ async def test_orchestrator(self): counter = 0 async for token in iterator: # Tokens come through as bytes. 
- print( - "actual output: " - + bytes(token.response[0], encoding="utf-8").decode() - ) + output_token = await token.response[0].token_id[0] + print("actual output: " + bytes(output_token, encoding="utf-8").decode()) assert ( - bytes(token.response[0], encoding="utf-8").decode() + bytes(output_token, encoding="utf-8").decode() == expected_tokens[counter] ) counter += 1 diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 6996f4e8..3e9cf782 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -22,6 +22,7 @@ from absl.testing import absltest, parameterized import grpc +import pytest from jetstream.core import config_lib from jetstream.core import server_lib from jetstream.core.proto import jetstream_pb2 @@ -45,7 +46,8 @@ class ServerTest(parameterized.TestCase): [None], ), ) - def test_server( + @pytest.mark.asyncio + async def test_server( self, config: Type[config_lib.ServerConfig], expected_tokens: list[str], @@ -64,34 +66,34 @@ def test_server( credentials=credentials, ) ###################### Requester side ###################################### - channel = grpc.secure_channel( + async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() - ) - stub = jetstream_pb2_grpc.OrchestratorStub(channel) + ) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) - # The string representation of np.array([[65, 66]]), [2] will be prependd - # as BOS - text = "AB" - request = jetstream_pb2.DecodeRequest( - session_cache="", - additional_text=text, - priority=1, - max_tokens=3, - ) - iterator = stub.Decode(request) - counter = 0 - for token in iterator: - # Tokens come through as bytes - print( - "actual output: " - + bytes(token.response[0], encoding="utf-8").decode() - ) - assert ( - bytes(token.response[0], encoding="utf-8").decode() - == expected_tokens[counter] + # The string representation of np.array([[65, 66]]), [2] will be prepended + # as BOS + 
text = "AB" + request = jetstream_pb2.DecodeRequest( + session_cache="", + additional_text=text, + priority=1, + max_tokens=3, ) - counter += 1 - server.stop() + iterator = stub.Decode(request) + counter = 0 + async for token in iterator: + # Tokens come through as bytes + output_token = await token.response[0].token_id[0] + print( + "actual output: " + bytes(output_token, encoding="utf-8").decode() + ) + assert ( + bytes(output_token, encoding="utf-8").decode() + == expected_tokens[counter] + ) + counter += 1 + server.stop() if __name__ == "__main__": diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 419e4e06..96cad029 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -75,22 +75,6 @@ def test_decode_vs_piece(self): self.assertNotEqual(jt_output, expeted_sp_output) - def test_mix_decode(self): - self.setup() - for n in range(0, self.sp_tokenizer.tokenizer.vocab_size()): - # From decode function - decode_output = self.sp_tokenizer.decode([n]) - # From IdToPiece function - piece_output = self.jt_tokenizer.decode(n) - # Mix output from decode and IdToPiece - mix_output = token_utils.mix_decode( - vocab=self.jt_tokenizer.vocab, tok_id=n - ) - if piece_output.lstrip() == decode_output: - self.assertEqual(mix_output, piece_output) - else: - self.assertEqual(mix_output, decode_output) - def test_sp_vs_seqio(self): self.setup() for n in range(0, self.sp_tokenizer.tokenizer.vocab_size()): @@ -98,14 +82,6 @@ def test_sp_vs_seqio(self): seqio_t = self.jt_tokenizer.vocab.tokenizer.decode([n]) self.assertEqual(sp_t, seqio_t) - def test_underscore_in_output(self): - self.setup() - n = 21326 - mix_output = token_utils.mix_decode(vocab=self.jt_tokenizer.vocab, tok_id=n) - decode_output = self.sp_tokenizer.decode([n]) - self.assertEqual(mix_output, " `__") - self.assertEqual(mix_output.lstrip(), decode_output) - def test_tokenize_and_pad_jax(self): 
jax.config.update("jax_platform_name", "cpu") self.setup() diff --git a/jetstream/tests/engine/test_utils.py b/jetstream/tests/engine/test_utils.py index 74819f51..e4cc761f 100644 --- a/jetstream/tests/engine/test_utils.py +++ b/jetstream/tests/engine/test_utils.py @@ -64,8 +64,9 @@ def test_speculations_with_multi_sample_slots(self, samples_per_slot=2): ) np.testing.assert_equal(complete, np.array([1, 0])) - assert not per_channel[0] # i.e. == '', because of the pad. - assert per_channel[1] == "AD" + text_output = [mock_utils.TestVocab().decode(row) for row in per_channel] + assert not text_output[0] # i.e. == '', because of the pad. + assert text_output[1] == "AD" mock_complete = np.zeros( (mock_tokens.shape[0] // samples_per_slot), dtype=np.int32 ) @@ -76,8 +77,9 @@ def test_speculations_with_multi_sample_slots(self, samples_per_slot=2): vocab=mock_utils.TestVocab(), complete=mock_complete, ) - assert per_channel[0] == "T3" - assert per_channel[1] == "A" # second token is padded. + text_output = [mock_utils.TestVocab().decode(row) for row in per_channel] + assert text_output[0] == "T3" + assert text_output[1] == "A" # second token is padded. 
np.testing.assert_equal(complete, np.array([0, 1])) diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py index fd3cc633..5c189aba 100644 --- a/jetstream/tools/requester.py +++ b/jetstream/tools/requester.py @@ -21,6 +21,7 @@ import grpc from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab _SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address") @@ -33,6 +34,12 @@ _MAX_TOKENS = flags.DEFINE_integer( "max_tokens", 3, "Maximum number of output/decode tokens of a sequence" ) +_TOKENIZER = flags.DEFINE_string( + "tokenizer", + "", + "Name or path of the tokenizer (matched to the model)", + required=True, +) def _GetResponseAsync( @@ -42,11 +49,13 @@ def _GetResponseAsync( """Gets an async response.""" response = stub.Decode(request) - output = "" - for token_list in response: - output += token_list.response[0] + output = [] + for sample_list in response: + output.extend(sample_list.response[0].token_ids) + vocab = load_vocab(_TOKENIZER.value) + text_output = vocab.tokenizer.decode(output) print(f"Prompt: {_TEXT.value}") - print(f"Response: {output}") + print(f"Response: {text_output}") def main(argv: Sequence[str]) -> None: diff --git a/requirements.in b/requirements.in index ae59e1b8..f500da6b 100644 --- a/requirements.in +++ b/requirements.in @@ -4,6 +4,7 @@ flax grpcio jax jaxlib -portpicker numpy +portpicker +pytest seqio \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 54a2dcfb..92217cfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,6 +53,8 @@ etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0 # clu # orbax-checkpoint # tfds-nightly +exceptiongroup==1.2.0 + # via pytest flatbuffers==23.5.26 # via tensorflow flax==0.8.0 @@ -84,6 +86,8 @@ idna==3.7 # via requests importlib-resources==6.1.1 # via etils +iniconfig==2.0.0 + # via pytest jax==0.4.23 # via # -r requirements.in @@ -160,8 +164,11 
@@ orbax-checkpoint==0.5.2 packaging==23.2 # via # clu + # pytest # seqio # tensorflow +pluggy==1.4.0 + # via pytest portpicker==1.6.0 # via -r requirements.in promise==2.3 @@ -190,6 +197,8 @@ pyglove==0.4.4 # via seqio pygments==2.17.2 # via rich +pytest==8.1.1 + # via -r requirements.in pyyaml==6.0.1 # via # flax @@ -251,6 +260,8 @@ tfds-nightly==4.9.2.dev202308090034 # via seqio toml==0.10.2 # via tfds-nightly +tomli==2.0.1 + # via pytest toolz==0.12.1 # via chex tqdm==4.66.1