BytePairTokenizer must not split sequences of \n (#1910)
* fix for loading of special tokens in Llama tokenizer

* fix for Llama tokenizer which can have multiple end tokens

* bug fix

* adding some missing tokens to Llama3 tokenizer

* fixed tests and Llama3Tokenizer init.

* now loading the correct eos_token config from the Hugging Face checkpoint; using a hack for Keras checkpoints because they do not have this info

* fix for BytePairTokenizer to make Llama3-instruct work in chat: \n\n sequences are significant in the chat template and must be preserved by the tokenizer

---------

Co-authored-by: Martin Görner <martin@huggingface.co>
martin-gorner authored Oct 7, 2024
1 parent f25c8ff commit ad66dc2
Showing 7 changed files with 62 additions and 9 deletions.
4 changes: 4 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -326,6 +326,10 @@ def generate(
)
elif stop_token_ids == "auto":
stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
# Some models like Llama3 use two end tokens: <|eot_id|> in
# "instruct" versions and <|end_of_text|> in others.
if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)

def preprocess(x):
return self.preprocessor.generate_preprocess(
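In effect, stop_token_ids="auto" now collects every end token the tokenizer defines. A minimal usage sketch, assuming the public keras_hub API; the preset name is illustrative, not taken from this commit:

import keras_hub

# With stop_token_ids="auto" (the default), generation stops on
# <|eot_id|> as well as <|end_of_text|> when the tokenizer defines both.
causal_lm = keras_hub.models.Llama3CausalLM.from_preset(
    "llama3_instruct_8b_en"  # illustrative preset name
)
causal_lm.generate("What is Keras?", stop_token_ids="auto")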
2 changes: 2 additions & 0 deletions keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
@@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
self.vocab += ["<|eot_id|>"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
4 changes: 3 additions & 1 deletion keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase):
def setUp(self):
self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
self.vocab += ["<|eot_id|>"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
@@ -44,7 +46,7 @@ def test_causal_lm_basics(self):
cls=Llama3CausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
expected_output_shape=(2, 7, 8),
expected_output_shape=(2, 7, 11),
)

def test_generate(self):
27 changes: 25 additions & 2 deletions keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ def __init__(
self,
vocabulary=None,
merges=None,
bos_token="<|begin_of_text|>",
eos_token="<|end_of_text|>",
misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
**kwargs,
):
self._add_special_token("<|begin_of_text|>", "start_token")
self._add_special_token("<|end_of_text|>", "end_token")
# Note: all special tokens must also appear in "vocabulary"

self._add_special_token(bos_token, "start_token")
misc_special_tokens -= {bos_token}
self._add_special_token(eos_token, "end_token")
misc_special_tokens -= {eos_token}
for i, token in enumerate(misc_special_tokens):
self._add_special_token(token, f"special_token_{i:03d}")

# Hack:
# Llama models use either <|end_of_text|> or <|eot_id|> as the stop
# token. This info can be read from the config when loading a Hugging Face
# checkpoint, but no such config exists for Keras checkpoints.
# Setting both probable end tokens when no config is available
# makes text generation work in all cases, since it will stop
# on either end token. However, the packer will always use
# "<|end_of_text|>", which will be the wrong eos_token for "instruct"
# variants of Llama3.
# TODO: load this correctly from a Keras tokenizer config.
if eos_token == "<|end_of_text|>":
self._add_special_token("<|eot_id|>", "end_token2")

self.pad_token_id = 0
super().__init__(
vocabulary=vocabulary,
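A hedged sketch of the new constructor behavior, reusing the toy vocabulary from the tests in this commit; the keras_hub.models.Llama3Tokenizer import path is assumed:

from keras_hub.models import Llama3Tokenizer

vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
vocab += ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
vocab = dict((token, i) for i, token in enumerate(vocab))
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]

tokenizer = Llama3Tokenizer(vocabulary=vocab, merges=merges)
# The default eos_token is "<|end_of_text|>", so the fallback above also
# registers "<|eot_id|>" as end_token2; both ids can serve as stop tokens.
assert tokenizer.end_token_id == vocab["<|end_of_text|>"]
assert tokenizer.end_token2_id == vocab["<|eot_id|>"]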
2 changes: 2 additions & 0 deletions keras_hub/src/models/llama3/llama3_tokenizer_test.py
@@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase):
def setUp(self):
self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
self.vocab += ["<|eot_id|>"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
6 changes: 5 additions & 1 deletion keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -43,7 +43,11 @@
SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
"{special_spaces}", SPECIAL_WHITESPACES
)
SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""

# The pattern " \t\r\f\v" matches the same characters as \s ("all whitespace")
# but without "\n". Sequences of \n must not be split for Llama3: \n\n is
# significant in its chat template and must be preserved.
# SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""


def create_alts_for_unsplittable_tokens(unsplittable_tokens):
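A standalone illustration of the regex change, with SPECIAL_WHITESPACES simplified to an empty placeholder: the old pattern marks a trailing \n as a split point, the new one does not, so the \n\n runs used by the Llama3 chat template survive intact.

import re

SPECIAL_WHITESPACES = ""  # placeholder; the real constant adds unicode spaces
OLD_SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
NEW_SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""

chunk = "<|start_header_id|>assistant<|end_header_id|>\n\n"
print(bool(re.search(OLD_SPLIT_PATTERN_2, chunk)))  # True: would split here
print(bool(re.search(NEW_SPLIT_PATTERN_2, chunk)))  # False: \n\n preserved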
26 changes: 21 additions & 5 deletions keras_hub/src/utils/transformers/convert_llama3.py
@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
vocab = tokenizer_config["model"]["vocab"]
merges = tokenizer_config["model"]["merges"]

bot = tokenizer_config["added_tokens"][0] # begin of text
eot = tokenizer_config["added_tokens"][1] # end of text

vocab[bot["content"]] = bot["id"]
vocab[eot["content"]] = eot["id"]
# Load all special tokens with the exception of "reserved" ones.
special_tokens = set()
for token in tokenizer_config["added_tokens"]:
if not token["content"].startswith("<|reserved_special_token_"):
vocab[token["content"]] = token["id"]
special_tokens.add(token["content"])

# Load text start and stop tokens from the config.
# Llama3 uses the <|end_of_text|> end token for regular models
# but uses <|eot_id|> for instruction-tuned variants.
tokenizer_config2 = load_json(preset, "tokenizer_config.json")
bos_token = tokenizer_config2["bos_token"]
eos_token = tokenizer_config2["eos_token"]

kwargs.update(
{
"bos_token": bos_token,
"eos_token": eos_token,
"misc_special_tokens": special_tokens,
}
)

return cls(vocabulary=vocab, merges=merges, **kwargs)
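With this, converting a Hugging Face checkpoint picks up the checkpoint's own eos_token, so instruct variants get <|eot_id|> without relying on the fallback hack. A hedged usage sketch; the hf:// handle is illustrative:

import keras_hub

# from_preset on a Hugging Face handle runs convert_tokenizer, which reads
# bos_token/eos_token from tokenizer_config.json and forwards all
# non-reserved special tokens to Llama3Tokenizer.
tokenizer = keras_hub.models.Llama3Tokenizer.from_preset(
    "hf://meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative handle
)
# For this instruct checkpoint, end_token resolves to "<|eot_id|>".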
