From 331bc1be4634cf95fdb33021d4355c35723b73fe Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Tue, 30 May 2023 16:59:24 +0200 Subject: [PATCH] [7.3.x] 11028 - Fix warlus operator behavior when called by a function --- changelog/11028.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 25 +++++++-- testing/test_assertrewrite.py | 90 ++++++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 changelog/11028.bugfix.rst diff --git a/changelog/11028.bugfix.rst b/changelog/11028.bugfix.rst new file mode 100644 index 0000000000..9efc04ba69 --- /dev/null +++ b/changelog/11028.bugfix.rst @@ -0,0 +1 @@ +Fixed bug in assertion rewriting where a variable assigned with the walrus operator could not be used later in a function call. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 8b18234705..ab8169da2e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -996,7 +996,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: ] ): pytest_temp = self.variable() - self.variables_overwrite[v.left.target.id] = pytest_temp + self.variables_overwrite[ + v.left.target.id + ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp self.push_format_context() res, expl = self.visit(v) @@ -1037,10 +1039,19 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: + arg = self.variables_overwrite[arg.id] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: + if ( + isinstance(keyword.value, ast.Name) + and keyword.value.id in self.variables_overwrite + ): + keyword.value = self.variables_overwrite[ + keyword.value.id + ] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1075,7 +1086,13 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left.id = self.variables_overwrite[comp.left.id] + comp.left = self.variables_overwrite[ + comp.left.id + ] # type:ignore[assignment] + if isinstance(comp.left, namedExpr): + self.variables_overwrite[ + comp.left.target.id + ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1093,7 +1110,9 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[left_res.id] = next_operand.target.id + self.variables_overwrite[ + left_res.id + ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 8d94414030..245241af2b 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1436,6 +1436,96 @@ def test_walrus_operator_not_override_value(): assert result.ret == 0 +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="walrus operator not available in py<38" +) +class TestIssue11028: + def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_in_string(): + assert (obj := "foo") in obj + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_in_operand_json_dumps( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + import json + + def test_json_encoder(): + assert (obj := "foo") in json.dumps(obj) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_equals_operand_function( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def f(a): + return a + + def test_call_other_function_arg(): + assert (obj := "foo") == f(obj) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_equals_operand_function_keyword_arg( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def f(a='test'): + return a + + def test_call_other_function_k_arg(): + assert (obj := "foo") == f(a=obj) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_equals_operand_function_arg_as_function( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def f(a='test'): + return a + + def test_function_of_function(): + assert (obj := "foo") == f(f(obj)) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_gt_operand_function( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def add_one(a): + return a + 1 + + def test_gt(): + assert (obj := 4) > add_one(obj) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )