import gc
|
import logging
|
import os
|
import subprocess
|
from pathlib import Path
|
|
import pytest
|
|
from traceback import (
|
extract_tb,
|
print_exception,
|
format_exception,
|
)
|
from traceback import _cause_message # type: ignore
|
import sys
|
import re
|
|
from .tutil import slow
|
from .._multierror import MultiError, concat_tb, NonBaseMultiError
|
from ... import TrioDeprecationWarning
|
from ..._core import open_nursery
|
|
if sys.version_info < (3, 11):
|
from exceptiongroup import ExceptionGroup
|
|
|
class NotHashableException(Exception):
    """Exception compared by its ``code`` value and deliberately unhashable.

    Used by tests that check MultiError machinery never needs to hash the
    exceptions it holds.
    """

    code = None

    # Python already sets __hash__ to None when a class defines __eq__ without
    # __hash__; stating it here just makes the "not hashable" intent explicit.
    __hash__ = None

    def __init__(self, code):
        super().__init__()
        self.code = code

    def __eq__(self, other):
        # Equal only to another NotHashableException carrying the same code.
        return isinstance(other, NotHashableException) and self.code == other.code
|
|
|
async def raise_nothashable(code):
    """Task body that immediately fails with ``NotHashableException(code)``."""
    failure = NotHashableException(code)
    raise failure
|
|
|
def raiser1():
    """Outermost frame of a three-deep call chain that ends in a ValueError."""
    raiser1_2()
|
|
|
def raiser1_2():
    """Middle frame of the raiser1 chain; just delegates to raiser1_3."""
    raiser1_3()
|
|
|
def raiser1_3():
    """Innermost frame of the raiser1 chain: raise the ValueError."""
    message = "raiser1_string"
    raise ValueError(message)
|
|
|
def raiser2():
    """Outermost frame of a two-deep call chain that ends in a KeyError."""
    raiser2_2()
|
|
|
def raiser2_2():
    """Innermost frame of the raiser2 chain: raise the KeyError."""
    missing_key = "raiser2_string"
    raise KeyError(missing_key)
|
|
|
def raiser3():
    """Single-frame raiser: raise a bare NameError with no arguments."""
    raise NameError()
|
|
|
def get_exc(raiser):
    """Call ``raiser`` and return the exception it raises.

    This gives tests a real exception object whose ``__traceback__`` is
    populated with the frames it passed through.

    Args:
        raiser: zero-argument callable expected to raise an ``Exception``.

    Returns:
        The caught exception instance.

    Raises:
        AssertionError: if ``raiser`` returns normally. Failing here is much
            clearer than returning ``None`` and letting a caller trip over a
            missing ``__traceback__`` attribute later.
    """
    try:
        raiser()
    except Exception as exc:
        return exc
    raise AssertionError(f"{raiser!r} returned without raising an exception")
|
|
|
def get_tb(raiser):
    """Return the traceback attached to the exception ``raiser`` raises."""
    exc = get_exc(raiser)
    return exc.__traceback__
|
|
|
def test_concat_tb():
    """concat_tb joins two traceback chains and leaves its inputs untouched."""
    tb1 = get_tb(raiser1)
    tb2 = get_tb(raiser2)

    # These return a list of (filename, lineno, fn name, text) tuples
    # https://docs.python.org/3/library/traceback.html#traceback.extract_tb
    entries1 = extract_tb(tb1)
    entries2 = extract_tb(tb2)

    tb12 = concat_tb(tb1, tb2)
    assert extract_tb(tb12) == entries1 + entries2

    tb21 = concat_tb(tb2, tb1)
    assert extract_tb(tb21) == entries2 + entries1

    # Check degenerate cases
    assert extract_tb(concat_tb(None, tb1)) == entries1
    assert extract_tb(concat_tb(tb1, None)) == entries1
    assert concat_tb(None, None) is None

    # Make sure the original tracebacks didn't get mutated by mistake: the
    # input chains themselves must still extract to the same frames...
    assert extract_tb(tb1) == entries1
    assert extract_tb(tb2) == entries2
    # ...and freshly produced tracebacks must still match the saved summaries.
    assert extract_tb(get_tb(raiser1)) == entries1
    assert extract_tb(get_tb(raiser2)) == entries2
