Document Python -> Rust Model Translation Best Practices #549
Comments
It's hard to call this "best practices", but here is how I would think about it.

First, using a jit module saved from Python (your (1.2)) makes things straightforward, as there is no need to write model-specific Rust code in this case (see this tutorial if you are not already familiar with how to use this). The only reasons I can see not to use a jit model would be customizing the model, running training rather than inference (though I think training can even be done with jit to some extent), or wanting to learn about the model by porting it to Rust.

I don't think I ever tried to reverse engineer a model just by looking at the weight files or by printing the model itself; this would be quite tricky, as you noticed, and certainly error-prone. I always used the Python implementation as a guide for implementing the Rust version. An important point is that tch tries to mimic all the default behavior of PyTorch so as to make porting models easier; this is the case for variable initializations, for example, or optional arguments in function calls. The

When it comes to tips and tricks, lots of issues are detected by shape mismatches when loading the weights, I use

I'm not sure I remember how I ended up with this yolo-v3 version; my guess is that it was the first self-contained and reasonably simple implementation that I came across back then.
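The shape-mismatch tip can be approximated before any weights are loaded. The following is only a minimal sketch, not part of tch: it assumes you have written out two name-to-shape listings by hand or by script (one from the Python model, one from the Rust port) and compares them. The `Mismatch` type, the `compare` function, and the parameter names and shapes in `main` are all hypothetical and made up for illustration.

```rust
use std::collections::BTreeMap;

/// Shape of a tensor, e.g. [16, 3, 3, 3] for a conv weight.
type Shape = Vec<i64>;

/// One discrepancy between the two parameter listings (hypothetical type).
#[derive(Debug, PartialEq)]
enum Mismatch {
    /// Present in the Python listing but not in the Rust one.
    MissingInRust(String),
    /// Present in the Rust listing but not in the Python one.
    MissingInPython(String),
    /// Present in both, but with different shapes.
    Shape { name: String, python: Shape, rust: Shape },
}

/// Compare two name -> shape listings and collect every discrepancy.
/// Names from the Python listing are reported first, in sorted order.
fn compare(
    python: &BTreeMap<String, Shape>,
    rust: &BTreeMap<String, Shape>,
) -> Vec<Mismatch> {
    let mut out = Vec::new();
    for (name, py_shape) in python {
        match rust.get(name) {
            None => out.push(Mismatch::MissingInRust(name.clone())),
            Some(rs_shape) if rs_shape != py_shape => out.push(Mismatch::Shape {
                name: name.clone(),
                python: py_shape.clone(),
                rust: rs_shape.clone(),
            }),
            Some(_) => {}
        }
    }
    for name in rust.keys() {
        if !python.contains_key(name) {
            out.push(Mismatch::MissingInPython(name.clone()));
        }
    }
    out
}

fn main() {
    // Illustrative, made-up listings; real ones would come from the
    // Python model and from the variables the Rust port creates.
    let mut python = BTreeMap::new();
    python.insert("conv1.weight".to_string(), vec![16, 3, 3, 3]);
    python.insert("conv1.bias".to_string(), vec![16]);

    let mut rust = BTreeMap::new();
    rust.insert("conv1.weight".to_string(), vec![16, 3, 5, 5]);

    for m in compare(&python, &rust) {
        println!("{:?}", m);
    }
}
```

An empty result does not prove the port is correct (two layers can share a shape and still be wired differently), but a non-empty one points at the exact variable to look at.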
Thanks a lot for outlining your process, it helped me to confirm I'm vaguely on the right track, and Not sure if I should close this, it might be worth adding your answer to a FAQ?
Ah that's a good point, I've added a small FAQ section to the main readme.
Thanks!
One follow-up to #543, which also touches #545 and an answer you've given here (#174).
While you provide instructions on how to re-create the weights from a specific Python version, would it be possible to provide a guide on how to best replicate their appropriate architectures?
Background
When I tried to follow the Python instructions they worked alright to get the `.ot`, but then (taking the yolo example) you also need to "magically know", amongst others:
- `nn::func_t` behavior
- `VarStore` / `Path` labels for each weight set
- `coco_classes.rs`

For simple networks this seems mildly guessable, but when I tried to re-create yolo3 I already ran into these issues, let alone when I was looking into yolo5, yolo6 or yolo7.
I tried to `print()` the models in Python and use their output as guidance for a coarse outline, as well as inspecting the model blobs in a Python debugger, but for the more intricate parts I hit a wall pretty quickly, having to step through the actual model source in minute detail and still not being sure if I would end up with the right thing.
Question
So, tl;dr, would you mind sharing or documenting your "best practice" for converting "arbitrary" models to tch? In particular ...
1.1 When is it feasible to recreate a model with existing weights?
1.2 If recreation doesn't work, any thoughts on JIT?
3.1 How did you determine the right layers and params (e.g., use a debugger, Python source line-by-line, ...)?
3.2 Same for custom functions? I assume these you have to get from source in any case.
3.3 Where do you get the `VarStore` labels from?
3.4 Do you have any debugging / QA tips (e.g., to actually verify all parameters / weights are correct w.r.t. Python)?
I don't think this has to be overly long, but a few lines might help people to ensure they're on the right track and follow "best practices".
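To make question 3.3 more concrete, here is a rough sketch of what I mean by "labels". It assumes (my understanding, not something confirmed in this thread) that the variable names in a `VarStore` are the `Path` segments joined by `.`, in the same way PyTorch's `state_dict` keys are formed. Under that assumption, each Python key can be split into the nested path the Rust code has to build and the leaf variable name. The `VarName` type, the `split_key` function, and the example keys are hypothetical.

```rust
/// A PyTorch `state_dict` key split into the nested module path and the
/// leaf variable name (hypothetical helper type, not part of tch).
#[derive(Debug, PartialEq)]
struct VarName {
    path: Vec<String>,
    leaf: String,
}

/// Split a dotted key such as "layers.0.bn.running_mean".
/// Returns `None` for keys with empty segments (e.g. "" or "a..b").
fn split_key(key: &str) -> Option<VarName> {
    let mut parts: Vec<String> = key.split('.').map(str::to_string).collect();
    if parts.iter().any(|p| p.is_empty()) {
        return None;
    }
    // The last segment is the variable itself; the rest is the module path.
    let leaf = parts.pop()?;
    Some(VarName { path: parts, leaf })
}

fn main() {
    // Made-up keys for illustration; real ones would be listed from the
    // Python model's `state_dict()`.
    for key in ["conv1.weight", "layers.0.bn.running_mean"] {
        match split_key(key) {
            Some(v) => println!("{}: path={:?} leaf={}", key, v.path, v.leaf),
            None => println!("{}: not a valid key", key),
        }
    }
}
```

If the assumption holds, the `path` part tells you which sub-paths the Rust model needs to create, and the `leaf` part should match whatever name the corresponding layer gives its variables.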