Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement custom_vjp.optimize_remat using custom_dce #26814

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dfm
Copy link
Collaborator

@dfm dfm commented Feb 27, 2025

The "optimize remat" feature of jax.custom_vjp was just a special case of custom_dce, so we can delete a lot of duplicated code by reimplementing that feature using custom_dce. I don't anticipate that this will change any behavior.

@dfm dfm self-assigned this Feb 27, 2025
@dfm dfm added the pull ready Ready for copybara import and testing label Feb 27, 2025
@dfm dfm force-pushed the remat-opt-custom-dce branch from 8a67fbe to 6c4cb0b Compare February 27, 2025 18:55
@dfm dfm force-pushed the remat-opt-custom-dce branch from 6c4cb0b to 10756cb Compare February 27, 2025 20:06
@dfm dfm requested a review from froystig February 27, 2025 20:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants