Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-91079: Decouple C stack overflow checks from Python recursion checks. #96510

Merged
merged 10 commits into from
Oct 5, 2022
16 changes: 14 additions & 2 deletions Include/cpython/pystate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ struct _ts {
/* Was this thread state statically allocated? */
int _static;

int recursion_remaining;
int recursion_limit;
int py_recursion_remaining;
int py_recursion_limit;

int c_recursion_remaining;
int recursion_headroom; /* Allow 50 more calls to handle any errors. */

/* 'tracing' keeps track of the execution depth when tracing/profiling.
Expand Down Expand Up @@ -202,6 +204,16 @@ struct _ts {
_PyCFrame root_cframe;
};

/* WASI has limited call stack. Python's recursion limit depends on code
layout, optimization, and WASI runtime. Wasmtime can handle about 700
recursions, sometimes less. 500 is a more conservative limit. */
#ifndef C_RECURSION_LIMIT
# ifdef __wasi__
# define C_RECURSION_LIMIT 500
# else
# define C_RECURSION_LIMIT 800
# endif
#endif

/* other API */

Expand Down
21 changes: 9 additions & 12 deletions Include/internal/pycore_ceval.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,8 @@ extern "C" {
struct pyruntimestate;
struct _ceval_runtime_state;

/* WASI has limited call stack. Python's recursion limit depends on code
layout, optimization, and WASI runtime. Wasmtime can handle about 700-750
recursions, sometimes less. 600 is a more conservative limit. */
#ifndef Py_DEFAULT_RECURSION_LIMIT
# ifdef __wasi__
# define Py_DEFAULT_RECURSION_LIMIT 600
# else
# define Py_DEFAULT_RECURSION_LIMIT 1000
# endif
# define Py_DEFAULT_RECURSION_LIMIT 1000
#endif

#include "pycore_interp.h" // PyInterpreterState.eval_frame
Expand Down Expand Up @@ -118,19 +111,22 @@ extern void _PyEval_DeactivateOpCache(void);
/* With USE_STACKCHECK macro defined, trigger stack checks in
_Py_CheckRecursiveCall() on every 64th call to _Py_EnterRecursiveCall. */
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return (tstate->recursion_remaining-- <= 0
|| (tstate->recursion_remaining & 63) == 0);
return (tstate->c_recursion_remaining-- <= 0
|| (tstate->c_recursion_remaining & 63) == 0);
}
#else
static inline int _Py_MakeRecCheck(PyThreadState *tstate) {
return tstate->recursion_remaining-- <= 0;
return tstate->c_recursion_remaining-- <= 0;
}
#endif

PyAPI_FUNC(int) _Py_CheckRecursiveCall(
PyThreadState *tstate,
const char *where);

int _Py_CheckRecursiveCallPy(
PyThreadState *tstate);

static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
const char *where) {
return (_Py_MakeRecCheck(tstate) && _Py_CheckRecursiveCall(tstate, where));
Expand All @@ -142,7 +138,7 @@ static inline int _Py_EnterRecursiveCall(const char *where) {
}

static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
tstate->recursion_remaining++;
tstate->c_recursion_remaining++;
}

