BytePairTokenizer must not split sequences of \n #1910

Merged (9 commits) on Oct 7, 2024
4 changes: 4 additions & 0 deletions keras_hub/src/models/causal_lm.py
@@ -326,6 +326,10 @@ def generate(
             )
         elif stop_token_ids == "auto":
             stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
+            # Some models like Llama3 use two end tokens: <|eot_id|> in
+            # "instruct" versions and <|end_of_text|> in others.
+            if hasattr(self.preprocessor.tokenizer, "end_token2_id"):
+                stop_token_ids.append(self.preprocessor.tokenizer.end_token2_id)
 
         def preprocess(x):
             return self.preprocessor.generate_preprocess(
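In effect, when `generate()` is called with `stop_token_ids="auto"`, both end tokens become stop conditions. A minimal sketch of the resolution logic, using a stand-in tokenizer object and illustrative token IDs (not taken from this PR):

```python
# Sketch of the "auto" stop-token resolution above. The class and the
# token IDs are illustrative stand-ins, not the real keras-hub objects.
class FakeTokenizer:
    end_token_id = 128001   # e.g. <|end_of_text|>
    end_token2_id = 128009  # e.g. <|eot_id|>

tokenizer = FakeTokenizer()
stop_token_ids = [tokenizer.end_token_id]
if hasattr(tokenizer, "end_token2_id"):
    stop_token_ids.append(tokenizer.end_token2_id)
print(stop_token_ids)  # [128001, 128009]: generation halts on either token
```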
@@ -11,6 +11,8 @@ class Llama3CausalLMPreprocessorTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
4 changes: 3 additions & 1 deletion keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -16,6 +16,8 @@ class Llama3CausalLMTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
@@ -44,7 +46,7 @@ def test_causal_lm_basics(self):
             cls=Llama3CausalLM,
             init_kwargs=self.init_kwargs,
             train_data=self.train_data,
-            expected_output_shape=(2, 7, 8),
+            expected_output_shape=(2, 7, 11),
         )
 
     def test_generate(self):
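The new expected shape follows directly from the vocabulary: the last dimension of the logits is the vocab size, which the three extra special tokens grow from 8 to 11. A quick check mirroring the test's `setUp()`:

```python
# The lists mirror setUp() above; the LM head projects to len(vocab).
base_vocab = ["!", "air", "Ġair", "plane", "Ġat", "port",
              "<|begin_of_text|>", "<|end_of_text|>"]
extra = ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
print(len(base_vocab), len(base_vocab + extra))  # 8 11
```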
27 changes: 25 additions & 2 deletions keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ def __init__(
         self,
         vocabulary=None,
         merges=None,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
         **kwargs,
     ):
-        self._add_special_token("<|begin_of_text|>", "start_token")
-        self._add_special_token("<|end_of_text|>", "end_token")
+        # Note: all special tokens must also appear in "vocabulary".
+
+        self._add_special_token(bos_token, "start_token")
+        misc_special_tokens -= {bos_token}
+        self._add_special_token(eos_token, "end_token")
+        misc_special_tokens -= {eos_token}
+        for i, token in enumerate(misc_special_tokens):
+            self._add_special_token(token, f"special_token_{i:03d}")
+
+        # Hack:
+        # Llama models use either <|end_of_text|> or <|eot_id|> as the stop
+        # token. This info can be read from the config when loading a Hugging
+        # Face checkpoint, but no such config exists for Keras checkpoints.
+        # Setting both probable end tokens when no config is available makes
+        # text generation work in all cases, since it will stop on either
+        # end token. However, the packer will always use "<|end_of_text|>",
+        # which will be the wrong eos_token for "instruct" variants of
+        # Llama3.
+        # TODO: load this correctly from a Keras tokenizer config.
+        if eos_token == "<|end_of_text|>":
+            self._add_special_token("<|eot_id|>", "end_token2")
+
         self.pad_token_id = 0
         super().__init__(
             vocabulary=vocabulary,
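For an "instruct" checkpoint the converter (below) passes `<|eot_id|>` as `eos_token`, so the `end_token2` hack never triggers. Constructing the tokenizer by hand would look roughly like this; the toy vocabulary and merges are modeled on the tests above, and the import path is the internal one (the public alias may differ):

```python
from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer

# Toy vocabulary/merges modeled on the tests; real presets load these from
# files. Every special token must also appear in the vocabulary.
vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
vocab += ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
vocab = dict((token, i) for i, token in enumerate(vocab))
merges = ["Ġ a", "Ġ t", "a i", "ai r", "Ġa t", "p o", "r t", "po rt"]

tokenizer = Llama3Tokenizer(
    vocabulary=vocab,
    merges=merges,
    eos_token="<|eot_id|>",  # "instruct" variants stop on <|eot_id|>
)
```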
2 changes: 2 additions & 0 deletions keras_hub/src/models/llama3/llama3_tokenizer_test.py
@@ -8,6 +8,8 @@ class Llama3TokenizerTest(TestCase):
     def setUp(self):
         self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
         self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
+        self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+        self.vocab += ["<|eot_id|>"]
         self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
         self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
         self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
6 changes: 5 additions & 1 deletion keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -43,7 +43,11 @@
 SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
     "{special_spaces}", SPECIAL_WHITESPACES
 )
-SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+
+# The class " \t\r\f\v" covers the same characters as \s ("all whitespace")
+# except \n. Sequences of \n (e.g. "\n\n\n") must not be split for Llama3.
+# Old: SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""
 
 
 def create_alts_for_unsplittable_tokens(unsplittable_tokens):
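The effect of the pattern change is easiest to see with plain `re`, leaving out the `६` sentinel and the `SPECIAL_WHITESPACES` class for brevity:

```python
import re

# Simplified versions of the old and new split patterns (the ६ sentinel
# and SPECIAL_WHITESPACES are omitted for brevity).
old_pattern = re.compile(r"[\s]$")         # trailing whitespace, incl. \n
new_pattern = re.compile(r"[ \t\r\f\v]$")  # trailing whitespace, excl. \n

for text in ["hello ", "hello\n\n\n"]:
    print(repr(text), bool(old_pattern.search(text)),
          bool(new_pattern.search(text)))
# 'hello '       True True   -> still splits after ordinary spaces
# 'hello\n\n\n'  True False  -> a run of \n is no longer a split point
```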
26 changes: 21 additions & 5 deletions keras_hub/src/utils/transformers/convert_llama3.py
@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
 
-    bot = tokenizer_config["added_tokens"][0]  # begin of text
-    eot = tokenizer_config["added_tokens"][1]  # end of text
-
-    vocab[bot["content"]] = bot["id"]
-    vocab[eot["content"]] = eot["id"]
+    # Load all special tokens, with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    # Load text start and stop tokens from the config.
+    # Llama3 uses the <|end_of_text|> end token for regular models
+    # but uses <|eot_id|> for instruction-tuned variants.
+    tokenizer_config2 = load_json(preset, "tokenizer_config.json")
+    bos_token = tokenizer_config2["bos_token"]
+    eos_token = tokenizer_config2["eos_token"]
+
+    kwargs.update(
+        {
+            "bos_token": bos_token,
+            "eos_token": eos_token,
+            "misc_special_tokens": special_tokens,
+        }
+    )
 
     return cls(vocabulary=vocab, merges=merges, **kwargs)
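Concretely, the converter now keys off two files in the Hugging Face preset: `tokenizer.json` supplies `added_tokens`, and `tokenizer_config.json` names the bos/eos pair for the variant. A sketch of the filtering step, with illustrative entries following the usual Llama 3 layout (the IDs are not taken from this PR):

```python
# tokenizer.json -> "added_tokens": every special token, including the
# block of <|reserved_special_token_N|> entries that get skipped.
added_tokens = [
    {"id": 128000, "content": "<|begin_of_text|>"},
    {"id": 128001, "content": "<|end_of_text|>"},
    {"id": 128002, "content": "<|reserved_special_token_0|>"},  # skipped
    {"id": 128009, "content": "<|eot_id|>"},
]
special_tokens = {
    t["content"]
    for t in added_tokens
    if not t["content"].startswith("<|reserved_special_token_")
}
print(special_tokens)
# {'<|begin_of_text|>', '<|end_of_text|>', '<|eot_id|>'} (order may vary)

# tokenizer_config.json then names the real pair for this variant: an
# instruct preset reports eos_token "<|eot_id|>", a base preset
# "<|end_of_text|>".
```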