Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Fix detect non-divisible iteration form like (x % 255) // 16 #15665

Merged

Conversation

wrongtest-intellif
Copy link
Contributor

Previously, the optimization from floordiv(floormod(..)) to floormod(floordiv(..)) do not check the divisibility. Which may create wrong iteration form. The change add a default handling and let it just fail in followup validity checks.

For example, (x % 255) // 16 should not get transformed to (x // 16) % 15. The change make it as (x % 255) // 16 % ceildiv(255, 16).

It currently could not pass the surjective check though the mapping actually is surjective. We could improve it in the future.

@wrongtest-intellif
Copy link
Contributor Author

Hi, could you kindly take a look at the change? Thank you! cc @vinx13 @Lunderberg

@Hzfengsy
Copy link
Member

Thanks for the fix!

@Hzfengsy Hzfengsy merged commit 44cc2b9 into apache:main Sep 18, 2023
@vinx13
Copy link
Member

vinx13 commented Oct 4, 2023

do we need another IterMark? I think in this case we can rewrite to (x // 16) % ceildiv(255, 16)

@Lunderberg
Copy link
Contributor

Lunderberg commented Oct 4, 2023

@vinx13 I don't think that simplification would hold. If I use x==-16, (x % 255) // 16 evaluates to 14, while (x // 16) % ceildiv(255, 16) evaluates to 15.

That said, if we know that x is on the range 0 <= x < 255, such that x % 255 == x, then the simplification would be valid. Since this simplification is in the CanonicalSimplifier where the range is frequently known, that might be an easy check to add.

@junrushao
Copy link
Member

This seems to cause a break in MLC LLM q3f16_1, which is a pretty commonly used format. @vinx13 mind sharing what a proper fix should look like?

@vinx13
Copy link
Member

vinx13 commented Oct 5, 2023

The simplification @Lunderberg mentioned should work. Actually we can get the extent from the source IterMark of the split.
Say we want to compute floordiv(lhs, rhs) where lhs == floormod(floordiv(x, c1), ext), x == IterMark(extent=ext_x), and ext % rhs != 0.

if we can prove ext_x // c1 \in [0, ext), it can be simplified as floordiv(lhs, rhs) == floordiv(x, c1 * rhs) == IterSplit(x, lower_factor=c1*rhs, extent=ceildiv(ext, rhs)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants