##// END OF EJS Templates
use zmq to protect subprocess stdout after fork
MinRK -
Show More
@@ -1,20 +1,28
1 1 import sys
2 2 import time
3 3 import os
4 import threading
5 import uuid
4 6 from io import StringIO
5 7
8 import zmq
9
6 10 from session import extract_header, Message
7 11
8 12 from IPython.utils import io, text
9 13 from IPython.utils import py3compat
10 14
11 15 import multiprocessing as mp
12 import multiprocessing.sharedctypes as mpshc
16 # import multiprocessing.sharedctypes as mpshc
13 17 from ctypes import c_bool
14 18 #-----------------------------------------------------------------------------
15 19 # Globals
16 20 #-----------------------------------------------------------------------------
17 21
22 MASTER_NO_CHILDREN = 0
23 MASTER_WITH_CHILDREN = 1
24 CHILD = 2
25
18 26 #-----------------------------------------------------------------------------
19 27 # Stream classes
20 28 #-----------------------------------------------------------------------------
@@ -33,46 +41,105 class OutStream(object):
33 41 self.name = name
34 42 self.parent_header = {}
35 43 self._new_buffer()
36 self._manager = mp.Manager()
37 #use sharectype here so it don't have to hit the manager
38 #no synchronize needed either(right?). Just a flag telling the master
39 #to switch the buffer to que
40 self._found_newprocess = mpshc.RawValue(c_bool, False)
41 self._que_buffer = self._manager.Queue()
42 self._que_lock = self._manager.Lock()
43 self._masterpid = os.getpid()
44 self._master_has_switched = False
45
46 def _switch_to_que(self):
47 #should only be called on master process
48 #don't clear the que before putting data in since
49 #child process might have put something in the que before the
50 #master know it.
51 self._que_buffer.put(self._buffer.getvalue())
52 self._new_buffer()
53 self._start = -1
44 self._found_newprocess = threading.Event()
45 self._buffer_lock = threading.Lock()
46 self._master_pid = os.getpid()
47 self._master_thread = threading.current_thread().ident
48 self._pipe_pid = os.getpid()
49 self._setup_pipe_in()
50
51 def _setup_pipe_in(self):
52 """setup listening pipe for subprocesses"""
53 ctx = self._pipe_ctx = zmq.Context()
54
55 # signal pair for terminating background thread
56 self._pipe_signaler = ctx.socket(zmq.PAIR)
57 self._pipe_signalee = ctx.socket(zmq.PAIR)
58 self._pipe_signaler.bind("inproc://ostream_pipe")
59 self._pipe_signalee.connect("inproc://ostream_pipe")
60 # thread event to signal cleanup is done
61 self._pipe_done = threading.Event()
62
63 # use UUID to authenticate pipe messages
64 self._pipe_uuid = uuid.uuid4().bytes
65
66 self._pipe_thread = threading.Thread(target=self._pipe_main)
67 self._pipe_thread.start()
68
69 def _setup_pipe_out(self):
70 # must be new context after fork
71 ctx = zmq.Context()
72 self._pipe_pid = os.getpid()
73 self._pipe_out = ctx.socket(zmq.PUSH)
74 self._pipe_out_lock = threading.Lock()
75 self._pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
76
77 def _pipe_main(self):
78 """eventloop for receiving"""
79 ctx = self._pipe_ctx
80 self._pipe_in = ctx.socket(zmq.PULL)
81 self._pipe_port = self._pipe_in.bind_to_random_port("tcp://127.0.0.1")
82 poller = zmq.Poller()
83 poller.register(self._pipe_signalee, zmq.POLLIN)
84 poller.register(self._pipe_in, zmq.POLLIN)
85 while True:
86 if not self._is_master_process():
87 return
88 try:
89 events = dict(poller.poll(1000))
90 except zmq.ZMQError:
91 # should only be triggered by process-ending cleanup
92 return
93
94 if self._pipe_signalee in events:
95 break
96 if self._pipe_in in events:
97 msg = self._pipe_in.recv_multipart()
98 if msg[0] != self._pipe_uuid:
99 # message not authenticated
100 continue
101 self._found_newprocess.set()
102 text = msg[1].decode(self.encoding, 'replace')
103 with self._buffer_lock:
104 self._buffer.write(text)
105 if self._start < 0:
106 self._start = time.time()
107
108 # wrap it up
109 self._pipe_signaler.close()
110 self._pipe_signalee.close()
111 self._pipe_in.close()
112 self._pipe_ctx.term()
113 self._pipe_done.set()
114
115
116 def __del__(self):
117 if not self._is_master_process():
118 return
119 self._pipe_signaler.send(b'die')
120 self._pipe_done.wait(10)
54 121
55 122 def _is_master_process(self):
56 return os.getpid()==self._masterpid
123 return os.getpid() == self._master_pid
57 124
58 def _debug_print(self,s):
59 sys.__stdout__.write(s+'\n')
60 sys.__stdout__.flush()
125 def _is_master_thread(self):
126 return threading.current_thread().ident == self._master_thread
61 127
62 def _check_mp_mode(self):
63 """check multiprocess and switch to que if necessary"""
64 if not self._found_newprocess.value:
65 if not self._is_master_process():
66 self._found_newprocess.value = True
67 elif self._found_newprocess.value and not self._master_has_switched:
128 def _have_pipe_out(self):
129 return os.getpid() == self._pipe_pid
68 130
69 #switch to que if it has not been switch
131 def _check_mp_mode(self):
132 """check for forks, and switch to zmq pipeline if necessary"""
70 133 if self._is_master_process():
71 self._switch_to_que()
72 self._master_has_switched = True
73
74 return self._found_newprocess.value
75
134 if self._found_newprocess.is_set():
135 return MASTER_WITH_CHILDREN
136 else:
137 return MASTER_NO_CHILDREN
138 else:
139 if not self._have_pipe_out():
140 # setup a new out pipe
141 self._setup_pipe_out()
142 return CHILD
76 143
77 144 def set_parent(self, parent):
78 145 self.parent_header = extract_header(parent)
@@ -81,20 +148,26 class OutStream(object):
81 148 self.pub_socket = None
82 149
83 150 def flush(self):
84 #io.rprint('>>>flushing output buffer: %s<<<' % self.name) # dbg
85
151 """trigger actual zmq send"""
86 152 if self.pub_socket is None:
87 153 raise ValueError(u'I/O operation on closed file')
88 154 else:
89 155 if self._is_master_process():
156 if not self._is_master_thread():
157 # sub-threads mustn't trigger flush,
158 # but at least they can force the timer.
159 self._start = 0
90 160 data = u''
91 161 #obtain data
92 if self._check_mp_mode():#multiprocess
93 with self._que_lock:
94 while not self._que_buffer.empty():
95 data += self._que_buffer.get()
162 if self._check_mp_mode(): # multiprocess, needs a lock
163 with self._buffer_lock:
164 data = self._buffer.getvalue()
165 self._buffer.close()
166 self._new_buffer()
96 167 else:#single process mode
97 168 data = self._buffer.getvalue()
169 self._buffer.close()
170 self._new_buffer()
98 171
99 172 if data:
100 173 content = {u'name':self.name, u'data':data}
@@ -104,10 +177,11 class OutStream(object):
104 177 if hasattr(self.pub_socket, 'flush'):
105 178 # socket itself has flush (presumably ZMQStream)
106 179 self.pub_socket.flush()
107 self._buffer.close()
108 self._new_buffer()
109 180 else:
110 pass
181 self._check_mp_mode()
182 with self._pipe_out_lock:
183 tracker = self._pipe_out.send(b'', copy=False, track=True)
184 tracker.wait(1)
111 185
112 186
113 187 def isatty(self):
@@ -133,14 +207,22 class OutStream(object):
133 207 if not isinstance(string, unicode):
134 208 string = string.decode(self.encoding, 'replace')
135 209
136 if self._check_mp_mode(): #multi process mode
137 with self._que_lock:
138 self._que_buffer.put(string)
139 else: #sigle process mode
210 mp_mode = self._check_mp_mode()
211 if mp_mode == CHILD:
212 with self._pipe_out_lock:
213 self._pipe_out.send_multipart([
214 self._pipe_uuid,
215 string.encode(self.encoding, 'replace'),
216 ])
217 return
218 elif mp_mode == MASTER_NO_CHILDREN:
219 self._buffer.write(string)
220 elif mp_mode == MASTER_WITH_CHILDREN:
221 with self._buffer_lock:
140 222 self._buffer.write(string)
141 223
142 224 current_time = time.time()
143 if self._start <= 0:
225 if self._start < 0:
144 226 self._start = current_time
145 227 elif current_time - self._start > self.flush_interval:
146 228 self.flush()
General Comments 0
You need to be logged in to leave comments. Login now