Add AbstractContextManager

Closes https://github.com/jazzband/contextlib2/issues/16
This commit is contained in:
John Vandenberg 2019-09-09 22:31:58 +07:00
parent 656a4eba29
commit a8e6733e21
2 changed files with 109 additions and 0 deletions

View file

@ -1,11 +1,13 @@
"""contextlib2 - backports and enhancements to the contextlib module"""
import abc
import sys
import warnings
from collections import deque
from functools import wraps
__all__ = ["contextmanager", "closing", "nullcontext",
"AbstractContextManager",
"ContextDecorator", "ExitStack",
"redirect_stdout", "redirect_stderr", "suppress"]
@ -13,6 +15,61 @@ __all__ = ["contextmanager", "closing", "nullcontext",
__all__ += ["ContextStack"]
# Backport abc.ABC
if sys.version_info[:2] >= (3, 4):
_abc_ABC = abc.ABC
else:
_abc_ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})
# Backport classic class MRO
def _classic_mro(C, result):
if C in result:
return
result.append(C)
for B in C.__bases__:
_classic_mro(B, result)
return result
# Backport _collections_abc._check_methods
def _check_methods(C, *methods):
try:
mro = C.__mro__
except AttributeError:
mro = tuple(_classic_mro(C, []))
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
class AbstractContextManager(_abc_ABC):
"""An abstract base class for context managers."""
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
@classmethod
def __subclasshook__(cls, C):
"""Check whether subclass is considered a subclass of this ABC."""
if cls is AbstractContextManager:
return _check_methods(C, "__enter__", "__exit__")
return NotImplemented
class ContextDecorator(object):
"""A base class or mixin that enables context managers to work as decorators."""

View file

@ -15,6 +15,58 @@ if not hasattr(unittest.TestCase, "assertRaisesRegex"):
requires_docstrings = unittest.skipIf(sys.flags.optimize >= 2,
"Test requires docstrings")
class TestAbstractContextManager(unittest.TestCase):
def test_enter(self):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
manager = DefaultEnter()
self.assertIs(manager.__enter__(), manager)
def test_exit_is_abstract(self):
class MissingExit(AbstractContextManager):
pass
with self.assertRaises(TypeError):
MissingExit()
def test_structural_subclassing(self):
# New style classes used here
class ManagerFromScratch(object):
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
if sys.version_info[:2] <= (3, 0):
def test_structural_subclassing_classic(self):
# Old style classes used here
class ManagerFromScratch:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
class ContextManagerTestCase(unittest.TestCase):
def test_contextmanager_plain(self):