-
Notifications
You must be signed in to change notification settings - Fork 228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for the differentiable head #112
Merged
Merged
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
6f13f03
add SetFitHead
blakechi dcffe0b
made SetFitHead inherit from Sentence Transformers's Dense for consis…
blakechi 0b7b230
integrate SetFitHead to SetFitModel's fit method
blakechi eb9c389
updated Trainer for the head
blakechi c4e0670
updated test_modeling.py and added in SetFitModel's to test SetFitHead
blakechi 358d4a9
added keep_body_frozen in SetFitTrainer's unfreeze and refine code
blakechi 559c3cf
added new tests in test_modeling.py and refine code
blakechi 91d03e4
aligned max sequence length with the model body when tokenize
blakechi f9e9dca
added notes for SetFitTrainer's freeze and unfreeze
blakechi af14c5e
Update src/setfit/modeling.py - adding types
blakechi 50320b6
Update src/setfit/modeling.py - doctoring
blakechi 3cb9a73
Update src/setfit/modeling.py - doc string
blakechi a6c5c17
merged SetFitHead's forward and _forward into one
blakechi f09c5a3
added types in modeling.py
blakechi 8557b86
added checks for and in SetFitTrainer
blakechi b6de5bd
switched from SGD to AdamW for SetFitModel's fit method and refined code
blakechi 2fdc51b
grouped differentiable head related tests
blakechi 3c5159e
updated run_fewshow.py and modeling.py to benchmark the differential …
blakechi 83db1d7
Merge branch 'main' into differentiable-head
blakechi 3fbe320
made SetFitModel to train on GPU
blakechi 92a8cdc
add a flag: keep_body_frozen to control whether to train SetFitModel …
blakechi 81e503f
Merge branch 'differentiable-head' of /~https://github.com/blakechi/set…
blakechi 6a0e039
assigned SetFit body's device to SetFitHead when initializing SetFitH…
blakechi f97dd13
updated docstring for SetFitHead
blakechi 11a67ce
Update scripts/setfit/run_fewshot.py
blakechi 4b1ae43
Update src/setfit/modeling.py
blakechi 49268f6
Update src/setfit/modeling.py
blakechi c10011c
Update src/setfit/modeling.py
blakechi 86aa7e7
removed cal_score.py
blakechi 2c9192c
moved SetFitDataset to data.py
blakechi dee87fc
Merge branch 'main' into differentiable-head
blakechi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import argparse | ||
import json | ||
from os import listdir | ||
from os.path import isdir, isfile, join | ||
|
||
import numpy as np | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--exp_folder", | ||
"-e", | ||
required=True, | ||
type=str, | ||
help="The folder path of the experiment created by `run_fewshot.py`.", | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def get_folders(folder): | ||
return [join(folder, f) for f in listdir(folder) if isdir(join(folder, f))] | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
args = parse_args() | ||
|
||
dataset_folders = get_folders(args.exp_folder) | ||
for dataset_folder in dataset_folders: | ||
run_folders = get_folders(dataset_folder) | ||
|
||
scores = [] | ||
for run_folder in run_folders: | ||
with open(join(run_folder, "results.json"), "r") as f: | ||
score = json.load(f)["score"] | ||
scores.append(score) | ||
|
||
scores = np.array(scores) | ||
with open(join(dataset_folder, "results.json"), "w") as f: | ||
json.dump( | ||
{ | ||
"mean": np.mean(scores).item(), | ||
"std": np.std(scores).item(), | ||
}, | ||
f, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.4.0.dev0" | ||
|
||
from .modeling import SetFitModel | ||
from .modeling import SetFitHead, SetFitModel | ||
from .trainer import SetFitTrainer |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious what you need this script for? We have a
create_summary_table.py
script for aggregating scores across runs, but maybe you needed something else?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh haha, I didn't see that script 😂
ok, I will remove
cal_score.py
since it's duplicated tocreate_summary_table.py
, thanks!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!