[Float8] Unable to run asyncTP + Float8 row with 'full' AC active, leading dims mismatch #864
Comments
@lessw2020 quick question, by "latest PR for rowwise TP" are you referring to the torchtitan PR here #808 or the torchao PR here pytorch/ao#1718? Were you using both?
Also feel free to assign this to me
Hi @danielvegamyhre - yes, good questions - I was using both and have updated the original issue.
A different error occurs at pytorch/ao@32a51ec:
I believe this was fixed later in pytorch/ao@988c5c9
Testing at pytorch/ao@988c5c9 the scale dim mismatch issue repros:
Tried disabling power-of-2 scale factors; the issue still occurs.
Error: the error is thrown here when the A shape is (1,8192,2048) and the A_scale shape is (8192,1): /~https://github.com/pytorch/pytorch/blob/fb1f7f6a09576e833647a19b0328d599c754af2a/torch/distributed/_symmetric_memory/__init__.py#L1151

Dumping logs from the torchao scale calculations: https://www.internalfb.com/phabricator/paste/view/P1736882650

We can see that we have a tensor of shape (1,8192,4096) and a scale of shape (1,8192,1). It follows naturally that with async TP degree = 2, this tensor is split along the final dim into 2x (1,8192,2048), and we keep the same scale of (1,8192,1). So we would expect to see an A shape of (1,8192,2048) and an A_scale shape of (1,8192,1). In the error, the A shape is correct but the A_scale shape is not.

Looking around the code where the error is thrown, it seems inductor inserts a reduce scatter op here (presumably needed for backward). By that point the A_scale has had its leading dim squeezed off and is (8192,1), which is incorrect, leading to the mismatch between the leading dims of A and A_scale. I believe at this point in the code the node has the wrong shape for the A_scale: /~https://github.com/pytorch/pytorch/blob/863ac20659adf63c998a272f7e32e33a38dcea91/torch/_inductor/fx_passes/micro_pipeline_tp.py#L393

Need to install pytorch from source with print statements to debug further.
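To make the mismatch concrete, here is a small standalone sketch of the shape invariant being violated (shapes copied from the logs above; the check is a simplification of the real assertion, not the actual symmetric-memory code):

```python
import torch

# Simplified sketch of the leading-dims invariant: for row-wise scaling,
# A and A_scale should share all leading dims, with the scale's last dim == 1.
A = torch.randn(1, 8192, 2048)              # sharded activation (TP degree 2)
A_scale_expected = torch.rand(1, 8192, 1)   # one scale per row
A_scale_actual = torch.rand(8192, 1)        # leading dim squeezed off (the bug)

def leading_dims_match(t: torch.Tensor, scale: torch.Tensor) -> bool:
    # same idea as the assertion in _symmetric_memory/__init__.py linked above
    return t.shape[:-1] == scale.shape[:-1]

print(leading_dims_match(A, A_scale_expected))  # True
print(leading_dims_match(A, A_scale_actual))    # False -> the reported error
```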
In torch/_inductor/fx_passes I see squeezing going on in a couple of places:
If "split_cat" has to do with concatenation, this seems like it could be related since presumably in TP we need to all-gather and concat at certain points. Wondering if the scaling factor tensor could be affected by this. |
I think this is where the graph is parsed to find reduce scatters. Presumably we would go down this code path, since the scales tensor is not 0-dim. I'm beginning to wonder if async TP is splitting the scales tensor instead of broadcasting it?
Dumped
Looks like there is a view op which changes the shape of the exp2 output from (1,8192,1) to (8192,1)?
Confirmed the issue still occurs with no AC.
This line of code will result in the A_scale shape going from (1,8192,1) -> (8192,1), which would explain the error. However, the error is actually raised on the line just above it, so it would seem this isn't actually reached?
Looks like this line could also rearrange the dims of the scale: /~https://github.com/pytorch/pytorch/blob/be0df96b50490b03c9037e4dc3d1529b0471c533/torch/distributed/_symmetric_memory/__init__.py#L1240
Eager mode rowwise with vanilla TP throws a different error:
Note this does NOT occur for tensorwise + eager + vanilla TP, only rowwise.
Looking for where sharding strategies are registered, I found the DTensor op dispatcher and the util function for registering op strategies. I think we need to register a sharding strategy for this op.
Checked with tianyu-l@ and he agreed this does seem to be the case. He suggested I submit a PR for this, using this WIP PR as a starting point: pytorch/pytorch#130887. To be clear, this is to fix eager mode + float8 rowwise + vanilla TP.
I reinstalled pytorch using the nightly build and am now getting a different error for eager mode + float8 rowwise + vanilla TP.
Added some logging and found the output dtype in question is fp32. I also logged the shapes of the A and B tensors and their scales:
The good news is the scales are the right shape:
The bad news is I'm not sure why these tensors are suddenly float32 and not bf16.
I managed to get the tensors into bf16 again by using HSDP2, which automatically casts to bf16 in torchtitan. In fact, by combining the pytorch nightly build + HSDP + vanilla TP, it actually works for both eager and compile (the "no sharding strategy registered" error is gone).
Update: I tried building pytorch from source and using
Since the scale is already wrong here, the rest of the execution matters less, but I mapped out the following for my own understanding:
I traced further back to the beginning of the micropipeline TP pass here and confirmed the scale shape is already wrong at that point. Viewing the part of the graph I know will be used as the scale, from the execution tracing in #864 (comment):

_get_tensor(reduce_scatters[1].input_node.args[0].args[2])
[rank0]:(Pdb) [rank0]:FakeTensor(..., device='cuda:0', size=(8192, 1))

We can see the raw scale tensor already has shape (8192, 1). To summarize, at this point I've verified the scale node feeding the fused op already has the squeezed shape before the async TP graph manipulation runs.
My conclusion is there is a problem in the graph tracing code for async TP, but I'm not sure what specifically yet. cc @yifuwang who worked on async TP - can you take a look at this comment and let me know if you agree with the root cause identified here? And who on the inductor side might be a good point of contact to take a look at this?
Adding some more logging in torchao for additional data:
We can observe that the raw scales computed have the correct shape, but by the time that tensor is later used in the scaled matmul, its shape has changed.
Added logging in the float8 view op since, as noted in my previous comment, the originally computed scale has the correct shape, but by the time the matmul is performed the scale shape has changed, so we need to find out where in between that happens. Previously I noted there was a view op changing the scale shape from (1,8192,1) to (8192,1) here. With logging added to the float8 view op, we can see the problematic op come through:
Since the data is being reshaped from [1,8192,4096] to [8192,4096], the scale shape is also adjusted from [1,8192,1] to [8192,1] here. Now the question is: why is pytorch performing a view that cuts off the leading dim of the data shape?
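To illustrate the behavior described above with plain tensors (a simplified stand-in, not the torchao Float8Tensor code): when the data is viewed from 3-D to 2-D, a row-wise scale must be viewed the same way to stay broadcastable.

```python
import torch

# Simplified stand-in for the view behavior described above: viewing the data
# (1, 8192, 4096) -> (8192, 4096) means the row-wise scale (1, 8192, 1) must
# follow along to (8192, 1) to remain broadcastable against the data.
data = torch.randn(1, 8192, 4096)
scale = data.abs().amax(dim=-1, keepdim=True)  # row-wise scale: (1, 8192, 1)

data_2d = data.view(8192, 4096)
scale_2d = scale.view(8192, 1)                 # scale follows the data view

assert (data_2d / scale_2d).shape == (8192, 4096)
```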
Set a breakpoint at the start of post_grad_passes before any post grad passes have run, and have confirmed the scale is wrong already:
This would seem to eliminate the micropipeline/async TP pass as well as the other post grad passes. However, if that were the case, I'd expect the issue to appear for vanilla TP too, and it doesn't.
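For reference, a minimal helper of the kind used for this check (assuming the usual FX convention that each node's example value is recorded as a FakeTensor under node.meta["val"]):

```python
import torch.fx as fx

# Hypothetical helper: print the recorded (fake) shape of an fx node inside a
# post-grad-pass breakpoint, e.g. to check the A_scale node's shape.
def dump_node_shape(node: fx.Node) -> None:
    val = node.meta.get("val")   # FakeTensor recorded during tracing, if any
    if val is not None and hasattr(val, "shape"):
        print(node.name, tuple(val.shape))
```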
I set a breakpoint in the float8 view() op implementation, and when the problematic view op described here was intercepted, I examined the stack trace and found the source of the call, which is this line. As shown here, the reshape of the tensor causes the scale to be reshaped as well. However, at this point in the async TP code, when the error is thrown, the tensor is back in its original shape but the scale is still in the reshaped form, causing the mismatch.
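For context, the float8 linear performs the matmul on 2-D tensors via a reshape -> mm -> reshape pattern; a plain-PyTorch sketch of that pattern (illustrative shapes, not the torchao implementation):

```python
import torch

# Plain-tensor sketch of the reshape -> mm -> reshape pattern referenced above.
# For a Float8Tensor, the first reshape also reshapes the row-wise scale
# (1, 8192, 1) -> (8192, 1), which is the shape async TP later picks up.
def mm_3d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    orig_shape = a.shape                          # e.g. (1, 8192, 2048)
    a_2d = a.reshape(-1, orig_shape[-1])          # (8192, 2048)
    out_2d = a_2d @ b                             # (8192, N)
    return out_2d.reshape(*orig_shape[:-1], b.shape[-1])  # back to (1, 8192, N)

out = mm_3d(torch.randn(1, 8192, 2048), torch.randn(2048, 4096))
print(out.shape)  # torch.Size([1, 8192, 4096])
```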
cc: @vkuzo
This is occurring downstream of the attention layer Q/K/V projections here. These projections are using colwise_parallel here.
I think maybe the problem isn't that the scale is the wrong shape, but rather that async TP is using the wrong node for the A tensor: a node from before the reshape here.
I visualized the fx graph before any post grad passes have been applied, and confirmed the hunch in my prior comment was correct: the graph itself is correct, but the graph manipulation in async TP is referencing the wrong "A tensor" node (specifically, the "A tensor" from prior to the reshape here), while the "A_scale" node it references is the one from after the reshape. To fix this, I tried changing this line to reference the post-reshape "A tensor" node instead, so that its dims match the scale dims.
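As a rough illustration (not the actual pass code) of what distinguishing the two candidate nodes looks like at the FX level, one can check whether a matmul input is produced by a reshape/view op:

```python
import torch
import torch.fx as fx

# Illustrative check: is the matmul's first input produced by a reshape/view
# node? If so, that producer's own input is the pre-reshape (3-D) "A tensor".
def produced_by_reshape(mm_node: fx.Node) -> bool:
    a_node = mm_node.args[0]
    return isinstance(a_node, fx.Node) and a_node.target in (
        torch.ops.aten.reshape.default,
        torch.ops.aten.view.default,
    )
```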
Side note for later investigation: this seems to restride the "A tensor" to be column major, which does not seem correct/ideal for fp8 gemms using _scaled_mm. We should use a row-major tensor, since the cublas api requires this, so we avoid unnecessary conversions: /~https://github.com/pytorch/pytorch/blob/55bf3ff3a5bd18ab1805d0d52f7a723d097294d4/torch/_inductor/fx_passes/micro_pipeline_tp.py#L567
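A quick standalone illustration of the layout concern in this side note (row-major vs. column-major strides for the same logical shape):

```python
import torch

# Per the comment above, the fp8 gemm path wants a row-major (contiguous)
# first operand; a column-major restride would force an extra conversion.
a = torch.randn(8192, 2048)
a_col_major = a.t().contiguous().t()   # column-major strides, same logical shape
print(a.shape == a_col_major.shape)    # True  -> same shape
print(a.is_contiguous())               # True  -> row-major
print(a_col_major.is_contiguous())     # False -> would need a row-major copy
```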
So problem 1 (the original error) is fixed by changing this line to use the post-reshape node as the A tensor param. However, this surfaces a new error:
I visualized the graph after the async TP post grad pass has been applied, and located the area of the graph where I believe this error is occurring.
The output of this mm will be shape (8192,4096), but it is being reduce scattered, which I think will result in it being sharded along the final dim into 2 shards of shape (8192,2048). The output of this is then used by a subsequent add op.
This results in the broadcast error.
I'm now trying to understand why we are adding these two mismatched tensors in the first place.

TL;DR: it seems the input to the attention layer is being sharded across the sequence dim, but the output of the attention layer is being reduce_scattered across the model dim, so there is a shape mismatch in the residual connection.

Details: I assume the addition is coming from a residual connection in the model, like this one. Through some additional testing and analysis, I confirmed the tensor of shape (1,4096,4096) is the input to the transformer block, which has shape (batch, seq_len // 2, model_dim) = (1, 8192//2, 4096) = (1,4096,4096). The output of the attention layer is (batch, seq_len, model_dim // 2) = (1, 8192, 4096//2) = (1,8192,2048). Therefore these tensors cannot be added. It looks to me like the input is being sharded across a different dimension than the attention outputs, hence the shape incompatibility. Specifically, I'm wondering if these input sharding specs in torchtitan here are the problem.
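A minimal shape-only reproduction of the residual-add mismatch described above (illustrative shapes from this comment):

```python
import torch

# Block input sharded on the sequence dim vs. attention output reduce-scattered
# on the hidden dim: the residual add cannot broadcast these shapes.
block_input = torch.randn(1, 8192 // 2, 4096)   # (batch, seq/TP, d_model)
attn_output = torch.randn(1, 8192, 4096 // 2)   # (batch, seq, d_model/TP)
try:
    block_input + attn_output
except RuntimeError as err:
    print(err)   # sizes must match or be broadcastable
```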
I think I see the problem causing the broadcast error: the pattern is pulling the scatter_dim from the original tensor with 3 dims, where scatter_dim=1 is the seq dim. However, after the reshape in torchao here, which reshapes the tensor to 2 dims, the scatter dim of 1 is outdated and is now incorrectly applied along the hidden dim - which does not match the sharding spec of the tensor it is added with in the residual here, which is sharded along the seq dim.
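A small illustration of why the stale scatter_dim matters: dim index 1 means the sequence dim on the 3-D tensor, but the hidden dim once the tensor has been flattened to 2-D.

```python
import torch

# scatter_dim=1 on the original 3-D tensor splits the sequence dim, but the
# same index on the reshaped 2-D tensor splits the hidden dim instead.
x3d = torch.randn(1, 8192, 4096)     # (batch, seq, hidden); dim 1 == seq
x2d = x3d.reshape(8192, 4096)        # (seq, hidden);        dim 1 == hidden

print(x3d.chunk(2, dim=1)[0].shape)  # torch.Size([1, 4096, 4096]) -> seq split
print(x2d.chunk(2, dim=1)[0].shape)  # torch.Size([8192, 2048])    -> hidden split
```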
After discovering how my original solution caused the broadcast error above, I went back to the drawing board and found a different solution: rather than changing which node we use for the A_tensor so that its dims match the scale dims, we can use an ancestor node of the scale tensor from before it was reshaped, so that its dims match the A_tensor dims. I confirmed via manual testing with torchtitan that this approach works and float8 rowwise + async TP no longer crashes. The change adds special handling for the reshape -> mm -> reshape pattern, as is done in torchao here, which causes the scale to change shape as well:
After trying multiple different solutions, I've finally managed to find one which fixes the crash AND doesn't affect numerics, while passing pytorch test cases: pytorch/pytorch#148001
…hape pattern" in async TP with rowwise scales (#148001) Fixes pytorch/torchtitan#864 ## Summary While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to. My [root cause analysis](pytorch/torchtitan#864 (comment)) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used. ## TL;DR of root cause - When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned. - In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ to the reshape op, but referencing the `A_scale` node _after_ the reshape op. ## Example - Concrete example: - `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1). - During post grad pass in async TP: - `A_node` has shape (1,8192,2048) (tensor from before this [reshape](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122)) - `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)). ## Solution **Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics. - Short-term solution: if the specific pattern showne below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the originals shape before the reshape. - reshape is just a view, so there should be no impact on performance ``` Before: reshape (a,bc,) to (a*b,c) -> reciprocal After: reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c) ``` - Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor` ## Test plan - Added unit test which exercises this new path - Manually tested with torchtitan with float8 rowwise + async TP Pull Request resolved: #148001 Approved by: /~https://github.com/yifuwang
…hape pattern" in async TP with rowwise scales (#148001) Fixes pytorch/torchtitan#864 ## Summary While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to. My [root cause analysis](pytorch/torchtitan#864 (comment)) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used. ## TL;DR of root cause - When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned. - In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ to the reshape op, but referencing the `A_scale` node _after_ the reshape op. ## Example - Concrete example: - `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1). - During post grad pass in async TP: - `A_node` has shape (1,8192,2048) (tensor from before this [reshape](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122)) - `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)). ## Solution **Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics. - Short-term solution: if the specific pattern showne below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the originals shape before the reshape. - reshape is just a view, so there should be no impact on performance ``` Before: reshape (a,bc,) to (a*b,c) -> reciprocal After: reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c) ``` - Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor` ## Test plan - Added unit test which exercises this new path - Manually tested with torchtitan with float8 rowwise + async TP Pull Request resolved: #148001 Approved by: /~https://github.com/yifuwang
…hape pattern" in async TP with rowwise scales (#148001) Fixes pytorch/torchtitan#864 ## Summary While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to. My [root cause analysis](pytorch/torchtitan#864 (comment)) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used. ## TL;DR of root cause - When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned. - In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ to the reshape op, but referencing the `A_scale` node _after_ the reshape op. ## Example - Concrete example: - `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1). - During post grad pass in async TP: - `A_node` has shape (1,8192,2048) (tensor from before this [reshape](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122)) - `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)). ## Solution **Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics. - Short-term solution: if the specific pattern showne below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the originals shape before the reshape. - reshape is just a view, so there should be no impact on performance ``` Before: reshape (a,bc,) to (a*b,c) -> reciprocal After: reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c) ``` - Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor` ## Test plan - Added unit test which exercises this new path - Manually tested with torchtitan with float8 rowwise + async TP Pull Request resolved: #148001 Approved by: /~https://github.com/yifuwang
Bug description
With the latest PR for Float8 rowwise to support TP, I hit the following error when full checkpointing is active. It is not an issue when no checkpointing is used.
Versions
Latest TorchAO nightly + TP PR + PyTorch nightly (Feb 10)
TP PR = /~https://github.com/pytorch/torchtitan/pull/808/files + pytorch/ao#1718
This was at 256 scale, with full AC
TP=2 + FP8 rowwise