|
|
|
def test_MultiError():
    value_exc = get_exc(raiser1)
    key_exc = get_exc(raiser2)

    # A single member is handed back unwrapped.
    assert MultiError([value_exc]) is value_exc

    multi = MultiError([value_exc, key_exc])
    assert multi.exceptions == (value_exc, key_exc)
    for render in (str, repr):
        assert "ValueError" in render(multi)

    # A non-iterable argument, and an iterable containing a class rather than
    # an instance, are both rejected.
    for bad_arg in (object(), [KeyError(), ValueError]):
        with pytest.raises(TypeError):
            MultiError(bad_arg)
|
|
|
def test_MultiErrorOfSingleMultiError():
    # Python calls __init__ on whatever __new__ returns when it is an instance
    # of the class. Wrapping a lone MultiError must therefore not recurse badly
    # or disturb the existing object's members.
    members = (KeyError(), ValueError())
    inner = MultiError(members)
    outer = MultiError([inner])
    assert outer == inner
    assert outer.exceptions == members
|
|
|
async def test_MultiErrorNotHashable():
    first = NotHashableException(42)
    second = NotHashableException(4242)
    unrelated = ValueError()
    assert first != second
    assert first != unrelated

    # Two failing children make the nursery combine unhashable exceptions
    # into a MultiError.
    with pytest.raises(MultiError):
        async with open_nursery() as nursery:
            for code in (42, 4242):
                nursery.start_soon(raise_nothashable, code)
|
|
|
def test_MultiError_filter_NotHashable():
    excs = MultiError([NotHashableException(42), ValueError()])

    def drop_value_errors(exc):
        # Returning None removes the exception from the tree.
        return None if isinstance(exc, ValueError) else exc

    with pytest.warns(TrioDeprecationWarning):
        filtered_excs = MultiError.filter(drop_value_errors, excs)

    # Only the unhashable member survives, no longer wrapped in a MultiError.
    assert isinstance(filtered_excs, NotHashableException)
|
|
|
def make_tree():
    """Build a nested MultiError fixture.

    Shape::

        MultiError([
            MultiError([ValueError, KeyError]),
            NameError,
        ])

    Every exception except the outermost MultiError carries a real traceback.
    The leaves come from get_exc, and the inner MultiError is raised and caught
    here so that it gets one too.
    """
    value_error = get_exc(raiser1)
    key_error = get_exc(raiser2)
    name_error = get_exc(raiser3)

    try:
        raise MultiError([value_error, key_error])
    except BaseException as inner:
        return MultiError([inner, name_error])
|
|
|
def assert_tree_eq(m1, m2):
    """Assert two exception trees have the same shape, types and tracebacks.

    The comparison recurses through ``__cause__``, ``__context__`` and, for
    MultiErrors, each member in order. ``None`` only matches ``None``.
    """
    if m1 is None or m2 is None:
        assert m1 is m2
        return

    assert type(m1) is type(m2)
    assert extract_tb(m1.__traceback__) == extract_tb(m2.__traceback__)
    for link in ("__cause__", "__context__"):
        assert_tree_eq(getattr(m1, link), getattr(m2, link))

    if not isinstance(m1, MultiError):
        return
    assert len(m1.exceptions) == len(m2.exceptions)
    for child1, child2 in zip(m1.exceptions, m2.exceptions):
        assert_tree_eq(child1, child2)
