# Upstream update #755

Merged 1 commit on Oct 16, 2024

## Changes from all commits
`transformer_lens/utils.py` — 25 changes: 20 additions & 5 deletions
```diff
@@ -312,8 +312,6 @@ def tokenize_and_concatenate(
 
     Returns:
         Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
-
-    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why
     """
     dataset = keep_single_column(dataset, column_name)
     if tokenizer.pad_token is None:
```
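The docstring note removed above described a real failure mode, and the fix lands in the final hunk below. Here is a minimal sketch of why small inputs used to produce an empty dataset; the values of `seq_len` and `num_tokens` are illustrative, not taken from this PR:

```python
# Old behavior: floor division yields zero batches whenever the dataset
# tokenizes to fewer than seq_len tokens, so every token was dropped.
seq_len = 1024                        # illustrative value (max_length - 1)
num_tokens = 300                      # a "very small dataset"
num_batches = num_tokens // seq_len   # 0
print(num_batches * seq_len)          # 0 -> tokens[:0] keeps nothing
```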
```diff
@@ -329,6 +327,11 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
         text = examples[column_name]
         # Concatenate it all into an enormous string, separated by eos_tokens
         full_text = tokenizer.eos_token.join(text)
+
+        # Handle the case when full_text is empty
+        if not full_text.strip():
+            return {"tokens": np.array([], dtype=np.int64)}
+
         # Divide into 20 chunks of ~ equal length
         num_chunks = 20
         chunk_length = (len(full_text) - 1) // num_chunks + 1
```
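A minimal sketch of what this new guard catches, assuming a GPT-2 tokenizer (the tokenizer choice is illustrative, not fixed by this PR): an empty or whitespace-only batch joins to an empty string, which the old path would push through chunking and tokenization to produce zero usable tokens.

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

full_text = tokenizer.eos_token.join([])   # an empty batch joins to ""
print(repr(full_text))                     # ''

# The new guard short-circuits before chunking and tokenization:
if not full_text.strip():
    result = {"tokens": np.array([], dtype=np.int64)}
print(result["tokens"].shape)              # (0,)
```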
```diff
@@ -338,9 +341,21 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
         # Drop padding tokens
         tokens = tokens[tokens != tokenizer.pad_token_id]
         num_tokens = len(tokens)
-        num_batches = num_tokens // (seq_len)
-        # Drop the final tokens if not enough to make a full sequence
-        tokens = tokens[: seq_len * num_batches]
+
+        # Handle cases where num_tokens is less than seq_len
+        if num_tokens < seq_len:
+            num_batches = 1
+            # Pad tokens if necessary
+            tokens = tokens[:seq_len]
+            if len(tokens) < seq_len:
+                padding_length = seq_len - len(tokens)
+                padding = np.full(padding_length, tokenizer.pad_token_id)
+                tokens = np.concatenate([tokens, padding], axis=0)
+        else:
+            num_batches = num_tokens // seq_len
+            # Drop the final tokens if not enough to make a full sequence
+            tokens = tokens[: seq_len * num_batches]
+
         tokens = einops.rearrange(
             tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
         )
```
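A minimal sketch of the new short-input branch with toy values (`seq_len`, the pad id, and the token values are all illustrative): fewer tokens than one full sequence now yields a single right-padded row instead of zero rows.

```python
import numpy as np
import einops

seq_len = 8                 # stand-in for max_length - 1
pad_token_id = 0            # stand-in for tokenizer.pad_token_id
tokens = np.arange(1, 6)    # 5 tokens, fewer than seq_len

num_tokens = len(tokens)
if num_tokens < seq_len:
    num_batches = 1
    tokens = tokens[:seq_len]
    if len(tokens) < seq_len:
        padding = np.full(seq_len - len(tokens), pad_token_id)
        tokens = np.concatenate([tokens, padding], axis=0)
else:
    num_batches = num_tokens // seq_len
    tokens = tokens[: seq_len * num_batches]

tokens = einops.rearrange(
    tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
)
print(tokens)  # [[1 2 3 4 5 0 0 0]] -- one full row, right padded
```

Note the trade-off this branch introduces: the padded row contains real `pad_token_id` entries, so downstream consumers that assume every position is a genuine token may need an attention mask or filtering.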