static inline void _Py_LeaveRecursiveCall(void) {
Expand All @@ -157,6 +153,7 @@ extern PyObject* _Py_MakeCoro(PyFunctionObject *func);
extern int _Py_HandlePending(PyThreadState *tstate);



#ifdef __cplusplus
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion Include/internal/pycore_runtime_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ extern "C" {
#define _PyThreadState_INIT \
{ \
._static = 1, \
.recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.py_recursion_limit = Py_DEFAULT_RECURSION_LIMIT, \
.context_ver = 1, \
}

Expand Down
5 changes: 4 additions & 1 deletion Lib/test/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"run_with_tz", "PGO", "missing_compiler_executable",
"ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST",
"LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT",
"Py_DEBUG",
"Py_DEBUG", "EXCEEDS_RECURSION_LIMIT",
]


Expand Down Expand Up @@ -2352,3 +2352,6 @@ def adjust_int_max_str_digits(max_digits):
yield
finally:
sys.set_int_max_str_digits(current)

#For recursion tests, easily exceeds default recursion limit
EXCEEDS_RECURSION_LIMIT = 5000
6 changes: 3 additions & 3 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,9 +825,9 @@ def next(self):

@support.cpython_only
def test_ast_recursion_limit(self):
fail_depth = sys.getrecursionlimit() * 3
crash_depth = sys.getrecursionlimit() * 300
success_depth = int(fail_depth * 0.75)
fail_depth = support.EXCEEDS_RECURSION_LIMIT
crash_depth = 100_000
success_depth = 1200

def check_limit(prefix, repeated):
expect_ok = prefix + repeated * success_depth
Expand Down
38 changes: 38 additions & 0 deletions Lib/test/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,44 @@ def test_multiple_values(self):
with self.check_raises_type_error(msg):
A().method_two_args("x", "y", x="oops")

@cpython_only
class TestRecursion(unittest.TestCase):

def test_super_deep(self):

def recurse(n):
if n:
recurse(n-1)

def py_recurse(n, m):
if n:
py_recurse(n-1, m)
else:
c_py_recurse(m-1)

def c_recurse(n):
if n:
_testcapi.pyobject_fastcall(c_recurse, (n-1,))

def c_py_recurse(m):
if m:
_testcapi.pyobject_fastcall(py_recurse, (1000, m))

depth = sys.getrecursionlimit()
sys.setrecursionlimit(100_000)
try:
recurse(90_000)
with self.assertRaises(RecursionError):
recurse(101_000)
c_recurse(100)
with self.assertRaises(RecursionError):
c_recurse(90_000)
c_py_recurse(90)
with self.assertRaises(RecursionError):
c_py_recurse(100_000)
finally:
sys.setrecursionlimit(depth)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion Lib/test/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def test_odd_sizes(self):
self.assertEqual(Dot(1)._replace(d=999), (999,))
self.assertEqual(Dot(1)._fields, ('d',))

n = 5000
n = support.EXCEEDS_RECURSION_LIMIT
names = list(set(''.join([choice(string.ascii_letters)
for j in range(10)]) for i in range(n)))
n = len(names)
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def __getitem__(self, key):

@unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI")
def test_extended_arg(self):
# default: 1000 * 2.5 = 2500 repetitions
repeat = int(sys.getrecursionlimit() * 2.5)
repeat = 2000
longexpr = 'x = x or ' + '-x' * repeat
g = {}
code = '''
Expand Down
8 changes: 4 additions & 4 deletions Lib/test/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ class MyGlobals(dict):
def __missing__(self, key):
return int(key.removeprefix("_number_"))

# 1,000 on most systems
limit = sys.getrecursionlimit()
code = "lambda: " + "+".join(f"_number_{i}" for i in range(limit))
# Need more than 256 variables to use EXTENDED_ARGS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume EXTENDED_ARGS has stack implications? explaining the "why" of this here would be useful.

variables = 400
code = "lambda: " + "+".join(f"_number_{i}" for i in range(variables))
sum_func = eval(code, MyGlobals())
expected = sum(range(limit))
expected = sum(range(variables))
# Warm up the the function for quickening (PEP 659)
for _ in range(30):
self.assertEqual(sum_func(), expected)
Expand Down
8 changes: 2 additions & 6 deletions Lib/test/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ def test_recursion_normalizing_exception(self):
code = """if 1:
import sys
from _testinternalcapi import get_recursion_depth
from test import support

class MyException(Exception): pass

Expand Down Expand Up @@ -1399,13 +1400,8 @@ def gen():
generator = gen()
next(generator)
recursionlimit = sys.getrecursionlimit()
depth = get_recursion_depth()
try:
# Upon the last recursive invocation of recurse(),
# tstate->recursion_depth is equal to (recursion_limit - 1)
# and is equal to recursion_limit when _gen_throw() calls
# PyErr_NormalizeException().
recurse(setrecursionlimit(depth + 2) - depth)
recurse(support.EXCEEDS_RECURSION_LIMIT)
finally:
sys.setrecursionlimit(recursionlimit)
print('Done.')
Expand Down
12 changes: 6 additions & 6 deletions Lib/test/test_isinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from test import support



class TestIsInstanceExceptions(unittest.TestCase):
# Test to make sure that an AttributeError when accessing the instance's
# class's bases is masked. This was actually a bug in Python 2.2 and
Expand Down Expand Up @@ -97,7 +97,7 @@ def getclass(self):
class D: pass
self.assertRaises(RuntimeError, isinstance, c, D)


# These tests are similar to above, but tickle certain code paths in
# issubclass() instead of isinstance() -- really PyObject_IsSubclass()
# vs. PyObject_IsInstance().
Expand Down Expand Up @@ -147,7 +147,7 @@ def getbases(self):
self.assertRaises(TypeError, issubclass, B, C())



# meta classes for creating abstract classes and instances
class AbstractClass(object):
def __init__(self, bases):
Expand Down Expand Up @@ -179,7 +179,7 @@ class Super:

class Child(Super):
pass

class TestIsInstanceIsSubclass(unittest.TestCase):
# Tests to ensure that isinstance and issubclass work on abstract
# classes and instances. Before the 2.2 release, TypeErrors were
Expand Down Expand Up @@ -353,10 +353,10 @@ def blowstack(fxn, arg, compare_to):
# Make sure that calling isinstance with a deeply nested tuple for its
# argument will raise RecursionError eventually.
tuple_arg = (compare_to,)
for cnt in range(sys.getrecursionlimit()+5):
for cnt in range(support.EXCEEDS_RECURSION_LIMIT):
tuple_arg = (tuple_arg,)
fxn(arg, tuple_arg)


if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion Lib/test/test_marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_code(self):

def test_many_codeobjects(self):
# Issue2957: bad recursion count on code objects
count = 5000 # more than MAX_MARSHAL_STACK_DEPTH
# more than MAX_MARSHAL_STACK_DEPTH
count = support.EXCEEDS_RECURSION_LIMIT
codes = (ExceptionTestCase.test_exceptions.__code__,) * count
marshal.loads(marshal.dumps(codes))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Separate Python recursion checking from C recursion checking which reduces
the chance of C stack overflow and allows the recursion limit to be
increased safely.
4 changes: 1 addition & 3 deletions Modules/_testinternalcapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ get_recursion_depth(PyObject *self, PyObject *Py_UNUSED(args))
{
PyThreadState *tstate = _PyThreadState_GET();

/* subtract one to ignore the frame of the get_recursion_depth() call */

return PyLong_FromLong(tstate->recursion_limit - tstate->recursion_remaining - 1);
return PyLong_FromLong(tstate->py_recursion_limit - tstate->py_recursion_remaining);
}


Expand Down
9 changes: 3 additions & 6 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,19 +1380,16 @@ class PartingShots(StaticVisitor):
return NULL;
}

int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Be careful here to prevent overflow. */
int COMPILER_STACK_FRAME_SCALE = 3;
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;

PyObject *result = ast2obj_mod(state, t);
Expand Down
9 changes: 3 additions & 6 deletions Python/Python-ast.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions Python/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,6 @@ _PyAST_Validate(mod_ty mod)
int res = -1;
struct validator state;
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;

/* Setup recursion depth check counters */
Expand All @@ -984,12 +983,10 @@ _PyAST_Validate(mod_ty mod)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state.recursion_depth = starting_recursion_depth;
state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;

switch (mod->kind) {
case Module_kind:
Expand Down
9 changes: 3 additions & 6 deletions Python/ast_opt.c
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,6 @@ int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;

/* Setup recursion depth check counters */
Expand All @@ -1089,12 +1088,10 @@ _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
return 0;
}
/* Be careful here to prevent overflow. */
int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
starting_recursion_depth = (recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;

int ret = astfold_mod(mod, arena, state);
assert(ret || PyErr_Occurred());
Expand Down
Loading