|
|
|
def test_MultiError_filter():
    """Exercise the deprecated ``MultiError.filter`` on the ``make_tree`` fixture.

    Covers, in order:
    - an identity handler returning the very same tree, unmodified;
    - no implicit-chaining detritus when filtering inside an ``except`` block;
    - dropping one leaf and replacing another;
    - the replaced KeyError's traceback absorbing its parents' tracebacks;
    - promotion of a branch when its sibling is removed;
    - removing everything.

    Every ``filter`` call is expected to emit ``TrioDeprecationWarning``.
    """

    def null_handler(exc):
        # Identity handler: keep every exception unchanged.
        return exc

    m = make_tree()
    assert_tree_eq(m, m)
    with pytest.warns(TrioDeprecationWarning):
        assert MultiError.filter(null_handler, m) is m

    # The identity filter must not have mutated the tree.
    assert_tree_eq(m, make_tree())

    # Make sure we don't pick up any detritus if run in a context where
    # implicit exception chaining would like to kick in
    m = make_tree()
    try:
        raise ValueError
    except ValueError:
        with pytest.warns(TrioDeprecationWarning):
            assert MultiError.filter(null_handler, m) is m
    assert_tree_eq(m, make_tree())

    def simple_filter(exc):
        # Drop ValueErrors, replace KeyErrors with a new RuntimeError, keep
        # anything else.
        if isinstance(exc, ValueError):
            return None
        if isinstance(exc, KeyError):
            return RuntimeError()
        return exc

    with pytest.warns(TrioDeprecationWarning):
        new_m = MultiError.filter(simple_filter, make_tree())

    assert isinstance(new_m, MultiError)
    assert len(new_m.exceptions) == 2
    # was: [[ValueError, KeyError], NameError]
    # ValueError disappeared & KeyError became RuntimeError, so now:
    assert isinstance(new_m.exceptions[0], RuntimeError)
    assert isinstance(new_m.exceptions[1], NameError)

    # implicit chaining: the replacement remembers what it replaced
    assert isinstance(new_m.exceptions[0].__context__, KeyError)

    # also, the traceback on the KeyError incorporates what used to be the
    # traceback on its parent MultiError
    orig = make_tree()
    # make sure we have the right path
    assert isinstance(orig.exceptions[0].exceptions[1], KeyError)
    # get original traceback summary: root, then inner MultiError, then the
    # KeyError itself, concatenated outermost-first
    orig_extracted = (
        extract_tb(orig.__traceback__)
        + extract_tb(orig.exceptions[0].__traceback__)
        + extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
    )

    def p(exc):
        # Debug aid: print an exception with its traceback.
        print_exception(type(exc), exc, exc.__traceback__)

    p(orig)
    p(orig.exceptions[0])
    p(orig.exceptions[0].exceptions[1])
    p(new_m.exceptions[0].__context__)
    # compare to the new path
    assert new_m.__traceback__ is None
    new_extracted = extract_tb(new_m.exceptions[0].__context__.__traceback__)
    assert orig_extracted == new_extracted

    # check preserving partial tree
    def filter_NameError(exc):
        # Remove only the NameError leaf.
        if isinstance(exc, NameError):
            return None
        return exc

    m = make_tree()
    with pytest.warns(TrioDeprecationWarning):
        new_m = MultiError.filter(filter_NameError, m)
    # with the NameError gone, the other branch gets promoted
    assert new_m is m.exceptions[0]

    # check fully handling everything
    def filter_all(exc):
        # Remove every exception; filter should then return None.
        return None

    with pytest.warns(TrioDeprecationWarning):
        assert MultiError.filter(filter_all, make_tree()) is None
|
|
|
def test_MultiError_catch():
    """Exercise the deprecated ``MultiError.catch`` context manager.

    Covers:
    - a body that raises nothing;
    - pass-through of the whole tree;
    - swallowing everything;
    - partial filtering, without the old tree being attached as ``__context__``;
    - preservation of ``__cause__``, ``__context__`` and ``__suppress_context__``
      on exceptions that pass through unchanged.
    """
    # No exception to catch

    def noop(_):
        pass  # pragma: no cover

    with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop):
        pass

    # Simple pass-through of all exceptions
    m = make_tree()
    with pytest.raises(MultiError) as excinfo:
        with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc):
            raise m
    assert excinfo.value is m
    # Should be unchanged, except that we added a traceback frame by raising
    # it here
    assert m.__traceback__ is not None
    assert m.__traceback__.tb_frame.f_code.co_name == "test_MultiError_catch"
    assert m.__traceback__.tb_next is None
    # Clear the frame added above so the rest of the tree can be compared
    # against a fresh fixture.
    m.__traceback__ = None
    assert_tree_eq(m, make_tree())

    # Swallows everything
    with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda _: None):
        raise make_tree()

    def simple_filter(exc):
        # Drop ValueErrors, replace KeyErrors with a new RuntimeError, keep
        # anything else.
        if isinstance(exc, ValueError):
            return None
        if isinstance(exc, KeyError):
            return RuntimeError()
        return exc

    with pytest.raises(MultiError) as excinfo:
        with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter):
            raise make_tree()
    new_m = excinfo.value
    assert isinstance(new_m, MultiError)
    assert len(new_m.exceptions) == 2
    # was: [[ValueError, KeyError], NameError]
    # ValueError disappeared & KeyError became RuntimeError, so now:
    assert isinstance(new_m.exceptions[0], RuntimeError)
    assert isinstance(new_m.exceptions[1], NameError)
    # Make sure that Python did not successfully attach the old MultiError to
    # our new MultiError's __context__
    assert not new_m.__suppress_context__
    assert new_m.__context__ is None

    # check preservation of __cause__ and __context__
    v = ValueError()
    v.__cause__ = KeyError()
    with pytest.raises(ValueError) as excinfo:
        with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc):
            raise v
    assert isinstance(excinfo.value.__cause__, KeyError)

    v = ValueError()
    context = KeyError()
    v.__context__ = context
    with pytest.raises(ValueError) as excinfo:
        with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc):
            raise v
    assert excinfo.value.__context__ is context
    assert not excinfo.value.__suppress_context__

    # When the handler strips a sibling (distractor) and collapses the
    # MultiError down to v, v's own __context__ and __suppress_context__ must
    # survive for both values of the flag.
    for suppress_context in [True, False]:
        v = ValueError()
        context = KeyError()
        v.__context__ = context
        v.__suppress_context__ = suppress_context
        distractor = RuntimeError()
        with pytest.raises(ValueError) as excinfo:

            def catch_RuntimeError(exc):
                # Remove the distractor, keep everything else.
                if isinstance(exc, RuntimeError):
                    return None
                else:
                    return exc

            with pytest.warns(TrioDeprecationWarning):
                with MultiError.catch(catch_RuntimeError):
                    raise MultiError([v, distractor])
        assert excinfo.value.__context__ is context
        assert excinfo.value.__suppress_context__ == suppress_context
