Support for the differentiable head #112
Conversation
…tency with the body
Absolutely incredible work on implementing a pure `torch` model @blakechi 🔥!

I've left a few comments / questions, but overall this is looking really good.

Would you mind sharing a code snippet in the PR description so that others can understand how the new API should work?

I'm also curious if you tested this implementation with some of the test datasets in our paper? It would be cool to know if (a) this implementation does better than our original one and (b) there are no subtle regressions with the `sklearn` version.
src/setfit/modeling.py (outdated)

```python
@@ -23,6 +25,43 @@
MODEL_HEAD_NAME = "model_head.pkl"


class SetFitDataset(Dataset):
    def __init__(self, x, y, tokenizer, max_length=32):
```
Small nit: could we add some types here? For things like `tokenizer`, we can import them with `typing.TYPE_CHECKING` and then wrap the type in quotes `"..."`.
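A minimal sketch of the suggested pattern (the tokenizer type below is an assumption; any concrete tokenizer class works the same way):

```python
from typing import TYPE_CHECKING, List, Union

from torch.utils.data import Dataset

if TYPE_CHECKING:
    # Only imported for type checkers, so there is no runtime dependency here
    from transformers import PreTrainedTokenizerBase


class SetFitDataset(Dataset):
    def __init__(
        self,
        x: List[str],
        y: Union[List[int], List[List[int]]],
        tokenizer: "PreTrainedTokenizerBase",  # quoted: resolved only by type checkers
        max_length: int = 32,
    ) -> None:
        self.x = x
        self.y = y
        self.tokenizer = tokenizer
        self.max_length = max_length
```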
Okay, I will add them in the next commit.
Just added it!
"bias": self.bias, | ||
} | ||
|
||
def __repr__(self): |
Nice!
Thanks!
src/setfit/modeling.py (outdated)

```python
    return outputs

def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    features.update({"prediction": self._forward(features["sentence_embedding"])})
```
I'm curious about the need for separate `_forward()` and `forward()` methods - is there a reason we can't wrap the logic in a single method?
My original intent was to separate the use cases: `forward()` keeps the head compatible with Sentence Transformers, while `_forward()` can be used when users feed in embeddings only. But yeah, I think I can merge them and branch on whichever format users feed into it. Will fix it in the next commit!
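A minimal sketch of what the merged method could look like (method body only; `self.linear` is a stand-in for the head's actual layers, and the module's `torch`/`typing` imports are assumed):

```python
def forward(
    self, features: Union[Dict[str, torch.Tensor], torch.Tensor]
) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
    # Detect whether we got a Sentence Transformers feature dict or raw embeddings
    is_feature_dict = isinstance(features, dict)
    embeddings = features["sentence_embedding"] if is_feature_dict else features
    logits = self.linear(embeddings)
    if is_feature_dict:
        # Stay compatible with the Sentence Transformers pipeline
        features.update({"prediction": logits})
        return features
    return logits
```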
Fixed!
src/setfit/modeling.py (outdated)

```python
optimizer.zero_grad()

outputs = self.model_body(features)
predictions = self.model_head._forward(outputs["sentence_embedding"])
```
Depending on the answer to my question about the two `forward()` methods, it might also be worth implementing a `__call__` method in `SetFitHead` to simplify this step a bit.
Yeah, I think that by merging `forward()` and `_forward()` together, as I answered in your question above, I can simplify this step.
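For example, with the merged method the training-loop lines from the diff above can rely on `nn.Module.__call__`, which dispatches to `forward()`:

```python
outputs = self.model_body(features)
predictions = self.model_head(outputs["sentence_embedding"])  # __call__ -> forward()
```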
Fixed!
src/setfit/modeling.py
Outdated
l2_weight: float, | ||
): | ||
body_learning_rate = body_learning_rate or learning_rate | ||
optimizer = torch.optim.SGD( |
Would it make sense to use `AdamW` here, or did you find `SGD` works better?
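For reference, a minimal sketch of an `AdamW` variant dropped in place of the `torch.optim.SGD(...)` call above (the per-module parameter groups are an assumption):

```python
body_learning_rate = body_learning_rate or learning_rate
optimizer = torch.optim.AdamW(
    [
        # Allow a different learning rate for the body than for the head
        {"params": self.model_body.parameters(), "lr": body_learning_rate},
        {"params": self.model_head.parameters(), "lr": learning_rate},
    ],
    weight_decay=l2_weight,
)
```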
Sorry, I used `SGD` for simplicity. Maybe we can decide which one to use after running experiments?
Sounds good to me :)
I will report the results using `SGD` later; for now the first table (please check below) uses `AdamW` :)
Reported the results using `SGD`, `Adam`, and `AdamW`, please see below :)
```python
def freeze(self):
    """
    Freeze SetFitModel's differentiable head.
    Note: call this function only when using the differentiable head.
```
Maybe we could do a small check on the `SetFitModel` and raise a `ValueError` if a user tries to execute this with the wrong model?
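A minimal sketch of such a guard (`has_differentiable_head` is a hypothetical flag; the real attribute may differ):

```python
def freeze(self) -> None:
    """Freeze SetFitModel's differentiable head."""
    if not self.has_differentiable_head:  # hypothetical flag
        raise ValueError("`freeze` is only supported when using the differentiable head.")
    for param in self.model_head.parameters():
        param.requires_grad = False
```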
That's a great idea. Let me add it!
Just added it
tests/test_modeling.py (outdated)

```python
@@ -72,3 +109,49 @@ def test_setfit_multilabel_classifier_chain_classifier_model_head():
    )

    assert type(model.model_head) is ClassifierChain


def test_setfit_single_target_differentiable_head():
```
I recently refactored the `test_trainer.py` tests into groups by using `unittest.TestCase`. This has the advantage of allowing us to download groups of models just once vs multiple times per test run. Would you be happy to group your tests in a similar way?

(I realise the existing tests in `test_modeling.py` also need to be grouped - I'll do that later :))
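A minimal sketch of the grouping pattern (the class and test names here are illustrative):

```python
import unittest

from setfit import SetFitModel


class SetFitDifferentiableHeadTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Download the model once for the whole group instead of once per test
        cls.model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-albert-small-v2",
            use_differentiable_head=True,
            head_params={"out_features": 2},
        )

    def test_model_head_is_set(self):
        self.assertIsNotNone(self.model.model_head)
```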
Sure, no problem 👍
done!
Thanks! Glad you like it! Sure, will provide a snippet in the next comment. :) Sorry, I might have rushed a bit - I wanted to share the implementation with you to make sure the APIs are correct, so I only tested the head by running one step (forward and backward) to check its gradients.
Here is the snippet for using the differentiable head (partially copied from the README example):

```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
num_classes = 2
train_dataset = dataset["train"].shuffle(seed=42).select(range(8 * num_classes))
eval_dataset = dataset["validation"]  # missing from the original snippet; added so the example runs

# Initialize `SetFitModel` with the differentiable head
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to text/label expected by trainer
)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()
# Or: unfreeze head and freeze body
# trainer.unfreeze(keep_body_frozen=True)

# Train end-to-end
trainer.train(num_epochs=1)
```
Hey @blakechi, sorry for the delay in reviewing your latest changes! They're looking really good, and there are just a few things we need to do before merging. How does that sound?
…head, added cal_score.py for aggregating scores from each run
Hi @lewtun, sorry for the late reply - I was swamped with other stuff. Yeah, that sounds great to me. :)
Results that mimic Table 2 (N=8) in the paper (all PyTorch heads use …):

[results table]

I also tried on …:

[results table]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for sharing the experimental results with the new API - incredible work 🔥 !
This PR is looking really close to being done - I've just left a few small nits and then we can merge 🥳
I'm excited to see this in action!
Edit: you also need to run `make style && make quality` to get the CI to pass.
scripts/setfit/cal_score.py (outdated)

```python
@@ -0,0 +1,50 @@
import argparse
```
I'm curious what you need this script for? We have a `create_summary_table.py` script for aggregating scores across runs, but maybe you needed something else?
Oh haha, I didn't see that script 😂 OK, I will remove `cal_score.py` since it duplicates `create_summary_table.py`, thanks!
done!
src/setfit/modeling.py
Outdated
MODEL_HEAD_NAME = "model_head.pkl" | ||
|
||
|
||
class SetFitDataset(Dataset): |
I think this would belong better in `setfit.data` - WDYT?

Could we also add a one-line docstring to summarise what it's for?
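Something along these lines would do (the wording is illustrative):

```python
class SetFitDataset(Dataset):
    """Dataset for the differentiable head: tokenizes texts and pairs them with their labels."""
```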
That's a good idea! Will add a docstring as well.
done!
Thanks for the final iteration - this looks great so I'm going to merge it now 🔥 !
Amazing contribution and thank you for working on it @blakechi 🤗
Good job @blakechi! Many thanks.
This pull request adds support as mentioned in issue #8:

- `SetFitHead`, which inherits from Sentence Transformers' `Dense` to make APIs consistent with the body
- `SetFitHead` added to `SetFitModel`
- `SetFitHead` added to `SetFitTrainer` (the sklearn-based head still works and usage remains the same)
- Tests for `SetFitHead` (tested initialization for single/multiple targets, forward, backward)
- New methods for `SetFitTrainer`: `trainer.freeze()` and `trainer.unfreeze(keep_body_frozen)`