-
Notifications
You must be signed in to change notification settings - Fork 354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MPS Load capabilities #623
Conversation
I am not one of the original contributors, but I do have an MPS device, and I confirm this fixed the issue. Specifically, I built /~https://github.com/LaurentMazare/diffusers-rs with the device manually set to MPS by changing this line to While I'm here, I note that this PR doesn't restore the original device in the two error cases (the |
Thanks for confirming, agreed that it would be nicer to set the device back to |
Thank you very much for confirming @bakkot - @LaurentMazare yes I should be able to implement a more robust setup later today |
@bakkot , @LaurentMazare I have wrapped the tensor loading in a closure to ensure the device reset code gets executed before the closure |
Looks good, thanks for the PR! |
VarStore
load method allowing loading weights on MPS-accelerated machinesThe device of the
VarStore
is temporarily set toCPU
when loading the weights and then changed back toMPS
as described by the workarounds described in Add MPS as default if available guillaume-be/rust-bert#311 and Torch("supported devices include CPU, CUDA and HPU, however got MPS") #609I do not have a
MPS
device to test the behaviour. It would be great if one of the contributors who identified the issue and the workaround could validate it solves the loading issues before merging.