Skip to content

Commit

Permalink
[unittest.mock] restore Mock._spec_asyncs and match upstream
Browse files Browse the repository at this point in the history
Summary:
D42039568 backported an early version of python/cpython#100252,
and there were some changes made before that PR was merged upstream.
Backport those changes so we match upstream behavior. Most importantly,
this restores `_spec_asyncs`, which although a private API does seem to be used.

Test Plan: Test suite

Reviewers: itamaro, #cinder

Reviewed By: itamaro

Subscribers: mpage, jackyzhang

Differential Revision: https://phabricator.intern.facebook.com/D42416767

Tasks: T141890380

Tags: commitClose, cinder-310-upstream-complete, publish_when_ready
  • Loading branch information
Carl Meyer authored and Service User committed Jan 10, 2023
1 parent 80b52b4 commit a7e7a21
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
_eat_self=False):
_spec_class = None
_spec_signature = None
_spec_asyncs = []

if spec is not None and not _is_list(spec):
if isinstance(spec, type):
Expand All @@ -503,13 +504,20 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
_spec_as_instance, _eat_self)
_spec_signature = res and res[1]

spec = dir(spec)
spec_list = dir(spec)

for attr in spec_list:
if iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr)

spec = spec_list

__dict__ = self.__dict__
__dict__['_spec_class'] = _spec_class
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs

def __get_return_value(self):
ret = self._mock_return_value
Expand Down Expand Up @@ -998,8 +1006,7 @@ def _get_child_mock(self, /, **kw):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
_new_name = kw.get("_new_name")
_spec_val = getattr(self.__dict__["_spec_class"], _new_name, None)
if _spec_val is not None and asyncio.iscoroutinefunction(_spec_val):
if _new_name in self.__dict__['_spec_asyncs']:
return AsyncMock(**kw)

if self._mock_sealed:
Expand Down

0 comments on commit a7e7a21

Please sign in to comment.