-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[typing] Use ParamSpec in JIT annotation #14688
base: main
Are you sure you want to change the base?
Conversation
Not sure what's going on with the type checking, but perhaps it's not running MyPy 1.0.1? Also, I'm still working on fixing this for applying |
ba45b45
to
353bba1
Compare
setup.py
Outdated
@@ -67,6 +67,7 @@ def generate_proto(source): | |||
'numpy>=1.20', | |||
'opt_einsum', | |||
'scipy>=1.5', | |||
'typing_extensions>=4.5.0', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason for adding this as a compulsory dependency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could make it a dev-dependency. Would that be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made it a dev-dependency in a subsequence change (within this PR) so that you can assess whether you like that better.
Normally, typing-extensions is seen as a light dependency since it only contains typing annotations, so it won't realistically break any runtime code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to add as a dependency, but I'm also curious why it's necessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
c877b16
to
ee60835
Compare
After testing this a lot with my code, I think MyPy is still not ready to check code with complex If we check something like this in, it may trip up MyPy quite a bit. On the other hand, it has a the amazing effect of exposing plenty Jax type annotations that were previously hidden. Plenty of functions in Also, there are some limitations in Python typing wrt method decorators. I've posted about this on python/typing. If people are using What I will do is add typed decorators to my tjax library. Please let me know if you'd like to get this in to Jax sooner or if we should wait on it. |
Thanks for looking at this! ParamSpec seems like it could be interesting; that said it still looks to be relatively unstable and probably the remarks here still apply. |
Hasn't this already been discussed here: #10311? |
@JesseFarebro My mistake, I didn't find that when I searched! MyPy has progressed since then, but maybe still not enough. @jakevdp Yes, fair enough. |
Great news: Thanks to the recently merged python/mypy#15837, it appears that the main MyPy error with this pull request may have been solved: python/mypy#12169. Also, many of the MyPy errors that may have affected its usage may have been solved: python/mypy#11846, python/mypy#12986, python/mypy#14802. |
We've updated mypy to 1.4.1 in the meantime, and #17147 bumps it to 1.5.0 – can you sync your PR to the current main branch? |
aaec678
to
0ffe6c8
Compare
@jakevdp Done. FYI the MyPy pull I linked is merged, but it's not in any released MyPy yet. |
ed78ba5
to
0e7d881
Compare
As soon as the new mypy version is mirrored at /~https://github.com/pre-commit/mirrors-mypy, we can bump the version in the pre-commit configuration here: /~https://github.com/google/jax/blob/5ed692809da29ff4fd9a89faee9af6d4b41c7056/.pre-commit-config.yaml#L30 It looks like the mirror update will happen automatically in about 12.5 hours: /~https://github.com/pre-commit/mirrors-mypy/blob/08cbc46b6e135adec84911b20e98e5bc52032152/.github/workflows/main.yml#L6 |
#18066 updates mypy to v1.6.0 |
ee38fd6
to
534cae8
Compare
Okay, I've rebased this to take advantage of the new MyPy version |
jax/_src/numpy/lax_numpy.py
Outdated
@overload | ||
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: | ||
... | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Split out these changes in a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, okay, I will do that when we get close to having this checked in. This PR will expose a lot of typing errors as code that's jitted becomes typed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Factored out as requested: #18395
That makes sense. I'll look at it as soon as I have more time. |
Filed a MyPy bug python/mypy#16404 |
@jakevdp I think we should wait one more release of MyPy, which should fix the above errors. After that, although this change will not help MyPy users (the inferred type of a jit-decorated function will just be |
9816dad
to
9f838ab
Compare
Sounds good |
9f838ab
to
91b4f1a
Compare
621d9c1
to
bdc0d85
Compare
@jakevdp I've rebased this now. Is there a way to get the tests to run? python/mypy#1484 is still unsolved, so there may only be modest gains in MyPy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! I'm going to approve and then pull it in to run internal pytype tests
Pytype is failing with this message:
Seems this is a known issue (google/pytype#1471). Is there any way to do this kind of improvement without using covariant TypeVars? |
Thanks for taking the time!
The return type does need to be covariant (MyPy gives an error without it). One option is to do this in two steps. Just keep the parameter specification annotation and remove the return type annotation. It's funny because in Python 3.12, we won't need these markers at all as the type checker can infer them, although I think we'd have to use the new PEP 695 syntax. Should I remove the annotations on the return type? |
I'm honestly not sure what the best fix is here. |
I think maybe the discussion in https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html#avoid-unstable-typing-mechanisms is still applicable. |
@jakevdp No worries. I thought we were just waiting for MyPy. I didn't realize that we were waiting for Pytype as well! |
c065282
to
fcfd29c
Compare
fcfd29c
to
b743a69
Compare
b743a69
to
617850d
Compare
This pull request would be a huge improvement for Jax users who use type checkers like MyPy or Pyright, which now support
ParamSpec
.Consider:
Previous to this PR, users who decorate a function with
jit
lose all annotations of the method:After this PR, we get:
This unblocks a lot of type checking.
Part of #12049 cc: @jakevdp