
Cache applied with torch.compile #453

Merged (1 commit) on Feb 25, 2025
Conversation

@Binary2355 (Contributor) commented Feb 24, 2025

This PR:

  • Adds the performance of TeaCache and FBCache on 1xH20
  • Adds the performance of TeaCache and FBCache on 1xH20 and 4xH20 with torch.compile
  • Refactors the code to make torch.compile as efficient as possible
  • Fixes a bug: postpones the logic that updates the modulated_inputs of FBCache (equivalent to the ParaAttn behavior)

The performance table is listed below (latency in seconds):

| Method | 4xH20 (no compile) | 1xH20 (no compile) | 4xH20 (torch.compile) | 1xH20 (torch.compile) |
|---|---|---|---|---|
| Baseline | 2.02s | 6.10s | 1.81s | 5.02s |
| use_teacache | 1.60s | 4.67s | 1.50s | 3.92s |
| use_fbcache | 0.93s | 2.51s | 0.85s | 2.09s |

@Binary2355 Binary2355 force-pushed the main branch 2 times, most recently from 526c180 to a74e42a Compare February 24, 2025 15:15
@Binary2355 Binary2355 changed the title add 1xH20 performance add 1xH20 performance and 1xH20 performance, 4xH20 performance with torch.compile Feb 24, 2025
Comment on lines 74 to 75
if engine_config.runtime_config.use_torch_compile:
pipe.transformer = torch.compile(apply_cache_on_transformer(pipe.original_transformer, **cache_args))
Collaborator:
The problem you are hitting is that _convert_transformer_backbone compiles the transformer's forward first; apply_cache_on_transformer then changes the logic, so the compilation is invalidated.

Put apply_cache_on_transformer inside the _convert_transformer_backbone function instead of doing it in the example.

Contributor (Author):

I save original_transformer inside _convert_transformer_backbone before running torch.compile, so the torch.compile issued after apply_cache_on_transformer operates on the original (uncompiled) transformer.
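The order being debated above can be sketched as follows. This is a minimal illustration, not the xDiT implementation: `ToyTransformer`, `ToyPipe`, and the no-op `apply_cache_on_transformer` stand-in are hypothetical; the real helper wraps the forward with cache logic. The point is the pattern the author describes: keep a reference to the uncompiled module, apply the cache wrapper to that original, and only then call `torch.compile`, so the wrapper does not invalidate an already-compiled forward. (`backend="eager"` keeps the sketch dependency-free; the actual code would use the default backend.)

```python
import torch
import torch.nn as nn

def apply_cache_on_transformer(transformer, **cache_args):
    # Stand-in for the real xDiT helper, which wraps the forward with
    # TeaCache/FBCache logic. Here it just returns the module unchanged.
    return transformer

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)

class ToyPipe:
    def __init__(self):
        self.transformer = ToyTransformer()
        # Save a reference to the uncompiled module *before* any
        # torch.compile call, as described in the comment above.
        self.original_transformer = self.transformer

pipe = ToyPipe()
cache_args = {}  # placeholder for use_teacache / use_fbcache options

# Wrap the original transformer, then compile the wrapped module once.
pipe.transformer = torch.compile(
    apply_cache_on_transformer(pipe.original_transformer, **cache_args),
    backend="eager",
)

out = pipe.transformer(torch.randn(2, 8))
print(tuple(out.shape))
```

Compiling after wrapping means Dynamo traces the cache-augmented forward directly, instead of tracing a forward that is later mutated (which would trigger recompilation or silently bypass the compiled graph).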

@feifeibear feifeibear changed the title add 1xH20 performance and 1xH20 performance, 4xH20 performance with torch.compile Cache applied with torch.compile Feb 25, 2025
@Binary2355 Binary2355 force-pushed the main branch 6 times, most recently from 5bf79c9 to 2c75c4e Compare February 25, 2025 05:25
- add 4xH20 performance and 1xH20 performance with torch.compile
@feifeibear feifeibear merged commit 248abec into xdit-project:main Feb 25, 2025
1 of 3 checks passed