From f02fba6d9fbad7595cfc5af7dbf1f1a4b6982a61 Mon Sep 17 00:00:00 2001 From: Tyson Clugg Date: Thu, 28 May 2015 10:53:29 +1000 Subject: [PATCH] Refactor serialization to imporove performance (less DB queries) and add support for auth updating subscriptions. --- dddp/__init__.py | 2 +- dddp/accounts/ddp.py | 61 ++++++++++++++++++++--- dddp/api.py | 95 +++++++++++++++++++++++------------- dddp/apps.py | 2 - dddp/models.py | 44 ++++++++++++++--- dddp/serializer.py | 114 ------------------------------------------- 6 files changed, 156 insertions(+), 162 deletions(-) delete mode 100644 dddp/serializer.py diff --git a/dddp/__init__.py b/dddp/__init__.py index d434d37..a2d4dcd 100644 --- a/dddp/__init__.py +++ b/dddp/__init__.py @@ -97,7 +97,7 @@ class RandomStreams(object): def serializer_factory(): """Make a new DDP serializer.""" from django.core.serializers import get_serializer - return get_serializer('ddp')() + return get_serializer('python')() THREAD_LOCAL = ThreadLocal( diff --git a/dddp/accounts/ddp.py b/dddp/accounts/ddp.py index 9599179..903416b 100644 --- a/dddp/accounts/ddp.py +++ b/dddp/accounts/ddp.py @@ -6,6 +6,7 @@ Matches Meteor 1.1 Accounts package: https://www.meteor.com/accounts See http://docs.meteor.com/#/full/accounts_api for details of each method. """ from binascii import Error +import collections from ejson import loads, dumps @@ -15,8 +16,8 @@ from django.contrib.auth.signals import user_login_failed from django.dispatch import Signal from django.utils import timezone -from dddp import THREAD_LOCAL as this -from dddp.models import get_meteor_id, get_object +from dddp import THREAD_LOCAL as this, ADDED, REMOVED +from dddp.models import get_meteor_id, get_object, Subscription from dddp.api import API, APIMixin, api_endpoint, Collection, Publication from dddp.websocket import MeteorError @@ -39,10 +40,10 @@ class Users(Collection): 'pk', ] - def serialize(self, obj): + def serialize(self, obj, *args, **kwargs): """Serialize user as per Meteor accounts serialization.""" # use default serialization, then modify to suit our needs. - data = super(Users, self).serialize(obj) + data = super(Users, self).serialize(obj, *args, **kwargs) # everything that isn't handled explicitly ends up in `profile` profile = data.pop('fields') @@ -145,6 +146,50 @@ class Auth(APIMixin): api_path_prefix = '' # auth endpoints don't have a common prefix user_model = auth.get_user_model() + def update_subs(self, new_user_id): + """Update subs to send added/removed for collections with user_rel.""" + for sub in Subscription.objects.filter(connection=this.ws.connection): + params = loads(sub.params_ejson) + pub = API.get_pub_by_name(sub.publication) + + # calculate the querysets prior to update + pre = collections.OrderedDict([ + (col, qs) for col, qs + in API.sub_unique_objects(sub, params, pub) + ]) + + # save the subscription with the updated user_id + sub.user_id = new_user_id + sub.save() + + # calculate the querysets after the update + post = collections.OrderedDict([ + (col, qs) for col, qs + in API.sub_unique_objects(sub, params, pub) + ]) + + # first pass, send `added` for objs unique to `post` + for col_post, qs in post.items(): + try: + qs_pre = pre[col_post] + qs = qs.exclude(pk__in=qs_pre.order_by().values('pk')) + except KeyError: + # collection not included pre-auth, everything is added. + pass + for obj in qs: + this.ws.send(col.obj_change_as_msg(obj, ADDED)) + + # second pass, send `removed` for objs unique to `pre` + for col_pre, qs in pre.items(): + try: + qs_post = post[col_pre] + qs = qs.exclude(pk__in=qs_post.order_by().values('pk')) + except KeyError: + # collection not included post-auth, everything is removed. + pass + for obj in qs: + this.ws.send(col.obj_change_as_msg(obj, REMOVED)) + @staticmethod def auth_failed(**credentials): """Consistent fail so we don't provide attackers with valuable info.""" @@ -279,17 +324,18 @@ class Auth(APIMixin): username=user.get_username(), password=params['password'], ) auth.login(this.request, user) + self.update_subs(user.pk) return self.get_user_token( user=user, session_key=this.request.session.session_key, expiry_date=this.request.session.get_expiry_date(), ) - @staticmethod @api_endpoint - def logout(): + def logout(self): """Logout current user.""" auth.logout(this.request) + self.update_subs(None) @api_endpoint def login(self, params): @@ -314,6 +360,7 @@ class Auth(APIMixin): # the password verified for the user if user.is_active: auth.login(this.request, user) + self.update_subs(user.pk) this.request.session.save() return self.get_user_token( user=user, @@ -342,6 +389,7 @@ class Auth(APIMixin): user, session = self.validated_user_and_session(params['resume']) auth.login(this.request, user) + self.update_subs(user.pk) this.request.session.save() return self.get_user_token( user=user, @@ -399,6 +447,7 @@ class Auth(APIMixin): user.set_password(params['newPassword']) user.save() auth.login(this.request, user) + self.update_subs(user.pk) API.register([Users, LoginPublication, Auth]) diff --git a/dddp/api.py b/dddp/api.py index a47060d..1601007 100644 --- a/dddp/api.py +++ b/dddp/api.py @@ -23,7 +23,7 @@ import ejson from dddp import ( AlreadyRegistered, THREAD_LOCAL as this, ADDED, CHANGED, REMOVED, ) -from dddp.models import Connection, Subscription, get_meteor_id +from dddp.models import Connection, Subscription, get_meteor_id, get_meteor_ids XMIN = {'select': {'xmin': "'xmin'"}} @@ -351,7 +351,11 @@ class Collection(APIMixin): in self.field_schema() } - def serialize(self, obj): + def serialize(self, obj, data): + """Default implementation for object serializer.""" + return data + + def serialize(self, obj, meteor_ids): """Generate a DDP msg for obj with specified msg type.""" # check for F expressions exps = [ @@ -367,20 +371,42 @@ class Collection(APIMixin): setattr(obj, name, val) # run serialization now all fields are "concrete" (not F expressions) - return this.serializer.serialize([obj])[0] + data = this.serializer.serialize([obj])[0] + fields = data['fields'] + del data['pk'], data['model'] + # Django supports model._meta -> pylint: disable=W0212 + meta = self.model._meta + for field in meta.local_fields: + rel = getattr(field, 'rel', None) + if rel: + fields[field.column] = get_meteor_id( + rel.to, fields.pop(field.name), + ) + for field in meta.local_many_to_many: + fields['%s_ids' % field.name] = get_meteor_ids( + field.rel.to, fields.pop(field.name), + ).values() + return data - def obj_change_as_msg(self, obj, msg): + def obj_change_as_msg(self, obj, msg, meteor_ids=None): """Return DDP change message of specified type (msg) for obj.""" + if meteor_ids is None: + meteor_ids = {} + try: + meteor_id = meteor_ids[str(obj.pk)] + except KeyError: + meteor_id = None + if meteor_id is None: + meteor_ids[str(obj.pk)] = meteor_id = get_meteor_id(obj) + assert meteor_id is not None if msg == REMOVED: - data = {'id': get_meteor_id(obj)} # `removed` only needs ID + data = {} # `removed` only needs ID (added below) elif msg in (ADDED, CHANGED): - data = self.serialize(obj) - data['id'] = str(data.pop('pk')) # force casting ID as string + data = self.serialize(obj, meteor_ids) else: raise ValueError('Invalid message type: %r' % msg) - data.pop('model', None) - data.update(msg=msg, collection=self.name) + data.update(msg=msg, collection=self.name, id=meteor_id) return data @@ -427,11 +453,6 @@ class Publication(APIMixin): ) -def pub_path(publication_name): - """Return api_path for a publication.""" - return Publication.api_path_prefix_format.format(name=publication_name) - - class DDP(APIMixin): """Django DDP API.""" @@ -455,6 +476,11 @@ class DDP(APIMixin): """Return collection instance for given name.""" return self._registry[COLLECTION_PATH_FORMAT.format(name=name)] + def get_pub_by_name(self, name): + """Return publication instance for given name.""" + path = Publication.api_path_prefix_format.format(name=name) + return self._registry[path] + @property def api_providers(self): """Return an iterable of API providers.""" @@ -474,11 +500,9 @@ class DDP(APIMixin): if params is None: params = ejson.loads(obj.params_ejson) if pub is None: - pub = self._registry[pub_path(obj.publication)] + pub = self.get_pub_by_name(obj.publication) queries = collections.OrderedDict( - (col.name, (col, qs)) - for (qs, col) - in ( + (col, qs) for (qs, col) in ( self.qs_and_collection(qs) for qs in pub.get_queries(*params) @@ -488,43 +512,42 @@ class DDP(APIMixin): # https://devcenter.heroku.com/articles/postgresql-concurrency to_send = collections.OrderedDict( ( - name, + col, col.objects_for_user( user=obj.user_id, qs=qs, *args, **kwargs ), ) - for name, (col, qs) + for col, qs in queries.items() ) for other in Subscription.objects.filter( connection=obj.connection_id, - collections__collection_name__in=queries.keys(), + collections__collection_name__in=[col.name for col in queries], ).exclude( pk=obj.pk, ).order_by('pk').distinct(): - other_pub = self._registry[pub_path(other.publication)] + other_pub = self.get_pub_by_name(other.publication) for qs in other_pub.get_queries(*other.params): qs, col = self.qs_and_collection(qs) - if col.name not in to_send: + if col not in to_send: continue - to_send[col.name] = to_send[col.name].exclude( + to_send[col] = to_send[col].exclude( pk__in=col.objects_for_user( user=other.user_id, qs=qs, *args, **kwargs ).values('pk'), ) - for collection_name, qs in to_send.items(): - col = self.get_col_by_name(collection_name) + for col, qs in to_send.items(): yield col, qs.distinct() @api_endpoint def sub(self, id_, name, *params): """Create subscription, send matched objects that haven't been sent.""" try: - pub = self._registry[pub_path(name)] + pub = self.get_pub_by_name(name) except KeyError: this.send({ 'msg': 'nosub', @@ -557,8 +580,11 @@ class DDP(APIMixin): model_name=model_name(qs.model), collection_name=col.name, ) + meteor_ids = get_meteor_ids( + qs.model, qs.values_list('pk', flat=True), + ) for obj in qs: - payload = col.obj_change_as_msg(obj, ADDED) + payload = col.obj_change_as_msg(obj, ADDED, meteor_ids) this.send(payload) this.send({'msg': 'ready', 'subs': [id_]}) @@ -569,8 +595,11 @@ class DDP(APIMixin): connection=this.ws.connection, sub_id=id_, ) for col, qs in self.sub_unique_objects(sub): + meteor_ids = get_meteor_ids( + qs.model, qs.values_list('pk', flat=True), + ) for obj in qs: - payload = col.obj_change_as_msg(obj, REMOVED) + payload = col.obj_change_as_msg(obj, REMOVED, meteor_ids) this.send(payload) sub.delete() this.send({'msg': 'nosub', 'id': id_}) @@ -739,12 +768,11 @@ class DDP(APIMixin): for sub in Subscription.objects.filter( collections__model_name=model_name(model), ).prefetch_related('collections'): + pub = self.get_pub_by_name(sub.publication) for qs, col in ( self.qs_and_collection(qs) for qs - in self._registry[ - 'publication/%s/' % sub.publication - ].get_queries(*sub.params) + in pub.get_queries(*sub.params) ): # check if obj is an instance of the model for the queryset if qs.model is not model: @@ -791,6 +819,7 @@ class DDP(APIMixin): my_connection_id = this.ws.connection.pk except AttributeError: ws = my_connection_id = None + meteor_ids = {} for col in set(old_col_connection_ids).union(new_col_connection_ids): old_connection_ids = old_col_connection_ids[col] new_connection_ids = new_col_connection_ids[col] @@ -801,7 +830,7 @@ class DDP(APIMixin): ): if not connection_ids: continue # nobody subscribed - payload = col.obj_change_as_msg(obj, msg) + payload = col.obj_change_as_msg(obj, msg, meteor_ids) payload['_connection_ids'] = sorted(connection_ids) if my_connection_id is not None: payload['_sender'] = my_connection_id diff --git a/dddp/apps.py b/dddp/apps.py index 6055dc6..9cbd8bc 100644 --- a/dddp/apps.py +++ b/dddp/apps.py @@ -3,7 +3,6 @@ from __future__ import print_function from django.apps import AppConfig -from django.core import serializers from django.conf import settings, ImproperlyConfigured from django.db import DatabaseError from django.db.models import signals @@ -23,7 +22,6 @@ class DjangoDDPConfig(AppConfig): def ready(self): """Initialisation for django-ddp (setup lookups and signal handlers).""" - serializers.register_serializer('ddp', 'dddp.serializer') if not settings.DATABASES: raise ImproperlyConfigured('No databases configured.') for (alias, conf) in settings.DATABASES.items(): diff --git a/dddp/models.py b/dddp/models.py index 707d6ff..8eea57d 100644 --- a/dddp/models.py +++ b/dddp/models.py @@ -1,4 +1,7 @@ """Django DDP models.""" +from __future__ import absolute_import + +import collections from django.db import models, transaction from django.conf import settings @@ -10,17 +13,21 @@ from dddp import meteor_random_id @transaction.atomic -def get_meteor_id(obj): +def get_meteor_id(obj_or_model, obj_pk=None): """Return an Alea ID for the given object.""" - if obj is None: + if obj_or_model is None: return None # Django model._meta is now public API -> pylint: disable=W0212 - meta = obj._meta - if meta.model is ObjectMapping: + meta = obj_or_model._meta + model = meta.model + if model is ObjectMapping: # this doesn't make sense - raise TypeError raise TypeError("Can't map ObjectMapping instances through self.") - obj_pk = str(obj.pk) - content_type = ContentType.objects.get_for_model(meta.model) + if obj_or_model is not model and obj_pk is None: + obj_pk = str(obj_or_model.pk) + if obj_pk is None: + return None + content_type = ContentType.objects.get_for_model(model) try: return ObjectMapping.objects.values_list( 'meteor_id', flat=True, @@ -36,6 +43,31 @@ def get_meteor_id(obj): ).meteor_id +@transaction.atomic +def get_meteor_ids(model, object_ids): + """Return Alea ID mapping for all given ids of specified model.""" + content_type = ContentType.objects.get_for_model(model) + result = collections.OrderedDict( + (str(obj_pk), None) + for obj_pk + in object_ids + ) + for obj_pk, meteor_id in ObjectMapping.objects.filter( + content_type=content_type, + object_id__in=list(result) + ).values_list('object_id', 'meteor_id'): + result[obj_pk] = meteor_id + for obj_pk, meteor_id in result.items(): + if meteor_id is None: + # Django model._meta is now public API -> pylint: disable=W0212 + result[obj_pk] = ObjectMapping.objects.create( + content_type=content_type, + object_id=obj_pk, + meteor_id=meteor_random_id('/collection/%s' % model._meta), + ).meteor_id + return result + + @transaction.atomic def get_object_id(model, meteor_id): """Return an object ID for the given meteor_id.""" diff --git a/dddp/serializer.py b/dddp/serializer.py deleted file mode 100644 index 5ea05d4..0000000 --- a/dddp/serializer.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -A Python "serializer". Doesn't do much serializing per se -- just converts to -and from basic Python data types (lists, dicts, strings, etc.). Useful as a basis for -other serializers. -""" -from __future__ import unicode_literals - -from django.apps import apps -from django.conf import settings -from django.core.serializers import base -from django.core.serializers import python -from django.db import DEFAULT_DB_ALIAS, models -from django.utils import six -from django.utils.encoding import force_text, is_protected_type -from dddp.models import get_meteor_id, get_object_id - - -class Serializer(python.Serializer): - """ - Serializes a QuerySet to basic Python objects. - """ - - def get_dump_object(self, obj): - data = super(Serializer, self).get_dump_object(obj) - data["pk"] = get_meteor_id(obj) - return data - - def handle_fk_field(self, obj, field): - value = getattr(obj, field.name) - self._current[field.column] = get_meteor_id(value) - - def handle_m2m_field(self, obj, field): - if field.rel.through._meta.auto_created: - m2m_value = lambda value: get_meteor_id(value) - self._current['%s_ids' % field.name] = [m2m_value(related) - for related in getattr(obj, field.name).iterator()] - - -def Deserializer(object_list, **options): - """ - Deserialize simple Python objects back into Django ORM instances. - - It's expected that you pass the Python objects themselves (instead of a - stream or a string) to the constructor - """ - db = options.pop('using', DEFAULT_DB_ALIAS) - ignore = options.pop('ignorenonexistent', False) - - for d in object_list: - # Look up the model and starting build a dict of data for it. - try: - Model = _get_model(d["model"]) - except base.DeserializationError: - if ignore: - continue - else: - raise - data = {} - if 'pk' in d: - data[Model._meta.pk.attname] = Model._meta.pk.to_python( - get_object_id(Model, d.get("pk", None)), - ) - m2m_data = {} - field_names = {f.name for f in Model._meta.fields} - field_name_map = { - f.column: f.name - for f in Model._meta.fields - } - for field in Model._meta.many_to_many: - field_name_map.setdefault('%s_ids' % field.name, field.name) - - # Handle each field - for (field_column, field_value) in six.iteritems(d["fields"]): - field_name = field_name_map.get(field_column, None) - - if ignore and field_name not in field_names: - # skip fields no longer on model - continue - - if isinstance(field_value, str): - field_value = force_text( - field_value, options.get("encoding", settings.DEFAULT_CHARSET), strings_only=True - ) - - field = Model._meta.get_field(field_name) - - # Handle M2M relations - if field.rel and isinstance(field.rel, models.ManyToManyRel): - m2m_data[field.name] = [get_object_id(field.rel.to, pk) for pk in field_value] - - # Handle FK fields - elif field.rel and isinstance(field.rel, models.ManyToOneRel): - if field_value is not None: - field_value= get_object_id(field.rel.to, field_value) - data[field.attname] = field.rel.to._meta.get_field(field.rel.field_name).to_python(field_value) - else: - data[field.attname] = None - - # Handle all other fields - else: - data[field.name] = field.to_python(field_value) - - obj = base.build_instance(Model, data, db) - yield base.DeserializedObject(obj, m2m_data) - - -def _get_model(model_identifier): - """ - Helper to look up a model from an "app_label.model_name" string. - """ - try: - return apps.get_model(model_identifier) - except (LookupError, TypeError): - raise base.DeserializationError("Invalid model identifier: '%s'" % model_identifier)