Support for the differentiable head #112

Merged: 31 commits, Nov 1, 2022

Commits (31)
6f13f03: add SetFitHead (blakechi, Oct 14, 2022)
dcffe0b: made SetFitHead inherit from Sentence Transformers's Dense for consis… (blakechi, Oct 17, 2022)
0b7b230: integrate SetFitHead into SetFitModel's fit method (blakechi, Oct 17, 2022)
eb9c389: updated Trainer for the head (blakechi, Oct 17, 2022)
c4e0670: updated test_modeling.py and added in SetFitModel's to test SetFitHead (blakechi, Oct 17, 2022)
358d4a9: added keep_body_frozen in SetFitTrainer's unfreeze and refined code (blakechi, Oct 17, 2022)
559c3cf: added new tests in test_modeling.py and refined code (blakechi, Oct 19, 2022)
91d03e4: aligned max sequence length with the model body when tokenizing (blakechi, Oct 19, 2022)
f9e9dca: added notes for SetFitTrainer's freeze and unfreeze (blakechi, Oct 19, 2022)
af14c5e: Update src/setfit/modeling.py - adding types (blakechi, Oct 22, 2022)
50320b6: Update src/setfit/modeling.py - docstring (blakechi, Oct 22, 2022)
3cb9a73: Update src/setfit/modeling.py - docstring (blakechi, Oct 22, 2022)
a6c5c17: merged SetFitHead's forward and _forward into one (blakechi, Oct 22, 2022)
f09c5a3: added types in modeling.py (blakechi, Oct 22, 2022)
8557b86: added checks for and in SetFitTrainer (blakechi, Oct 22, 2022)
b6de5bd: switched from SGD to AdamW for SetFitModel's fit method and refined code (blakechi, Oct 23, 2022)
2fdc51b: grouped differentiable-head-related tests (blakechi, Oct 23, 2022)
3c5159e: updated run_fewshot.py and modeling.py to benchmark the differential … (blakechi, Oct 29, 2022)
83db1d7: Merge branch 'main' into differentiable-head (blakechi, Oct 29, 2022)
3fbe320: made SetFitModel train on GPU (blakechi, Oct 30, 2022)
92a8cdc: add a flag: keep_body_frozen to control whether to train SetFitModel … (blakechi, Oct 30, 2022)
81e503f: Merge branch 'differentiable-head' of /~https://github.com/blakechi/set… (blakechi, Oct 30, 2022)
6a0e039: assigned SetFit body's device to SetFitHead when initializing SetFitH… (blakechi, Oct 30, 2022)
f97dd13: updated docstring for SetFitHead (blakechi, Oct 30, 2022)
11a67ce: Update scripts/setfit/run_fewshot.py (blakechi, Nov 1, 2022)
4b1ae43: Update src/setfit/modeling.py (blakechi, Nov 1, 2022)
49268f6: Update src/setfit/modeling.py (blakechi, Nov 1, 2022)
c10011c: Update src/setfit/modeling.py (blakechi, Nov 1, 2022)
86aa7e7: removed cal_score.py (blakechi, Nov 1, 2022)
2c9192c: moved SetFitDataset to data.py (blakechi, Nov 1, 2022)
dee87fc: Merge branch 'main' into differentiable-head (blakechi, Nov 1, 2022)
24 changes: 22 additions & 2 deletions scripts/setfit/run_fewshot.py
@@ -51,6 +51,7 @@ def parse_args():
     parser.add_argument("--is_dev_set", type=bool, default=False)
     parser.add_argument("--is_test_set", type=bool, default=False)
     parser.add_argument("--override_results", default=False, action="store_true")
+    parser.add_argument("--keep_body_frozen", default=False, action="store_true")
     parser.add_argument("--add_data_augmentation", default=False)

     args = parser.parse_args()
@@ -105,7 +106,14 @@ def main():
                continue

            # Load model
-            model = SetFitModel.from_pretrained(args.model)
+            if args.classifier == "pytorch":
+                model = SetFitModel.from_pretrained(
+                    args.model,
+                    use_differentiable_head=True,
+                    head_params={"out_features": len(set(train_data["label"]))},
+                )
+            else:
+                model = SetFitModel.from_pretrained(args.model)
             model.model_body.max_seq_length = args.max_seq_length
             if args.add_normalization_layer:
                 model.model_body._modules["2"] = models.Normalize()
@@ -121,7 +129,19 @@ def main():
                num_epochs=args.num_epochs,
                num_iterations=args.num_iterations,
            )
-            trainer.train()
+            if args.classifier == "pytorch":
+                trainer.freeze()
+                trainer.train()
+                trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
+                trainer.train(
+                    num_epochs=25,
+                    body_learning_rate=1e-5,
+                    learning_rate=args.lr,  # recommend: 1e-2
+                    l2_weight=0.0,
+                    batch_size=args.batch_size,
+                )
+            else:
+                trainer.train()

            # Evaluate the model on the test data
            metrics = trainer.evaluate()
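For reference, here is the two-phase flow this script now implements, as a minimal standalone sketch. The checkpoint name, dataset, and column mapping are illustrative assumptions; the `freeze`/`unfreeze`/`train` calls and their keyword arguments mirror the diff above.

```python
# Sketch of the two-phase training flow; checkpoint and dataset are assumptions.
from datasets import load_dataset
from setfit import SetFitModel, SetFitTrainer

train_ds = load_dataset("sst2", split="train[:64]")  # small illustrative split

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",  # assumed checkpoint
    use_differentiable_head=True,
    head_params={"out_features": len(set(train_ds["label"]))},
)
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    column_mapping={"sentence": "text"},  # sst2 stores its text in "sentence"
)

trainer.freeze()  # phase 1: keep the head frozen; train() fine-tunes the body contrastively
trainer.train()

trainer.unfreeze(keep_body_frozen=True)  # phase 2: unfreeze the head, keep the body frozen
trainer.train(  # keyword arguments taken from the diff above
    num_epochs=25,
    body_learning_rate=1e-5,
    learning_rate=1e-2,  # the value the script comment recommends
    l2_weight=0.0,
)
```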
2 changes: 1 addition & 1 deletion src/setfit/__init__.py
@@ -1,4 +1,4 @@
 __version__ = "0.4.0.dev0"

-from .modeling import SetFitModel
+from .modeling import SetFitHead, SetFitModel
 from .trainer import SetFitTrainer
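The practical effect of this one-line change is that the head becomes importable from the package root:

```python
# After this PR, both symbols are available at the top level.
from setfit import SetFitHead, SetFitModel
```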
77 changes: 76 additions & 1 deletion src/setfit/data.py
@@ -1,9 +1,16 @@
-from typing import Dict, List
+from typing import TYPE_CHECKING, Dict, List, Tuple

 import pandas as pd
+import torch
 from datasets import Dataset, DatasetDict
+from torch.utils.data import Dataset as TorchDataset


+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+
+TokenizerOutput = Dict[str, List[int]]
 SEEDS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 SAMPLE_SIZES = [2, 4, 8, 16, 32, 64]

@@ -189,3 +196,71 @@ def add_templated_examples(
         dataset = dataset.add_item(example)

     return dataset
+
+
+class SetFitDataset(TorchDataset):
+    """SetFitDataset
+
+    A dataset for training the differentiable head on text classification.
+
+    Args:
+        x (`List[str]`):
+            A list of input data as texts that will be fed into `SetFitModel`.
+        y (`List[int]`):
+            A list of input data's labels.
+        tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer from `SetFitModel`'s body.
+        max_length (`int`, defaults to `32`):
+            The maximum token length a tokenizer can generate.
+            Will pad or truncate tokens when the number of tokens for a text is either smaller or larger than this value.
+    """
+
+    def __init__(
+        self,
+        x: List[str],
+        y: List[int],
+        tokenizer: "PreTrainedTokenizerBase",
+        max_length: int = 32,
+    ) -> None:
+        assert len(x) == len(y)
+
+        self.x = x
+        self.y = y
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self) -> int:
+        return len(self.x)
+
+    def __getitem__(self, idx: int) -> Tuple[TokenizerOutput, int]:
+        feature = self.tokenizer(
+            self.x[idx],
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            return_token_type_ids=True,
+        )
+        label = self.y[idx]
+
+        return feature, label
+
+    @staticmethod
+    def collate_fn(batch):
+        features = {
+            "input_ids": [],
+            "attention_mask": [],
+            "token_type_ids": [],
+        }
+        labels = []
+        for feature, label in batch:
+            features["input_ids"].append(feature["input_ids"])
+            features["attention_mask"].append(feature["attention_mask"])
+            features["token_type_ids"].append(feature["token_type_ids"])
+            labels.append(label)
+
+        # convert to tensors
+        features = {k: torch.Tensor(v).int() for k, v in features.items()}
+        labels = torch.Tensor(labels).long()
+
+        return features, labels
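Because `collate_fn` is a static method, the dataset plugs directly into a standard PyTorch `DataLoader`. A minimal usage sketch follows; the tokenizer checkpoint and toy examples are assumptions, not part of the diff.

```python
# Sketch: batching SetFitDataset with a DataLoader; checkpoint and data are assumptions.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from setfit.data import SetFitDataset

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
dataset = SetFitDataset(
    x=["a gripping, well-acted film", "a dull and lifeless script"],
    y=[1, 0],
    tokenizer=tokenizer,
    max_length=32,
)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=SetFitDataset.collate_fn,  # stacks per-example features into tensors
)

features, labels = next(iter(dataloader))
print(features["input_ids"].shape)  # torch.Size([2, 32])
print(labels.shape)                 # torch.Size([2])
```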