Refactor pub/sub functionality to fix support for removed.

This commit is contained in:
Tyson Clugg 2015-05-21 12:51:43 +10:00
parent 86be4fbd32
commit c7ecd60ea7
5 changed files with 182 additions and 203 deletions

View file

@ -19,6 +19,10 @@ else:
default_app_config = 'dddp.apps.DjangoDDPConfig'
ADDED = 'added'
CHANGED = 'changed'
REMOVED = 'removed'
def greenify():
"""Patch threading and psycopg2 modules for green threads."""

View file

@ -20,7 +20,9 @@ from django.db.models import signals
import ejson
# django-ddp
from dddp import AlreadyRegistered, THREAD_LOCAL as this
from dddp import (
AlreadyRegistered, THREAD_LOCAL as this, ADDED, CHANGED, REMOVED,
)
from dddp.models import Connection, Subscription, get_meteor_id
@ -226,6 +228,9 @@ class Collection(APIMixin):
"""Find user IDs related to object/pk in queryset."""
qs = base_qs or self.queryset
if self.user_rel:
user_ids = set()
if obj.pk is None:
return user_ids # nobody can see objects that don't exist
user_rels = self.user_rel
if isinstance(user_rels, basestring):
user_rels = [user_rels]
@ -235,7 +240,6 @@ class Collection(APIMixin):
in enumerate(user_rels)
}
user_ids = set()
if include_superusers:
user_ids.update(
get_user_model().objects.filter(
@ -356,15 +360,15 @@ class Collection(APIMixin):
def obj_change_as_msg(self, obj, msg):
"""Return DDP change message of specified type (msg) for obj."""
if msg == 'removed':
data = {'pk': get_meteor_id(obj)} # `removed` only needs ID
elif msg in ('added', 'changed'):
if msg == REMOVED:
data = {'id': get_meteor_id(obj)} # `removed` only needs ID
elif msg in (ADDED, CHANGED):
data = self.serialize(obj)
data['id'] = str(data.pop('pk')) # force casting ID as string
else:
raise ValueError('Invalid message type: %r' % msg)
del data['model']
data.pop('model', None)
data.update(msg=msg, collection=self.name)
return data
@ -426,18 +430,11 @@ class DDP(APIMixin):
pgworker = None
_in_migration = False
class Msg(object):
"""DDP message type enumeration."""
ADDED = 'added'
CHANGED = 'changed'
REMOVED = 'removed'
def __init__(self):
"""DDP API init."""
self._registry = {}
self._subs = {}
self._ddp_subscribers = {}
def get_collection(self, model):
"""Return collection instance for given model."""
@ -453,19 +450,6 @@ class DDP(APIMixin):
"""Return an iterable of API providers."""
return self._registry.values()
def sub_notify(self, id_, names, data):
"""Dispatch DDP updates to connections."""
ws, _ = self._subs[id_]
connection_pk = data.pop('_sender', None)
tx_id = data.pop('_tx_id', None)
connection_obj = getattr(ws, 'connection', None)
if connection_obj is None:
tx_id = None
else:
if connection_obj.pk != connection_pk:
tx_id = None
ws.send_msg(data, tx_id=tx_id)
def qs_and_collection(self, qs):
"""Return (qs, collection) from qs (which may be a tuple)."""
if hasattr(qs, 'model'):
@ -515,7 +499,6 @@ class DDP(APIMixin):
)
)
self._subs[id_] = (this.ws, sorted(queries))
self.pgworker.subscribe(self.sub_notify, id_, sorted(queries))
# mergebox via MVCC! For details on how this is possible, read this:
# https://devcenter.heroku.com/articles/postgresql-concurrency
to_send = collections.OrderedDict(
@ -556,24 +539,18 @@ class DDP(APIMixin):
for collection_name, qs in to_send.items():
col = self.get_col_by_name(collection_name)
for obj in qs:
payload = col.obj_change_as_msg(obj, self.Msg.ADDED)
payload = col.obj_change_as_msg(obj, ADDED)
this.send_msg(payload)
this.send_msg({'msg': 'ready', 'subs': [id_]})
def unsub_notify(self, id_):
"""Dispatch DDP updates to connections."""
(ws, _) = self._subs.pop(id_, (None, []))
if ws is not None:
Subscription.objects.filter(
connection=ws.connection,
sub_id=id_,
).delete()
ws.send_msg({'msg': 'nosub', 'id': id_})
@api_endpoint
def unsub(self, id_):
"""Remove a subscription."""
self.pgworker.unsubscribe(self.unsub_notify, id_)
Subscription.objects.filter(
connection=this.ws.connection,
sub_id=id_,
).delete()
this.ws.send_msg({'msg': 'nosub', 'id': id_})
@api_endpoint
def method(self, method, params, id_):
@ -633,57 +610,16 @@ class DDP(APIMixin):
def ready(self):
"""Initialisation for django-ddp (setup lookups and signal handlers)."""
signals.post_save.connect(self.on_save)
signals.post_delete.connect(self.on_delete)
signals.m2m_changed.connect(self.on_m2m_changed)
# set/unset self._in_migration
signals.pre_migrate.connect(self.on_pre_migrate)
signals.post_migrate.connect(self.on_post_migrate)
def on_save(self, sender, **kwargs):
"""Post-save signal handler."""
if self._in_migration:
return
self.send_notify(
model=sender,
obj=kwargs['instance'],
msg=kwargs['created'] and self.Msg.ADDED or self.Msg.CHANGED,
using=kwargs['using'],
)
def on_delete(self, sender, **kwargs):
"""Post-delete signal handler."""
if self._in_migration:
return
self.send_notify(
model=sender,
obj=kwargs['instance'],
msg='removed',
using=kwargs['using'],
)
def on_m2m_changed(self, sender, **kwargs):
"""M2M-changed signal handler."""
if self._in_migration:
return
# See https://docs.djangoproject.com/en/1.7/ref/signals/#m2m-changed
if kwargs['action'] in (
'post_add',
'post_remove',
'post_clear',
):
if kwargs['reverse'] is False:
objs = [kwargs['instance']]
model = objs[0].__class__
else:
model = kwargs['model']
objs = model.objects.filter(pk__in=kwargs['pk_set'])
for obj in objs:
self.send_notify(
model=model,
obj=obj,
msg='changed',
using=kwargs['using'],
)
# update self._ddp_subscribers before changes made
signals.pre_delete.connect(self.on_pre_change)
signals.pre_save.connect(self.on_pre_change)
# emit change message after changes made
signals.post_save.connect(self.on_post_save)
signals.post_delete.connect(self.on_post_delete)
signals.m2m_changed.connect(self.on_m2m_changed)
def on_pre_migrate(self, sender, **kwargs):
"""Pre-migrate signal handler."""
@ -697,15 +633,88 @@ class DDP(APIMixin):
except DatabaseError: # pylint: disable=E0712
pass
def send_notify(self, model, obj, msg, using):
"""Dispatch PostgreSQL async NOTIFY."""
col_user_ids = {}
def on_pre_change(self, sender, **kwargs):
"""Pre change (save/delete) signal handler."""
if self._in_migration:
return
# mod_name = model_name(sender)
# if mod_name.split('.', 1)[0] in ('migrations', 'dddp'):
# return # never send migration or DDP internal models
obj = kwargs['instance']
using = kwargs['using']
self._ddp_subscribers.setdefault(
using, {},
).setdefault(
sender, {},
)[obj.pk] = self.valid_subscribers(
model=sender, obj=obj, using=using,
)
def on_m2m_changed(self, sender, **kwargs):
"""M2M-changed signal handler."""
if self._in_migration:
return
if kwargs['reverse'] is False:
objs = [kwargs['instance']]
model = objs[0].__class__
else:
model = kwargs['model']
objs = model.objects.filter(pk__in=kwargs['pk_set'])
mod_name = model_name(model)
if mod_name.split('.', 1)[0] in ('migrations', 'dddp'):
return # never send migration or DDP internal models
col_sub_ids = collections.defaultdict(set)
# See https://docs.djangoproject.com/en/1.7/ref/signals/#m2m-changed
if kwargs['action'] in (
'pre_add',
'pre_remove',
'pre_clear',
):
for obj in objs:
self.on_pre_change(
sender=model, instance=obj, using=kwargs['using'],
)
elif kwargs['action'] in (
'post_add',
'post_remove',
'post_clear',
):
for obj in objs:
self.send_notify(
model=model,
obj=obj,
msg=CHANGED,
using=kwargs['using'],
)
def on_post_save(self, sender, **kwargs):
"""Post-save signal handler."""
if self._in_migration:
return
self.send_notify(
model=sender,
obj=kwargs['instance'],
msg=kwargs['created'] and ADDED or CHANGED,
using=kwargs['using'],
)
def on_post_delete(self, sender, **kwargs):
"""Post-delete signal handler."""
if self._in_migration:
return
self.send_notify(
model=sender,
obj=kwargs['instance'],
msg=REMOVED,
using=kwargs['using'],
)
def valid_subscribers(self, model, obj, using):
"""Calculate valid subscribers (connections) for obj."""
col_user_ids = {}
col_connection_ids = collections.defaultdict(set)
for sub in Subscription.objects.filter(
collections__model_name=mod_name,
collections__model_name=model_name(model),
).prefetch_related('collections'):
for qs, col in (
self.qs_and_collection(qs)
@ -729,39 +738,60 @@ class DDP(APIMixin):
except KeyError:
user_ids = col_user_ids[col.__class__] = \
col.user_ids_for_object(obj)
if user_ids is None:
pass # unrestricted collection, anyone permitted to see.
# check if user is in permitted list of users
if user_ids is not None:
if user_ids is None:
pass # unrestricted collection, anyone permitted to see.
elif sub.user_id in user_ids:
elif sub.user_id not in user_ids:
continue # not for this user
col_sub_ids[col].add(sub.sub_id)
col_connection_ids[col].add(sub.connection_id)
if not col_sub_ids:
get_meteor_id(obj) # force creation of meteor ID using randomSeed
return # no subscribers for this object, nothing more to do.
# result is {colleciton: set([connection_id])}
return col_connection_ids
for col, sub_ids in col_sub_ids.items():
payload = col.obj_change_as_msg(obj, msg)
payload['_sub_ids'] = sorted(sub_ids)
try:
ws = this.ws
payload['_sender'] = ws.connection.pk
if set(sub_ids).intersection(self._subs):
# message must go to connection that initiated the change
payload['_tx_id'] = ws.get_tx_id()
except AttributeError:
pass
cursor = connections[using].cursor()
cursor.execute(
'NOTIFY "%s", %%s' % col.name,
[
ejson.dumps(payload),
],
)
def send_notify(self, model, obj, msg, using):
"""Dispatch PostgreSQL async NOTIFY."""
if model_name(model).split('.', 1)[0] in ('migrations', 'dddp'):
return # never send migration or DDP internal models
new_col_connection_ids = self.valid_subscribers(model, obj, using)
old_col_connection_ids = self._ddp_subscribers.get(
using, {},
).get(
model, {},
).pop(
obj.pk, collections.defaultdict(set),
)
try:
ws = this.ws
my_connection_id = this.ws.connection.pk
except AttributeError:
ws = my_connection_id = None
for col in set(old_col_connection_ids).union(new_col_connection_ids):
old_connection_ids = old_col_connection_ids[col]
new_connection_ids = new_col_connection_ids[col]
for (msg, connection_ids) in (
(REMOVED, old_connection_ids - new_connection_ids),
(CHANGED, old_connection_ids & new_connection_ids),
(ADDED, new_connection_ids - old_connection_ids),
):
if not connection_ids:
continue # nobody subscribed
payload = col.obj_change_as_msg(obj, msg)
payload['_connection_ids'] = sorted(connection_ids)
if my_connection_id is not None:
payload['_sender'] = my_connection_id
if my_connection_id in connection_ids:
# msg must go to connection that initiated the change
payload['_tx_id'] = ws.get_tx_id()
cursor = connections[using].cursor()
cursor.execute(
'NOTIFY "ddp", %s',
[
ejson.dumps(payload),
],
)
API = DDP()

View file

@ -1,6 +1,6 @@
"""Django DDP utils for DDP messaging."""
from copy import deepcopy
from dddp import THREAD_LOCAL as this
from dddp import THREAD_LOCAL as this, REMOVED
from django.db.models.expressions import ExpressionNode
@ -34,7 +34,7 @@ def obj_change_as_msg(obj, msg):
'collection': name,
'id': data['pk'],
}
if msg != 'removed':
if msg != REMOVED:
payload['fields'] = data['fields']
return (name, payload)

View file

@ -2,8 +2,6 @@
from __future__ import absolute_import
import collections
import ejson
import gevent
import gevent.queue
@ -23,16 +21,9 @@ class PostgresGreenlet(gevent.Greenlet):
self.logger = create_logger(__name__, debug=debug)
# queues for processing incoming sub/unsub requests and processing
self.subs = gevent.queue.Queue()
self.unsubs = gevent.queue.Queue()
self.proc_queue = gevent.queue.Queue()
self.connections = {}
self._stop_event = gevent.event.Event()
# dict of name: subscribers
# eg: {'bookstore.book': {'tpozNWMPphaJ2n8bj': <function at ...>}}
self.all_subs = collections.defaultdict(dict)
self._sub_lock = gevent.lock.RLock()
# connect to DB in async mode
conn.allow_thread_sharing = True
self.connection = conn
@ -47,26 +38,16 @@ class PostgresGreenlet(gevent.Greenlet):
def _run(self): # pylint: disable=method-hidden
"""Spawn sub tasks, wait for stop signal."""
gevent.spawn(self.process_conn)
gevent.spawn(self.process_subs)
gevent.spawn(self.process_unsubs)
self._stop_event.wait()
def stop(self):
"""Stop subtasks and let run() finish."""
self._stop_event.set()
def subscribe(self, func, id_, names):
"""Register callback `func` to be called after NOTIFY for `names`."""
self.subs.put((func, id_, names))
def unsubscribe(self, func, id_):
"""Un-register callback `func` to be called after NOTIFY for `id`."""
self.unsubs.put((func, id_))
def process_conn(self):
"""Subtask to process NOTIFY async events from DB connection."""
self.cur.execute('LISTEN "ddp";')
while not self._stop_event.is_set():
# TODO: change timeout so self._stop_event is periodically checked?
gevent.select.select([self.conn], [], [], timeout=None)
self.poll()
@ -77,25 +58,22 @@ class PostgresGreenlet(gevent.Greenlet):
if state == psycopg2.extensions.POLL_OK:
while self.conn.notifies:
notify = self.conn.notifies.pop()
name = notify.channel
self.logger.info(
"Got NOTIFY (pid=%d, name=%r, payload=%r)",
notify.pid, name, notify.payload,
"Got NOTIFY (pid=%d, payload=%r)",
notify.pid, notify.payload,
)
try:
self._sub_lock.acquire()
self.logger.info(self.all_subs)
subs = self.all_subs[name]
data = ejson.loads(notify.payload)
sub_ids = data.pop('_sub_ids')
self.logger.info('Subscribers: %r', sub_ids)
self.logger.info(subs)
for id_, func in subs.items():
if id_ not in sub_ids:
continue # not for this subscription
gevent.spawn(func, id_, name, data)
finally:
self._sub_lock.release()
data = ejson.loads(notify.payload)
sender = data.pop('_sender', None)
tx_id = data.pop('_tx_id', None)
for connection_id in data.pop('_connection_ids'):
try:
ws = self.connections[connection_id]
except KeyError:
continue # connection not in this process
if connection_id == sender:
ws.send_msg(data, tx_id=tx_id)
else:
ws.send_msg(data)
break
elif state == psycopg2.extensions.POLL_WRITE:
gevent.select.select([], [self.conn.fileno()], [])
@ -103,38 +81,3 @@ class PostgresGreenlet(gevent.Greenlet):
gevent.select.select([self.conn.fileno()], [], [])
else:
self.logger.warn('POLL_ERR: %s', state)
def process_subs(self):
"""Subtask to process `sub` requests from `self.subs` queue."""
while not self._stop_event.is_set():
func, id_, names = self.subs.get()
try:
self._sub_lock.acquire()
for name in names:
subs = self.all_subs[name]
if len(subs) == 0:
self.logger.debug('LISTEN "%s";', name)
self.poll()
self.cur.execute('LISTEN "%s";' % name)
self.poll()
subs[id_] = func
finally:
self._sub_lock.release()
def process_unsubs(self):
"""Subtask to process `unsub` requests from `self.unsubs` queue."""
while not self._stop_event.is_set():
func, id_ = self.unsubs.get()
try:
self._sub_lock.acquire()
for name in list(self.all_subs):
subs = self.all_subs[name]
subs.pop(id_, None)
if len(subs) == 0:
self.logger.info('UNLISTEN "%s";', name)
self.cur.execute('UNLISTEN "%s";' % name)
self.poll()
del self.all_subs[name]
finally:
self._sub_lock.release()
gevent.spawn(func, id_)

View file

@ -17,7 +17,7 @@ from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.db import connection, transaction
from dddp import THREAD_LOCAL as this, alea
from dddp import THREAD_LOCAL as this, alea, ADDED, CHANGED, REMOVED
class MeteorError(Exception):
@ -144,6 +144,7 @@ class DDPWebSocketApplication(geventwebsocket.WebSocketApplication):
def on_close(self, reason):
"""Handle closing of websocket connection."""
if self.connection is not None:
del self.pgworker.connections[self.connection.pk]
self.connection.delete()
self.connection = None
self.logger.info('- %s %s', self, reason or 'CLOSE')
@ -248,20 +249,20 @@ class DDPWebSocketApplication(geventwebsocket.WebSocketApplication):
def send_msg(self, payload, tx_id=None):
"""Send EJSON payload to remote."""
msg = payload.get('msg', None)
if msg in ('added', 'changed', 'removed'):
if msg in (ADDED, CHANGED, REMOVED):
ids = self.remote_ids[payload['collection']]
meteor_id = payload['id']
if msg == 'added':
if msg == ADDED:
if meteor_id in ids:
msg = payload['msg'] = 'changed'
msg = payload['msg'] = CHANGED
else:
ids.add(meteor_id)
elif msg == 'changed':
elif msg == CHANGED:
if meteor_id not in ids:
# object has become visible, treat as `added`.
msg = payload['msg'] = 'added'
msg = payload['msg'] = ADDED
ids.add(meteor_id)
elif msg == 'removed':
elif msg == REMOVED:
ids.remove(meteor_id)
data = ejson.dumps([ejson.dumps(payload)])
self.send('a%s' % data, tx_id=tx_id)
@ -341,6 +342,7 @@ class DDPWebSocketApplication(geventwebsocket.WebSocketApplication):
remote_addr=self.remote_addr,
version=version,
)
self.pgworker.connections[self.connection.pk] = self
atexit.register(self.on_close, 'Shutting down.')
self.reply('connected', session=self.connection.connection_id)