|
|
|
@pytest.mark.skipif(
    sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC"
)
def test_MultiError_catch_doesnt_create_cyclic_garbage():
    # https://github.com/python-trio/trio/pull/2063
    # Regression test: MultiError.catch must not leave reference cycles behind.
    # With gc.DEBUG_SAVEALL the collector stores every unreachable object it
    # finds in gc.garbage instead of freeing it, so an empty gc.garbage after a
    # collection means no cycles were created.
    # Avoid binding the exception to any local name here (e.g. `as excinfo`):
    # that would itself form a frame <-> traceback cycle.
    gc.collect()  # start from a clean slate
    old_flags = gc.get_debug()

    def make_multi():
        # make_tree creates cycles itself (it raises and catches inside its own
        # frame), so a simple, cycle-free MultiError is raised directly here.
        raise MultiError([get_exc(raiser1), get_exc(raiser2)])

    def simple_filter(exc):
        # Replace both leaves so MultiError.catch has to build a new tree.
        if isinstance(exc, ValueError):
            return Exception()
        if isinstance(exc, KeyError):
            return RuntimeError()
        assert False, "only ValueError and KeyError should exist"  # pragma: no cover

    try:
        gc.set_debug(gc.DEBUG_SAVEALL)
        with pytest.raises(MultiError):
            # covers MultiErrorCatcher.__exit__ and _multierror.copy_tb
            with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter):
                # make_multi() raises on its own; the outer `raise` is never
                # reached.
                raise make_multi()
        gc.collect()
        assert not gc.garbage
    finally:
        # Restore the global GC state for the rest of the test session.
        gc.set_debug(old_flags)
        gc.garbage.clear()
|
|
|
def assert_match_in_seq(pattern_list, string):
    """Assert each regex in ``pattern_list`` is found in ``string``, in order.

    Each pattern is searched for starting where the previous match ended, so
    the matches must appear left to right and must not overlap.
    Progress is printed to help diagnose failures.
    """
    print("looking for pattern matches...")
    position = 0
    for pattern in pattern_list:
        print("checking pattern:", pattern)
        found = re.compile(pattern).search(string, position)
        assert found is not None
        position = found.end()
|
|
|
def test_assert_match_in_seq():
    # Patterns appearing in the requested order are accepted...
    for patterns, text in (
        (["a", "b"], "xx a xx b xx"),
        (["b", "a"], "xx b xx a xx"),
    ):
        assert_match_in_seq(patterns, text)
    # ...but the same patterns in the wrong order are rejected.
    with pytest.raises(AssertionError):
        assert_match_in_seq(["a", "b"], "xx b xx a xx")
|
|
|
def test_base_multierror():
    """A MultiError with a BaseException-only member stays a plain MultiError.

    Here that member is KeyboardInterrupt.
    """
    members = [ZeroDivisionError(), KeyboardInterrupt()]
    assert type(MultiError(members)) is MultiError
|
|
|
def test_non_base_multierror():
    """A MultiError whose members are all ``Exception``s is a NonBaseMultiError.

    It is therefore also an ``ExceptionGroup``.
    """
    group = MultiError([ZeroDivisionError(), ValueError()])
    assert type(group) is NonBaseMultiError
    assert isinstance(group, ExceptionGroup)
