
[dtensor][fix] fix _scaled_dot_product_flash_attention sharding #148125

Closed
wants to merge 2 commits

Conversation

@XilunWu XilunWu (Contributor) commented Feb 27, 2025

Stack from ghstack (oldest at bottom):

Summary

#146372 changed the op signature of `_scaled_dot_product_flash_attention`; as a consequence, DTensor needs to update the sharding strategy defined in `scaled_dot_product_flash_attention_strategy` (torch/distributed/tensor/_ops/_matrix_ops.py).
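For context, that strategy is what DTensor consults when the ATen op runs on sharded inputs. Below is a minimal sketch of that dispatch path using the public DTensor API; the 2-GPU mesh, tensor shapes, and head-dim sharding are illustrative assumptions, and the script must be launched as a 2-rank distributed run (e.g. `torchrun --nproc-per-node 2`):

```python
# Sketch: sharded SDPA inputs route through DTensor's registered sharding
# strategy for _scaled_dot_product_flash_attention. Mesh size, shapes, and
# placements are illustrative; launch under torchrun with 2 ranks.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))

def sharded(shape):
    # bf16 on CUDA keeps the flash-attention backend eligible; shard the
    # head dimension (dim 1) across the mesh.
    t = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
    return distribute_tensor(t, mesh, [Shard(1)])

# (batch, num_heads, seq_len, head_dim)
q, k, v = (sharded((8, 16, 128, 64)) for _ in range(3))

# DTensor intercepts the ATen op and consults the registered strategy to
# propagate input/output placements; this PR realigns that strategy with
# the op's updated signature.
out = F.scaled_dot_product_attention(q, k, v)
```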

Test

`pytest test/distributed/tensor/test_attention.py`

Follow-up

It's still unclear why the CP unit tests were not run on the original PR, which is BC-breaking.
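For readers unfamiliar with CP: Context Parallel shards the sequence dimension of attention inputs across ranks, which is exactly the path those unit tests exercise. A minimal sketch, assuming the experimental `torch.distributed.tensor.experimental.context_parallel` context manager (the API is experimental and may change; mesh size and shapes are illustrative):

```python
# Sketch: Context Parallel shards q/k/v in place along the sequence dim
# (dim 2) and dispatches SDPA to its context-parallel variant. Experimental
# API; launch under torchrun with 2 ranks.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

mesh = init_device_mesh("cuda", (2,))
# (batch, num_heads, seq_len, head_dim); the sequence dim is 2.
q, k, v = (
    torch.randn(8, 16, 128, 64, dtype=torch.bfloat16, device="cuda")
    for _ in range(3)
)

# Inside the context, the listed buffers are sharded along their sequence
# dims and restored on exit.
with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```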

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l

pytorch-bot bot commented Feb 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148125

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5a8391c with merge base 2978771:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Feb 27, 2025
XilunWu added a commit that referenced this pull request Feb 27, 2025
ghstack-source-id: 47c426b771a7515a3057e4b1d100dac640933265
Pull Request resolved: #148125
@XilunWu XilunWu marked this pull request as draft February 28, 2025 00:00
XilunWu added a commit that referenced this pull request Feb 28, 2025
ghstack-source-id: 408ec85127af2c09bde0248956fb6bc2456e858b
Pull Request resolved: #148125
@XilunWu XilunWu added the better-engineering, module: dtensor, and module: context parallel labels Feb 28, 2025
@XilunWu XilunWu changed the title [dtensor] fix scaled dot product flash attention sharding [dtensor][fix] fix _scaled_dot_product_flash_attention sharding Feb 28, 2025
@XilunWu XilunWu marked this pull request as ready for review February 28, 2025 00:46
@tianyu-l tianyu-l (Contributor) left a comment


lgtm

@XilunWu XilunWu added the ciflow/trunk label Feb 28, 2025
@XilunWu XilunWu (Contributor, Author) commented Feb 28, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@fegin fegin (Contributor) left a comment


Thanks for the fix!

fegin pushed a commit to pytorch/torchtitan that referenced this pull request Mar 3, 2025
…as been fixed (#912)

Stack from [ghstack](/~https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #912

### Summary
This PR undoes #898 and re-enables CP tests in CI, as pytorch/pytorch#148125 fixed the DTensor SDPA flash attention op.

### Test
CI
fegin added a commit to pytorch/torchtitan that referenced this pull request Mar 3, 2025
#921)

…as been fixed (#912)

Stack from [ghstack](/~https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #912

### Summary
This PR undoes #898 and re-enables CP tests in CI, as pytorch/pytorch#148125 fixed the DTensor SDPA flash attention op.

### Test
CI

Co-authored-by: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
@XilunWu XilunWu mentioned this pull request Mar 3, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
…rch#148125)

### Summary
pytorch#146372 changed the op signature of `_scaled_dot_product_flash_attention`; as a consequence, DTensor needs to update the sharding strategy defined at /~https://github.com/pytorch/pytorch/blob/40ad5e01dff05c7d64e070fb01683820e678f788/torch/distributed/tensor/_ops/_matrix_ops.py#L232

### Test
`pytest test/distributed/tensor/test_attention.py`

### Follow-up
It's still unclear why the CP unit tests were not run on the original PR, which is BC-breaking.

Pull Request resolved: pytorch#148125
Approved by: /~https://github.com/tianyu-l, /~https://github.com/fegin
Labels
better-engineering, ciflow/inductor, ciflow/trunk, Merged, module: context parallel, module: dtensor, oncall: distributed, topic: not user facing