-
Notifications
You must be signed in to change notification settings - Fork 675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Evaluator support to update multiple accumulators #2894
Conversation
Improve the performance of EvaluatorTrainingListener by enabling evaluators to update multiple accumulators from the same labels and predictions, rather than needing to recompute values.
Aims to fix failing test
I see from the failing test that my PR is more problematic than I realised. It may cause most existing subclasses of I've tried to fix all affected cases within DJL (there aren't very many) but I guess such a change could affect external subclasses that rely on the current behavior. I think subclasses of I haven't been able to test with CUDA yet, but on MPS I'm seeing training times reduced by 20-40%. Is there a better way to achieve the performance improvements? |
Codecov ReportAttention:
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## master #2894 +/- ##
============================================
+ Coverage 72.08% 72.27% +0.18%
- Complexity 5126 7184 +2058
============================================
Files 473 708 +235
Lines 21970 32014 +10044
Branches 2351 3337 +986
============================================
+ Hits 15838 23138 +7300
- Misses 4925 7284 +2359
- Partials 1207 1592 +385 ☔ View full report in Codecov by Sentry. |
Last one... I've tried on an old Windows computer with a GTX 1060 and PyTorch 2.0.1, and see a ~15-20% improvement in performance (although ~45% by removing logging altogether, consistent with it adding considerable overhead - as with MPS). Previous09:58:43.121 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished. With PR09:59:48.296 [main] [INFO ] a.d.t.l.LoggingTrainingListener - Epoch 50 finished. Without loggingTime: 5.894 Using the CPU of my tired old laptop, the PR reduces the training time slightly from 23.586 s to 21.54 s. |
Thanks for your contribution! This is awesome. @zachgk do you mind take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great! And I appreciate the thorough work
We will have to document in the release notes that users will have to modify classes extending AbstractAccuracy, but I don't think that should be too common. One alternative might be to remove updateAccumulator()
entirely. This will mean that there won't be silent problems, but it will require all users to have to change their Evalaturo and Loss classes. Overall, I think this is slightly better
* Evaluator support to update multiple accumulators Improve the performance of EvaluatorTrainingListener by enabling evaluators to update multiple accumulators from the same labels and predictions, rather than needing to recompute values. * Fix formatting * Update AbstractCompositeLoss.java Aims to fix failing test
Description
This PR proposes adding a method to the
Evaluator
abstract classand then overriding this in subclasses to more efficiently update accumulators.
The reason is that the use of
EvaluatorTrainingListener
can dominate training time - at least when using Apple Silicon + MPS with the recent PR #2873Part of the issue seems to be that
updateAccumulator
needs to be called multiple times for different evaluators after each batch, e.g. accuracy and loss. This results in the same values being recalculated multiple times and transferred to the CPU. The recalculation itself is quite fast, but transfer to the CPU is slow.Example with MNIST
I see the following improvements when training with MNIST using MPS.
With the PR
Without the PR
Without logging
Removing
TrainingListener.Defaults.logging()
, I seeindicating that there is still a considerable overhead in the use of training listeners, but it is roughly halved with the changes here.
On the CPU
The MNIST example admittedly isn't the best, because it's much faster to use the CPU than MPS anyway. There are still modest improvements though
I see more substantial improvements with custom training for semantic segmentation, e.g. with U-Net, when MPS is much faster than using the CPU.
Without logging, the PR should have little or no effect:
Code
Adapted from the MNIST tutorial