Support for Multiple Devices and Parallelism #1495
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master    #1495      +/-   ##
==========================================
- Coverage   95.70%   95.26%   -0.44%
==========================================
  Files         101      100       -1
  Lines       26349    26478     +129
==========================================
+ Hits        25216    25225       +9
- Misses       1133     1253     +120
```
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +3.41 +/- 7.21 | +2.12e-02 +/- 4.48e-02 | 6.43e-01 +/- 1.8e-02 | 6.22e-01 +/- 4.1e-02 |
| test_equilibrium_init_medres | -0.92 +/- 3.14 | -4.11e-02 +/- 1.40e-01 | 4.41e+00 +/- 4.8e-02 | 4.46e+00 +/- 1.3e-01 |
| test_equilibrium_init_highres | -0.09 +/- 3.08 | -5.03e-03 +/- 1.68e-01 | 5.45e+00 +/- 1.4e-01 | 5.45e+00 +/- 9.9e-02 |
| test_objective_compile_dshape_current | +1.50 +/- 5.69 | +6.41e-02 +/- 2.43e-01 | 4.33e+00 +/- 1.8e-01 | 4.27e+00 +/- 1.6e-01 |
| test_objective_compute_dshape_current | +1.24 +/- 4.11 | +6.76e-05 +/- 2.24e-04 | 5.53e-03 +/- 2.1e-04 | 5.47e-03 +/- 7.3e-05 |
| test_objective_jac_dshape_current | -2.01 +/- 6.56 | -8.97e-04 +/- 2.93e-03 | 4.38e-02 +/- 2.0e-03 | 4.47e-02 +/- 2.1e-03 |
| test_perturb_2 | -0.01 +/- 3.22 | -2.51e-03 +/- 6.74e-01 | 2.10e+01 +/- 5.7e-01 | 2.10e+01 +/- 3.6e-01 |
| test_proximal_jac_atf_with_eq_update | -0.20 +/- 1.57 | -3.36e-02 +/- 2.69e-01 | 1.71e+01 +/- 2.2e-01 | 1.71e+01 +/- 1.5e-01 |
| test_proximal_freeb_jac | +0.25 +/- 2.38 | +1.84e-02 +/- 1.72e-01 | 7.24e+00 +/- 5.1e-02 | 7.23e+00 +/- 1.6e-01 |
| test_solve_fixed_iter_compiled | -0.65 +/- 1.22 | -1.39e-01 +/- 2.60e-01 | 2.11e+01 +/- 1.9e-01 | 2.13e+01 +/- 1.7e-01 |
| test_LinearConstraintProjection_build | -1.65 +/- 3.78 | -1.91e-01 +/- 4.39e-01 | 1.14e+01 +/- 2.7e-01 | 1.16e+01 +/- 3.4e-01 |
| test_objective_compute_ripple_spline | +0.08 +/- 3.10 | +2.82e-04 +/- 1.08e-02 | 3.49e-01 +/- 7.7e-03 | 3.49e-01 +/- 7.6e-03 |
| test_objective_grad_ripple_spline | -0.16 +/- 1.08 | -2.25e-03 +/- 1.53e-02 | 1.42e+00 +/- 1.2e-02 | 1.42e+00 +/- 9.9e-03 |
| test_build_transform_fft_midres | -0.99 +/- 2.44 | -6.36e-03 +/- 1.57e-02 | 6.35e-01 +/- 9.6e-03 | 6.41e-01 +/- 1.2e-02 |
| test_build_transform_fft_highres | -1.47 +/- 2.33 | -1.38e-02 +/- 2.18e-02 | 9.24e-01 +/- 1.6e-02 | 9.38e-01 +/- 1.5e-02 |
| test_equilibrium_init_lowres | -0.49 +/- 1.47 | -1.94e-02 +/- 5.85e-02 | 3.97e+00 +/- 3.9e-02 | 3.99e+00 +/- 4.3e-02 |
| test_objective_compile_atf | +0.28 +/- 2.88 | +2.32e-02 +/- 2.39e-01 | 8.31e+00 +/- 1.9e-01 | 8.28e+00 +/- 1.4e-01 |
| test_objective_compute_atf | +0.78 +/- 2.60 | +1.27e-04 +/- 4.22e-04 | 1.64e-02 +/- 3.7e-04 | 1.62e-02 +/- 2.0e-04 |
| test_objective_jac_atf | -0.59 +/- 1.97 | -1.18e-02 +/- 3.95e-02 | 1.99e+00 +/- 2.8e-02 | 2.00e+00 +/- 2.8e-02 |
| test_perturb_1 | +0.29 +/- 2.12 | +4.46e-02 +/- 3.22e-01 | 1.52e+01 +/- 2.5e-01 | 1.52e+01 +/- 2.0e-01 |
| test_proximal_jac_atf | -0.70 +/- 0.79 | -5.69e-02 +/- 6.43e-02 | 8.04e+00 +/- 4.5e-02 | 8.09e+00 +/- 4.6e-02 |
| test_proximal_freeb_compute | +1.16 +/- 0.96 | +2.49e-03 +/- 2.05e-03 | 2.18e-01 +/- 1.4e-03 | 2.15e-01 +/- 1.5e-03 |
| test_solve_fixed_iter | +0.42 +/- 2.62 | +1.36e-01 +/- 8.55e-01 | 3.28e+01 +/- 6.4e-01 | 3.27e+01 +/- 5.6e-01 |
| test_objective_compute_ripple | +0.52 +/- 2.04 | +3.63e-03 +/- 1.43e-02 | 7.02e-01 +/- 9.8e-03 | 6.98e-01 +/- 1.0e-02 |
| test_objective_grad_ripple | +1.42 +/- 1.34 | +3.77e-02 +/- 3.56e-02 | 2.70e+00 +/- 8.4e-03 | 2.66e+00 +/- 3.5e-02 |
```diff
@@ -139,6 +139,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"):  # noqa
     xp = put(xp, unfixed_idx, A_inv @ b)
     xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])
     # cast to jnp arrays
+    # TODO: might consider sharding these
```
I mean for the multiple-GPU case: if we use sharding for `Z`, then `project` and `recover` (which are called a lot) will be faster. It is just a matrix-vector product, but still...
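For context, the kind of sharding suggested here can be sketched in a few lines of JAX. This is illustrative only, not the DESC implementation: `Z` and `project` are stand-ins, and the mesh is simply 1-D over whatever devices are available.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over all available devices (also works with 1 device)
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("rows",))

# shard Z by rows across devices; the (small) vector it multiplies
# stays replicated on every device
Z = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P("rows", None)))

@jax.jit
def project(Z, x):
    # each device multiplies only its block of rows of Z
    return Z @ x

y = project(Z, jnp.arange(4.0))  # each entry is 0 + 1 + 2 + 3 = 6
```

Because the mat-vec is embarrassingly parallel across rows, the jitted product needs no cross-device communication until the result is consumed.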
Initial support for multi-GPU optimization.

- `ObjectiveFunction` with multiple `ForceBalance` objectives that are distributed to multiple GPUs
- `jit_with_device` decorator to jit a function on a specific device (without this, everything runs on GPU id:0)
- `pconcat` to concatenate, hstack, or vstack a list of arrays that live on different devices onto a single device. If the resulting array fits on GPU id:0, it is put on the GPU; otherwise it is put on the CPU. This function is used for the compute and jac methods. If the CPU is used, QR and other linear algebra on the Jacobian get slower, but there is no memory restriction.
- `_device_id` and `_device` added to the `_Objective` class (defaulting to 0) to make parallelization work with other objectives
- `mpi4py` for the compute and jvp methods of parallel objective functions (there is no separate class, but if the user passes objectives with different devices and an MPI communicator, the parallel strategy will be used). To make this work, the user has to use the context manager properly (which I think is not too much work).

We won't see any speed improvement in the trust-region subproblem solvers, because JAX doesn't support distributed linear algebra yet.

TODO:

- `blocked` option to `_jvp` of `ProximalProjection`.
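A decorator like the `jit_with_device` described above could look roughly like the following minimal sketch (an illustration, not the actual DESC implementation): committing the inputs to the chosen device makes JAX place the jitted computation there instead of on device 0.

```python
import functools

import jax
import jax.numpy as jnp

def jit_with_device(device, **jit_kwargs):
    """Jit ``func`` and run it on ``device`` (illustrative sketch)."""
    def decorator(func):
        jitted = jax.jit(func, **jit_kwargs)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # device_put commits the inputs to the target device, so the
            # jitted computation is executed there
            args = jax.device_put(args, device)
            return jitted(*args, **kwargs)

        return wrapper
    return decorator

@jit_with_device(jax.devices()[0])
def sum_of_squares(x):
    return jnp.sum(x**2)

result = sum_of_squares(jnp.arange(3.0))  # 0^2 + 1^2 + 2^2 = 5.0
```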
Things to consider:

- make `get_forcabalance_parallel` smaller.

Resolves #1071 (but with `mpi4py`)
Resolves #1601