
Support for Multiple Devices and Parallelism #1495

Open · wants to merge 95 commits into master
Conversation

@YigitElma (Collaborator) commented Dec 25, 2024

Initial support for multi-GPU optimization.

  • Adds a convenience function for building an ObjectiveFunction whose multiple ForceBalance objectives are distributed across multiple GPUs
  • Adds a jit_with_device decorator to jit a function on a specific device (without this, everything runs on GPU id:0)
  • Adds pconcat to concatenate, hstack, or vstack a list of arrays that live on different devices onto a single device. If the resulting array fits on GPU id:0, it is placed on the GPU; otherwise it is placed on the CPU. This function is used by the compute and jac methods. If the CPU is used, QR and other linear algebra on the Jacobian gets slower, but there is no memory restriction.
  • Adds _device_id and _device attributes to the _Objective class (defaulting to 0) to make parallelization work with other objectives
  • Uses mpi4py for the compute and jvp methods of parallel objective functions. (There is no separate class; if the user passes objectives with different devices and an MPI communicator, the parallel strategy is used.) To make this work, the user has to use the context manager properly (which I think is not too much work)
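The fallback logic described for pconcat can be sketched as follows. This is a minimal illustration only: NumPy stands in for JAX arrays, and the `gpu_capacity_bytes` threshold and the `(array, target)` return convention are hypothetical, not the actual implementation.

```python
import numpy as np

def pconcat(arrays, mode="concatenate", gpu_capacity_bytes=1 << 30):
    """Sketch of pconcat: merge arrays from different devices onto one.

    gpu_capacity_bytes is a hypothetical stand-in for the free memory
    on GPU id:0; a real implementation would query the device.
    """
    op = {"concatenate": np.concatenate,
          "hstack": np.hstack,
          "vstack": np.vstack}[mode]
    result = op(arrays)
    # If the merged array fits on GPU id:0, place it there; otherwise
    # fall back to host memory (slower linear algebra, no memory limit).
    target = "gpu:0" if result.nbytes <= gpu_capacity_bytes else "cpu"
    return result, target
```

The point of the sketch is just the size-based placement decision; device transfers themselves are elided.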

We won't see any speed improvement for the trust_region_subproblem solvers, because JAX doesn't support distributed linear algebra yet.

TODO:

  • The proximal part still uses a for loop; to parallelize it too, we need to add a blocked option to _jvp of ProximalProjection.
  • Remove the redundant lines of code
  • If we add a new GitHub Action, we can test this with virtual devices, where JAX is made to see different CPU cores as different devices
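For the GitHub Action idea above, JAX's host-platform XLA flag could provide the virtual devices; a sketch (the device count of 8 is an arbitrary choice):

```shell
# Ask XLA to present 8 virtual CPU devices to JAX, so multi-device code
# paths can be exercised in CI without any GPUs.
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
python -c "import jax; print(jax.devices())"
```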

Things to consider:

  • Maybe implement a new optimizer that uses distributed matrix operations instead of QR and SVD. Probably a future PR.
  • Make the default grid of get_forcabalance_parallel smaller.

Resolves #1071 (but with mpi4py)
Resolves #1601


codecov bot commented Dec 26, 2024

Codecov Report

Attention: Patch coverage is 44.44444% with 145 lines in your changes missing coverage. Please review.

Project coverage is 95.26%. Comparing base (c1217bf) to head (6587728).

Files with missing lines Patch % Lines
desc/objectives/objective_funs.py 40.90% 91 Missing ⚠️
desc/objectives/getters.py 4.16% 23 Missing ⚠️
desc/optimize/_constraint_wrappers.py 64.61% 23 Missing ⚠️
desc/backend.py 55.55% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1495      +/-   ##
==========================================
- Coverage   95.70%   95.26%   -0.44%     
==========================================
  Files         101      100       -1     
  Lines       26349    26478     +129     
==========================================
+ Hits        25216    25225       +9     
- Misses       1133     1253     +120     
Files with missing lines Coverage Δ
desc/objectives/_bootstrap.py 97.18% <ø> (ø)
desc/objectives/_coils.py 99.38% <ø> (ø)
desc/objectives/_equilibrium.py 95.08% <ø> (ø)
desc/objectives/_fast_ion.py 98.78% <ø> (ø)
desc/objectives/_free_boundary.py 96.17% <ø> (ø)
desc/objectives/_generic.py 99.40% <ø> (ø)
desc/objectives/_geometry.py 96.79% <ø> (ø)
desc/objectives/_neoclassical.py 98.75% <ø> (ø)
desc/objectives/_omnigenity.py 97.06% <ø> (ø)
desc/objectives/_power_balance.py 91.83% <ø> (ø)
... and 7 more

... and 3 files with indirect coverage changes

Contributor

github-actions bot commented Dec 26, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +3.41 +/- 7.21     | +2.12e-02 +/- 4.48e-02 |  6.43e-01 +/- 1.8e-02  |  6.22e-01 +/- 4.1e-02  |
 test_equilibrium_init_medres            |     -0.92 +/- 3.14     | -4.11e-02 +/- 1.40e-01 |  4.41e+00 +/- 4.8e-02  |  4.46e+00 +/- 1.3e-01  |
 test_equilibrium_init_highres           |     -0.09 +/- 3.08     | -5.03e-03 +/- 1.68e-01 |  5.45e+00 +/- 1.4e-01  |  5.45e+00 +/- 9.9e-02  |
 test_objective_compile_dshape_current   |     +1.50 +/- 5.69     | +6.41e-02 +/- 2.43e-01 |  4.33e+00 +/- 1.8e-01  |  4.27e+00 +/- 1.6e-01  |
 test_objective_compute_dshape_current   |     +1.24 +/- 4.11     | +6.76e-05 +/- 2.24e-04 |  5.53e-03 +/- 2.1e-04  |  5.47e-03 +/- 7.3e-05  |
 test_objective_jac_dshape_current       |     -2.01 +/- 6.56     | -8.97e-04 +/- 2.93e-03 |  4.38e-02 +/- 2.0e-03  |  4.47e-02 +/- 2.1e-03  |
 test_perturb_2                          |     -0.01 +/- 3.22     | -2.51e-03 +/- 6.74e-01 |  2.10e+01 +/- 5.7e-01  |  2.10e+01 +/- 3.6e-01  |
 test_proximal_jac_atf_with_eq_update    |     -0.20 +/- 1.57     | -3.36e-02 +/- 2.69e-01 |  1.71e+01 +/- 2.2e-01  |  1.71e+01 +/- 1.5e-01  |
 test_proximal_freeb_jac                 |     +0.25 +/- 2.38     | +1.84e-02 +/- 1.72e-01 |  7.24e+00 +/- 5.1e-02  |  7.23e+00 +/- 1.6e-01  |
 test_solve_fixed_iter_compiled          |     -0.65 +/- 1.22     | -1.39e-01 +/- 2.60e-01 |  2.11e+01 +/- 1.9e-01  |  2.13e+01 +/- 1.7e-01  |
 test_LinearConstraintProjection_build   |     -1.65 +/- 3.78     | -1.91e-01 +/- 4.39e-01 |  1.14e+01 +/- 2.7e-01  |  1.16e+01 +/- 3.4e-01  |
 test_objective_compute_ripple_spline    |     +0.08 +/- 3.10     | +2.82e-04 +/- 1.08e-02 |  3.49e-01 +/- 7.7e-03  |  3.49e-01 +/- 7.6e-03  |
 test_objective_grad_ripple_spline       |     -0.16 +/- 1.08     | -2.25e-03 +/- 1.53e-02 |  1.42e+00 +/- 1.2e-02  |  1.42e+00 +/- 9.9e-03  |
 test_build_transform_fft_midres         |     -0.99 +/- 2.44     | -6.36e-03 +/- 1.57e-02 |  6.35e-01 +/- 9.6e-03  |  6.41e-01 +/- 1.2e-02  |
 test_build_transform_fft_highres        |     -1.47 +/- 2.33     | -1.38e-02 +/- 2.18e-02 |  9.24e-01 +/- 1.6e-02  |  9.38e-01 +/- 1.5e-02  |
 test_equilibrium_init_lowres            |     -0.49 +/- 1.47     | -1.94e-02 +/- 5.85e-02 |  3.97e+00 +/- 3.9e-02  |  3.99e+00 +/- 4.3e-02  |
 test_objective_compile_atf              |     +0.28 +/- 2.88     | +2.32e-02 +/- 2.39e-01 |  8.31e+00 +/- 1.9e-01  |  8.28e+00 +/- 1.4e-01  |
 test_objective_compute_atf              |     +0.78 +/- 2.60     | +1.27e-04 +/- 4.22e-04 |  1.64e-02 +/- 3.7e-04  |  1.62e-02 +/- 2.0e-04  |
 test_objective_jac_atf                  |     -0.59 +/- 1.97     | -1.18e-02 +/- 3.95e-02 |  1.99e+00 +/- 2.8e-02  |  2.00e+00 +/- 2.8e-02  |
 test_perturb_1                          |     +0.29 +/- 2.12     | +4.46e-02 +/- 3.22e-01 |  1.52e+01 +/- 2.5e-01  |  1.52e+01 +/- 2.0e-01  |
 test_proximal_jac_atf                   |     -0.70 +/- 0.79     | -5.69e-02 +/- 6.43e-02 |  8.04e+00 +/- 4.5e-02  |  8.09e+00 +/- 4.6e-02  |
 test_proximal_freeb_compute             |     +1.16 +/- 0.96     | +2.49e-03 +/- 2.05e-03 |  2.18e-01 +/- 1.4e-03  |  2.15e-01 +/- 1.5e-03  |
 test_solve_fixed_iter                   |     +0.42 +/- 2.62     | +1.36e-01 +/- 8.55e-01 |  3.28e+01 +/- 6.4e-01  |  3.27e+01 +/- 5.6e-01  |
 test_objective_compute_ripple           |     +0.52 +/- 2.04     | +3.63e-03 +/- 1.43e-02 |  7.02e-01 +/- 9.8e-03  |  6.98e-01 +/- 1.0e-02  |
 test_objective_grad_ripple              |     +1.42 +/- 1.34     | +3.77e-02 +/- 3.56e-02 |  2.70e+00 +/- 8.4e-03  |  2.66e+00 +/- 3.5e-02  |

@dpanici (Collaborator) commented Jan 6, 2025

Check for overlap with #763.

@YigitElma YigitElma requested review from a team, rahulgaur104, f0uriest, dpanici, sinaatalay and unalmis and removed request for a team February 14, 2025 05:43
@YigitElma YigitElma self-assigned this Feb 14, 2025
@YigitElma YigitElma added the performance (New feature or request to make the code faster) and gpu (Issues related to the GPU backend) labels Feb 14, 2025
@YigitElma YigitElma requested a review from a team February 25, 2025 21:08
…d batched and blocked options, also implement mpi for proximal
@@ -139,6 +139,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
xp = put(xp, unfixed_idx, A_inv @ b)
xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])
# cast to jnp arrays
# TODO: might consider sharding these
Collaborator Author

I mean for the multiple-GPU case: if we use sharding for Z, then project and recover (which are called a lot) will be faster. It is just a matrix-vector product, but still...
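As a rough illustration of why sharding Z could help: project and recover reduce to matrix-vector products with Z, which parallelize naturally by splitting rows across devices. A minimal NumPy sketch of the row-sharded pattern follows; the device split is simulated with array blocks, whereas real code would use `jax.sharding` to keep each block resident on its own GPU.

```python
import numpy as np

def sharded_matvec(A, x, n_devices=2):
    # Split A's rows into one block per (simulated) device. Each device
    # computes its partial product independently; the partial results
    # are then concatenated. This is the pattern that sharding Z would
    # give project/recover for free.
    blocks = np.array_split(A, n_devices, axis=0)
    partials = [block @ x for block in blocks]
    return np.concatenate(partials)
```

With real sharding, only the small per-device result vectors would move between devices, not the blocks of Z themselves.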

@YigitElma YigitElma changed the title Support for Multiple GPUs Support for Multiple Devices and Parallelism Feb 27, 2025
Successfully merging this pull request may close these issues:

  • Parallelization options
  • Parallelize across multiple GPUs with MPI4Jax