# Mirror of https://github.com/jart/cosmopolitan.git (synced 2025-02-07 15:03:34 +00:00).
# Accompanying note: "The WIN32 CreateProcess() function does not require an
# .exe or .com suffix in order to spawn an executable. Now that we have Cosmo
# bash we're no longer so dependent on the cmd.exe prompt."
# File size: 2105 lines, 77 KiB (Python).
import abc
import cosmo
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import time
import unittest
from weakref import proxy
import contextlib
try:
    import _thread
    import threading
except ImportError:
    threading = None

import functools
# NOTE(review): ``__name__`` is never 'PYOBJ' under a normal import, so this
# branch does not run; it appears to exist only so Cosmopolitan's build
# tooling sees a static dependency on ``decimal`` -- confirm.
if __name__ == 'PYOBJ':
    import decimal


# Two independent copies of functools: one with the ``_functools`` C
# accelerator blocked (pure Python) and one importing a fresh ``_functools``.
# The ``if c_functools`` guards and skip decorators in this module treat a
# falsy c_functools as "C accelerator unavailable".
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal re-imported with a fresh ``_decimal`` accelerator.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
|
|
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previously registered module is put back when the ``with`` block
    exits, whether normally or via an exception.  A KeyError propagates
    (before the body runs) if *name* is not currently in ``sys.modules``.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = saved
|
|
def capture(*args, **kw):
    """Echo the call back as an ``(args_tuple, kwargs_dict)`` pair."""
    call = (args, kw)
    return call
|
|
|
|
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
|
|
class MyTuple(tuple):
    """Behaviourless tuple subclass, fed to ``partial.__setstate__`` in tests."""
|
|
class BadTuple(tuple):
    """tuple subclass whose ``+`` returns a list instead of a tuple.

    The partial tests use it to check that partial does not rely on a
    subclass's ``__add__`` when storing or combining positional arguments.
    """

    def __add__(self, other):
        return [*self, *other]
|
|
class MyDict(dict):
    """Behaviourless dict subclass, fed to ``partial.__setstate__`` in tests."""
