**Numeric Parity**

1D FSDP
- Eager: 1k steps of minipile on 8 H100 GPUs, local batch size 8, sequence length 2048, AC/SAC, bf16 mixed precision, fp32 reduce-scatter (see the FSDP2 wrapping sketch after this summary)
  - FSDP1 (AC): 24.81% peak active, 33.82% peak reserved, 6100-6200 WPS
  - FSDP1 (SAC): 52.98% peak active, 67.23% peak reserved, 6500-6700 WPS
  - FSDP2 (AC): 23.92% peak active, 32.64% peak reserved, 6100-6300 WPS
  - FSDP2 (SAC): 52.13% peak active, 62.51% peak reserved, 6600-6800 WPS
  - Loss curves match between FSDP1 and FSDP2
  - Memory numbers are reported as percentages since that is how they are logged; they can be converted against 95.0396 GiB of GPU memory
- Compile: same setup as eager
  - FSDP2 (AC), buffer reuse disabled: 28.72 GiB (30.22%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (AC), buffer reuse enabled: 28.90 GiB (30.40%) peak reserved, 7200-7500 WPS, 33% MFU
  - FSDP2 (SAC), buffer reuse enabled: 53.83 GiB (56.64%) peak reserved, 8100-8400 WPS, 36% MFU
  - Loss curves are slightly better than eager
- For fun -- how much can we push MFU?
  - FSDP2 (SAC) with local batch size 16 (doubled): 88.23 GiB (92.84%) peak reserved, 8600 WPS, 38% MFU
  - FSDP2 (no AC) with local batch size 8: 90.28 GiB (94.99%) peak reserved, 9100-9300 WPS, 40% MFU
- Why is FSDP2 faster? (1) The fp32 reduce-scatter only uses one div kernel instead of two, and (2) `reshard_after_forward=False` is used for the last transformer block

2D FSDP
- Eager (2-way SP, 4-way FSDP): 1k steps of minipile on 8 H100 GPUs, local batch size 16 (to preserve global batch size), sequence length 2048, bf16 mixed precision, fp32 reduce-scatter
  - FSDP2 (AC): 50.12% peak active, 60.97% peak reserved, 5800-5900 WPS
  - FSDP2 (SAC): 76.49% peak active, 90.14% peak reserved, 6100-6300 WPS
  - Loss curves match 8-way FSDP
  - FSDP1 + SP has incorrect numerics because `FSDP.clip_grad_norm_` does not all-reduce over the TP mesh dimension

<details>
<summary>Loss curves</summary>
<img width="732" alt="Screenshot 2024-03-26 at 3 31 19 PM" src="/~https://github.com/pytorch/torchtrain/assets/31054793/59ec71cc-ad0a-4dd1-b5c6-a8cbf9ab5e85">
</details>

**Meta-Device Initialization**
- The PyTorch Core guideline is for `module.reset_parameters()` to initialize only the parameters/buffers immediately owned by `module` (i.e. `module.parameters(recurse=False)` and `module.buffers(recurse=False)`).
- This makes it challenging to specify custom initializations for core modules like `nn.Linear` and `nn.Embedding`. For example, in @lessw2020's depth-wise truncated normal initialization, the `trunc_normal_` standard deviation depends on the layer ID, which is a property of the `TransformerBlock` but affects its child `nn.Linear`s.
- To avoid ambiguity, I suggest not using the name `reset_parameters()` in the cases where we violate the PyTorch Core guideline and instead using a different name (e.g. `init_weights`; see the sketch below).

**DCP & Save/Load**
- Tested 1D and 2D by specifying `checkpoint_folder = "/tmp/checkpoint_andgu"` in the `.toml`, training until a checkpoint was saved, terminating the run, and restarting the training to load the checkpoint -- the loss after loading looks reasonable (see the save/load sketch below)
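As a reference for the 1D FSDP2 setup above, here is a minimal wrapping sketch. It assumes the `torch.distributed._composable.fsdp` `fully_shard`/`MixedPrecisionPolicy` API and a model that exposes its `TransformerBlock`s as `model.layers`; it is not the exact code in this PR.

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def apply_fsdp2(model: nn.Module) -> nn.Module:
    # bf16 compute with fp32 reduce-scatter, matching the benchmark setup above.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    blocks = list(model.layers)  # assumed container of TransformerBlocks
    for layer_id, block in enumerate(blocks):
        # The last block's parameters are needed again right at the start of
        # backward, so skipping its post-forward reshard saves an all-gather.
        reshard = layer_id < len(blocks) - 1
        fully_shard(block, mp_policy=mp_policy, reshard_after_forward=reshard)
    # Wrap the root module to shard the remaining parameters (embedding, norm, output).
    fully_shard(model, mp_policy=mp_policy)
    return model
```

Per-block wrapping is what makes it possible to toggle `reshard_after_forward` for just the last transformer block.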
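To illustrate the `init_weights` naming suggestion, here is a hypothetical sketch: the depth-dependent std is computed on the block, which then initializes its child `nn.Linear`s directly instead of relying on their `reset_parameters()`. The scaling formula is illustrative, not the one from @lessw2020's recipe.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def init_weights(self) -> None:
        # Depth-wise truncated normal: the std depends on the layer ID, which
        # only the block knows, so the block initializes its children itself.
        std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5  # illustrative formula
        for linear in (self.wq, self.wo):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)


# Typical meta-device flow: build on meta, materialize storage, then initialize.
with torch.device("meta"):
    block = TransformerBlock(dim=4096, layer_id=7)
block.to_empty(device="cpu")
block.init_weights()
```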
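Finally, a rough sketch of the DCP save/load round trip described in the last section, assuming `torch.distributed.checkpoint`'s filesystem writer/reader and a `model` that has already been parallelized with FSDP2 (so its state dict holds sharded tensors); the repo's checkpointing code wraps this in more machinery.

```python
import torch.distributed.checkpoint as dcp

CHECKPOINT_DIR = "/tmp/checkpoint_andgu"  # matches checkpoint_folder in the .toml

# Save: each rank writes its own shards of the sharded state dict.
state_dict = {"model": model.state_dict()}
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR))

# Load: after rebuilding the same parallelized model on restart, DCP loads
# in place into the existing sharded tensors, which we then copy back in.
state_dict = {"model": model.state_dict()}
dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))
model.load_state_dict(state_dict["model"])
```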