Extend Bert support #829

Open
wants to merge 12 commits into base: dev

degenfabian
Contributor

@degenfabian degenfabian commented Jan 6, 2025

Description

The current BERT implementation supports only masked language modelling and has several other limitations: it accepts only tokens as input and supports only the model "bert-base-cased". This PR extends the BERT support of TransformerLens by addressing these limitations.

Features added include:

  • Next Sentence Prediction
  • Accepting strings and lists of strings as input and not only tokens
  • Returning human-readable predictions directly instead of only logits
  • Support for the BERT models "bert-base-uncased", "bert-large-uncased", "bert-large-cased"
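
To illustrate what "human-readable predictions instead of only logits" can mean for masked language modelling, here is a minimal sketch. It is not the code from this PR: the function name `top_token`, the toy vocabulary, and the use of plain Python lists in place of tensors are all assumptions made for illustration.

```python
from typing import Sequence


def top_token(logits: Sequence[float], vocab: Sequence[str]) -> str:
    """Return the vocabulary entry with the highest logit.

    `logits` holds one score per vocabulary entry for a single masked
    position. Both arguments are hypothetical stand-ins for real model
    output and a real tokenizer vocabulary.
    """
    if len(logits) != len(vocab):
        raise ValueError("logits and vocab must have the same length")
    # argmax over the vocabulary dimension
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return vocab[best_index]


# Toy example: three-word vocabulary, scores for one masked position.
prediction = top_token([0.1, 2.5, -1.0], ["cat", "dog", "fish"])
```

In a real model the same step would presumably run on a `batch pos d_vocab` tensor and decode through the tokenizer, but the idea is the same: pick the highest-scoring entry and map it back to text.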

I also added and revised documentation extensively and updated the existing BERT notebook.

There is no specific issue attached to this PR.

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@degenfabian degenfabian changed the base branch from main to dev January 6, 2025 23:32
@degenfabian degenfabian marked this pull request as ready for review January 7, 2025 22:02
Comment on lines +143 to +150
) -> Optional[
    Union[
        Float[torch.Tensor, "batch pos d_vocab"],
        Float[torch.Tensor, "batch 2"],
        str,
        List[str],
    ]
]:
Collaborator

Is there anything that can be done to simplify this return type? This function can return five different things (four types, or None), and the worst part is that two of them are the exact same type of data in different shapes. I understand that they will be used in completely different contexts, but that is exactly why I wonder whether we should go further and create two distinct modules for the two use cases. All of the shared functionality could live in components used by both, and distinct modules may simply be easier to explain to end users.

The real question is: is there any use case where someone would need both behaviours from the same instance of the module? If there is a valid one, leaving it as is is probably fine. If not, I would prefer to find a way to make these simpler and single-purpose.
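
A rough sketch of the split suggested above, under stated assumptions: the class names `SharedEncoder`, `MaskedLMModel`, and `NextSentenceModel` are hypothetical and do not come from TransformerLens, and nested Python lists stand in for tensors so the example runs without torch. The point is only that each class has exactly one return shape while the shared work lives in one component.

```python
from typing import List

# Plain nested lists stand in for tensors in this sketch.
Hidden = List[List[float]]  # [pos][d_model]


class SharedEncoder:
    """Hypothetical stand-in for the shared embedding and transformer blocks."""

    def __init__(self, d_model: int) -> None:
        self.d_model = d_model

    def encode(self, tokens: List[int]) -> Hidden:
        if not tokens:
            raise ValueError("tokens must not be empty")
        # Toy deterministic "hidden states": one d_model-sized vector per token.
        return [[float(tok + j) for j in range(self.d_model)] for tok in tokens]


class MaskedLMModel:
    """Single purpose: always returns per-position vocabulary logits."""

    def __init__(self, encoder: SharedEncoder, d_vocab: int) -> None:
        self.encoder = encoder
        self.d_vocab = d_vocab

    def forward(self, tokens: List[int]) -> List[List[float]]:  # [pos][d_vocab]
        hidden = self.encoder.encode(tokens)
        # Toy projection from each hidden vector to d_vocab scores.
        return [[sum(h) * (v + 1) for v in range(self.d_vocab)] for h in hidden]


class NextSentenceModel:
    """Single purpose: always returns two logits for the sequence pair."""

    def __init__(self, encoder: SharedEncoder) -> None:
        self.encoder = encoder

    def forward(self, tokens: List[int]) -> List[float]:  # [2]
        hidden = self.encoder.encode(tokens)
        # Toy classifier over the first position only.
        score = sum(hidden[0])
        return [score, -score]


# Both models reuse the same encoder instance.
encoder = SharedEncoder(d_model=4)
mlm = MaskedLMModel(encoder, d_vocab=5)
nsp = NextSentenceModel(encoder)
mlm_logits = mlm.forward([1, 2, 3])
nsp_logits = nsp.forward([1, 2, 3])
```

With this layout the type annotations need no `Union`: a caller who wants both behaviours constructs both wrappers around one encoder, which also answers the "same instance" question at the call site rather than inside `forward`.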
