diff --git a/IPython/kernel/zmq/iostream.py b/IPython/kernel/zmq/iostream.py index 9523e5b..93f6f6b 100644 --- a/IPython/kernel/zmq/iostream.py +++ b/IPython/kernel/zmq/iostream.py @@ -41,7 +41,7 @@ class OutStream(object): self.name = name self.parent_header = {} self._new_buffer() - self._found_newprocess = threading.Event() + self._found_newprocess = 0 self._buffer_lock = threading.Lock() self._master_pid = os.getpid() self._master_thread = threading.current_thread().ident @@ -50,21 +50,16 @@ class OutStream(object): def _setup_pipe_in(self): """setup listening pipe for subprocesses""" - ctx = self._pipe_ctx = zmq.Context() - - # signal pair for terminating background thread - self._pipe_signaler = ctx.socket(zmq.PAIR) - self._pipe_signalee = ctx.socket(zmq.PAIR) - self._pipe_signaler.bind("inproc://ostream_pipe") - self._pipe_signalee.connect("inproc://ostream_pipe") - # thread event to signal cleanup is done - self._pipe_done = threading.Event() + ctx = self.pub_socket.context # use UUID to authenticate pipe messages self._pipe_uuid = uuid.uuid4().bytes - self._pipe_thread = threading.Thread(target=self._pipe_main) - self._pipe_thread.start() + self._pipe_in = ctx.socket(zmq.PULL) + self._pipe_in.linger = 0 + self._pipe_port = self._pipe_in.bind_to_random_port("tcp://127.0.0.1") + self._pipe_poller = zmq.Poller() + self._pipe_poller.register(self._pipe_in, zmq.POLLIN) def _setup_pipe_out(self): # must be new context after fork @@ -74,51 +69,6 @@ class OutStream(object): self._pipe_out_lock = threading.Lock() self._pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port) - def _pipe_main(self): - """eventloop for receiving""" - ctx = self._pipe_ctx - self._pipe_in = ctx.socket(zmq.PULL) - self._pipe_port = self._pipe_in.bind_to_random_port("tcp://127.0.0.1") - poller = zmq.Poller() - poller.register(self._pipe_signalee, zmq.POLLIN) - poller.register(self._pipe_in, zmq.POLLIN) - while True: - if not self._is_master_process(): - return - try: - events = dict(poller.poll(1000)) - except zmq.ZMQError: - # should only be triggered by process-ending cleanup - return - - if self._pipe_signalee in events: - break - if self._pipe_in in events: - msg = self._pipe_in.recv_multipart() - if msg[0] != self._pipe_uuid: - # message not authenticated - continue - self._found_newprocess.set() - text = msg[1].decode(self.encoding, 'replace') - with self._buffer_lock: - self._buffer.write(text) - if self._start < 0: - self._start = time.time() - - # wrap it up - self._pipe_signaler.close() - self._pipe_signalee.close() - self._pipe_in.close() - self._pipe_ctx.term() - self._pipe_done.set() - - - def __del__(self): - if not self._is_master_process(): - return - self._pipe_signaler.send(b'die') - self._pipe_done.wait(10) - def _is_master_process(self): return os.getpid() == self._master_pid @@ -131,7 +81,7 @@ class OutStream(object): def _check_mp_mode(self): """check for forks, and switch to zmq pipeline if necessary""" if self._is_master_process(): - if self._found_newprocess.is_set(): + if self._found_newprocess: return MASTER_WITH_CHILDREN else: return MASTER_NO_CHILDREN @@ -147,42 +97,57 @@ class OutStream(object): def close(self): self.pub_socket = None + def _flush_from_subprocesses(self): + """flush possible pub data from subprocesses into my buffer""" + if not self._is_master_process(): + return + for i in range(100): + if self._pipe_poller.poll(0): + msg = self._pipe_in.recv_multipart() + if msg[0] != self._pipe_uuid: + continue + else: + self._buffer.write(msg[1].decode(self.encoding, 'replace')) + # this always means a flush, + # so reset our timer + self._start = 0 + else: + break + def flush(self): """trigger actual zmq send""" if self.pub_socket is None: raise ValueError(u'I/O operation on closed file') + + mp_mode = self._check_mp_mode() + + if mp_mode != CHILD: + # we are master + if not self._is_master_thread(): + # sub-threads must not trigger flush, + # but at least they can force the timer. + self._start = 0 + return + + self._flush_from_subprocesses() + data = self._flush_buffer() + + if data: + content = {u'name':self.name, u'data':data} + msg = self.session.send(self.pub_socket, u'stream', content=content, + parent=self.parent_header, ident=self.topic) + + if hasattr(self.pub_socket, 'flush'): + # socket itself has flush (presumably ZMQStream) + self.pub_socket.flush() else: - if self._is_master_process(): - if not self._is_master_thread(): - # sub-threads mustn't trigger flush, - # but at least they can force the timer. - self._start = 0 - data = u'' - # obtain data - if self._check_mp_mode(): # multiprocess, needs a lock - with self._buffer_lock: - data = self._buffer.getvalue() - self._buffer.close() - self._new_buffer() - else: # single process mode - data = self._buffer.getvalue() - self._buffer.close() - self._new_buffer() - - if data: - content = {u'name':self.name, u'data':data} - msg = self.session.send(self.pub_socket, u'stream', content=content, - parent=self.parent_header, ident=self.topic) - - if hasattr(self.pub_socket, 'flush'): - # socket itself has flush (presumably ZMQStream) - self.pub_socket.flush() - else: - self._check_mp_mode() - with self._pipe_out_lock: - tracker = self._pipe_out.send(b'', copy=False, track=True) - tracker.wait(1) - + with self._pipe_out_lock: + string = self._flush_buffer() + tracker = self._pipe_out.send_multipart([ + self._pipe_uuid, + string.encode(self.encoding, 'replace'), + ], copy=False, track=True) + tracker.wait(1) def isatty(self): return False @@ -206,21 +171,10 @@ class OutStream(object): # Make sure that we're handling unicode if not isinstance(string, unicode): string = string.decode(self.encoding, 'replace') - - mp_mode = self._check_mp_mode() - if mp_mode == CHILD: - with self._pipe_out_lock: - self._pipe_out.send_multipart([ - self._pipe_uuid, - string.encode(self.encoding, 'replace'), - ]) - return - elif mp_mode == MASTER_NO_CHILDREN: - self._buffer.write(string) - elif mp_mode == MASTER_WITH_CHILDREN: - with self._buffer_lock: - self._buffer.write(string) - + self._buffer.write(string) + self._check_mp_mode() + # do we want to check subprocess flushes on write? + # self._flush_from_subprocesses() current_time = time.time() if self._start < 0: self._start = current_time @@ -234,6 +188,16 @@ class OutStream(object): for string in sequence: self.write(string) + def _flush_buffer(self): + """clear the current buffer and return the current buffer data""" + data = u'' + if self._buffer is not None: + data = self._buffer.getvalue() + self._buffer.close() + self._new_buffer() + return data + def _new_buffer(self): self._buffer = StringIO() self._start = -1 + diff --git a/IPython/zmq/tests/test_kernel.py b/IPython/zmq/tests/test_kernel.py index f684145..b6875f9 100644 --- a/IPython/zmq/tests/test_kernel.py +++ b/IPython/zmq/tests/test_kernel.py @@ -151,7 +151,7 @@ def test_subprocess_print(): for n in range(np): nt.assert_equal(stdout.count(str(n)), 1, stdout) nt.assert_equal(stderr, '') - _check_mp_mode(km, expected=True) + _check_mp_mode(km, expected=False) _check_mp_mode(km, expected=False, stream="stderr") @@ -199,5 +199,5 @@ def test_subprocess_error(): nt.assert_true("ZeroDivisionError" in stderr, stderr) _check_mp_mode(km, expected=False) - _check_mp_mode(km, expected=True, stream="stderr") + _check_mp_mode(km, expected=False, stream="stderr")