"""Django DDP WebSocket service.""" from __future__ import absolute_import, print_function import atexit import collections import inspect import itertools import sys import traceback from six.moves import range as irange import ejson import geventwebsocket from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.db import connection, transaction from dddp import THREAD_LOCAL as this, alea, ADDED, CHANGED, REMOVED class MeteorError(Exception): """MeteorError.""" pass def validate_kwargs(func, kwargs): """Validate arguments to be supplied to func.""" func_name = func.__name__ argspec = inspect.getargspec(func) all_args = argspec.args[:] defaults = list(argspec.defaults or []) # ignore implicit 'self' argument if inspect.ismethod(func) and all_args[:1] == ['self']: all_args[:1] = [] # don't require arguments that have defaults if defaults: required = all_args[:-len(defaults)] else: required = all_args[:] # translate 'foo_' to avoid reserved names like 'id' trans = { arg: arg.endswith('_') and arg[:-1] or arg for arg in all_args } for key in list(kwargs): key_adj = '%s_' % key if key_adj in all_args: kwargs[key_adj] = kwargs.pop(key) # figure out what we're missing supplied = sorted(kwargs) missing = [ trans.get(arg, arg) for arg in required if arg not in supplied ] if missing: raise MeteorError( 'Missing required arguments to %s: %s' % ( func_name, ' '.join(missing), ), getattr(func, 'err', None), ) # figure out what is extra extra = [ arg for arg in supplied if arg not in all_args ] if extra: raise MeteorError( 'Unknown arguments to %s: %s' % ( func_name, ' '.join(extra), ), ) class DDPWebSocketApplication(geventwebsocket.WebSocketApplication): """Django DDP WebSocket application.""" _tx_buffer = None _tx_buffer_id_gen = None _tx_next_id_gen = None _tx_next_id = None methods = {} versions = [ # first item is preferred version '1', 'pre1', 'pre2', ] api = None logger = None pgworker = None remote_addr = None version = None support = None connection = None remote_ids = None base_handler = BaseHandler() def get_tx_id(self): """Get the next TX msg ID.""" return next(self._tx_buffer_id_gen) def on_open(self): """Handle new websocket connection.""" this.request = WSGIRequest(self.ws.environ) this.ws = self this.send = self.send this.reply = self.reply this.error = self.error self.logger = self.ws.logger self.remote_ids = collections.defaultdict(set) # self._tx_buffer collects outgoing messages which must be sent in order self._tx_buffer = {} # track the head of the queue (buffer) and the next msg to be sent self._tx_buffer_id_gen = itertools.cycle(irange(sys.maxint)) self._tx_next_id_gen = itertools.cycle(irange(sys.maxint)) # start by waiting for the very first message self._tx_next_id = next(self._tx_next_id_gen) this.remote_addr = self.remote_addr = \ '{0[REMOTE_ADDR]}:{0[REMOTE_PORT]}'.format( self.ws.environ, ) this.subs = {} self.logger.info('+ %s OPEN', self) self.send('o') self.send('a["{\\"server_id\\":\\"0\\"}"]') def __str__(self): """Show remote address that connected to us.""" return self.remote_addr def on_close(self, *args, **kwargs): """Handle closing of websocket connection.""" if self.connection is not None: del self.pgworker.connections[self.connection.pk] self.connection.delete() self.connection = None self.logger.info('- %s %s', self, args or 'CLOSE') def on_message(self, message): """Process a message received from remote.""" if self.ws.closed: return None try: self.logger.debug('< %s %r', self, message) # parse message set try: msgs = ejson.loads(message) except ValueError as err: self.error(400, 'Data is not valid EJSON') return if not isinstance(msgs, list): self.error(400, 'Invalid EJSON messages') return # process individual messages while msgs: # parse message payload raw = msgs.pop(0) try: data = ejson.loads(raw) except (TypeError, ValueError) as err: self.error(400, 'Data is not valid EJSON') continue if not isinstance(data, dict): self.error(400, 'Invalid EJSON message payload', raw) continue try: msg = data.pop('msg') except KeyError: self.error(400, 'Bad request', offendingMessage=data) continue # dispatch message try: self.dispatch(msg, data) except MeteorError as err: self.error(err) except Exception as err: traceback.print_exc() self.error(err) except geventwebsocket.WebSocketError as err: self.ws.close() @transaction.atomic def dispatch(self, msg, kwargs): """Dispatch msg to appropriate recv_foo handler.""" # enforce calling 'connect' first if self.connection is None and msg != 'connect': self.error(400, 'Must connect first') return # lookup method handler try: handler = getattr(self, 'recv_%s' % msg) except (AttributeError, UnicodeEncodeError): print('Method not found: %s %r' % (msg, kwargs)) self.error(404, 'Method not found', msg='result') return # validate handler arguments validate_kwargs(handler, kwargs) # dispatch to handler try: handler(**kwargs) except Exception as err: # print stack trace --> pylint: disable=W0703 traceback.print_exc() self.error(500, 'Internal server error', err) def send(self, data, tx_id=None): """Send `data` (raw string or EJSON payload) to WebSocket client.""" # buffer data until we get pre-requisite data if tx_id is None: tx_id = self.get_tx_id() self._tx_buffer[tx_id] = data # de-queue messages from buffer while self._tx_next_id in self._tx_buffer: # pull next message from buffer data = self._tx_buffer.pop(self._tx_next_id) if self._tx_buffer: self.logger.debug('TX found %d', self._tx_next_id) # advance next message ID self._tx_next_id = next(self._tx_next_id_gen) if not isinstance(data, basestring): # ejson payload msg = data.get('msg', None) if msg in (ADDED, CHANGED, REMOVED): ids = self.remote_ids[data['collection']] meteor_id = data['id'] if msg == ADDED: if meteor_id in ids: msg = data['msg'] = CHANGED else: ids.add(meteor_id) elif msg == CHANGED: if meteor_id not in ids: # object has become visible, treat as `added`. msg = data['msg'] = ADDED ids.add(meteor_id) elif msg == REMOVED: try: ids.remove(meteor_id) except KeyError: continue # client doesn't have this, don't send. data = 'a%s' % ejson.dumps([ejson.dumps(data)]) # send message self.logger.debug('> %s %r', self, data) try: self.ws.send(data) except geventwebsocket.WebSocketError: self.ws.close() break num_waiting = len(self._tx_buffer) if num_waiting > 10: self.logger.warn( 'TX received %d, waiting for %d, have %d waiting: %r.', tx_id, self._tx_next_id, num_waiting, self._tx_buffer, ) def reply(self, msg, **kwargs): """Send EJSON reply to remote.""" kwargs['msg'] = msg self.send(kwargs) def error( self, err, reason=None, detail=None, msg='error', exc_info=1, **kwargs ): """Send EJSON error to remote.""" if isinstance(err, MeteorError): ( err, reason, detail, kwargs, ) = ( err.args[:] + (None, None, None, None) )[:4] elif isinstance(err, Exception): reason = str(err) data = { 'error': '%s' % (err or ''), } if reason: if reason is Exception: reason = str(reason) data['reason'] = reason if detail: if isinstance(detail, Exception): detail = str(detail) data['detail'] = detail if kwargs: data.update(kwargs) record = { 'extra': { 'request': this.request, }, } self.logger.error('! %s %r', self, data, exc_info=exc_info, **record) self.reply(msg, **data) def recv_connect(self, version=None, support=None, session=None): """DDP connect handler.""" del session # Meteor doesn't even use this! if self.connection is not None: self.error( 400, 'Session already established.', detail=self.connection.connection_id, ) elif None in (version, support) or version not in self.versions: self.reply('failed', version=self.versions[0]) elif version not in support: self.error(400, 'Client version/support mismatch.') else: from dddp.models import Connection cur = connection.cursor() cur.execute('SELECT pg_backend_pid()') (backend_pid,) = cur.fetchone() this.version = version this.support = support self.connection = Connection.objects.create( server_addr='%d:%s' % ( backend_pid, self.ws.handler.socket.getsockname(), ), remote_addr=self.remote_addr, version=version, ) self.pgworker.connections[self.connection.pk] = self atexit.register(self.on_close, 'Shutting down.') self.reply('connected', session=self.connection.connection_id) def recv_ping(self, id_=None): """DDP ping handler.""" if id_ is None: self.reply('pong') else: self.reply('pong', id=id_) def recv_sub(self, id_, name, params): """DDP sub handler.""" self.api.sub(id_, name, *params) recv_sub.err = 'Malformed subscription' def recv_unsub(self, id_=None): """DDP unsub handler.""" if id_: self.api.unsub(id_) else: self.reply('nosub') def recv_method(self, method, params, id_, randomSeed=None): """DDP method handler.""" if randomSeed is not None: this.random_streams.random_seed = randomSeed this.alea_random = alea.Alea(randomSeed) self.api.method(method, params, id_) self.reply('updated', methods=[id_]) recv_method.err = 'Malformed method invocation'