|
|
|
|
class TestPartial:
    """Implementation-agnostic tests for ``partial``.

    This is a mixin: a concrete subclass combines it with
    ``unittest.TestCase`` and supplies ``partial`` (the implementation under
    test) and ``AllowPickle`` (a class whose instances the pickling tests
    enter as context managers).
    """

    def _expected_repr_name(self):
        """Return the name that repr() of a partial under test starts with.

        The stock C and pure-Python classes report themselves as
        ``functools.partial``; subclasses use their own ``__name__``.
        ``c_functools`` is falsy when the C accelerator is unavailable (the
        situation in which the pure-Python test class runs), so its
        ``partial`` attribute is looked up defensively.
        """
        stock = (getattr(c_functools, 'partial', None), py_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Self-reference through func.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            # Break the reference cycle.
            f.__setstate__((capture, (), {}, {}))

        # Self-reference through args.
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        # Self-reference through keywords.
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the attribute, args and keywords objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates every mutable container, down to the values.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None __dict__ state is accepted and yields an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None keywords are accepted and behave like no keywords.
        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check is disabled in the source this
        # file derives from; the reason is not visible here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        # Wrong tuple length, wrong container type, non-callable func,
        # and wrong types for args / keywords must all be rejected.
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances are normalised to exact tuple / dict.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; partial must not use it.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # NOTE(review): the flattened source does not show how far this
            # Cosmo-specific guard extends; it is assumed to cover only the
            # self-referencing-func check, which expects RecursionError.
            if cosmo.MODE == "dbg":
                f = self.partial(capture)
                f.__setstate__((f, (), {}, {}))
                try:
                    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                        with self.assertRaises(RecursionError):
                            pickle.dumps(f, proto)
                finally:
                    f.__setstate__((capture, (), {}, {}))

            # Self-reference through args survives a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through keywords survives a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Looks like a 4-item state; element 1 is a large tuple so that
            # a reference-count error would be more likely to show up.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
|
|
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the TestPartial suite against the C ``partial``, plus C-only checks."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager entered by the inherited pickling tests.

        The pure-Python test class swaps ``sys.modules['functools']`` at this
        point; this C variant changes nothing.
        """
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr() must still work and show the odd key...
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # ...but calling with a non-string keyword is a TypeError.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
|
|
@unittest.skipIf(c_functools, "skip pure-python test if C impl is present")
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the TestPartial suite against the pure-Python ``partial``.

    Skipped whenever the C accelerator could be imported.
    """
    partial = py_functools.partial

    class AllowPickle:
        """While active, make ``sys.modules['functools']`` be ``py_functools``.

        Presumably needed so that pickle's by-name lookup of
        ``functools.partial`` resolves to the pure-Python class under test.
        """
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            cm = self._cm
            return cm.__enter__()

        def __exit__(self, type, value, tb):
            cm = self._cm
            return cm.__exit__(type, value, tb)
|
|
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial; only defined when C is available."""


class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial."""
|
|
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial suite against a trivial subclass of it."""
    if c_functools:
        partial = CPartialSubclass

    # Nested partial objects are not flattened for subclasses, so the
    # inherited nested-optimization test is disabled here.
    test_nested_optimization = None
|
|
@unittest.skipIf(c_functools, "skip pure-python test if C impl is present")
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial suite against a trivial subclass of it."""
    partial = PyPartialSubclass
|
|
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as class attributes."""

    class A(object):
        # Fixture class: each attribute wraps ``capture`` differently, so the
        # tests can check exactly which positional/keyword args arrive.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a functools.partial object.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping staticmethod / classmethod objects.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # One instance shared by all tests.
    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # Accessed through the class, the instance is passed explicitly.
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        # Bound partialmethods expose __self__ like ordinary bound methods;
        # the classmethod variant is bound to the class.
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        # Class access (and staticmethod access either way) yields an
        # unbound callable without __self__.
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        # Call-site keywords override the ones frozen in the partialmethod.
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # A non-callable, non-descriptor target is rejected at class creation.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract subclasses ABCMeta rather than using it as a
        # metaclass; the test only inspects __isabstractmethod__ propagation.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
|
|
|
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (also reused by TestWraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert *wrapper* carries *wrapped*'s attributes.

        Every name in *assigned* must be the identical object on both sides.
        Every mapping named in *updated* must contain wrapped's entries
        (except ``__dict__['__wrapped__']``), and ``wrapper.__wrapped__``
        must be *wrapped*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Build ``(wrapper, f)`` via update_wrapper with default settings.

        ``f`` carries a docstring, an annotation, a custom attribute and a
        bogus ``__wrapped__`` that update_wrapper must replace.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    @unittest.skipIf(cosmo.MODE.startswith("tiny"),
                     "No .py files available in Cosmo MODE=tiny")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: nothing is copied, only __wrapped__.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly named attribute/mapping are copied.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ...and they must support .update() (an int does not).
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    @unittest.skipIf(cosmo.MODE.startswith("tiny"),
                     "No .py files available in Cosmo MODE=tiny")
    def test_builtin_update(self):
        # Test for bug #1576241
        # Wrapping a builtin (max) must copy its name and docstring.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
|
|
|
|
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator.

    Inherits check_wrapper, test_missing_attributes and test_builtin_update
    from TestUpdateWrapper; the remaining tests are redefined below.
    """

    def _default_update(self):
        """Build ``(wrapper, f)`` using ``@functools.wraps(f)``."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    @unittest.skipIf(cosmo.MODE.startswith("tiny"),
                     "No .py files available in Cosmo MODE=tiny")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: nothing is copied.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Give the wrapper the mapping that wraps() is asked to update.
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
|
|
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of reduce()."""
    if c_functools:
        # Only defined when the C module imported; the skip decorator keeps
        # the tests from running otherwise.
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily filled sequence of squares 0, 1, 4, ... with `max` items.
            # __len__ reports only how many items have been generated so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        # Errors raised by __iter__ must propagate unchanged.
        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        # With an empty iterable, the initial value is returned as-is.
        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        # Errors raised by __getitem__ (old sequence protocol) propagate.
        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence 0..n-1 (only __getitem__, no __iter__).
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Iterating a dict reduces over its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
|
|
|
|
class TestCmpToKey:
    """Implementation-agnostic tests for ``cmp_to_key``.

    This is a mixin: a concrete subclass combines it with
    ``unittest.TestCase`` and supplies ``cmp_to_key`` (the implementation
    under test).
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        # The comparison function decides equality, even across types.
        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against 0 propagates as well.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        # Build a new key so cmp2 is actually used (previously the old key
        # was reused, making this check vacuous).  '<' is used because
        # BadCmp only defines __lt__; '>' would raise TypeError instead.
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        # Reverse ordering via the comparison function.
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        # Mixed ints and numeric strings sort by numeric value.
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # Key objects must not be hashable.
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
|
|
|
|
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey suite against the C ``cmp_to_key``."""
    if c_functools:
        # Stored directly on the class; unlike the pure-Python test class,
        # it is not wrapped in staticmethod.
        cmp_to_key = c_functools.cmp_to_key
|
|
|
|
@unittest.skipIf(c_functools, "skip pure-python test if C impl is present")
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey suite against the pure-Python ``cmp_to_key``."""
    # staticmethod keeps the plain Python function from being bound to the
    # test instance when it is looked up as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
|
|
|
|
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _assert_consistent_order(self, cls):
        """Check the comparisons every completed ordering must satisfy.

        *cls* is constructed from a single comparable value; the checks hold
        regardless of which root comparison method the class defined.
        """
        self.assertLess(cls(1), cls(2))
        self.assertGreater(cls(2), cls(1))
        self.assertLessEqual(cls(1), cls(2))
        self.assertGreaterEqual(cls(2), cls(1))
        self.assertLessEqual(cls(2), cls(2))
        self.assertGreaterEqual(cls(2), cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __lt__(self, other):
                return self.value < other.value

            def __eq__(self, other):
                return self.value == other.value

        self._assert_consistent_order(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __le__(self, other):
                return self.value <= other.value

            def __eq__(self, other):
                return self.value == other.value

        self._assert_consistent_order(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __gt__(self, other):
                return self.value > other.value

            def __eq__(self, other):
                return self.value == other.value

        self._assert_consistent_order(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __ge__(self, other):
                return self.value >= other.value

            def __eq__(self, other):
                return self.value == other.value

        self._assert_consistent_order(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing ones (here every
        # comparison method is inherited from int)
        @functools.total_ordering
        class A(int):
            pass

        self._assert_consistent_order(A)

    def test_no_operations_defined(self):
        # With no ordering method to derive from, decoration is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value

            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value

            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value

            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                return NotImplemented

        # Every pair below refuses to compare, so each comparison must end
        # in TypeError rather than recursing forever.
        mismatched = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for description, compare in mismatched:
            with self.subTest(description), self.assertRaises(TypeError):
                compare()

        # The root __lt__ always answers NotImplemented, so the derived
        # >= and <= must raise TypeError even though the operands are equal.
        equal_cases = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for description, value, compare in equal_cases:
            with self.subTest(description):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Both the hand-written and the generated comparison methods must
        # pickle by reference, i.e. round-trip to the same object.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in ('__lt__', '__gt__', '__le__', '__ge__'):
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(restored, method)
|
|
|
|
# Module-level fixture for TestTotalOrdering.test_pickle: only __lt__ and
# __eq__ are written by hand and total_ordering supplies the rest.  It lives
# at module scope so pickle can find its methods by qualified name.
@functools.total_ordering
class Orderable_LT:

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return self.value < other.value

    def __eq__(self, other):
        return self.value == other.value
|
|
|
|
|
|
class TestLRU:
    """lru_cache tests shared by the pure-Python and C-backed variants.

    Concrete subclasses supply ``module`` (the functools module under test)
    and the ``cached_func``/``cached_meth``/``cached_staticmeth`` fixtures
    used by the pickling and copying tests.  Every test decorates through
    ``self.module`` so that each implementation is actually exercised.
    """

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (evicting 7) and
        # the final 7 (evicting 9); every other call is a hit.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_reentrancy_with_len(self):
        # Test to make sure the LRU cache code isn't thrown-off by
        # caching the built-in len() function.  Since len() can be
        # cached, we shouldn't use it inside the lru code itself.
        old_len = builtins.len
        try:
            builtins.len = self.module.lru_cache(4)(len)
            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
                self.assertEqual(len('abcdefghijklmn'[:i]), i)
        finally:
            builtins.len = old_len

    def test_lru_type_error(self):
        # Regression test for issue #28653.
        # lru_cache was leaking when one of the arguments
        # wasn't cacheable.  Use the module under test, not the global
        # functools, so both implementations are covered.

        @self.module.lru_cache(maxsize=None)
        def infinite_cache(o):
            pass

        @self.module.lru_cache(maxsize=10)
        def limited_cache(o):
            pass

        with self.assertRaises(TypeError):
            infinite_cache([])

        with self.assertRaises(TypeError):
            limited_cache([])

    def test_lru_with_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for i in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_kwargs_order(self):
        # PEP 468: Preserving Keyword Argument Order
        @self.module.lru_cache(maxsize=10)
        def f(**kwargs):
            return list(kwargs.items())
        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
        self.assertEqual(f.cache_info(),
            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))

    def test_lru_cache_decoration(self):
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded3(self):
        @self.module.lru_cache(maxsize=2)
        def f(x):
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        with support.start_threads(threads):
            pass

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_early_detection_of_bad_call(self):
        # Issue #22184
        # A bare ``@lru_cache`` (not called) must be rejected with TypeError
        # at decoration time.  Checked against the module under test.
        with self.assertRaises(TypeError):
            @self.module.lru_cache
            def f():
                pass

    def test_lru_method(self):
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)
|
|
|
|
|
|
# Module-level cached fixture for the pure-Python LRU tests; defined at
# module scope so pickle/copy can resolve it by name.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y
|
|
|
|
# Module-level cached fixture for the C-backed LRU tests; defined at module
# scope so pickle/copy can resolve it by name.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y
|
|
|
|
|
|
@unittest.skipIf(c_functools, "skip pure-python test if C impl is present")
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared LRU tests against the pure-Python functools."""

    module = py_functools
    # Held in a one-element tuple so the cached function is not turned into
    # a bound method when looked up through the class; tests use [0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
|
|
|
|
|
|
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared LRU tests against the C-accelerated functools."""

    module = c_functools
    # Held in a one-element tuple so the cached function is not turned into
    # a bound method when looked up through the class; tests use [0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
|
|
|
|
|
|
class TestSingleDispatch(unittest.TestCase):
|
|
def test_simple_overloads(self):
    # register(cls, func) routes exactly-matching types to func and leaves
    # everything else on the base implementation.
    @functools.singledispatch
    def g(obj):
        return "base"

    def g_int(i):
        return "integer"

    g.register(int, g_int)
    cases = (("str", "base"), (1, "integer"), ([1, 2, 3], "base"))
    for arg, expected in cases:
        self.assertEqual(g(arg), expected)
|
|
|
|
def test_mro(self):
    # Unregistered subclasses dispatch to the first registered class in
    # their MRO.  D's MRO is D, C, B, A, so B's implementation wins for D.
    @functools.singledispatch
    def g(obj):
        return "base"

    class A:
        pass

    class C(A):
        pass

    class B(A):
        pass

    class D(C, B):
        pass

    def g_A(a):
        return "A"

    def g_B(b):
        return "B"

    g.register(A, g_A)
    g.register(B, g_B)
    for cls, expected in ((A, "A"), (B, "B"), (C, "A"), (D, "B")):
        self.assertEqual(g(cls()), expected)
|
|
|
|
def test_register_decorator(self):
    # register(cls) also works as a decorator and hands back the decorated
    # function itself, as the identity check on g.dispatch(int) shows.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(int)
    def g_int(i):
        return "int %s" % (i,)

    self.assertEqual(g(""), "base")
    self.assertEqual(g(12), "int 12")
    self.assertIs(g.dispatch(int), g_int)
    # ``g`` is the wrapper returned by @singledispatch, so what it dispatches
    # to for object is the original base function rather than ``g``; compare
    # it with the implementation chosen for str instead.
    self.assertIs(g.dispatch(object), g.dispatch(str))
|
|
|
|
@unittest.skipIf(cosmo.MODE in ('tiny', 'rel'),
                 "no pydocs in rel mode")
def test_wrapping_attributes(self):
    # The singledispatch wrapper carries over the wrapped function's name
    # and docstring.
    @functools.singledispatch
    def g(obj):
        "Simple test"
        return "Test"

    self.assertEqual(g.__name__, "g")
    # Docstrings are stripped when running with -OO.
    if sys.flags.optimize < 2:
        self.assertEqual(g.__doc__, "Simple test")
|
|
|
|
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on the exception hierarchy of the C _decimal module.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    # A more specific registration takes over for Subnormal only.
    @g.register(decimal.Subnormal)
    def _(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
|
|
|
|
def test_compose_mro(self):
    # _compose_mro splices the relevant ABCs from a "haystack" into a
    # class's MRO.  None of the examples in this test depend on haystack
    # ordering, so every case is checked against every permutation.
    c = collections
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
    for haystack in permutations(bases):
        m = mro(c.ChainMap, haystack)
        self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        # Feed the permutation itself (not a fixed list) so that ordering
        # independence is actually exercised.
        m = mro(c.defaultdict, haystack)
        self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                             object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(c.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                             c.defaultdict, dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(c.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
|
|
|
|
def test_register_abc(self):
    # Implementations are registered one at a time, from broad ABCs down to
    # concrete builtins; after each registration every sample is dispatched
    # again and must pick the most specific applicable implementation.
    c = collections
    a_dict = {"a": "b"}
    a_list = [1, 2, 3]
    a_set = {object(), None}
    a_frozenset = frozenset(a_set)
    a_tuple = (1, 2, 3)
    samples = (a_dict, a_list, a_set, a_frozenset, a_tuple)

    @functools.singledispatch
    def g(obj):
        return "base"

    def returning(label):
        # Factory so each registered implementation captures its own label
        # instead of the loop variable's final value.
        return lambda obj: label

    def check(expected):
        # *expected* holds the dispatch result for each of ``samples``.
        for sample, result in zip(samples, expected):
            self.assertEqual(g(sample), result)

    check(("base", "base", "base", "base", "base"))
    # (type to register, label it returns, expected results for the
    #  dict, list, set, frozenset and tuple samples afterwards)
    steps = [
        (c.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # irrelevant ABCs registered: nothing changes
        (c.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        # Mapping is not specific enough to beat MutableMapping for dict
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set",
          "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    ]
    for cls, label, expected in steps:
        g.register(cls, returning(label))
        check(expected)
|
|
|
|
def test_c3_abc(self):
    # _c3_mro places ABCs (implicit via special methods, or explicitly
    # registered) right after the class that introduces them.
    c = collections
    mro = functools._c3_mro

    class A(object):
        pass

    class B(A):
        def __len__(self):
            return 0   # implies Sized

    @c.Container.register
    class C(object):
        pass

    class D(object):
        pass   # unrelated

    class X(D, C, B):
        def __call__(self):
            pass   # implies Callable

    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    for abcs in permutations([c.Sized, c.Callable, c.Container]):
        self.assertEqual(mro(X, abcs=abcs), expected)
    # unrelated ABCs don't appear in the resulting MRO
    extra_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
    self.assertEqual(mro(X, abcs=extra_abcs), expected)
|
|
|
|
def test_false_meta(self):
    # see issue23572
    # MetaA.__len__ returns 0, so the classes A and AA are themselves falsy
    # objects; dispatch must still honour the registration for A.
    class MetaA(type):
        def __len__(self):
            return 0

    class A(metaclass=MetaA):
        pass

    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    @fun.register(A)
    def _(a):
        return 'fun A'

    self.assertEqual(fun(AA()), 'fun A')
|
|
|
|
def test_mro_conflicts(self):
    # Explicit MRO entries beat registered/inferred ABCs, and ABCs that are
    # equally applicable make dispatch fail with RuntimeError.  Statement
    # order matters here: ABC registrations are global side effects.
    c = collections

    def assert_ambiguous(call, *messages):
        # *call* must raise RuntimeError whose text is one of *messages*;
        # the order in which the two ABCs are named is not guaranteed.
        with self.assertRaises(RuntimeError) as cm:
            call()
        self.assertIn(str(cm.exception), messages)

    @functools.singledispatch
    def g(arg):
        return "base"

    class O(c.Sized):
        def __len__(self):
            return 0

    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
    c.Set.register(O)
    # c.Set is a subclass of c.Sized and c.Container
    self.assertEqual(g(o), "set")

    class P:
        pass

    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    assert_ambiguous(
        lambda: g(p),
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Iterable'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
         "or <class 'collections.abc.Container'>"),
    )

    class Q(c.Sized):
        def __len__(self):
            return 0

    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
    c.Set.register(Q)
    # c.Set is a subclass of c.Sized and c.Iterable
    self.assertEqual(g(q), "set")

    @functools.singledispatch
    def h(arg):
        return "base"

    @h.register(c.Sized)
    def _(arg):
        return "sized"

    @h.register(c.Container)
    def _(arg):
        return "container"

    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    assert_ambiguous(
        lambda: h(c.defaultdict(lambda: 0)),
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Sized'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
         "or <class 'collections.abc.Container'>"),
    )

    class R(c.defaultdict):
        pass

    c.MutableSequence.register(R)

    @functools.singledispatch
    def i(arg):
        return "base"

    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"

    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"

    r = R()
    self.assertEqual(i(r), "sequence")

    class S:
        pass

    class T(S, c.Sized):
        def __len__(self):
            return 0

    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO

    class U:
        def __len__(self):
            return 0

    u = U()
    # implicit Sized subclass inferred from the existence of __len__()
    self.assertEqual(h(u), "sized")
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    assert_ambiguous(
        lambda: h(u),
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Sized'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
         "or <class 'collections.abc.Container'>"),
    )

    class V(c.Sized, S):
        def __len__(self):
            return 0

    @functools.singledispatch
    def j(arg):
        return "base"

    @j.register(S)
    def _(arg):
        return "s"

    @j.register(c.Container)
    def _(arg):
        return "container"

    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    # Container ends up right after Sized in the MRO, ahead of S.
    self.assertEqual(j(v), "container")
|
|
|
|
def test_cache_invalidation(self):
|
|
from collections import UserDict
|
|
class TracingDict(UserDict):
|
|
def __init__(self, *args, **kwargs):
|
|
super(TracingDict, self).__init__(*args, **kwargs)
|
|
self.set_ops = []
|
|
self.get_ops = []
|
|
def __getitem__(self, key):
|
|
result = self.data[key]
|
|
self.get_ops.append(key)
|
|
return result
|
|
def __setitem__(self, key, value):
|
|
self.set_ops.append(key)
|
|
self.data[key] = value
|
|
def clear(self):
|
|
self.data.clear()
|
|
_orig_wkd = functools.WeakKeyDictionary
|
|
td = TracingDict()
|
|
functools.WeakKeyDictionary = lambda: td
|
|
c = collections
|
|
@functools.singledispatch
|
|
def g(arg):
|
|
return "base"
|
|
d = {}
|
|
l = []
|
|
self.assertEqual(len(td), 0)
|
|
self.assertEqual(g(d), "base")
|
|
self.assertEqual(len(td), 1)
|
|
self.assertEqual(td.get_ops, [])
|
|
self.assertEqual(td.set_ops, [dict])
|
|
self.assertEqual(td.data[dict], g.registry[object])
|
|
self.assertEqual(g(l), "base")
|
|
self.assertEqual(len(td), 2)
|
|
self.assertEqual(td.get_ops, [])
|
|
self.assertEqual(td.set_ops, [dict, list])
|
|
self.assertEqual(td.data[dict], g.registry[object])
|
|
self.assertEqual(td.data[list], g.registry[object])
|
|
self.assertEqual(td.data[dict], td.data[list])
|
|
self.assertEqual(g(l), "base")
|
|
self.assertEqual(g(d), "base")
|
|
self.assertEqual(td.get_ops, [list, dict])
|
|
self.assertEqual(td.set_ops, [dict, list])
|
|
g.register(list, lambda arg: "list")
|
|
self.assertEqual(td.get_ops, [list, dict])
|
|
self.assertEqual(len(td), 0)
|
|
self.assertEqual(g(d), "base")
|
|
self.assertEqual(len(td), 1)
|
|
self.assertEqual(td.get_ops, [list, dict])
|
|
self.assertEqual(td.set_ops, [dict, list, dict])
|
|
self.assertEqual(td.data[dict],
|
|
functools._find_impl(dict, g.registry))
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(len(td), 2)
|
|
self.assertEqual(td.get_ops, [list, dict])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list])
|
|
self.assertEqual(td.data[list],
|
|
functools._find_impl(list, g.registry))
|
|
class X:
|
|
pass
|
|
c.MutableMapping.register(X) # Will not invalidate the cache,
|
|
# not using ABCs yet.
|
|
self.assertEqual(g(d), "base")
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(td.get_ops, [list, dict, dict, list])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list])
|
|
g.register(c.Sized, lambda arg: "sized")
|
|
self.assertEqual(len(td), 0)
|
|
self.assertEqual(g(d), "sized")
|
|
self.assertEqual(len(td), 1)
|
|
self.assertEqual(td.get_ops, [list, dict, dict, list])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(len(td), 2)
|
|
self.assertEqual(td.get_ops, [list, dict, dict, list])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(g(d), "sized")
|
|
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
|
|
g.dispatch(list)
|
|
g.dispatch(dict)
|
|
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
|
|
list, dict])
|
|
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
|
|
c.MutableSet.register(X) # Will invalidate the cache.
|
|
self.assertEqual(len(td), 2) # Stale cache.
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(len(td), 1)
|
|
g.register(c.MutableMapping, lambda arg: "mutablemapping")
|
|
self.assertEqual(len(td), 0)
|
|
self.assertEqual(g(d), "mutablemapping")
|
|
self.assertEqual(len(td), 1)
|
|
self.assertEqual(g(l), "list")
|
|
self.assertEqual(len(td), 2)
|
|
g.register(dict, lambda arg: "dict")
|
|
self.assertEqual(g(d), "dict")
|
|
self.assertEqual(g(l), "list")
|
|
g._clear_cache()
|
|
self.assertEqual(len(td), 0)
|
|
functools.WeakKeyDictionary = _orig_wkd
|
|
|
|
def test_invalid_positional_argument(self):
|
|
@functools.singledispatch
|
|
def f(*args):
|
|
pass
|
|
msg = 'f requires at least 1 positional argument'
|
|
with self.assertRaisesRegex(TypeError, msg):
|
|
f()
|
|
|
|
if __name__ == "__main__":
    # Allow the test suite to be run directly as a script.
    unittest.main()
|