diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py index 4cf9b66170c055..d84702411afe41 100644 --- a/Lib/test/test_capi/test_opt.py +++ b/Lib/test/test_capi/test_opt.py @@ -1,4 +1,5 @@ import contextlib +import itertools import sys import textwrap import unittest @@ -1511,6 +1512,49 @@ def test_jit_error_pops(self): with self.assertRaises(TypeError): {item for item in items} + def test_power_type_depends_on_input_values(self): + template = textwrap.dedent(""" + import _testinternalcapi + + L, R, X, Y = {l}, {r}, {x}, {y} + + def check(actual: complex, expected: complex) -> None: + assert actual == expected, (actual, expected) + assert type(actual) is type(expected), (actual, expected) + + def f(l: complex, r: complex) -> None: + expected_local_local = pow(l, r) + pow(l, r) + expected_const_local = pow(L, r) + pow(L, r) + expected_local_const = pow(l, R) + pow(l, R) + expected_const_const = pow(L, R) + pow(L, R) + for _ in range(_testinternalcapi.TIER2_THRESHOLD): + # Narrow types: + l + l, r + r + # The powers produce results, and the addition is unguarded: + check(l ** r + l ** r, expected_local_local) + check(L ** r + L ** r, expected_const_local) + check(l ** R + l ** R, expected_local_const) + check(L ** R + L ** R, expected_const_const) + + # JIT for one pair of values... 
+ f(L, R) + # ...then run with another: + f(X, Y) + """) + interesting = [ + (1, 1), # int ** int -> int + (1, -1), # int ** int -> float + (1.0, 1), # float ** int -> float + (1, 1.0), # int ** float -> float + (-1, 0.5), # int ** float -> complex + (1.0, 1.0), # float ** float -> float + (-1.0, 0.5), # float ** float -> complex + ] + for (l, r), (x, y) in itertools.product(interesting, repeat=2): + s = template.format(l=l, r=r, x=x, y=y) + with self.subTest(l=l, r=r, x=x, y=y): + script_helper.assert_python_ok("-c", s) + def global_identity(x): return x diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst new file mode 100644 index 00000000000000..19c8cc6e99c8c5 --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst @@ -0,0 +1,2 @@ +Fix an issue where the experimental JIT may infer an incorrect result type +for exponentiation (``**`` and ``**=``), leading to bugs or crashes. 
diff --git a/Python/bytecodes.c b/Python/bytecodes.c index ec1cd00962ac0a..8bab4ea16b629b 100644 --- a/Python/bytecodes.c +++ b/Python/bytecodes.c @@ -530,6 +530,8 @@ dummy_func( pure op(_BINARY_OP_MULTIPLY_INT, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o); @@ -543,6 +545,8 @@ dummy_func( pure op(_BINARY_OP_ADD_INT, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o); @@ -556,6 +560,8 @@ dummy_func( pure op(_BINARY_OP_SUBTRACT_INT, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o); @@ -593,6 +599,8 @@ dummy_func( pure op(_BINARY_OP_MULTIPLY_FLOAT, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = @@ -607,6 +615,8 @@ dummy_func( pure op(_BINARY_OP_ADD_FLOAT, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = @@ -621,6 +631,8 @@ dummy_func( pure op(_BINARY_OP_SUBTRACT_FLOAT, (left, 
right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = @@ -650,6 +662,8 @@ dummy_func( pure op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = PyUnicode_Concat(left_o, right_o); @@ -672,6 +686,8 @@ dummy_func( op(_BINARY_OP_INPLACE_ADD_UNICODE, (left, right --)) { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); int next_oparg; #if TIER_ONE diff --git a/Python/executor_cases.c.h b/Python/executor_cases.c.h index ac2f69b7e98dc3..e40fa88be89172 100644 --- a/Python/executor_cases.c.h +++ b/Python/executor_cases.c.h @@ -638,6 +638,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); @@ -658,6 +660,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); @@ -678,6 +682,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o 
= PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); @@ -738,6 +744,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval * @@ -759,6 +767,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval + @@ -780,6 +790,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval - @@ -819,6 +831,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = PyUnicode_Concat(left_o, right_o); PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc); @@ -838,6 +852,8 @@ left = stack_pointer[-2]; PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); int next_oparg; #if TIER_ONE assert(next_instr->op.code == STORE_FAST); diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h index 
eaa8a563464068..7028ba52faae96 100644 --- a/Python/generated_cases.c.h +++ b/Python/generated_cases.c.h @@ -80,6 +80,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval + @@ -116,6 +118,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); @@ -151,6 +155,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = PyUnicode_Concat(left_o, right_o); PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc); @@ -185,6 +191,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyUnicode_CheckExact(left_o)); + assert(PyUnicode_CheckExact(right_o)); int next_oparg; #if TIER_ONE assert(next_instr->op.code == STORE_FAST); @@ -247,6 +255,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval * @@ -283,6 +293,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject 
*)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); @@ -318,6 +330,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyFloat_CheckExact(left_o)); + assert(PyFloat_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); double dres = ((PyFloatObject *)left_o)->ob_fval - @@ -354,6 +368,8 @@ { PyObject *left_o = PyStackRef_AsPyObjectBorrow(left); PyObject *right_o = PyStackRef_AsPyObjectBorrow(right); + assert(PyLong_CheckExact(left_o)); + assert(PyLong_CheckExact(right_o)); STAT_INC(BINARY_OP, hit); PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o); PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc); diff --git a/Python/optimizer_bytecodes.c b/Python/optimizer_bytecodes.c index a14d119b7a1dec..86394480f76bb8 100644 --- a/Python/optimizer_bytecodes.c +++ b/Python/optimizer_bytecodes.c @@ -167,23 +167,56 @@ dummy_func(void) { } op(_BINARY_OP, (left, right -- res)) { - PyTypeObject *ltype = sym_get_type(left); - PyTypeObject *rtype = sym_get_type(right); - if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) && - rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type)) - { - if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE && - ltype == &PyLong_Type && rtype == &PyLong_Type) { - /* If both inputs are ints and the op is not division the result is an int */ - res = sym_new_type(ctx, &PyLong_Type); + bool lhs_int = sym_matches_type(left, &PyLong_Type); + bool rhs_int = sym_matches_type(right, &PyLong_Type); + bool lhs_float = sym_matches_type(left, &PyFloat_Type); + bool rhs_float = sym_matches_type(right, &PyFloat_Type); + if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) { + // There's something other than an int or float involved: + res = sym_new_unknown(ctx); + } + else if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) { + // This one's fun... 
the *type* of the result depends on the + // *values* being exponentiated. However, exponents with one + // constant part are reasonably common, so it's probably worth + // trying to infer some simple cases: + // - A: 1 ** 1 -> 1 (int ** int -> int) + // - B: 1 ** -1 -> 1.0 (int ** int -> float) + // - C: 1.0 ** 1 -> 1.0 (float ** int -> float) + // - D: 1 ** 1.0 -> 1.0 (int ** float -> float) + // - E: (-1) ** 0.5 ~> 1j (int ** float -> complex) + // - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float) + // - G: (-1.0) ** 0.5 ~> 1j (float ** float -> complex) + if (rhs_float) { + // Case D, E, F, or G... can't know without the sign of the LHS + // or whether the RHS is whole, which isn't worth the effort: + res = sym_new_unknown(ctx); } - else { - /* For any other op combining ints/floats the result is a float */ + else if (lhs_float) { + // Case C: res = sym_new_type(ctx, &PyFloat_Type); } + else if (!sym_is_const(right)) { + // Case A or B... can't know without the sign of the RHS: + res = sym_new_unknown(ctx); + } + else if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) { + // Case B: + res = sym_new_type(ctx, &PyFloat_Type); + } + else { + // Case A: + res = sym_new_type(ctx, &PyLong_Type); + } + } + else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) { + res = sym_new_type(ctx, &PyFloat_Type); + } + else if (lhs_int && rhs_int) { + res = sym_new_type(ctx, &PyLong_Type); } else { - res = sym_new_unknown(ctx); + res = sym_new_type(ctx, &PyFloat_Type); } } diff --git a/Python/optimizer_cases.c.h b/Python/optimizer_cases.c.h index be3e06108aec92..c72ae7b6281e80 100644 --- a/Python/optimizer_cases.c.h +++ b/Python/optimizer_cases.c.h @@ -2307,24 +2307,69 @@ _Py_UopsSymbol *res; right = stack_pointer[-1]; left = stack_pointer[-2]; - PyTypeObject *ltype = sym_get_type(left); - PyTypeObject *rtype = sym_get_type(right); - if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) && - rtype != NULL && (rtype == &PyLong_Type || rtype 
== &PyFloat_Type)) - { - if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE && - ltype == &PyLong_Type && rtype == &PyLong_Type) { - /* If both inputs are ints and the op is not division the result is an int */ - res = sym_new_type(ctx, &PyLong_Type); + bool lhs_int = sym_matches_type(left, &PyLong_Type); + bool rhs_int = sym_matches_type(right, &PyLong_Type); + bool lhs_float = sym_matches_type(left, &PyFloat_Type); + bool rhs_float = sym_matches_type(right, &PyFloat_Type); + if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) { + // There's something other than an int or float involved: + res = sym_new_unknown(ctx); + } + else { + if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) { + // This one's fun... the *type* of the result depends on the + // *values* being exponentiated. However, exponents with one + // constant part are reasonably common, so it's probably worth + // trying to infer some simple cases: + // - A: 1 ** 1 -> 1 (int ** int -> int) + // - B: 1 ** -1 -> 1.0 (int ** int -> float) + // - C: 1.0 ** 1 -> 1.0 (float ** int -> float) + // - D: 1 ** 1.0 -> 1.0 (int ** float -> float) + // - E: (-1) ** 0.5 ~> 1j (int ** float -> complex) + // - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float) + // - G: (-1.0) ** 0.5 ~> 1j (float ** float -> complex) + if (rhs_float) { + // Case D, E, F, or G... can't know without the sign of the LHS + // or whether the RHS is whole, which isn't worth the effort: + res = sym_new_unknown(ctx); + } + else { + if (lhs_float) { + // Case C: + res = sym_new_type(ctx, &PyFloat_Type); + } + else { + if (!sym_is_const(right)) { + // Case A or B... 
can't know without the sign of the RHS: + res = sym_new_unknown(ctx); + } + else { + if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) { + // Case B: + res = sym_new_type(ctx, &PyFloat_Type); + } + else { + // Case A: + res = sym_new_type(ctx, &PyLong_Type); + } + } + } + } } else { - /* For any other op combining ints/floats the result is a float */ - res = sym_new_type(ctx, &PyFloat_Type); + if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) { + res = sym_new_type(ctx, &PyFloat_Type); + } + else { + if (lhs_int && rhs_int) { + res = sym_new_type(ctx, &PyLong_Type); + } + else { + res = sym_new_type(ctx, &PyFloat_Type); + } + } } } - else { - res = sym_new_unknown(ctx); - } stack_pointer[-2] = res; stack_pointer += -1; assert(WITHIN_STACK_BOUNDS()); diff --git a/Tools/cases_generator/analyzer.py b/Tools/cases_generator/analyzer.py index c0a370a936aa94..679beca3ec3a9d 100644 --- a/Tools/cases_generator/analyzer.py +++ b/Tools/cases_generator/analyzer.py @@ -599,6 +599,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool: "_PyLong_CompactValue", "_PyLong_DigitCount", "_PyLong_IsCompact", + "_PyLong_IsNegative", "_PyLong_IsNonNegativeCompact", "_PyLong_IsZero", "_PyLong_Multiply",