diff --git a/contextlib2.py b/contextlib2.py index 632adf2..3aae8f4 100644 --- a/contextlib2.py +++ b/contextlib2.py @@ -1,17 +1,75 @@ """contextlib2 - backports and enhancements to the contextlib module""" +import abc import sys import warnings from collections import deque from functools import wraps -__all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack", +__all__ = ["contextmanager", "closing", "nullcontext", + "AbstractContextManager", + "ContextDecorator", "ExitStack", "redirect_stdout", "redirect_stderr", "suppress"] # Backwards compatibility __all__ += ["ContextStack"] +# Backport abc.ABC +if sys.version_info[:2] >= (3, 4): + _abc_ABC = abc.ABC +else: + _abc_ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) + + +# Backport classic class MRO +def _classic_mro(C, result): + if C in result: + return + result.append(C) + for B in C.__bases__: + _classic_mro(B, result) + return result + + +# Backport _collections_abc._check_methods +def _check_methods(C, *methods): + try: + mro = C.__mro__ + except AttributeError: + mro = tuple(_classic_mro(C, [])) + + for method in methods: + for B in mro: + if method in B.__dict__: + if B.__dict__[method] is None: + return NotImplemented + break + else: + return NotImplemented + return True + + +class AbstractContextManager(_abc_ABC): + """An abstract base class for context managers.""" + + def __enter__(self): + """Return `self` upon entering the runtime context.""" + return self + + @abc.abstractmethod + def __exit__(self, exc_type, exc_value, traceback): + """Raise any exception triggered within the runtime context.""" + return None + + @classmethod + def __subclasshook__(cls, C): + """Check whether subclass is considered a subclass of this ABC.""" + if cls is AbstractContextManager: + return _check_methods(C, "__enter__", "__exit__") + return NotImplemented + + class ContextDecorator(object): """A base class or mixin that enables context managers to work as decorators.""" @@ -439,3 +497,22 @@ class ContextStack(ExitStack): def preserve(self): return self.pop_all() + + +class nullcontext(AbstractContextManager): + """Context manager that does no additional processing. + Used as a stand-in for a normal context manager, when a particular + block of code is only sometimes used with a normal context manager: + cm = optional_cm if condition else nullcontext() + with cm: + # Perform operation, using optional_cm if condition is True + """ + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass diff --git a/test_contextlib2.py b/test_contextlib2.py index e9cd0c7..a9b15d3 100755 --- a/test_contextlib2.py +++ b/test_contextlib2.py @@ -15,6 +15,58 @@ if not hasattr(unittest.TestCase, "assertRaisesRegex"): requires_docstrings = unittest.skipIf(sys.flags.optimize >= 2, "Test requires docstrings") + +class TestAbstractContextManager(unittest.TestCase): + + def test_enter(self): + class DefaultEnter(AbstractContextManager): + def __exit__(self, *args): + super().__exit__(*args) + + manager = DefaultEnter() + self.assertIs(manager.__enter__(), manager) + + def test_exit_is_abstract(self): + class MissingExit(AbstractContextManager): + pass + + with self.assertRaises(TypeError): + MissingExit() + + def test_structural_subclassing(self): + # New style classes used here + class ManagerFromScratch(object): + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return None + + self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager)) + + class DefaultEnter(AbstractContextManager): + def __exit__(self, *args): + super().__exit__(*args) + + self.assertTrue(issubclass(DefaultEnter, AbstractContextManager)) + + if sys.version_info[:2] <= (3, 0): + def test_structural_subclassing_classic(self): + # Old style classes used here + class ManagerFromScratch: + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return None + + self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager)) + + class DefaultEnter(AbstractContextManager): + def __exit__(self, *args): + super().__exit__(*args) + + self.assertTrue(issubclass(DefaultEnter, AbstractContextManager)) + + class ContextManagerTestCase(unittest.TestCase): def test_contextmanager_plain(self):