Implement end-to-end differentiable model #8

Closed
lewtun opened this issue Jul 8, 2022 · 14 comments

lewtun (Member) commented Jul 8, 2022

The current SetFit implementation combines the embeddings of a frozen body with the learnable parameters of a logistic regression classifier.

It would be desirable to have a PyTorch implementation that is end-to-end differentiable. This isn't entirely trivial, as we'll likely need different learning rates for the body vs. the head.
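For context, PyTorch supports this directly via optimizer parameter groups. A minimal sketch (not the SetFit API; `body`, `head`, and the learning rates below are placeholders):

```python
import torch
from torch import nn

# Hypothetical stand-ins for the SetFit body and head
body = nn.Linear(768, 768)  # in practice, a sentence-transformer encoder
head = nn.Linear(768, 2)    # in practice, a classification head

# One optimizer, two parameter groups with different learning rates
optimizer = torch.optim.AdamW([
    {"params": body.parameters(), "lr": 1e-5},  # small LR for the pretrained body
    {"params": head.parameters(), "lr": 1e-2},  # larger LR for the fresh head
])
```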

lewtun mentioned this issue Oct 3, 2022
blakechi (Contributor) commented Oct 7, 2022

Hi @lewtun,

Just wanted to ask: is this issue currently in progress? If not, I have some bandwidth to work on it.
Or maybe we should wait until #69 is merged?

Thanks!

lewtun (Member, Author) commented Oct 7, 2022

Hey @blakechi if you have bandwidth, I would certainly welcome a PR for this 😍 !

For the API, I've been wondering whether we need something like fastai's freeze() and unfreeze() methods, e.g.:

```python
trainer = SetFitTrainer(...)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()

# Train end-to-end
trainer.train(num_epochs=1)
```

The alternative would be to do the freeze/unfreeze automatically for the user and have a single trainer.train() step - I think the answer will depend a bit on how easy it is to make this work end-to-end.

Another thing to keep in mind is that we'll probably need different learning rates for the head vs body, so the fastai approach provides a potentially simpler way to decouple these steps by simply passing a different learning_rate to each trainer.train() call.
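With that approach, a two-phase run might look like this (a sketch: the `learning_rate` argument on `train()` is assumed here, not an existing parameter, and the values are purely illustrative):

```python
# Phase 1: contrastive training of the body with the head frozen
trainer.freeze()
trainer.train(num_epochs=1, learning_rate=2e-5)

# Phase 2: end-to-end training with a different learning rate
trainer.unfreeze()
trainer.train(num_epochs=1, learning_rate=1e-3)
```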

And yes, it's probably best to wait until #69 is merged - let's see what their timeline is :)

blakechi (Contributor) commented Oct 8, 2022

Hey @lewtun,
Great! The API looks good to me, but just one suggestion: since trainer.freeze() sort of implies we're going to freeze the whole model (body + head), what do you think about branching freeze/unfreeze out into separate methods for the head and the body, and keeping freeze()/unfreeze() for both? E.g.:

```python
# freeze/unfreeze body only
trainer.freeze_body()
trainer.unfreeze_body()

# freeze/unfreeze head only
trainer.freeze_head()
trainer.unfreeze_head()

# freeze/unfreeze both body and head
trainer.freeze()
trainer.unfreeze()
```
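Under the hood, these could be thin wrappers that toggle requires_grad on the relevant parameters. A minimal sketch, assuming the model exposes its PyTorch submodules under attributes like model_head (hypothetical names here):

```python
class SetFitTrainer:
    # ... existing trainer code ...

    def freeze_head(self):
        # Exclude the head's parameters from gradient updates
        for param in self.model.model_head.parameters():
            param.requires_grad = False

    def unfreeze_head(self):
        # Re-enable gradient updates for the head
        for param in self.model.model_head.parameters():
            param.requires_grad = True
```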

Okay, I will work on it after #69 is merged. ;)

Thanks!

lewtun (Member, Author) commented Oct 10, 2022

Sure, we can have methods like you suggest if we find it's necessary to freeze/unfreeze the head and the body.

I'm not 100% sure yet if we can get by with just freezing the body - it will probably take some experimentation to find out what works best :)

lewtun (Member, Author) commented Oct 11, 2022

Hey @blakechi, we've just merged #69, so feel free to take a stab at the pure PyTorch model whenever you want (and feel free to ping me here if you have questions!)

blakechi (Contributor) commented
Hi @lewtun, good to hear that! I think I will begin by implementing a new class `SetFitHead` and then integrate it into `SetFitModel` by adding one more argument: `use_differentiable_head`. Does this sound good to you?
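Roughly, I imagine something like this (just a sketch: the input dimension and number of classes below are placeholder values, and the real design may differ):

```python
import torch
from torch import nn

class SetFitHead(nn.Module):
    """A differentiable classification head over sentence embeddings (sketch)."""

    def __init__(self, in_features: int = 768, num_classes: int = 2):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Return raw logits; a loss like nn.CrossEntropyLoss applies softmax internally
        return self.linear(embeddings)
```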

lewtun (Member, Author) commented Oct 12, 2022

> Hi @lewtun, good to hear that! I think I will begin by implementing a new class `SetFitHead` and then integrate it into `SetFitModel` by adding one more argument: `use_differentiable_head`. Does this sound good to you?

That sounds like a good plan to start with. If we can consistently match the performance of the scikit-learn approach, we can eventually deprecate that part and have a pure PyTorch model (my preference, since it enables other features like ONNX export in the future).
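For reference, a pure PyTorch model could then be exported with torch.onnx.export along these lines (a rough sketch: `model`, the input shapes, and the assumption that it maps token ids and an attention mask to logits are all hypothetical here):

```python
import torch

# Hypothetical: `model` is an end-to-end differentiable SetFit model
# taking token ids + attention mask and returning class logits
dummy_input_ids = torch.ones(1, 128, dtype=torch.long)
dummy_attention_mask = torch.ones(1, 128, dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_input_ids, dummy_attention_mask),
    "setfit.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
)
```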

blakechi (Contributor) commented
Ok, I will start implementing it!

blakechi (Contributor) commented
Hi @lewtun,

I think I'm about halfway (or more) through the implementation, and a question came to mind. With the API you suggested (repeated below), after we unfreeze the head and fire up training again, the body will be trained too, since everything runs end-to-end. Do you think we should also enable users to freeze the body and train only the head, as before?

```python
trainer = SetFitTrainer(...)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()

# Train end-to-end
trainer.train(num_epochs=1)
```

lewtun (Member, Author) commented Oct 17, 2022

This is great news, and a good question @blakechi!

I think for the first version, enabling the body to be frozen makes sense - especially so we can check that we can reproduce the results from the original logistic regression head :)

Maybe this could be supported with a simple boolean arg like keep_body_frozen in trainer.unfreeze()?
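For example, usage might look like this (a sketch of the proposed signature; keep_body_frozen is just the suggestion above, not an existing argument):

```python
# Unfreeze only the head; the body stays frozen,
# so this step trains the head on fixed embeddings
trainer.unfreeze(keep_body_frozen=True)
trainer.train(num_epochs=25)
```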

blakechi (Contributor) commented
Sounds great! Just added that, and I'm starting to work on testing.

blakechi (Contributor) commented
Hi @lewtun,

Just opened a pull request for this issue! Please take a look.

I haven't trained with the differentiable head yet; could you suggest a script for training? Thanks!

PhilipMay (Contributor) commented
Can this be closed since the PR was merged?

lewtun (Member, Author) commented Nov 1, 2022

Yes, I think we can!

lewtun closed this as completed Nov 1, 2022