Skip to content

Commit

Permalink
Merge pull request #573 from amoffat/559-execution-context
Browse files Browse the repository at this point in the history
Let Command know about baked call args when wrapping the module
  • Loading branch information
ecederstrand authored Jun 6, 2021
2 parents 59c1bda + 06eaa91 commit 309b5cc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Changelog
## 1.14.3
* bugfix where `Command` was not aware of default call args when wrapping the module [#559](/~https://github.com/amoffat/sh/pull/573)

## 1.14.1 - 10/24/20
* bugfix where setting `_ok_code` to not include 0, but 0 was the exit code [#545](/~https://github.com/amoffat/sh/pull/545)

Expand Down
13 changes: 12 additions & 1 deletion sh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3512,7 +3512,18 @@ def __init__(self, self_module, baked_args=None):
# if we set this to None. and 3.3 needs a value for __path__
self.__path__ = []
self.__self_module = self_module
self.__env = Environment(globals(), baked_args=baked_args)

# Copy the Command class and add any baked call kwargs to it
cls_attrs = Command.__dict__.copy()
if baked_args:
call_args, _ = Command._extract_call_args(baked_args)
cls_attrs['_call_args'] = cls_attrs['_call_args'].copy()
cls_attrs['_call_args'].update(call_args)
command_cls = type(Command.__name__, Command.__bases__, cls_attrs)
globs = globals().copy()
globs[Command.__name__] = command_cls

self.__env = Environment(globs, baked_args=baked_args)

def __getattr__(self, name):
return self.__env[name]
Expand Down
7 changes: 7 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3138,6 +3138,13 @@ def test_reimport_no_interfere(self):
_sh.echo("-n", "TEST")
self.assertEqual("TEST", out.getvalue())

def test_command_with_baked_call_args(self):
# Test that sh.Command() knows about baked call args
import sh
_sh = sh(_ok_code=1)
self.assertEqual(sh.Command._call_args['ok_code'], 0)
self.assertEqual(_sh.Command._call_args['ok_code'], 1)

def test_importer_detects_module_name(self):
import sh
_sh = sh()
Expand Down

0 comments on commit 309b5cc

Please sign in to comment.