Add stack overflow checking to Python

This commit is contained in:
Justine Tunney 2021-10-02 10:31:24 -07:00
parent 47a53e143b
commit 7521bf9e73
16 changed files with 47 additions and 46 deletions

View file

@ -1,6 +1,8 @@
#ifndef Py_CEVAL_H #ifndef Py_CEVAL_H
#define Py_CEVAL_H #define Py_CEVAL_H
#include "libc/bits/likely.h"
#include "third_party/python/Include/object.h" #include "third_party/python/Include/object.h"
#include "third_party/python/Include/pyerrors.h"
#include "third_party/python/Include/pystate.h" #include "third_party/python/Include/pystate.h"
#include "third_party/python/Include/pythonrun.h" #include "third_party/python/Include/pythonrun.h"
COSMOPOLITAN_C_START_ COSMOPOLITAN_C_START_
@ -77,18 +79,24 @@ int Py_MakePendingCalls(void);
void Py_SetRecursionLimit(int); void Py_SetRecursionLimit(int);
int Py_GetRecursionLimit(void); int Py_GetRecursionLimit(void);
#if IsModeDbg() forceinline int Py_EnterRecursiveCall(const char *where) {
#define Py_EnterRecursiveCall(where) \ const char *rsp, *bot;
(_Py_MakeRecCheck(PyThreadState_GET()->recursion_depth) && \ extern char ape_stack_vaddr[] __attribute__((__weak__));
_Py_CheckRecursiveCall(where)) if (!IsTiny()) {
#define Py_LeaveRecursiveCall() \ rsp = __builtin_frame_address(0);
do{ if(_Py_MakeEndRecCheck(PyThreadState_GET()->recursion_depth)) \ asm(".weak\tape_stack_vaddr\n\t"
PyThreadState_GET()->overflowed = 0; \ "movabs\t$ape_stack_vaddr+12288,%0" : "=r"(bot));
} while(0) if (UNLIKELY(rsp < bot)) {
#else PyErr_Format(PyExc_MemoryError, "Stack overflow%s", where);
#define Py_EnterRecursiveCall(where) (0) return -1;
#define Py_LeaveRecursiveCall(where) ((void)0) }
#endif }
return 0;
}
forceinline void Py_LeaveRecursiveCall(void) {
}
int _Py_CheckRecursiveCall(const char *where); int _Py_CheckRecursiveCall(const char *where);
extern int _Py_CheckRecursionLimit; extern int _Py_CheckRecursionLimit;

View file

@ -490,8 +490,6 @@ class ClassTests(unittest.TestCase):
self.assertRaises(TypeError, hash, C2()) self.assertRaises(TypeError, hash, C2())
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def testSFBug532646(self): def testSFBug532646(self):
# Test for SF bug 532646 # Test for SF bug 532646
@ -502,7 +500,7 @@ class ClassTests(unittest.TestCase):
try: try:
a() # This should not segfault a() # This should not segfault
except RecursionError: except (RecursionError, MemoryError):
pass pass
else: else:
self.fail("Failed to raise RecursionError") self.fail("Failed to raise RecursionError")

View file

@ -3470,16 +3470,14 @@ order (MRO) for bases """
list.__init__(a, sequence=[0, 1, 2]) list.__init__(a, sequence=[0, 1, 2])
self.assertEqual(a, [0, 1, 2]) self.assertEqual(a, [0, 1, 2])
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursive_call(self): def test_recursive_call(self):
# Testing recursive __call__() by setting to instance of class... # Testing recursive __call__() by setting to instance of class...
class A(object): class A(object):
pass pass
A.__call__ = A() A.__call__ = A()
try: try:
A()() A()()
except RecursionError: except (RecursionError, MemoryError):
pass pass
else: else:
self.fail("Recursion limit should have been reached for __call__()") self.fail("Recursion limit should have been reached for __call__()")
@ -4496,7 +4494,6 @@ order (MRO) for bases """
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
str.__add__(fake_str, "abc") str.__add__(fake_str, "abc")
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_repr_as_str(self): def test_repr_as_str(self):
# Issue #11603: crash or infinite loop when rebinding __str__ as # Issue #11603: crash or infinite loop when rebinding __str__ as
# __repr__. # __repr__.
@ -4504,8 +4501,12 @@ order (MRO) for bases """
pass pass
Foo.__repr__ = Foo.__str__ Foo.__repr__ = Foo.__str__
foo = Foo() foo = Foo()
if cosmo.MODE == 'dbg':
self.assertRaises(RecursionError, str, foo) self.assertRaises(RecursionError, str, foo)
self.assertRaises(RecursionError, repr, foo) self.assertRaises(RecursionError, repr, foo)
else:
self.assertRaises(MemoryError, str, foo)
self.assertRaises(MemoryError, repr, foo)
def test_mixing_slot_wrappers(self): def test_mixing_slot_wrappers(self):
class X(dict): class X(dict):

View file

@ -219,7 +219,12 @@ class DictSetTest(unittest.TestCase):
d = {} d = {}
for i in range(sys.getrecursionlimit() + 100): for i in range(sys.getrecursionlimit() + 100):
d = {42: d.values()} d = {42: d.values()}
self.assertRaises(RecursionError, repr, d) try:
repr(d)
except (RecursionError, MemoryError):
pass
else:
assert False
def test_copy(self): def test_copy(self):
d = {1: 10, "a": "ABC"} d = {1: 10, "a": "ABC"}

View file

@ -515,12 +515,11 @@ class ExceptionTests(unittest.TestCase):
self.assertEqual(x.fancy_arg, 42) self.assertEqual(x.fancy_arg, 42)
@no_tracing @no_tracing
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking") @unittest.skipUnless(cosmo.MODE == "dbg", "TODO: disabled recursion checking")
def testInfiniteRecursion(self): def testInfiniteRecursion(self):
def f(): def f():
return f() return f()
self.assertRaises(RecursionError, f) self.assertRaises(RecursionError, f)
def g(): def g():
try: try:
return g() return g()
@ -939,8 +938,8 @@ class ExceptionTests(unittest.TestCase):
self.assertTrue(isinstance(v, RecursionError), type(v)) self.assertTrue(isinstance(v, RecursionError), type(v))
self.assertIn("maximum recursion depth exceeded", str(v)) self.assertIn("maximum recursion depth exceeded", str(v))
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
@cpython_only @cpython_only
@unittest.skipUnless(cosmo.MODE == "dbg", "TODO: disabled recursion checking")
def test_recursion_normalizing_exception(self): def test_recursion_normalizing_exception(self):
# Issue #22898. # Issue #22898.
# Test that a RecursionError is raised when tstate->recursion_depth is # Test that a RecursionError is raised when tstate->recursion_depth is
@ -1017,8 +1016,8 @@ class ExceptionTests(unittest.TestCase):
b'while normalizing an exception', err) b'while normalizing an exception', err)
self.assertIn(b'Done.', out) self.assertIn(b'Done.', out)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
@cpython_only @cpython_only
@unittest.skipUnless(cosmo.MODE == "dbg", "TODO: disabled recursion checking")
def test_recursion_normalizing_with_no_memory(self): def test_recursion_normalizing_with_no_memory(self):
# Issue #30697. Test that in the abort that occurs when there is no # Issue #30697. Test that in the abort that occurs when there is no
# memory left and the size of the Python frames stack is greater than # memory left and the size of the Python frames stack is greater than
@ -1123,7 +1122,8 @@ class ExceptionTests(unittest.TestCase):
self.assertEqual(wr(), None) self.assertEqual(wr(), None)
@no_tracing @no_tracing
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking") @unittest.skipIf(cosmo.MODE != 'dbg',
"WUT: Fatal Python error: Cannot recover from MemoryErrors while normalizing exceptions.")
def test_recursion_error_cleanup(self): def test_recursion_error_cleanup(self):
# Same test as above, but with "recursion exceeded" errors # Same test as above, but with "recursion exceeded" errors
class C: class C:

View file

@ -177,12 +177,16 @@ class AutoFileTests:
finally: finally:
os.close(fd) os.close(fd)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking") # @unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def testRecursiveRepr(self): def testRecursiveRepr(self):
# Issue #25455 # Issue #25455
with swap_attr(self.f, 'name', self.f): with swap_attr(self.f, 'name', self.f):
with self.assertRaises(RuntimeError): try:
repr(self.f) # Should not crash repr(self.f) # Should not crash
except (RuntimeError, MemoryError):
pass
else:
assert False
def testErrors(self): def testErrors(self):
f = self.f f = self.f

View file

@ -217,7 +217,6 @@ class TestPartial:
[f'{name}({capture!r}, {args_repr}, {kwargs_repr})' [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs]) for kwargs_repr in kwargs_reprs])
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursive_repr(self): def test_recursive_repr(self):
if self.partial in (c_functools.partial, py_functools.partial): if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial' name = 'functools.partial'

View file

@ -1100,7 +1100,6 @@ class CommonBufferedTests:
raw.name = b"dummy" raw.name = b"dummy"
self.assertEqual(repr(b), "<%s name=b'dummy'>" % clsname) self.assertEqual(repr(b), "<%s name=b'dummy'>" % clsname)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursive_repr(self): def test_recursive_repr(self):
# Issue #25455 # Issue #25455
raw = self.MockRawIO() raw = self.MockRawIO()
@ -2542,7 +2541,6 @@ class TextIOWrapperTest(unittest.TestCase):
t.buffer.detach() t.buffer.detach()
repr(t) # Should not raise an exception repr(t) # Should not raise an exception
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursive_repr(self): def test_recursive_repr(self):
# Issue #25455 # Issue #25455
raw = self.BytesIO() raw = self.BytesIO()

View file

@ -67,7 +67,6 @@ class TestRecursion:
self.fail("didn't raise ValueError on default recursion") self.fail("didn't raise ValueError on default recursion")
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_highly_nested_objects_decoding(self): def test_highly_nested_objects_decoding(self):
# test that loading highly-nested objects doesn't segfault when C # test that loading highly-nested objects doesn't segfault when C
# accelerations are used. See #12017 # accelerations are used. See #12017
@ -78,7 +77,6 @@ class TestRecursion:
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
self.loads('[' * 100000 + '1' + ']' * 100000) self.loads('[' * 100000 + '1' + ']' * 100000)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_highly_nested_objects_encoding(self): def test_highly_nested_objects_encoding(self):
# See #12051 # See #12051
l, d = [], {} l, d = [], {}
@ -89,7 +87,6 @@ class TestRecursion:
with self.assertRaises(RecursionError): with self.assertRaises(RecursionError):
self.dumps(d) self.dumps(d)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_endless_recursion(self): def test_endless_recursion(self):
# See #12051 # See #12051
class EndlessJSONEncoder(self.json.JSONEncoder): class EndlessJSONEncoder(self.json.JSONEncoder):

View file

@ -814,13 +814,12 @@ class TestBinaryPlistlib(unittest.TestCase):
b = plistlib.loads(plistlib.dumps(a, fmt=plistlib.FMT_BINARY)) b = plistlib.loads(plistlib.dumps(a, fmt=plistlib.FMT_BINARY))
self.assertIs(b['x'], b) self.assertIs(b['x'], b)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_deep_nesting(self): def test_deep_nesting(self):
for N in [300, 100000]: for N in [300, 100000]:
chunks = [b'\xa1' + (i + 1).to_bytes(4, 'big') for i in range(N)] chunks = [b'\xa1' + (i + 1).to_bytes(4, 'big') for i in range(N)]
try: try:
result = self.decode(*chunks, b'\x54seed', offset_size=4, ref_size=4) result = self.decode(*chunks, b'\x54seed', offset_size=4, ref_size=4)
except RecursionError: except (RecursionError, MemoryError):
pass pass
else: else:
for i in range(N): for i in range(N):

View file

@ -221,7 +221,7 @@ class MiscTest(unittest.TestCase):
for func in (do, operator.not_): for func in (do, operator.not_):
self.assertRaises(Exc, func, Bad()) self.assertRaises(Exc, func, Bad())
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking") @unittest.skipUnless(cosmo.MODE == "dbg", "TODO: disabled recursion checking")
@support.no_tracing @support.no_tracing
def test_recursion(self): def test_recursion(self):
# Check that comparison for recursive objects fails gracefully # Check that comparison for recursive objects fails gracefully

View file

@ -724,7 +724,6 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin):
self._check_import_error(zip_name, msg) self._check_import_error(zip_name, msg)
@no_tracing @no_tracing
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_main_recursion_error(self): def test_main_recursion_error(self):
with temp_dir() as script_dir, temp_dir() as dummy_dir: with temp_dir() as script_dir, temp_dir() as dummy_dir:
mod_name = '__main__' mod_name = '__main__'

View file

@ -191,7 +191,6 @@ class SysModuleTest(unittest.TestCase):
finally: finally:
sys.setswitchinterval(orig) sys.setswitchinterval(orig)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursionlimit(self): def test_recursionlimit(self):
self.assertRaises(TypeError, sys.getrecursionlimit, 42) self.assertRaises(TypeError, sys.getrecursionlimit, 42)
oldlimit = sys.getrecursionlimit() oldlimit = sys.getrecursionlimit()
@ -201,7 +200,6 @@ class SysModuleTest(unittest.TestCase):
self.assertEqual(sys.getrecursionlimit(), 10000) self.assertEqual(sys.getrecursionlimit(), 10000)
sys.setrecursionlimit(oldlimit) sys.setrecursionlimit(oldlimit)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursionlimit_recovery(self): def test_recursionlimit_recovery(self):
if hasattr(sys, 'gettrace') and sys.gettrace(): if hasattr(sys, 'gettrace') and sys.gettrace():
self.skipTest('fatal error if run with a trace function') self.skipTest('fatal error if run with a trace function')
@ -225,7 +223,6 @@ class SysModuleTest(unittest.TestCase):
finally: finally:
sys.setrecursionlimit(oldlimit) sys.setrecursionlimit(oldlimit)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
@test.support.cpython_only @test.support.cpython_only
def test_setrecursionlimit_recursion_depth(self): def test_setrecursionlimit_recursion_depth(self):
# Issue #25274: Setting a low recursion limit must be blocked if the # Issue #25274: Setting a low recursion limit must be blocked if the
@ -261,7 +258,6 @@ class SysModuleTest(unittest.TestCase):
finally: finally:
sys.setrecursionlimit(oldlimit) sys.setrecursionlimit(oldlimit)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursionlimit_fatalerror(self): def test_recursionlimit_fatalerror(self):
# A fatal error occurs if a second recursion limit is hit when recovering # A fatal error occurs if a second recursion limit is hit when recovering
# from a first one. # from a first one.

View file

@ -883,7 +883,6 @@ class ThreadingExceptionTests(BaseTestCase):
lock = threading.Lock() lock = threading.Lock()
self.assertRaises(RuntimeError, lock.release) self.assertRaises(RuntimeError, lock.release)
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
@unittest.skipUnless(sys.platform == 'darwin' and test.support.python_is_optimized(), @unittest.skipUnless(sys.platform == 'darwin' and test.support.python_is_optimized(),
'test macosx problem') 'test macosx problem')
def test_recursion_limit(self): def test_recursion_limit(self):

View file

@ -301,7 +301,6 @@ class TracebackFormatTests(unittest.TestCase):
]) ])
# issue 26823 - Shrink recursive tracebacks # issue 26823 - Shrink recursive tracebacks
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def _check_recursive_traceback_display(self, render_exc): def _check_recursive_traceback_display(self, render_exc):
# Always show full diffs when this test fails # Always show full diffs when this test fails
# Note that rearranging things may require adjusting # Note that rearranging things may require adjusting

View file

@ -1921,7 +1921,6 @@ class BadElementTest(ElementTestCase, unittest.TestCase):
e.extend([ET.Element('bar')]) e.extend([ET.Element('bar')])
self.assertRaises(ValueError, e.remove, X('baz')) self.assertRaises(ValueError, e.remove, X('baz'))
@unittest.skipUnless(cosmo.MODE == "dbg", "disabled recursion checking")
def test_recursive_repr(self): def test_recursive_repr(self):
# Issue #25455 # Issue #25455
e = ET.Element('foo') e = ET.Element('foo')