Implement end-to-end differentiable model #8
Comments
Hey @blakechi if you have bandwidth, I would certainly welcome a PR for this 😍! For the API, I've been wondering whether we need something like:

```python
trainer = SetFitTrainer(...)

# Freeze head
trainer.freeze()
# Do contrastive training
trainer.train(num_epochs=1)
# Unfreeze head
trainer.unfreeze()
# Train end-to-end
trainer.train(num_epochs=1)
```

The alternative would be to do the freeze/unfreeze automatically for the user and have a single `train()` call. Another thing to keep in mind is that we'll probably need different learning rates for the head vs. the body, so the training API will need to account for that.

And yes, it's probably best to wait until #69 is merged - let's see what their timeline is :)
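For reference, a common way to handle separate learning rates for the body and the head in plain PyTorch is optimizer parameter groups. The sketch below is only an illustration with hypothetical stand-ins for the two components, not SetFit's actual API:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the two components, for illustration only.
body = nn.Linear(768, 768)  # plays the role of the pretrained sentence-transformer body
head = nn.Linear(768, 2)    # plays the role of the differentiable classification head

# One parameter group per component, each with its own learning rate.
optimizer = torch.optim.AdamW(
    [
        {"params": body.parameters(), "lr": 1e-5},  # small LR for the pretrained body
        {"params": head.parameters(), "lr": 1e-2},  # larger LR for the freshly initialized head
    ]
)
```

A trainer could build such an optimizer internally from two learning-rate arguments, one for each component.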
Hey @lewtun, how about also offering finer-grained control, something like:

```python
# freeze/unfreeze body only
trainer.freeze_body()
trainer.unfreeze_body()

# freeze/unfreeze head only
trainer.freeze_head()
trainer.unfreeze_head()

# freeze/unfreeze both body and head
trainer.freeze()
trainer.unfreeze()
```

Okay, I will work on it after #69 is merged. ;) Thanks!
Sure, we can have methods like you suggest if we find it's necessary to freeze/unfreeze the head and the body. I'm not 100% sure yet if we can get by with just freezing the body - it will probably take some experimentation to find out what works best :)
Hi @lewtun, good to hear that! I think I will begin by implementing a new class for the differentiable head.
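For context, a differentiable head in PyTorch is essentially a small `nn.Module` applied to the body's sentence embeddings. The class below is a hypothetical sketch for illustration, not the class that was eventually implemented:

```python
import torch
from torch import nn


class DifferentiableHead(nn.Module):
    """Hypothetical logistic-regression-style head operating on sentence embeddings."""

    def __init__(self, embedding_dim: int = 768, num_classes: int = 2) -> None:
        super().__init__()
        self.linear = nn.Linear(embedding_dim, num_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Return raw logits; cross-entropy loss applies the softmax during training.
        return self.linear(embeddings)


# Example: classify a batch of 4 sentence embeddings.
head = DifferentiableHead()
logits = head(torch.randn(4, 768))
```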
That sounds like a good plan to start with. If we can consistently match the performance of the scikit-learn approach, we can eventually deprecate that part and have a pure PyTorch model (my preference, since it enables other features like ONNX export in the future).
Ok, I will start implementing it!
Hi @lewtun, I think I'm about halfway (or more) through the implementation, and a question came to mind. From the API you suggested above, after we unfreeze the head and fire up training again, the body will be trained as well, since training is end-to-end. Do you think we should also enable users to freeze the body and train only the head, as usual?
This is great news and a good question, @blakechi! I think for the first version, enabling the body to be frozen makes sense - especially so we can check that we can reproduce the results from the original logistic regression head :) Maybe this could be supported with a simple boolean arg.
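One possible shape for such a flag is sketched below; the argument name and helper are hypothetical, purely to illustrate the idea of toggling `requires_grad` on the body's parameters:

```python
from torch import nn


def set_body_frozen(body: nn.Module, keep_body_frozen: bool) -> None:
    """Hypothetical helper: disable gradients for the body when it should stay frozen."""
    for param in body.parameters():
        param.requires_grad = not keep_body_frozen


# Usage: keep the body frozen so only the head is updated during training.
body = nn.Linear(768, 768)  # stand-in for the sentence-transformer body
set_body_frozen(body, keep_body_frozen=True)
```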
Sounds great! Just added that. Starting to work on testing.
Hi @lewtun, just opened a pull request for this issue! Please take a look. I haven't trained using the differentiable head yet - could you suggest a script for training? Thanks!
Can this be closed since the PR was merged?
Yes, I think we can!
The current SetFit implementation combines the embeddings of a frozen body with the learnable parameters of a logistic regression classifier.
It would be desirable to have a PyTorch implementation that is end-to-end differentiable. This isn't entirely trivial as we'll likely need different learning rates for the body vs the head.
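As a rough illustration of what end-to-end differentiability means here (using hypothetical stand-in modules rather than a real sentence transformer), gradients would flow through both the head and the body in a single backward pass:

```python
import torch
from torch import nn

# Hypothetical stand-ins: a trainable "body" producing sentence embeddings
# and a linear "head" mapping them to class logits.
body = nn.Sequential(nn.Linear(768, 768), nn.ReLU())
head = nn.Linear(768, 2)

features = torch.randn(8, 768)      # pretend sentence features for a batch of 8
labels = torch.randint(0, 2, (8,))

logits = head(body(features))       # a single differentiable computation graph
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()                     # both the head and the body receive gradients
```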