
[Float8] Unable to run asyncTP + Float8 row with 'full' AC active, leading dims mismatch #864

Open
lessw2020 opened this issue Feb 20, 2025 · 34 comments
Labels: bug (Something isn't working), module: float8

@lessw2020
Contributor

lessw2020 commented Feb 20, 2025

Bug description

With the latest PRs adding TP support for Float8 rowwise, I hit the following error when full activation checkpointing is active. The error does not occur with checkpointing disabled.

_fused_scaled_matmul_reduce_scatter_fallback
[rank5]:E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0]     raise ValueError(
[rank5]:E0216 16:37:13.835000 2683263 site-packages/torch/_subclasses/fake_tensor.py:2391] [0/0] ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
....
[rank5]: ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([5, 8192, 4096]), A_scale shape: torch.Size([40960, 1]))
[P1734043454]
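
For context on the check that fires here: it is essentially a leading-dims comparison between A and its row-wise scale. A minimal sketch of the invariant (my paraphrase for illustration, not the actual PyTorch source):

```python
import torch

def check_rowwise_scale(A: torch.Tensor, A_scale: torch.Tensor) -> None:
    # Row-wise scaling expects one scale per row, so every dim of A_scale
    # except the last must line up with A's leading dims.
    if A_scale.shape[:-1] != A.shape[:-1]:
        raise ValueError(
            "For row-wise scaling, the leading dims of A_scale must match the "
            f"leading dims of A (A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

check_rowwise_scale(torch.empty(5, 8192, 4096), torch.empty(5, 8192, 1))  # passes
check_rowwise_scale(torch.empty(5, 8192, 4096), torch.empty(40960, 1))    # raises ValueError, matching the log above
```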

Versions

Latest TorchAO nightly + TP PR + PyTorch nightly (Feb 10)
TP PR = /~https://github.com/pytorch/torchtitan/pull/808/files +
pytorch/ao#1718
This was at 256 scale, with full AC
TP=2 + FP8 rowwise

@lessw2020 added the bug (Something isn't working) label on Feb 20, 2025
@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

@lessw2020 quick question, by "latest PR for rowwise TP" are you referring to the torchtitan PR here #808 or the torchao PR here pytorch/ao#1718? Were you using both?

@danielvegamyhre
Contributor

Also feel free to assign this to me

@lessw2020
Contributor Author

Hi @danielvegamyhre - yes, good questions - I was using both and have updated the original issue.
Assigning over to you, thanks for taking a look!

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

Different error occurs at pytorch/ao@32a51ec:

aten.view.default with axiswise scaling and t.shape torch.Size([1, 8192, 4096]) t._scale.shape torch.Size([1, 8192, 1]) t._axiswise_dim -1 new_shape [1, 8192, 4096] is not supported yet.

I believe this was fixed later in pytorch/ao@988c5c9

@danielvegamyhre
Contributor

Testing at pytorch/ao@988c5c9 the scale dim mismatch issue repros:

  ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([1, 8192, 2048]), A_scale shape: torch.Size([8192, 1]))

@danielvegamyhre
Contributor

Tried disabling power of 2 scale factors, issue still occurs.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

Error: torch._inductor.exc.InductorError: ValueError: For row-wise scaling, the leading dims of A_scale must match the leading dims of A (A shape: torch.Size([1, 8192, 2048]), A_scale shape: torch.Size([8192, 1]))

Error is thrown here when A shape is (1,8192,2048) and A_scale shape is (8192, 1): /~https://github.com/pytorch/pytorch/blob/fb1f7f6a09576e833647a19b0328d599c754af2a/torch/distributed/_symmetric_memory/__init__.py#L1151

Dumping logs from torchao scale calculations: https://www.internalfb.com/phabricator/paste/view/P1736882650

We can see that we have a tensor shape (1,8192,4096) and scale shape of (1,8192,1). It follows naturally that with async TP degree = 2, this tensor is split along the final dim into 2x of (1,8192,2048), and we keep the same scale of (1,8192,1).

So we would expect to see an A shape of (1,8192,2048) and A_scale shape of (1,8192,1). In the error, we see the A shape is correct but the A scale shape is not.
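
To make that expectation concrete, a small shape-only sketch (illustrative only, using the shapes from the logs; this is not the async TP code):

```python
import torch

x = torch.randn(1, 8192, 4096)
scale = x.abs().amax(dim=-1, keepdim=True)  # row-wise scale, shape (1, 8192, 1)

# With async TP degree 2, the data is split along the final dim...
shards = x.chunk(2, dim=-1)                 # two shards of shape (1, 8192, 2048)

# ...but each shard should still pair with the same (1, 8192, 1) scale.
for shard in shards:
    assert scale.shape[:-1] == shard.shape[:-1]   # holds for a (1, 8192, 1) scale
    assert scale.shape[:-1] != (8192,)            # a squeezed (8192, 1) scale would not
```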

Looking around the code where the error is thrown, it seems inductor inserts a reduce-scatter op (presumably needed for backward) here, and by that point the A_scale has had its leading dim squeezed off to (8192,1), which is incorrect and leads to the mismatch between the leading dims of A and A_scale.

I believe at this point in the code the node has the wrong shape for the A_scale: /~https://github.com/pytorch/pytorch/blob/863ac20659adf63c998a272f7e32e33a38dcea91/torch/_inductor/fx_passes/micro_pipeline_tp.py#L393

Need to install pytorch from source with print statements to debug further.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

In torch/_inductor/fx_passes I see squeezing going on in a couple of places.

If "split_cat" has to do with concatenation, this seems like it could be related, since presumably in TP we need to all-gather and concat at certain points. I'm wondering if the scaling factor tensor could be affected by this.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

I think this is where the graph is parsed to find reduce-scatters, creating a ReduceScatterMatch for reduce-scatters of non-zero-dim tensors: /~https://github.com/pytorch/pytorch/blob/863ac20659adf63c998a272f7e32e33a38dcea91/torch/_inductor/fx_passes/micro_pipeline_tp.py#L265

Presumably we would go down this code path, since the scales tensor is not 0 dim.

I also see that the non-zero-dim reduce-scatter pattern they're looking for in the graph starts with the aten.cat.default op, which potentially relates to the squeeze() found in split_cat in #864 (comment).

I'm beginning to wonder if async TP is splitting the scales tensor instead of broadcasting it?

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

Dumped TORCH_LOGS=+aot,+inductor,+dynamo logs (link) and found some relevant lines:

[rank0]:I0220 12:15:11.288000 3845728 site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:402] [0/0] [__aot_joint_graph]         exp2: "f32[1, 8192, 1][8192, 1, 1]cuda:0" = torch.ops.aten.exp2.default(floor);  floor = None
...
[rank0]:I0220 12:15:11.288000 3845728 site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:402] [0/0] [__aot_joint_graph]         view_4: "f32[8192, 1][1, 1]cuda:0" = torch.ops.aten.view.default(exp2, [-1, 1]);  exp2 = None

Looks like there is a view op which changes the shape of exp2 output from (1,8192,1) to (8192,1)?
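
For reference, the same squeeze is easy to reproduce standalone (illustrative snippet, not the traced graph):

```python
import torch

exp2 = torch.rand(1, 8192, 1)     # scale shape seen in the joint graph
view_4 = exp2.view(-1, 1)         # the view op recorded above
print(view_4.shape)               # torch.Size([8192, 1]) -- the leading dim is gone
```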

@danielvegamyhre
Contributor

Confirmed issue still occurs with no AC.

@danielvegamyhre
Contributor

This line of code will result in the A_scale shape going from (1,8192,1) -> (8192,1), which would explain the error.

However, the error is actually raised on the line just above it, so it would seem this isn't actually reached?

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 20, 2025

Eager mode rowwise with vanilla TP throws a different error:

Operator aten.amax.default does not have a sharding strategy registered.

Note this does NOT occur for tensorwise + eager + vanilla TP, only rowwise.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

> Eager mode rowwise with vanilla TP throws a different error:
>
> Operator aten.amax.default does not have a sharding strategy registered.

Looking for where sharding strategies are registered, I found the DTensor op dispatcher and the util function for registering op strategies. I think we need to register aten.amax.default.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

> Eager mode rowwise with vanilla TP throws a different error:
> Operator aten.amax.default does not have a sharding strategy registered.
>
> Looking for where sharding strategies are registered, I found the DTensor op dispatcher and the util function for registering op strategies. I think we need to register aten.amax.default.

Checked with tianyu-l@ and he agreed this does seem to be the case. He suggested I submit a PR for this, using this WIP PR as a starting point: pytorch/pytorch#130887.

To be clear, this is to fix eager mode + float8 rowwise + vanilla TP.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

I reinstalled pytorch using the nightly build and am now getting a different error for eager mode + float8 rowwise + vanilla TP.

RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.

Added some logging and found the output dtype in question is fp32.

I also logged the shapes of A, B tensors and their scales:

[rank0]:A_shape: torch.Size([256, 4096]), B_shape: torch.Size([4096, 2048]), orig dtype: torch.float32
[rank0]:A_scale: torch.Size([256, 1]), B_scale: torch.Size([1, 2048]), orig dtype: torch.float32
[rank0]:$$$$$$$$$$$ output dtype: torch.float32

The good news is the scales are the right shape:

  • The A tensor should be sharded row-wise for TP (256,4096) -> scales for each row, col-vector of shape (256,1)
  • The B tensor should be sharded col-wise for TP (4096,2048) -> scales for each col, row-vector of shape (1,2048)

The bad news is I'm not sure why these tensors are suddenly float32 and not bf16.
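
For reference, a minimal sketch of how row-wise and column-wise scales end up with the shapes listed above (an amax-based recipe is assumed here; this is a simplification of what torchao does):

```python
import torch

A = torch.randn(256, 4096, dtype=torch.bfloat16)   # row-wise sharded for TP
B = torch.randn(4096, 2048, dtype=torch.bfloat16)  # col-wise sharded for TP

fp8_max = torch.finfo(torch.float8_e4m3fn).max
A_scale = fp8_max / A.abs().amax(dim=-1, keepdim=True).float()  # (256, 1): one scale per row
B_scale = fp8_max / B.abs().amax(dim=0, keepdim=True).float()   # (1, 2048): one scale per col

print(A_scale.shape, B_scale.shape)  # torch.Size([256, 1]) torch.Size([1, 2048])
```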

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

> I reinstalled pytorch using the nightly build and am now getting a different error for eager mode + float8 rowwise + vanilla TP: RuntimeError: Only bf16 high precision output types are supported for row-wise scaling. [...]

I managed to get the tensors into bf16 again by using HSDP2, which automatically casts to bf16 in torchtitan.

In fact, by combining pytorch nightly build + HSDP + vanilla TP, it actually works for both eager and compile ("no sharding strategy registered" error is gone).

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

Update: I tried building pytorch from source and using torch.distributed.breakpoint() to understand the execution flow of async TP and found the following:

  • During tracing we first hit _ScaledMatmul.from_match
    • A_node: _get_tensor(match[0].args[0]) will evaluate to a fake tensor of shape (1,8192,2048)
    • B_node: _get_tensor(mm_node.args[1]) will evaluate to fake tensor of shape (2048, 4096)
    • A_scale: _get_tensor(mm_node.args[2]) evaluates to a fake tensor of shape (8192,1). So at this point, the graph trace holds a scale that is already wrong.

Since the scale is already wrong here, the rest of the execution matters less, but I mapped out the following for my own understanding:

  • Next we hit restride_A_for_fused_matmul_reduce_scatter first.
    • The fake tensor t has shape (1, 8192, 2048) with dtype torch.float8_e4m3fn, and the scatter_dim is 1.
    • perm holds [0,1,2] to start.
    • after perm.insert(0, perm.pop(scatter_dim)) it holds [1,0,2]
    • it then restrides the tensor t such that t.permute(perm) is contiguous, but then undoes the permutation, so it's back to the original shape of (1,8192,2048) (see here)
  • Next, we hit the same function again, but with t shape (1, 8192, 7168). Same steps as above happen.
  • Next, we hit _fused_scaled_matmul_reduce_scatter_fallback(code) where the A fake tensor is shape (1, 8192, 2048) and A_scale is (8192,1).
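
A rough sketch of the permutation bookkeeping described above (paraphrased from the trace; this is not the actual restride_A_for_fused_matmul_reduce_scatter source):

```python
import torch

def inverse(perm):
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

def restride_for_scatter(t: torch.Tensor, scatter_dim: int) -> torch.Tensor:
    perm = list(range(t.dim()))             # [0, 1, 2]
    perm.insert(0, perm.pop(scatter_dim))   # scatter_dim=1 -> [1, 0, 2]
    # Make t.permute(perm) contiguous, then permute back: the logical shape is
    # unchanged, only the strides now favor the scatter dim.
    return t.permute(perm).contiguous().permute(inverse(perm))

t = torch.empty(1, 8192, 2048)              # float8_e4m3fn in the actual trace
out = restride_for_scatter(t, scatter_dim=1)
print(out.shape)                            # torch.Size([1, 8192, 2048]) -- same shape as the input
```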

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 21, 2025

I traced further back to the beginning of the micropipeline TP pass here and confirmed that the reduce_scatters parsed from the raw fx.Graph have the incorrect scale.

Viewing the part of the graph I know will be used as scale from the execution tracing in #864 (comment):

_get_tensor(reduce_scatters[1].input_node.args[0].args[2])
[rank0]:(Pdb) [rank0]:FakeTensor(..., device='cuda:0', size=(8192, 1))

We can see the raw reduce_scatter graph parsed from the full fx.Graph contains an incorrect scale tensor (should be shape [1,8192,1]). This conflicts with the torchao logging analyzed in this comment which shows the scale has the correct shape of (1,8192,1).

To summarize, at this point I've verified:

  1. pytorch nightly build + float8 rowwise + HSDP2 + vanilla TP works for both eager and compile
  2. pytorch nightly build + float8 rowwise + HSDP2 + async TP + compile fails, because the reduce_scatters parsed from the fx.Graph here contain the incorrect scale for tensor A.

My conclusion is there is a problem in the graph tracing code for async TP, but I'm not sure what specifically yet.

cc @yifuwang who worked on async TP - can you take a look at this comment and let me know if you agree with the root cause identified here? And who on the inductor side might be a good point of contact to take a look at this?

@danielvegamyhre
Contributor

Adding some more logging in torchao for additional data:

  • Logged tensor_to_scale tensor shape and scale shape, to log the raw scale shape for the given tensor.
  • Logged the tensor and scale shape at the beginning and end of preprocess_addmm, which is called when an aten.mm.default op is performed on a Float8Tensor and dispatched here. preprocess_addmm gets the scales from the tensors.
[rank0]:$$$ tensor_to_scale -  hp_tensor.shape= torch.Size([1, 8192, 4096]) amax.shape= torch.Size([1, 8192, 1]) scale.shape= torch.Size([1, 8192, 1])
[rank0]:$$$ tensor_to_scale -  hp_tensor.shape= torch.Size([4096, 4096]) amax.shape= torch.Size([1, 4096]) scale.shape= torch.Size([1, 4096])
[rank0]:$$$ START preprocess_addmm: a.shape: torch.Size([8192, 2048]), a_scale.shape: torch.Size([8192, 1])
[rank0]:$$$ END preprocess_addmm: a.shape: torch.Size([8192, 2048]), a_scale.shape: torch.Size([8192, 1])

We can observe that the raw scales computed have the correct shape, but when that tensor is later used in an aten.matmul op by pytorch, the scale has somehow changed shape.

@danielvegamyhre
Contributor

Added logging in the float8 view op, since as noted in my previous comment, the originally computed scale has the correct shape, but by the time the matmul is performed, the scale shape has changed, so we need to find out where in between that happens.

Previously I noted there was a view op changing the scale shape from (1,8192,1) to (8192,1) here.

Added logging to float8 view op and we can see the problematic op come through:

view called: t.shape: torch.Size([1, 8192, 4096]), t._scale.shape: torch.Size([1, 8192, 1]), new_shape: [-1, 4096]

Since the data is being reshaped from [1, 8192, 4096] to [8192,4096], the scale shape is also adjusted from [1,8192,1] to [8192,1] here.

Now the question is, why is pytorch performing a view that cuts off the leading dim of the data shape?
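
For reference, the data reshape and the lockstep scale reshape look like this in isolation (a simplified stand-in, not torchao's actual float8 linear):

```python
import torch

def linear_3d(x: torch.Tensor, w: torch.Tensor, x_scale: torch.Tensor) -> torch.Tensor:
    orig_shape = x.shape                         # (1, 8192, 4096)
    x_2d = x.reshape(-1, orig_shape[-1])         # (8192, 4096): mm wants a 2D input
    scale_2d = x_scale.reshape(-1, 1)            # (1, 8192, 1) -> (8192, 1), kept in lockstep
    assert scale_2d.shape == (x_2d.shape[0], 1)
    out_2d = x_2d @ w.t()                        # stand-in for the scaled mm
    return out_2d.reshape(*orig_shape[:-1], -1)  # back to (1, 8192, 2048)

x = torch.randn(1, 8192, 4096)
w = torch.randn(2048, 4096)
x_scale = x.abs().amax(dim=-1, keepdim=True)     # (1, 8192, 1)
print(linear_3d(x, w, x_scale).shape)            # torch.Size([1, 8192, 2048])
```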

@danielvegamyhre
Contributor

Set a breakpoint at the start of post_grad_passes before any post grad passes have run, and have confirmed the scale is wrong already:

rs = find_reduce_scatter_patterns(gm.graph)
rs
[rank0]:(Pdb) [rank0]:(Pdb) [rank0]:[_ReduceScatterMatch(match=Match(..., [], {'input': view_35, 'scatter_dim': 1, 'reduce_op': 'sum', 'group_name': '9'}), input_node=view_35, rs_node=reduce_scatter_tensor, res_node=wait_tensor_4, reduce_op='sum', scatter_dim=1, group_name='9'), _ReduceScatterMatch(match=Match(..., [], {'input': view_55, 'scatter_dim': 1, 'reduce_op': 'sum', 'group_name': '9'}), input_node=view_55, rs_node=reduce_scatter_tensor_1, res_node=wait_tensor_9, reduce_op='sum', scatter_dim=1, group_name='9')]

len(rs)
[rank0]:(Pdb) [rank0]:2

rs[1].input_node.args[0].args[2].meta["val"]
[rank0]:(Pdb) [rank0]:FakeTensor(..., device='cuda:0', size=(8192, 1))

This would seem to eliminate the micropipeline/async TP pass as well as the other post grad passes. However, if that's the case I'd expect the issue to appear for vanilla TP too, and it doesn't.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 22, 2025

I set a breakpoint in the float8 view() op implementation, and when the problematic view op described here was intercepted, I examined the stack trace and found the source of the call, which is this line.

As shown here, the reshape of the tensor causes the scale to be reshaped. However, at the point in the async TP code where the error is thrown, the tensor is back in its original shape, but the scale is still in reshaped form, causing the mismatch.

@gnadathur
Contributor

cc: @vkuzo

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 24, 2025

> I set a breakpoint in the float8 view() op implementation, and when the problematic view op described here was intercepted, I examined the stack trace and found the source of the call, which is this line.
>
> As shown here, the reshape of the tensor causes the scale to be reshaped. However, at the point in the async TP code where the error is thrown, the tensor is back in its original shape, but the scale is still in reshaped form, causing the mismatch.

This is occurring downstream of the attention layer Q/K/V projections here. These projections are using colwise_parallel here.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 24, 2025

I think maybe the problem isn't that the scale is the wrong shape, but rather that async TP is referencing the wrong node for the A tensor: a node from before the reshape here.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 24, 2025

I visualized the fx graph before any post grad passes were applied, and confirmed that the hunch in my prior comment was correct: the graph itself is correct, but the graph manipulation in async TP is referencing the wrong "A tensor" node (specifically, the "A tensor" prior to the reshape here).

The TL;DR: for the reshape -> mm -> reshape pattern done in torchao here, match[0].args[0] holds the original tensor shape from before the first reshape, so it does not match the scale shape in mm_node.args[2].

To be more specific, the fuse_matmul_reduce_scatter function should be referencing the nodes in the red boxes below (the "A tensor" operand after it has been reshaped for compatibility with torch.mm, and its corresponding scale). However, in practice, it is incorrectly referencing the green box for the "A tensor", which still has the non-reshaped shape that doesn't match the scale.

[screenshot of the fx graph omitted: red boxes mark the reshaped "A tensor" and its scale, the green box marks the pre-reshape "A tensor"]

To fix this, I tried changing this line to reference mm_node.args[0], after confirming this held the correct shape of (8192,2048) (as opposed to the incorrect shape of (1,8192,2048) that match[0].args[0] held). This resulted in a new error, though:

  torch._inductor.exc.InductorError: RuntimeError: Attempting to broadcast a dimension of length 2048 at -1! Mismatching argument at index 1 had torch.Size([8192, 2048]); but expected shape should be broadcastable to [1, 4096, 4096]

@danielvegamyhre
Contributor

Side note for later investigation: this seems to restride the "A tensor" to be column-major, which does not seem correct/ideal for fp8 gemms using _scaled_mm; we should use a row-major tensor since the cuBLAS API requires this, so we avoid unnecessary conversions: /~https://github.com/pytorch/pytorch/blob/55bf3ff3a5bd18ab1805d0d52f7a723d097294d4/torch/_inductor/fx_passes/micro_pipeline_tp.py#L567

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 25, 2025

So problem 1 (the original error) is fixed by changing this line to reference mm_node.args[0]. I'm now investigating the new error, which occurs later in the graph:

 torch._inductor.exc.InductorError: RuntimeError: Attempting to broadcast a dimension of length 2048 at -1! Mismatching argument at index 1 had torch.Size([8192, 2048]); but expected shape should be broadcastable to [1, 4096, 4096]

I visualized the graph after the async TP post grad pass has been applied, and located the area of the graph where I believe this error is occurring.

The fused_scaled_matmul_reduce_scatter_default (red box) is doing a _scaled_mm with args:

  • A_tensor: (8192,2048)
  • B_tensor: (2048,4096)
  • A_scale: (8192,1)
  • B_scale: (4096,1)

The output of this mm will be shape (8192,4096), but it's being reduce-scattered, which I think will shard it along the final dim into 2 shards of shape (8192,2048).

Now the output of this is used by the add node (green box) which is adding:

  • (8192,2048) from above
  • (1,4096,4096) from a wait_tensor way higher in the graph (second screenshot).

This results in the broadcast error.

[screenshot omitted: fx graph region around fused_scaled_matmul_reduce_scatter_default (red box) and the add node (green box)]

[screenshot omitted: the wait_tensor higher in the graph that feeds the add]
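
The shape incompatibility itself is easy to reproduce in isolation (eager PyTorch raises a differently worded error, but the same broadcasting rule is violated):

```python
import torch

mm_out_shard = torch.empty(8192, 2048)    # reduce-scattered matmul output
residual_in = torch.empty(1, 4096, 4096)  # transformer block input from the wait_tensor

try:
    residual_in + mm_out_shard
except RuntimeError as e:
    # Broadcasting aligns trailing dims: 2048 vs 4096 cannot be reconciled.
    print(e)
```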

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 25, 2025

I'm now trying to understand why we are adding the result of the fused_scaled_matmul_reduce_scatter (shape (1,8192,4096)) to a tensor of shape (1,4096,4096), which is incompatible.

TL;DR is it seems the input to the attention layer is being sharded across the sequence dim, but the output of the attention layer is being reduce_scattered across the model dim, so there is a shape mismatch in the residual layer.

Details:

I assume the addition is coming from a residual layer in the model, like this one.

Through some additional testing and analysis, I confirmed the tensor of shape (1,4096,4096) is the input to the transformer block, which has a shape (batch, seq_len // 2, model_dim) = (1, 8192//2, 4096) = (1,4096,4096).

The output of the attention layer is (batch, seq_len, model_dim // 2) = (1,8192,4096//2) = (1,8192,2048). Therefore these tensors cannot be added.

It looks to me like the input is being sharded across a different dimension than the attention outputs, thus the shape incompatibility.

Specifically, I'm wondering if these input sharding specs in torchtitan here are the problem.

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 25, 2025

I think I see the problem causing the broadcast error: the pattern is pulling the scatter_dim from the original tensor with 3 dims, where scatter_dim=1 is the seq dim.

However, after the reshape in torchao here, which reshapes the tensor to have 2 dims, the scatter dim of 1 is now outdated and incorrectly applied along the hidden dim, which does not match the sharding spec of the tensor it is added with in the residual here (sharded along the seq dim).
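
A shape-only illustration of the stale scatter_dim (no collectives involved, just the dim bookkeeping; shapes assumed from the earlier comments):

```python
import torch

x3d = torch.empty(1, 8192, 4096)        # (batch, seq, hidden)
scatter_dim = 1                         # the seq dim for the 3D layout

x2d = x3d.reshape(-1, x3d.shape[-1])    # (8192, 4096) after the torchao reshape

# On the 3D tensor, dim 1 is the 8192-long sequence dim; on the flattened 2D
# tensor, dim 1 is the 4096-long hidden dim. Reusing scatter_dim=1 after the
# reshape therefore scatters along the wrong logical dimension.
print(x3d.shape[scatter_dim], x2d.shape[scatter_dim])  # 8192 4096
```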

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 26, 2025

> I think I see the problem causing the broadcast error: the pattern is pulling the scatter_dim from the original tensor with 3 dims, where scatter_dim=1 is the seq dim.
>
> However, after the reshape in torchao here, which reshapes the tensor to have 2 dims, the scatter dim of 1 is now outdated and incorrectly applied along the hidden dim, which does not match the sharding spec of the tensor it is added with in the residual here (sharded along the seq dim).

After discovering how my original solution made the scatter_dim outdated here, I realized finding a way to keep the scatter dim in sync with the reshapes would be complicated/not feasible.

So I went back to the drawing board and found a different solution: rather than changing which node we use for the A_tensor so that its dims match the scale dims, we can use an ancestor node of the scale tensor from before it was reshaped, so that its dims match the A_tensor dims.

I confirmed via manual testing with torchtitan that this approach works and float8 rowwise + async TP is no longer crashing.

The change adds special handling for the reshape -> mm -> reshape pattern, which is done in torchao here and causes the scale to change shape as well:

pytorch/pytorch#147794

@danielvegamyhre
Contributor

After trying multiple different solutions, I've finally managed to find one which fixes the crash AND doesn't affect numerics, while passing pytorch test cases: pytorch/pytorch#148001
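
Shape-wise, the idea behind that fix amounts to the following (conceptual sketch only; the actual change manipulates fx graph nodes in the post-grad pass):

```python
import torch

a, b, c = 1, 8192, 2048
scale_3d = torch.rand(a, b, 1)          # scale computed for the (a, b, c) A tensor
scale_2d = scale_3d.reshape(-1, 1)      # reshaped alongside the data: (a*b, 1)
recip = scale_2d.reciprocal()           # the compiler-inserted reciprocal
scale_fixed = recip.reshape(a, b, 1)    # inserted reshape: back to (a, b, 1)

# The restored scale once again matches the leading dims of the 3D A tensor,
# and numerics are untouched because only a view was added after the reciprocal.
assert scale_fixed.shape[:-1] == (a, b)
```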

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Feb 28, 2025
…hape pattern" in async TP with rowwise scales (#148001)

Fixes pytorch/torchtitan#864

## Summary
While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to.

My [root cause analysis](pytorch/torchtitan#864 (comment)) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used.

## TL;DR of root cause
- When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned.
- In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ the reshape op, but referencing the `A_scale` node _after_ the reshape op.

## Example
- Concrete example:
    - `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](/~https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1).
    - During post grad pass in async TP:
        - `A_node` has shape (1,8192,2048) (tensor from before this [reshape](/~https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122))
        - `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)).

## Solution

**Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics.

- Short-term solution: if the specific pattern shown below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the original shape from before the reshape.
    - reshape is just a view, so there should be no impact on performance
```
Before:
    reshape (a,b,c) to (a*b,c) -> reciprocal

After:
    reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
```

- Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor`

## Test plan
- Added unit test which exercises this new path
- Manually tested with torchtitan with float8 rowwise + async TP

Pull Request resolved: #148001
Approved by: /~https://github.com/yifuwang
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Mar 1, 2025
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Mar 2, 2025