Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for the differentiable head #112

Merged
merged 31 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6f13f03
add SetFitHead
blakechi Oct 14, 2022
dcffe0b
made SetFitHead inherit from Sentence Transformers's Dense for consis…
blakechi Oct 17, 2022
0b7b230
integrate SetFitHead to SetFitModel's fit method
blakechi Oct 17, 2022
eb9c389
updated Trainer for the head
blakechi Oct 17, 2022
c4e0670
updated test_modeling.py and added in SetFitModel's to test SetFitHead
blakechi Oct 17, 2022
358d4a9
added keep_body_frozen in SetFitTrainer's unfreeze and refine code
blakechi Oct 17, 2022
559c3cf
added new tests in test_modeling.py and refine code
blakechi Oct 19, 2022
91d03e4
aligned max sequence length with the model body when tokenize
blakechi Oct 19, 2022
f9e9dca
added notes for SetFitTrainer's freeze and unfreeze
blakechi Oct 19, 2022
af14c5e
Update src/setfit/modeling.py - adding types
blakechi Oct 22, 2022
50320b6
Update src/setfit/modeling.py - doctoring
blakechi Oct 22, 2022
3cb9a73
Update src/setfit/modeling.py - doc string
blakechi Oct 22, 2022
a6c5c17
merged SetFitHead's forward and _forward into one
blakechi Oct 22, 2022
f09c5a3
added types in modeling.py
blakechi Oct 22, 2022
8557b86
added checks for and in SetFitTrainer
blakechi Oct 22, 2022
b6de5bd
switched from SGD to AdamW for SetFitModel's fit method and refined code
blakechi Oct 23, 2022
2fdc51b
grouped differentiable head related tests
blakechi Oct 23, 2022
3c5159e
updated run_fewshow.py and modeling.py to benchmark the differential …
blakechi Oct 29, 2022
83db1d7
Merge branch 'main' into differentiable-head
blakechi Oct 29, 2022
3fbe320
made SetFitModel to train on GPU
blakechi Oct 30, 2022
92a8cdc
add a flag: keep_body_frozen to control whether to train SetFitModel …
blakechi Oct 30, 2022
81e503f
Merge branch 'differentiable-head' of /~https://github.com/blakechi/set…
blakechi Oct 30, 2022
6a0e039
assigned SetFit body's device to SetFitHead when initializing SetFitH…
blakechi Oct 30, 2022
f97dd13
updated docstring for SetFitHead
blakechi Oct 30, 2022
11a67ce
Update scripts/setfit/run_fewshot.py
blakechi Nov 1, 2022
4b1ae43
Update src/setfit/modeling.py
blakechi Nov 1, 2022
49268f6
Update src/setfit/modeling.py
blakechi Nov 1, 2022
c10011c
Update src/setfit/modeling.py
blakechi Nov 1, 2022
86aa7e7
removed cal_score.py
blakechi Nov 1, 2022
2c9192c
moved SetFitDataset to data.py
blakechi Nov 1, 2022
dee87fc
Merge branch 'main' into differentiable-head
blakechi Nov 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions scripts/setfit/cal_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import argparse
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious what you need this script for? We have a create_summary_table.py script for aggregating scores across runs, but maybe you needed something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh haha, I didn't see that script 😂
ok, I will remove cal_score.py since it's duplicated to create_summary_table.py, thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

import json
from os import listdir
from os.path import isdir, isfile, join

import numpy as np


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--exp_folder",
"-e",
required=True,
type=str,
help="The folder path of the experiment created by `run_fewshot.py`.",
)

args = parser.parse_args()

return args


def get_folders(folder):
return [join(folder, f) for f in listdir(folder) if isdir(join(folder, f))]


if __name__ == "__main__":

args = parse_args()

dataset_folders = get_folders(args.exp_folder)
for dataset_folder in dataset_folders:
run_folders = get_folders(dataset_folder)

scores = []
for run_folder in run_folders:
with open(join(run_folder, "results.json"), "r") as f:
score = json.load(f)["score"]
scores.append(score)

scores = np.array(scores)
with open(join(dataset_folder, "results.json"), "w") as f:
json.dump(
{
"mean": np.mean(scores).item(),
"std": np.std(scores).item(),
},
f,
)
24 changes: 22 additions & 2 deletions scripts/setfit/run_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def parse_args():
parser.add_argument("--is_dev_set", type=bool, default=False)
parser.add_argument("--is_test_set", type=bool, default=False)
parser.add_argument("--override_results", default=False, action="store_true")
parser.add_argument("--keep_body_frozen", default=False, action="store_true")
parser.add_argument("--add_data_augmentation", default=False)

args = parser.parse_args()
Expand Down Expand Up @@ -105,7 +106,14 @@ def main():
continue

# Load model
model = SetFitModel.from_pretrained(args.model)
if args.classifier == "pytorch":
model = SetFitModel.from_pretrained(
args.model,
use_differentiable_head=True,
head_params={"out_features": len(set(train_data["label"]))},
)
else:
model = SetFitModel.from_pretrained(args.model)
model.model_body.max_seq_length = args.max_seq_length
if args.add_normalization_layer:
model.model_body._modules["2"] = models.Normalize()
Expand All @@ -121,7 +129,19 @@ def main():
num_epochs=args.num_epochs,
num_iterations=args.num_iterations,
)
trainer.train()
if args.classifier == "pytorch":
trainer.freeze()
trainer.train()
trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
trainer.train(
num_epochs=25,
body_learning_rate=1e-5,
learning_rate=args.lr, # recommand: 1e-2
blakechi marked this conversation as resolved.
Show resolved Hide resolved
l2_weight=0.0,
batch_size=args.batch_size,
)
else:
trainer.train()

# Evaluate the model on the test data
metrics = trainer.evaluate()
Expand Down
2 changes: 1 addition & 1 deletion src/setfit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.0.dev0"

from .modeling import SetFitModel
from .modeling import SetFitHead, SetFitModel
from .trainer import SetFitTrainer
Loading