Merge pull request #21 from jayvdb/abc

Add nullcontext and AbstractContextManager
This commit is contained in:
Nick Coghlan 2019-09-20 21:15:35 +10:00 committed by GitHub
commit 739bd253bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 130 additions and 1 deletions

View file

@ -1,17 +1,75 @@
"""contextlib2 - backports and enhancements to the contextlib module"""
import abc
import sys
import warnings
from collections import deque
from functools import wraps
__all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack",
__all__ = ["contextmanager", "closing", "nullcontext",
"AbstractContextManager",
"ContextDecorator", "ExitStack",
"redirect_stdout", "redirect_stderr", "suppress"]
# Backwards compatibility
__all__ += ["ContextStack"]
# Backport abc.ABC
if sys.version_info[:2] >= (3, 4):
_abc_ABC = abc.ABC
else:
_abc_ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})
# Backport classic class MRO
def _classic_mro(C, result):
if C in result:
return
result.append(C)
for B in C.__bases__:
_classic_mro(B, result)
return result
# Backport _collections_abc._check_methods
def _check_methods(C, *methods):
try:
mro = C.__mro__
except AttributeError:
mro = tuple(_classic_mro(C, []))
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
class AbstractContextManager(_abc_ABC):
"""An abstract base class for context managers."""
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
@classmethod
def __subclasshook__(cls, C):
"""Check whether subclass is considered a subclass of this ABC."""
if cls is AbstractContextManager:
return _check_methods(C, "__enter__", "__exit__")
return NotImplemented
class ContextDecorator(object):
"""A base class or mixin that enables context managers to work as decorators."""
@ -439,3 +497,22 @@ class ContextStack(ExitStack):
def preserve(self):
return self.pop_all()
class nullcontext(AbstractContextManager):
"""Context manager that does no additional processing.
Used as a stand-in for a normal context manager, when a particular
block of code is only sometimes used with a normal context manager:
cm = optional_cm if condition else nullcontext()
with cm:
# Perform operation, using optional_cm if condition is True
"""
def __init__(self, enter_result=None):
self.enter_result = enter_result
def __enter__(self):
return self.enter_result
def __exit__(self, *excinfo):
pass

View file

@ -15,6 +15,58 @@ if not hasattr(unittest.TestCase, "assertRaisesRegex"):
requires_docstrings = unittest.skipIf(sys.flags.optimize >= 2,
"Test requires docstrings")
class TestAbstractContextManager(unittest.TestCase):
def test_enter(self):
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
manager = DefaultEnter()
self.assertIs(manager.__enter__(), manager)
def test_exit_is_abstract(self):
class MissingExit(AbstractContextManager):
pass
with self.assertRaises(TypeError):
MissingExit()
def test_structural_subclassing(self):
# New style classes used here
class ManagerFromScratch(object):
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
if sys.version_info[:2] <= (3, 0):
def test_structural_subclassing_classic(self):
# Old style classes used here
class ManagerFromScratch:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return None
self.assertTrue(issubclass(ManagerFromScratch, AbstractContextManager))
class DefaultEnter(AbstractContextManager):
def __exit__(self, *args):
super().__exit__(*args)
self.assertTrue(issubclass(DefaultEnter, AbstractContextManager))
class ContextManagerTestCase(unittest.TestCase):
def test_contextmanager_plain(self):