|
|
|
def run_script(name, use_ipython=False):
    """Run a script from ``test_multierror_scripts`` in a child interpreter.

    The child's ``PYTHONPATH`` is the scripts directory, then the directory
    containing the ``trio`` package, then any inherited entries.

    If ``use_ipython`` is true, the script's source is run through IPython with
    ``--quick`` (no startup files), followed by ``exit()``.

    stdout and stderr are merged. The output is printed (decoded as UTF-8) for
    debugging, and the ``CompletedProcess`` is returned.
    """
    import trio

    trio_root = Path(trio.__file__).parent.parent
    script_path = Path(__file__).parent / "test_multierror_scripts" / name

    env = dict(os.environ)
    print("parent PYTHONPATH:", env.get("PYTHONPATH"))
    inherited = env["PYTHONPATH"].split(os.pathsep) if "PYTHONPATH" in env else []
    env["PYTHONPATH"] = os.pathsep.join(
        [str(script_path.parent), str(trio_root), *inherited]
    )
    print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))

    if not use_ipython:
        cmd = [sys.executable, "-u", str(script_path)]
    else:
        code_to_run = "\n".join([script_path.read_text(), "exit()"])
        cmd = [
            sys.executable,
            "-u",
            "-m",
            "IPython",
            # no startup files
            "--quick",
            "--TerminalIPythonApp.code_to_run=" + code_to_run,
        ]
    print("running:", cmd)

    completed = subprocess.run(
        cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
    print("process output:")
    print(completed.stdout.decode("utf-8"))
    return completed
|
|
|
def check_simple_excepthook(completed):
    """Check the process output for a MultiError report with two sections.

    Section 1 must show ValueError raised in exc1_fn, and section 2 must show
    KeyError raised in exc2_fn, in that order.
    """
    expected_in_order = [
        "in <module>",
        "MultiError",
        "--- 1 ---",
        "in exc1_fn",
        "ValueError",
        "--- 2 ---",
        "in exc2_fn",
        "KeyError",
    ]
    output = completed.stdout.decode("utf-8")
    assert_match_in_seq(expected_in_order, output)
|
|
|
# Probe whether IPython is importable; the IPython-based tests are skipped
# when it is not.
have_ipython = True
try:
    import IPython  # noqa: F401  (imported only to check availability)
except ImportError:  # pragma: no cover
    have_ipython = False

need_ipython = pytest.mark.skipif(not have_ipython, reason="need IPython")
|
|
|
@slow
@need_ipython
def test_ipython_exc_handler():
    # Same expectations as the plain excepthook, but with the script run
    # inside an IPython session.
    check_simple_excepthook(run_script("simple_excepthook.py", use_ipython=True))
|
|
|
@slow
@need_ipython
def test_ipython_imported_but_unused():
    # Runs under the plain interpreter, not IPython. Judging by the test name,
    # the script imports IPython without using it; the excepthook output
    # should still be the normal one.
    check_simple_excepthook(run_script("simple_excepthook_IPython.py"))
|
|
|
@slow
@need_ipython
def test_ipython_custom_exc_handler():
    # When the user already has an IPython set_custom_exc handler installed,
    # Trio should emit its IPython-specific warning, and not the generic
    # "custom sys.excepthook" one. The MultiError should still be reported.
    completed = run_script("ipython_custom_exc.py", use_ipython=True)
    output = completed.stdout.decode("utf-8")
    assert_match_in_seq(
        [
            # The warning
            "RuntimeWarning",
            "IPython detected",
            "skip installing Trio",
            # The MultiError
            "MultiError",
            "ValueError",
            "KeyError",
        ],
        output,
    )
    # Make sure our other warning doesn't show up
    assert "custom sys.excepthook" not in output
|
|
|
@slow
@pytest.mark.skipif(
    not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(),
    reason="need Ubuntu with python3-apport installed",
)
def test_apport_excepthook_monkeypatch_interaction():
    output = run_script("apport_excepthook.py").stdout.decode("utf-8")

    # No complaint about a foreign sys.excepthook...
    assert "custom sys.excepthook" not in output

    # ...and the MultiError members are still rendered as numbered sections.
    assert_match_in_seq(["--- 1 ---", "KeyError", "--- 2 ---", "ValueError"], output)
|