
Support for the differentiable head #112

Merged 31 commits into huggingface:main on Nov 1, 2022

Conversation

blakechi (Contributor, author):

This pull request adds support for the differentiable head, as discussed in issue #8.

  • added SetFitHead, which inherits from Sentence Transformers' Dense to keep the APIs consistent
  • integrated SetFitHead into SetFitModel
  • integrated SetFitHead into SetFitTrainer (the sklearn-based head still works and its usage remains the same)
  • added new tests for SetFitHead (covering initialization for single/multiple targets, forward, and backward)
  • added new APIs to SetFitTrainer: trainer.freeze() and trainer.unfreeze(keep_body_frozen)

lewtun (Member) left a comment:

Absolutely incredible work on implementing a pure torch model @blakechi 🔥 !

I've left a few comments / questions, but overall this is looking really good.

Would you mind sharing a code snippet in the PR description so that others can understand how the new API should work?

I'm also curious if you tested this implementation with some of the test datasets in our paper? It would be cool to know if (a) this implementation does better than our original one and (b) there are no subtle regressions with the sklearn version.

@@ -23,6 +25,43 @@
MODEL_HEAD_NAME = "model_head.pkl"


class SetFitDataset(Dataset):
def __init__(self, x, y, tokenizer, max_length=32):
lewtun (Member):

Small nit: could we add some types here? For things like tokenizer, we can import them with typing.TYPE_CHECKING and then wrap the type in quotes " ... "
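For illustration, a minimal sketch of that pattern (the tokenizer type is an assumption used only to show the signature):

```python
from typing import TYPE_CHECKING, List, Union

from torch.utils.data import Dataset

if TYPE_CHECKING:
    # Imported only for type checkers, so there is no hard runtime dependency.
    from transformers import PreTrainedTokenizerBase


class SetFitDataset(Dataset):
    def __init__(
        self,
        x: List[str],
        y: Union[List[int], List[List[int]]],
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int = 32,
    ) -> None:
        self.x = x
        self.y = y
        self.tokenizer = tokenizer
        self.max_length = max_length
```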

blakechi (Contributor, author):

okay, I will add them in the next commit

blakechi (Contributor, author):

Just added it!

src/setfit/modeling.py: 3 resolved review threads (outdated)
"bias": self.bias,
}

def __repr__(self):
lewtun (Member):

Nice!

blakechi (Contributor, author):

Thanks!

return outputs

def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
features.update({"prediction": self._forward(features["sentence_embedding"])})
lewtun (Member):

I'm curious about the need for separate _forward() and forward() methods - is there a reason we can't wrap the logic in a single method?

blakechi (Contributor, author):

My original intent was to separate the use cases: forward() keeps compatibility with Sentence Transformers, while _forward() can be used when users feed in embeddings only.

But yes, I think I can merge them and have the method behave accordingly by detecting which format the user passes in. Will fix it in the next commit!
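A rough sketch of what that merged method could look like (a hedged example, not the final implementation; the `linear` layer name is illustrative):

```python
from typing import Dict, Union

import torch
from torch import nn


class SetFitHead(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(
        self, features: Union[Dict[str, torch.Tensor], torch.Tensor]
    ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
        # Sentence Transformers passes a feature dict; users may pass raw embeddings.
        if isinstance(features, dict):
            features.update({"prediction": self.linear(features["sentence_embedding"])})
            return features
        return self.linear(features)
```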

blakechi (Contributor, author):

Fixed!

optimizer.zero_grad()

outputs = self.model_body(features)
predictions = self.model_head._forward(outputs["sentence_embedding"])
lewtun (Member):

Depending on the answer to my question about two forward() methods, it might also be worth implementing a __call__ method in SetFitHead to simplify this step a bit

blakechi (Contributor, author):

Yes, I think merging forward() and _forward(), as described above, will let me simplify this step.

blakechi (Contributor, author):

Fixed!

l2_weight: float,
):
body_learning_rate = body_learning_rate or learning_rate
optimizer = torch.optim.SGD(
lewtun (Member):

Would it make sense to use AdamW here or did you find SGD works better?

blakechi (Contributor, author):

Sorry, I used SGD for simplicity. Maybe we can decide which one to use after running experiments?

lewtun (Member):

Sounds good to me :)

blakechi (Contributor, author):

I will report the results using SGD later; for now, the first table (please see below) uses AdamW :)

blakechi (Contributor, author):

I reported the results using SGD, Adam, and AdamW; please see below :)
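For reference, a minimal sketch of how separate head/body learning rates could be wired up with AdamW via parameter groups (the module definitions are illustrative stand-ins):

```python
import torch
from torch import nn

# Illustrative stand-ins for the SetFit body and head.
model_body = nn.Linear(768, 768)
model_head = nn.Linear(768, 2)

learning_rate = 1e-2       # head learning rate
body_learning_rate = 1e-5  # could fall back to `learning_rate` when unset
l2_weight = 0.0

optimizer = torch.optim.AdamW(
    [
        {"params": model_body.parameters(), "lr": body_learning_rate},
        {"params": model_head.parameters(), "lr": learning_rate},
    ],
    weight_decay=l2_weight,
)
```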

def freeze(self):
"""
Freeze SetFitModel's differentiable head.
Note: call this function only when using the differentiable head.
lewtun (Member):

Maybe we could do a small check on the SetFitModel and raise a ValueError if a user tries to execute this with the wrong model?

blakechi (Contributor, author):

That's a great idea. Let me add it!

blakechi (Contributor, author):

Just added it
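A sketch of the kind of guard discussed above, assuming the differentiable head is a torch nn.Module while the sklearn head is not:

```python
import torch


def freeze(self):
    """Freeze SetFitModel's differentiable head."""
    # Hypothetical check: the sklearn head is not an nn.Module, so reject it early.
    if not isinstance(self.model.model_head, torch.nn.Module):
        raise ValueError("`freeze` is only supported when using the differentiable head.")
    for param in self.model.model_head.parameters():
        param.requires_grad = False
```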

@@ -72,3 +109,49 @@ def test_setfit_multilabel_classifier_chain_classifier_model_head():
)

assert type(model.model_head) is ClassifierChain


def test_setfit_single_target_differentiable_head():
lewtun (Member):

I recently refactored the test_trainer.py tests into groups by using unittest.TestCase. This has the advantage of allowing us to download groups of models just once vs multiple times per test run.

Would you be happy to group your tests in a similar way?

(I realise the existing tests in test_modeling.py also need to be grouped - I'll do that later :))
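For illustration, a minimal sketch of that grouping pattern (the test body is a placeholder):

```python
import unittest

from setfit import SetFitModel


class SetFitHeadTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Download the model once for the whole group of tests.
        cls.model = SetFitModel.from_pretrained(
            "sentence-transformers/paraphrase-albert-small-v2",
            use_differentiable_head=True,
            head_params={"out_features": 2},
        )

    def test_head_is_differentiable(self):
        self.assertTrue(any(p.requires_grad for p in self.model.model_head.parameters()))
```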

blakechi (Contributor, author):

sure no problem 👍

blakechi (Contributor, author):

done!

blakechi (Contributor, author) commented Oct 23, 2022:

> Absolutely incredible work on implementing a pure torch model @blakechi 🔥 !
>
> I've left a few comments / questions, but overall this is looking really good.
>
> Would you mind sharing a code snippet in the PR description so that others can understand how the new API should work?
>
> I'm also curious if you tested this implementation with some of the test datasets in our paper? It would be cool to know if (a) this implementation does better than our original one and (b) there are no subtle regressions with the sklearn version.

Thanks! Glad you like it!

Sure, I will provide a snippet in the next comment. :)

Sorry, I might have rushed a bit. I wanted to share the implementation with you to make sure the APIs are right, so I have only tested the head by running one step (forward and backward) to check its gradients.
I will test this implementation on the test datasets. If you have scripts you'd suggest I run, I'm happy to hear about them!

blakechi (Contributor, author):
Here is the snippet for using the differentiable head (partially copied from README.md):

from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
num_classes = 2
train_dataset = dataset["train"].shuffle(seed=42).select(range(8 * num_classes))
eval_dataset = dataset["validation"]

# Initialize `SetFitModel` with the differentiable head
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
)

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20, # The number of text pairs to generate for contrastive learning
    num_epochs=1, # The number of epochs to use for contrastive learning
    column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()

# Unfreeze head and freeze body
# trainer.unfreeze(keep_body_frozen=True)

# Train end-to-end
trainer.train(num_epochs=1)

lewtun (Member) commented Oct 28, 2022:

Hey @blakechi, sorry for the delay in reviewing your latest changes! They're looking really good, and there are just a few things we need to do before merging:

  • Resolve the merge conflicts (from some unrelated PRs)
  • Verify that running scripts/setfit/run_fewshot.py reproduces the results from our paper when using the sklearn head (I can do this)
  • [Optional] Check if this PR can produce similar results as the sklearn head. We could just run a few experiments in Colab for e.g. the emotion dataset to see if we're in the same ballpark

How does that sound?

blakechi (Contributor, author) commented Oct 29, 2022:

Hi @lewtun,

Sorry for the late reply; I was swamped with other things. Yes, that sounds great to me. :)

  • Sure, just resolved them!

  • Thanks! I also ran experiments using the sklearn head so we can cross-check your results against mine. :)

  • I ran some experiments on the test sets using the pytorch head with different numbers of epochs, and the performance is similar to the paper's. Please see the table below.

  • Updated run_fewshot.py so the differentiable head can be selected via --classifier=pytorch

Results mimicking Table 2 (N=8) in the paper (all pytorch heads use batch size 16, optimizer AdamW, L2 weight (weight decay) 0, head learning rate 1e-2, body learning rate 1e-5):

| Head | SST-5 | Amazon-CF | CR | Emotion | EnronSpam | AGNews |
|------|-------|-----------|----|---------|-----------|--------|
| sklearn | 43.9 (2.8) | 40.2 (9.3) | 88.2 (2.5) | 48.4 (4.6) | 89.6 (4.0) | 82.8 (3.0) |
| pytorch (freeze body, 25 epochs) | 43.9 (3.0) | 40.7 (12.6) | 88.8 (1.2) | 48.6 (4.0) | 88.6 (4.7) | 82.7 (2.8) |
| pytorch (freeze body, 50 epochs) | 44.4 (2.9) | 39.9 (11.7) | 89.0 (1.0) | 48.4 (5.2) | 89.1 (4.1) | 83.3 (2.9) |
| pytorch (end to end, 25 epochs) | 43.6 (2.2) | 40.6 (12.2) | 88.6 (1.4) | 46.3 (4.7) | 89.9 (3.6) | 83.0 (2.9) |
| pytorch (end to end, 50 epochs) | 43.0 (2.8) | 39.1 (11.6) | 88.8 (1.3) | 47.0 (3.1) | 89.3 (3.8) | 83.3 (2.7) |

I also tried SGD and Adam, keeping the other parameters the same as above. With SGD the performance dropped significantly (e.g. accuracy dropped to 1X.XX for Emotion), so I excluded it here. Below are the results using Adam:

| Head | SST-5 | Amazon-CF | CR | Emotion | EnronSpam | AGNews |
|------|-------|-----------|----|---------|-----------|--------|
| pytorch (freeze body, 25 epochs) | 42.9 (3.0) | 41.9 (12.4) | 88.5 (1.2) | 48.8 (5.4) | 89.9 (3.7) | 83.2 (2.8) |
| pytorch (end to end, 25 epochs) | 43.1 (3.1) | 41.6 (11.8) | 88.6 (1.4) | 48.7 (4.3) | 89.4 (4.5) | 83.4 (2.4) |

lewtun (Member) left a comment:

Thanks so much for sharing the experimental results with the new API - incredible work 🔥 !

This PR is looking really close to being done - I've just left a few small nits and then we can merge 🥳

I'm excited to see this in action!

Edit: you also need to run make style && make quality to get the CI to pass

@@ -0,0 +1,50 @@
import argparse
lewtun (Member):

I'm curious what you need this script for? We have a create_summary_table.py script for aggregating scores across runs, but maybe you needed something else?

blakechi (Contributor, author):

Oh haha, I didn't see that script 😂
OK, I will remove cal_score.py since it duplicates create_summary_table.py. Thanks!

blakechi (Contributor, author):

done!

scripts/setfit/run_fewshot.py: resolved review thread (outdated)
MODEL_HEAD_NAME = "model_head.pkl"


class SetFitDataset(Dataset):
lewtun (Member):

I think this would belong better in setfit.data - WDYT?

Could we also add a one-line docstring to summarise what it's for?
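The one-line docstring might read something like this (a sketch; the wording is only a suggestion):

```python
from torch.utils.data import Dataset


class SetFitDataset(Dataset):
    """A torch Dataset that tokenizes sentences and pairs them with labels for training the differentiable head."""
```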

blakechi (Contributor, author):

That's a good idea! I will add a docstring as well.

blakechi (Contributor, author):

done!

src/setfit/modeling.py: 3 resolved review threads (outdated)
blakechi and others added 7 commits November 1, 2022 00:15 (several co-authored by lewtun <lewis.c.tunstall@gmail.com>)
lewtun (Member) left a comment:

Thanks for the final iteration - this looks great so I'm going to merge it now 🔥 !

Amazing contribution and thank you for working on it @blakechi 🤗

@lewtun lewtun merged commit 43dbaf1 into huggingface:main Nov 1, 2022
PhilipMay (Contributor):

Good job @blakechi! Many thanks.

blakechi (Contributor, author) commented Nov 2, 2022:

> Thanks for the final iteration - this looks great so I'm going to merge it now 🔥 !
>
> Amazing contribution and thank you for working on it @blakechi 🤗

I want to thank you for your advice and reviews as well; they made this PR wonderful. I really enjoyed this collaboration! 🤗 @lewtun
