diff --git a/dddp/__init__.py b/dddp/__init__.py index 7c3c404..db0f1c4 100644 --- a/dddp/__init__.py +++ b/dddp/__init__.py @@ -48,28 +48,27 @@ class ThreadLocal(local): _init_done = False - def __init__(self, **default_factories): + def __init__(self): """Create new thread storage instance.""" if self._init_done: raise SystemError('__init__ called too many times') self._init_done = True - self._default_factories = default_factories def __getattr__(self, name): """Create missing attributes using default factories.""" try: - factory = self._default_factories[name] + factory = THREAD_LOCAL_FACTORIES[name] except KeyError: raise AttributeError(name) - obj = factory() - setattr(self, name, obj) - return obj + return self.get(name, factory) def get(self, name, factory, *factory_args, **factory_kwargs): """Get attribute, creating if required using specified factory.""" - if not hasattr(self, name): + update_thread_local = getattr(factory, 'update_thread_local', True) + if (not update_thread_local) or (not hasattr(self, name)): obj = factory(*factory_args, **factory_kwargs) - setattr(self, name, obj) + if update_thread_local: + setattr(self, name, obj) return obj return getattr(self, name) @@ -99,11 +98,12 @@ def serializer_factory(): return get_serializer('python')() -THREAD_LOCAL = ThreadLocal( - alea_random=alea.Alea, - random_streams=RandomStreams, - serializer=serializer_factory, -) +THREAD_LOCAL_FACTORIES = { + 'alea_random': alea.Alea, + 'random_streams': RandomStreams, + 'serializer': serializer_factory, +} +THREAD_LOCAL = ThreadLocal() METEOR_ID_CHARS = u'23456789ABCDEFGHJKLMNPQRSTWXYZabcdefghijkmnopqrstuvwxyz'