##// END OF EJS Templates
use zmq to protect subprocess stdout after fork
MinRK -
Show More
@@ -1,20 +1,28 b''
1 import sys
1 import sys
2 import time
2 import time
3 import os
3 import os
4 import threading
5 import uuid
4 from io import StringIO
6 from io import StringIO
5
7
8 import zmq
9
6 from session import extract_header, Message
10 from session import extract_header, Message
7
11
8 from IPython.utils import io, text
12 from IPython.utils import io, text
9 from IPython.utils import py3compat
13 from IPython.utils import py3compat
10
14
11 import multiprocessing as mp
15 import multiprocessing as mp
12 import multiprocessing.sharedctypes as mpshc
16 # import multiprocessing.sharedctypes as mpshc
13 from ctypes import c_bool
17 from ctypes import c_bool
14 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
15 # Globals
19 # Globals
16 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
17
21
22 MASTER_NO_CHILDREN = 0
23 MASTER_WITH_CHILDREN = 1
24 CHILD = 2
25
18 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
19 # Stream classes
27 # Stream classes
20 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
@@ -33,46 +41,105 b' class OutStream(object):'
33 self.name = name
41 self.name = name
34 self.parent_header = {}
42 self.parent_header = {}
35 self._new_buffer()
43 self._new_buffer()
36 self._manager = mp.Manager()
44 self._found_newprocess = threading.Event()
37 #use sharectype here so it don't have to hit the manager
45 self._buffer_lock = threading.Lock()
38 #no synchronize needed either(right?). Just a flag telling the master
46 self._master_pid = os.getpid()
39 #to switch the buffer to que
47 self._master_thread = threading.current_thread().ident
40 self._found_newprocess = mpshc.RawValue(c_bool, False)
48 self._pipe_pid = os.getpid()
41 self._que_buffer = self._manager.Queue()
49 self._setup_pipe_in()
42 self._que_lock = self._manager.Lock()
50
43 self._masterpid = os.getpid()
51 def _setup_pipe_in(self):
44 self._master_has_switched = False
52 """setup listening pipe for subprocesses"""
45
53 ctx = self._pipe_ctx = zmq.Context()
46 def _switch_to_que(self):
54
47 #should only be called on master process
55 # signal pair for terminating background thread
48 #don't clear the que before putting data in since
56 self._pipe_signaler = ctx.socket(zmq.PAIR)
49 #child process might have put something in the que before the
57 self._pipe_signalee = ctx.socket(zmq.PAIR)
50 #master know it.
58 self._pipe_signaler.bind("inproc://ostream_pipe")
51 self._que_buffer.put(self._buffer.getvalue())
59 self._pipe_signalee.connect("inproc://ostream_pipe")
52 self._new_buffer()
60 # thread event to signal cleanup is done
53 self._start = -1
61 self._pipe_done = threading.Event()
54
62
63 # use UUID to authenticate pipe messages
64 self._pipe_uuid = uuid.uuid4().bytes
65
66 self._pipe_thread = threading.Thread(target=self._pipe_main)
67 self._pipe_thread.start()
68
69 def _setup_pipe_out(self):
70 # must be new context after fork
71 ctx = zmq.Context()
72 self._pipe_pid = os.getpid()
73 self._pipe_out = ctx.socket(zmq.PUSH)
74 self._pipe_out_lock = threading.Lock()
75 self._pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
76
77 def _pipe_main(self):
78 """eventloop for receiving"""
79 ctx = self._pipe_ctx
80 self._pipe_in = ctx.socket(zmq.PULL)
81 self._pipe_port = self._pipe_in.bind_to_random_port("tcp://127.0.0.1")
82 poller = zmq.Poller()
83 poller.register(self._pipe_signalee, zmq.POLLIN)
84 poller.register(self._pipe_in, zmq.POLLIN)
85 while True:
86 if not self._is_master_process():
87 return
88 try:
89 events = dict(poller.poll(1000))
90 except zmq.ZMQError:
91 # should only be triggered by process-ending cleanup
92 return
93
94 if self._pipe_signalee in events:
95 break
96 if self._pipe_in in events:
97 msg = self._pipe_in.recv_multipart()
98 if msg[0] != self._pipe_uuid:
99 # message not authenticated
100 continue
101 self._found_newprocess.set()
102 text = msg[1].decode(self.encoding, 'replace')
103 with self._buffer_lock:
104 self._buffer.write(text)
105 if self._start < 0:
106 self._start = time.time()
107
108 # wrap it up
109 self._pipe_signaler.close()
110 self._pipe_signalee.close()
111 self._pipe_in.close()
112 self._pipe_ctx.term()
113 self._pipe_done.set()
114
115
116 def __del__(self):
117 if not self._is_master_process():
118 return
119 self._pipe_signaler.send(b'die')
120 self._pipe_done.wait(10)
121
55 def _is_master_process(self):
122 def _is_master_process(self):
56 return os.getpid()==self._masterpid
123 return os.getpid() == self._master_pid
57
124
58 def _debug_print(self,s):
125 def _is_master_thread(self):
59 sys.__stdout__.write(s+'\n')
126 return threading.current_thread().ident == self._master_thread
60 sys.__stdout__.flush()
127
128 def _have_pipe_out(self):
129 return os.getpid() == self._pipe_pid
61
130
62 def _check_mp_mode(self):
131 def _check_mp_mode(self):
63 """check multiprocess and switch to que if necessary"""
132 """check for forks, and switch to zmq pipeline if necessary"""
64 if not self._found_newprocess.value:
133 if self._is_master_process():
65 if not self._is_master_process():
134 if self._found_newprocess.is_set():
66 self._found_newprocess.value = True
135 return MASTER_WITH_CHILDREN
67 elif self._found_newprocess.value and not self._master_has_switched:
136 else:
68
137 return MASTER_NO_CHILDREN
69 #switch to que if it has not been switch
138 else:
70 if self._is_master_process():
139 if not self._have_pipe_out():
71 self._switch_to_que()
140 # setup a new out pipe
72 self._master_has_switched = True
141 self._setup_pipe_out()
73
142 return CHILD
74 return self._found_newprocess.value
75
76
143
77 def set_parent(self, parent):
144 def set_parent(self, parent):
78 self.parent_header = extract_header(parent)
145 self.parent_header = extract_header(parent)
@@ -81,20 +148,26 b' class OutStream(object):'
81 self.pub_socket = None
148 self.pub_socket = None
82
149
83 def flush(self):
150 def flush(self):
84 #io.rprint('>>>flushing output buffer: %s<<<' % self.name) # dbg
151 """trigger actual zmq send"""
85
86 if self.pub_socket is None:
152 if self.pub_socket is None:
87 raise ValueError(u'I/O operation on closed file')
153 raise ValueError(u'I/O operation on closed file')
88 else:
154 else:
89 if self._is_master_process():
155 if self._is_master_process():
156 if not self._is_master_thread():
157 # sub-threads mustn't trigger flush,
158 # but at least they can force the timer.
159 self._start = 0
90 data = u''
160 data = u''
91 #obtain data
161 # obtain data
92 if self._check_mp_mode():#multiprocess
162 if self._check_mp_mode(): # multiprocess, needs a lock
93 with self._que_lock:
163 with self._buffer_lock:
94 while not self._que_buffer.empty():
164 data = self._buffer.getvalue()
95 data += self._que_buffer.get()
165 self._buffer.close()
96 else:#single process mode
166 self._new_buffer()
167 else: # single process mode
97 data = self._buffer.getvalue()
168 data = self._buffer.getvalue()
169 self._buffer.close()
170 self._new_buffer()
98
171
99 if data:
172 if data:
100 content = {u'name':self.name, u'data':data}
173 content = {u'name':self.name, u'data':data}
@@ -104,10 +177,11 b' class OutStream(object):'
104 if hasattr(self.pub_socket, 'flush'):
177 if hasattr(self.pub_socket, 'flush'):
105 # socket itself has flush (presumably ZMQStream)
178 # socket itself has flush (presumably ZMQStream)
106 self.pub_socket.flush()
179 self.pub_socket.flush()
107 self._buffer.close()
108 self._new_buffer()
109 else:
180 else:
110 pass
181 self._check_mp_mode()
182 with self._pipe_out_lock:
183 tracker = self._pipe_out.send(b'', copy=False, track=True)
184 tracker.wait(1)
111
185
112
186
113 def isatty(self):
187 def isatty(self):
@@ -132,15 +206,23 b' class OutStream(object):'
132 # Make sure that we're handling unicode
206 # Make sure that we're handling unicode
133 if not isinstance(string, unicode):
207 if not isinstance(string, unicode):
134 string = string.decode(self.encoding, 'replace')
208 string = string.decode(self.encoding, 'replace')
135
209
136 if self._check_mp_mode(): #multi process mode
210 mp_mode = self._check_mp_mode()
137 with self._que_lock:
211 if mp_mode == CHILD:
138 self._que_buffer.put(string)
212 with self._pipe_out_lock:
139 else: #sigle process mode
213 self._pipe_out.send_multipart([
214 self._pipe_uuid,
215 string.encode(self.encoding, 'replace'),
216 ])
217 return
218 elif mp_mode == MASTER_NO_CHILDREN:
140 self._buffer.write(string)
219 self._buffer.write(string)
220 elif mp_mode == MASTER_WITH_CHILDREN:
221 with self._buffer_lock:
222 self._buffer.write(string)
141
223
142 current_time = time.time()
224 current_time = time.time()
143 if self._start <= 0:
225 if self._start < 0:
144 self._start = current_time
226 self._start = current_time
145 elif current_time - self._start > self.flush_interval:
227 elif current_time - self._start > self.flush_interval:
146 self.flush()
228 self.flush()
General Comments 0
You need to be logged in to leave comments. Login now