From 20dd4dfe8eaf40d4b0159eb6a8e002e23660b7bc 2012-07-21 06:12:17 From: MinRK Date: 2012-07-21 06:12:17 Subject: [PATCH] add copy_threshold to limit use of zero-copy to sufficiently large messages threshold is triggered on the largest part of the message, and copy is applied to all parts, to avoid split-sends. --- diff --git a/IPython/zmq/session.py b/IPython/zmq/session.py index 26569fe..787fb79 100644 --- a/IPython/zmq/session.py +++ b/IPython/zmq/session.py @@ -49,7 +49,7 @@ from IPython.utils.importstring import import_item from IPython.utils.jsonutil import extract_dates, squash_dates, date_default from IPython.utils.py3compat import str_to_bytes from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set, - DottedObjectName, CUnicode, Dict) + DottedObjectName, CUnicode, Dict, Int) #----------------------------------------------------------------------------- # utility functions @@ -84,8 +84,9 @@ pickle_unpacker = pickle.loads default_packer = json_packer default_unpacker = json_unpacker -DELIM=b"" - +DELIM = b"" +# singleton dummy tracker, which will always report as done +DONE = zmq.MessageTracker() #----------------------------------------------------------------------------- # Mixin tools for apps that use Sessions @@ -329,6 +330,9 @@ class Session(Configurable): # unpacker is not checked - it is assumed to be if not callable(new): raise TypeError("unpacker must be callable, not %s"%type(new)) + + copy_threshold = Int(2**12, config=True, + help="Threshold (in bytes) beyond which a buffer should be sent without copying.") def __init__(self, **kwargs): """create a Session object @@ -544,11 +548,6 @@ class Session(Configurable): ------- msg : dict The constructed message. 
-        (msg,tracker) : (dict, MessageTracker)
-            if track=True, then a 2-tuple will be returned,
-            the first element being the constructed
-            message, and the second being the MessageTracker
-
         """
 
         if not isinstance(stream, (zmq.Socket, ZMQStream)):
@@ -566,25 +565,19 @@ class Session(Configurable):
 
         buffers = [] if buffers is None else buffers
         to_send = self.serialize(msg, ident)
-        flag = 0
-        if buffers:
-            flag = zmq.SNDMORE
-            _track = False
+        to_send.extend(buffers)
+        longest = max([ len(s) for s in to_send ])
+        # copy small messages; only go zero-copy when the largest part reaches copy_threshold
+        copy = (longest < self.copy_threshold)
+
+        if buffers and track and not copy:
+            # only really track when we are doing zero-copy buffers
+            tracker = stream.send_multipart(to_send, copy=False, track=True)
         else:
-            _track=track
-        if track:
-            tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
-        else:
-            tracker = stream.send_multipart(to_send, flag, copy=False)
-        for b in buffers[:-1]:
-            stream.send(b, flag, copy=False)
-        if buffers:
-            if track:
-                tracker = stream.send(buffers[-1], copy=False, track=track)
-            else:
-                tracker = stream.send(buffers[-1], copy=False)
+            # use dummy tracker, which will be done immediately
+            tracker = DONE
+            stream.send_multipart(to_send, copy=copy)
 
-        # omsg = Message(msg)
         if self.debug:
             pprint.pprint(msg)
             pprint.pprint(to_send)