diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py
index c86bd7be9f..8a63c6964b 100644
--- a/keras_hub/src/models/causal_lm.py
+++ b/keras_hub/src/models/causal_lm.py
@@ -326,6 +326,10 @@ def generate(
             )
         elif stop_token_ids == "auto":
             stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
+            # Some models like Llama3 use two end tokens: <|eot_id|> in
+            # "instruct" versions and <|end_of_text|> in others.
+            if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
+                stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)
 
         def preprocess(x):
             return self.preprocessor.generate_preprocess(
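With this change, `stop_token_ids="auto"` can resolve to more than one stop id. A usage sketch of the resulting behavior (the preset name and prompt are illustrative only):

    import keras_hub

    causal_lm = keras_hub.models.Llama3CausalLM.from_preset(
        "llama3_instruct_8b_en"
    )

    # A tokenizer left at the default eos_token="<|end_of_text|>" (e.g. one
    # restored from a Keras checkpoint with no token config) now also carries
    # end_token2_id for <|eot_id|>, so "auto" stops generation on either token.
    causal_lm.generate("What is Keras?", max_length=64, stop_token_ids="auto")

    # An explicit list still overrides "auto", e.g. to stop only on <|eot_id|>:
    eot_id = causal_lm.preprocessor.tokenizer.token_to_id("<|eot_id|>")
    causal_lm.generate("What is Keras?", max_length=64, stop_token_ids=[eot_id])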
diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
index b8b45d8fd6..f79be674fb 100644
--- a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
@@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_test.py
index 995c1a00e1..7af0a18f77 100644
--- a/keras_hub/src/models/llama3/llama3_causal_lm_test.py
+++ b/keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
@@ -44,7 +46,7 @@ def test_causal_lm_basics(self):
             cls=Llama3CausalLM,
             init_kwargs=self.init_kwargs,
             train_data=self.train_data,
-            expected_output_shape=(2, 7, 8),
+            expected_output_shape=(2, 7, 11),
         )
 
     def test_generate(self):
diff --git a/keras_hub/src/models/llama3/llama3_tokenizer.py b/keras_hub/src/models/llama3/llama3_tokenizer.py
index 397b5e1923..ee3037e854 100644
--- a/keras_hub/src/models/llama3/llama3_tokenizer.py
+++ b/keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ def __init__(
         self,
         vocabulary=None,
         merges=None,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
         **kwargs,
     ):
-        self._add_special_token("<|begin_of_text|>", "start_token")
-        self._add_special_token("<|end_of_text|>", "end_token")
+        # Note: all special tokens must also appear in "vocabulary".
+
+        self._add_special_token(bos_token, "start_token")
+        misc_special_tokens -= {bos_token}
+        self._add_special_token(eos_token, "end_token")
+        misc_special_tokens -= {eos_token}
+        for i, token in enumerate(misc_special_tokens):
+            self._add_special_token(token, f"special_token_{i:03d}")
+
+        # Hack:
+        # Llama models use either <|end_of_text|> or <|eot_id|> as the stop
+        # token. This info can be read from the config when loading a Hugging
+        # Face checkpoint, but no such config exists for Keras checkpoints.
+        # Registering both probable end tokens when no config is available
+        # makes text generation work in all cases, since it will stop on
+        # either end token. However, the packer will always use
+        # "<|end_of_text|>", which is the wrong eos_token for "instruct"
+        # variants of Llama3.
+        # TODO: load this correctly from a Keras tokenizer config.
+        if eos_token == "<|end_of_text|>":
+            self._add_special_token("<|eot_id|>", "end_token2")
+
         self.pad_token_id = 0
         super().__init__(
             vocabulary=vocabulary,
diff --git a/keras_hub/src/models/llama3/llama3_tokenizer_test.py b/keras_hub/src/models/llama3/llama3_tokenizer_test.py
index 8440d8ebb2..aff591de04 100644
--- a/keras_hub/src/models/llama3/llama3_tokenizer_test.py
+++ b/keras_hub/src/models/llama3/llama3_tokenizer_test.py
@@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer.py b/keras_hub/src/tokenizers/byte_pair_tokenizer.py
index 41cef2b652..a7447c562e 100644
--- a/keras_hub/src/tokenizers/byte_pair_tokenizer.py
+++ b/keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -43,7 +43,11 @@
 SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
     "{special_spaces}", SPECIAL_WHITESPACES
 )
-SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+
+# The class " \t\r\f\v" matches the same whitespace as \s, minus "\n".
+# Runs of consecutive newlines ("\n\n\n") must not be split apart for Llama3.
+# SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""
 
 
 def create_alts_for_unsplittable_tokens(unsplittable_tokens):
diff --git a/keras_hub/src/utils/transformers/convert_llama3.py b/keras_hub/src/utils/transformers/convert_llama3.py
index 08e982e862..75c7eb801c 100644
--- a/keras_hub/src/utils/transformers/convert_llama3.py
+++ b/keras_hub/src/utils/transformers/convert_llama3.py
@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
 
-    bot = tokenizer_config["added_tokens"][0]  # begin of text
-    eot = tokenizer_config["added_tokens"][1]  # end of text
-
-    vocab[bot["content"]] = bot["id"]
-    vocab[eot["content"]] = eot["id"]
+    # Load all special tokens, with the exception of the "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    # Load text start and stop tokens from the config.
+    # Llama3 uses the <|end_of_text|> end token for regular models
+    # but uses <|eot_id|> for instruction-tuned variants.
+    tokenizer_config2 = load_json(preset, "tokenizer_config.json")
+    bos_token = tokenizer_config2["bos_token"]
+    eos_token = tokenizer_config2["eos_token"]
+
+    kwargs.update(
+        {
+            "bos_token": bos_token,
+            "eos_token": eos_token,
+            "misc_special_tokens": special_tokens,
+        }
+    )
 
     return cls(vocabulary=vocab, merges=merges, **kwargs)
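The SPLIT_PATTERN_2 change above narrows the trailing-whitespace class so that "\n" no longer terminates a split group. A plain-`re` illustration of the difference (the real pattern runs under tf_text and also includes SPECIAL_WHITESPACES and the ६ sentinel, both omitted here):

    import re

    old_pattern = re.compile(r"[\s]$")         # \s includes "\n"
    new_pattern = re.compile(r"[ \t\r\f\v]$")  # the same class without "\n"

    print(bool(old_pattern.search("foo\n")))  # True: the old class split after "\n"
    print(bool(new_pattern.search("foo\n")))  # False: "\n\n\n" runs stay whole
    print(bool(new_pattern.search("foo ")))   # True: other whitespace still splits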
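The end_token2 registration in llama3_tokenizer.py above can be exercised with the toy vocabulary from the tests. A minimal sketch, assuming the default kwargs (so `eos_token="<|end_of_text|>"` triggers the second end token; the merge list is truncated to the entries shown in the tests):

    from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer

    vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
    vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
    vocab += ["<|start_header_id|>", "<|end_header_id|>"]
    vocab += ["<|eot_id|>"]
    vocab = dict((token, i) for i, token in enumerate(vocab))
    merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
    merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]

    tokenizer = Llama3Tokenizer(vocabulary=vocab, merges=merges)
    print(tokenizer.end_token_id)   # id of <|end_of_text|>
    print(tokenizer.end_token2_id)  # id of <|eot_id|>, picked up by "auto" stop ids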