zmqhandlers.py
269 lines
| 9.0 KiB
| text/x-python
|
PythonLexer
MinRK
|
r18498 | # coding: utf-8 | ||
MinRK
|
r16697 | """Tornado handlers for WebSocket <-> ZMQ sockets.""" | ||
Brian E. Granger
|
r10653 | |||
MinRK
|
r16697 | # Copyright (c) IPython Development Team. | ||
# Distributed under the terms of the Modified BSD License. | ||||
Brian E. Granger
|
r10653 | |||
MinRK
|
r17021 | import json | ||
MinRK
|
r18329 | import struct | ||
MinRK
|
r17021 | |||
Thomas Kluyver
|
r13354 | try: | ||
Kyle Kelley
|
r14652 | from urllib.parse import urlparse # Py 3 | ||
Kyle Kelley
|
r14646 | except ImportError: | ||
Kyle Kelley
|
r14652 | from urlparse import urlparse # Py 2 | ||
Kyle Kelley
|
r14646 | |||
MinRK
|
r17106 | import tornado | ||
Min RK
|
r18522 | from tornado import gen, ioloop, web, websocket | ||
Brian E. Granger
|
r10653 | |||
from IPython.kernel.zmq.session import Session | ||||
MinRK
|
r18335 | from IPython.utils.jsonutil import date_default, extract_dates | ||
MinRK
|
r18498 | from IPython.utils.py3compat import cast_unicode | ||
Brian E. Granger
|
r10653 | |||
Brian E. Granger
|
r10667 | from .handlers import IPythonHandler | ||
Brian E. Granger
|
r10653 | |||
MinRK
|
def serialize_binary_message(msg):
    """Pack a message and its binary buffers into one bytes blob.

    Layout (all integers are big-endian unsigned 32-bit):

    Header:
        4 bytes: number of parts (nbufs)
        4 * nbufs bytes: byte offset of each part, measured from the start
        of the blob (so the header itself is included in the offsets).
    Parts:
        part 0 is the UTF-8 JSON encoding of the message minus 'buffers';
        parts 1.. are the entries of msg['buffers'], in order.

    Returns
    -------
    The message serialized to bytes.
    """
    # Work on copies: neither the caller's dict nor its buffer list is touched.
    payload = msg.copy()
    extra_parts = list(payload.pop('buffers'))
    parts = [json.dumps(payload, default=date_default).encode('utf8')]
    parts.extend(extra_parts)

    count = len(parts)
    # The first part begins right after the header: count + one offset per part.
    cursor = 4 * (count + 1)
    positions = []
    for part in parts:
        positions.append(cursor)
        cursor += len(part)

    header = struct.pack('!%dI' % (count + 1), count, *positions)
    return header + b''.join(parts)
MinRK
|
def deserialize_binary_message(bmsg):
    """Deserialize a message from a binary blob.

    Inverse of serialize_binary_message. Layout (big-endian unsigned 32-bit ints):

    Header:
        4 bytes: number of msg parts (nbufs)
        4 * nbufs bytes: offset of each part, measured from the start of
        the blob (i.e. including this header).
    Part 0 is the UTF-8 JSON message; the remaining parts become msg['buffers'].

    Returns
    -------
    message dictionary

    Raises
    ------
    struct.error
        If the blob is too short to contain the header it declares.
    """
    # '!I' (unsigned) matches what serialize_binary_message writes.
    nbufs = struct.unpack_from('!I', bmsg, 0)[0]
    # unpack_from with a repeat count ('!<n>I') checks the buffer length before
    # doing any work, so a bogus huge nbufs fails fast with struct.error instead
    # of building an enormous format string.
    offsets = list(struct.unpack_from('!%dI' % nbufs, bmsg, 4))
    # Sentinel so the last part runs to the end of the blob.
    offsets.append(None)
    bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]

    msg = json.loads(bufs[0].decode('utf8'))
    msg['header'] = extract_dates(msg['header'])
    msg['parent_header'] = extract_dates(msg['parent_header'])
    msg['buffers'] = bufs[1:]
    return msg
Brian E. Granger
|
class ZMQStreamHandler(websocket.WebSocketHandler):
    """Base WebSocket handler that relays messages from ZMQ to the browser."""

    def check_origin(self, origin):
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.
        On Tornado < 4 it is called explicitly by
        AuthenticatedZMQStreamHandler.pre_get.
        """
        if self.allow_origin == '*':
            return True

        host = self.request.headers.get("Host")

        # If no header is provided, assume we can't verify origin
        if origin is None:
            self.log.warning("Missing Origin header, rejecting WebSocket connection.")
            return False
        if host is None:
            self.log.warning("Missing Host header, rejecting WebSocket connection.")
            return False

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host. Host names are case-insensitive and
        # origin was lowercased above, so lowercase the Host header too.
        if origin_host == host.lower():
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(self.allow_origin_pat.match(origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                origin, host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""
        pass

    def _reserialize_reply(self, msg_list):
        """Reserialize a reply message for sending to the browser.

        This takes the msg list from the ZMQ socket, deserializes it using
        self.session and then re-serializes the result:

        - messages with binary buffers are packed with
          serialize_binary_message and returned as bytes;
        - all other messages are returned as a JSON text string.

        This method should be used by self._on_zmq_reply to build messages
        that can be sent back to the browser.
        """
        _idents, msg_list = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(msg_list)
        if msg['buffers']:
            return serialize_binary_message(msg)
        smsg = json.dumps(msg, default=date_default)
        return cast_unicode(smsg)

    def _on_zmq_reply(self, msg_list):
        """Forward a message received from ZMQ to the browser.

        Text replies are sent as text frames and bytes replies as binary
        frames. Messages that fail to reserialize are logged and dropped.
        """
        # Sometimes this gets triggered when the on_close method is scheduled in the
        # eventloop but hasn't been called.
        if self.stream.closed():
            return
        try:
            msg = self._reserialize_reply(msg_list)
        except Exception:
            # Lazy %-args: formatting is left to the logging framework.
            self.log.critical("Malformed message: %r", msg_list, exc_info=True)
        else:
            self.write_message(msg, binary=isinstance(msg, bytes))

    def allow_draft76(self):
        """Allow draft 76, until browsers such as Safari update to RFC 6455.

        This has been disabled by default in tornado in release 2.2.0, and
        support will be removed in later versions.
        """
        return True
MinRK
|
# Default keep-alive ping interval for websockets, in milliseconds (30 s).
WS_PING_INTERVAL = 30 * 1000
Brian E. Granger
|
r10653 | |||
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
    """ZMQ WebSocket handler that authenticates the user before connecting.

    The GET request is rejected with 403 unless ``get_current_user()``
    returns a user (and, on tornado < 4, unless the origin check passes).
    Once the socket is open, a ping is sent every ``ping_interval`` ms and
    the connection is closed if no pong arrives within ``ping_timeout`` ms.
    """

    # PeriodicCallback driving send_ping; stays None if pinging is disabled.
    ping_callback = None
    # IOLoop timestamps (seconds) of the last ping sent / pong received.
    last_ping = 0
    last_pong = 0

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).

        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get('ws_ping_timeout',
            max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler

        which doesn't make sense for websockets
        """
        pass

    def pre_get(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises web.HTTPError(403) if the origin check (tornado < 4) or
        authentication fails.
        """
        # Check to see that origin matches host directly, including ports
        # Tornado 4 already does CORS checking
        if tornado.version_info < (4,):
            if not self.check_origin(self.get_origin()):
                raise web.HTTPError(403)

        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        session_id = self.get_argument('session_id', None)
        if session_id:
            self.session.session = cast_unicode(session_id)
        else:
            self.log.warning("No session ID specified")

    @gen.coroutine
    def get(self, *args, **kwargs):
        """Run pre_get, then (tornado >= 4) perform the websocket upgrade."""
        # pre_get can be a coroutine in subclasses; wait for it if so.
        # Only yield real awaitables: older tornado raises BadYieldError on None,
        # and gen.maybe_future would not await a native coroutine.
        res = self.pre_get()
        if res is not None:
            yield res
        # tornado 3 has no WebSocketHandler.get (calling it would raise 405);
        # there, `open` calls this method after the handshake instead.
        if tornado.version_info >= (4,):
            # Newer tornado returns a Future/coroutine from get(); it must be
            # yielded or the upgrade never runs and its errors are lost.
            res = super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
            if res is not None:
                yield res

    def initialize(self):
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)

    def open(self, *args, **kwargs):
        """Called once the websocket is open: start keep-alive pinging.

        On tornado < 4 this also runs the GET checks, closing the socket on
        HTTPError.
        """
        self.log.debug("Opening websocket %s", self.request.path)
        if tornado.version_info < (4,):
            # NOTE(review): `get` is a gen.coroutine, so an HTTPError raised by
            # pre_get is likely captured in the returned Future rather than
            # raised here; confirm this except-branch fires on tornado 3.
            try:
                self.get(*self.open_args, **self.open_kwargs)
            except web.HTTPError:
                self.close()
                raise

        # start the pinging
        if self.ping_interval > 0:
            self.last_ping = ioloop.IOLoop.instance().time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
            self.ping_callback.start()

    def send_ping(self):
        """send a ping to keep the websocket alive"""
        if self.stream.closed():
            # Connection is gone: stop the periodic callback (if any) and
            # never attempt to ping a closed stream.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.instance().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b'')
        self.last_ping = now

    def on_pong(self, data):
        """Record when the client last answered a ping (checked by send_ping)."""
        self.last_pong = ioloop.IOLoop.instance().time()