Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Initializer protocol #26765

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

fix Initializer protocol #26765

wants to merge 1 commit into from

Conversation

Qwlouse
Copy link

@Qwlouse Qwlouse commented Feb 26, 2025

The initializer protocol is currently defined in terms of a static __call__ method.
However that definition does not interact correctly with callable classes. For example this raises a mypy error (but no pytype error):

import jax

class MyInit:
  def __call__(self, key, shape, dtype=None):
    pass

def foo(init: jax.nn.Initializer):
  return

foo(MyInit())

After this PR the above code passes the mypy check, and passing plain functions still works as before:

import jax

def my_init(key, shape, dtype=None):
  return

def foo(init: jax.nn.Initializer):
  return

foo(my_init)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant