##// END OF EJS Templates
commandserver: get around ETIMEDOUT raised by selectors2...
Yuya Nishihara -
r40833:41f0529b stable
parent child Browse files
Show More
@@ -1,538 +1,546 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import errno
10 import errno
11 import gc
11 import gc
12 import os
12 import os
13 import random
13 import random
14 import signal
14 import signal
15 import socket
15 import socket
16 import struct
16 import struct
17 import traceback
17 import traceback
18
18
19 try:
19 try:
20 import selectors
20 import selectors
21 selectors.BaseSelector
21 selectors.BaseSelector
22 except ImportError:
22 except ImportError:
23 from .thirdparty import selectors2 as selectors
23 from .thirdparty import selectors2 as selectors
24
24
25 from .i18n import _
25 from .i18n import _
26 from . import (
26 from . import (
27 encoding,
27 encoding,
28 error,
28 error,
29 util,
29 util,
30 )
30 )
31 from .utils import (
31 from .utils import (
32 procutil,
32 procutil,
33 )
33 )
34
34
35 logfile = None
35 logfile = None
36
36
37 def log(*args):
37 def log(*args):
38 if not logfile:
38 if not logfile:
39 return
39 return
40
40
41 for a in args:
41 for a in args:
42 logfile.write(str(a))
42 logfile.write(str(a))
43
43
44 logfile.flush()
44 logfile.flush()
45
45
46 class channeledoutput(object):
46 class channeledoutput(object):
47 """
47 """
48 Write data to out in the following format:
48 Write data to out in the following format:
49
49
50 data length (unsigned int),
50 data length (unsigned int),
51 data
51 data
52 """
52 """
53 def __init__(self, out, channel):
53 def __init__(self, out, channel):
54 self.out = out
54 self.out = out
55 self.channel = channel
55 self.channel = channel
56
56
57 @property
57 @property
58 def name(self):
58 def name(self):
59 return '<%c-channel>' % self.channel
59 return '<%c-channel>' % self.channel
60
60
61 def write(self, data):
61 def write(self, data):
62 if not data:
62 if not data:
63 return
63 return
64 # single write() to guarantee the same atomicity as the underlying file
64 # single write() to guarantee the same atomicity as the underlying file
65 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
65 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
66 self.out.flush()
66 self.out.flush()
67
67
68 def __getattr__(self, attr):
68 def __getattr__(self, attr):
69 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
69 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
70 raise AttributeError(attr)
70 raise AttributeError(attr)
71 return getattr(self.out, attr)
71 return getattr(self.out, attr)
72
72
73 class channeledinput(object):
73 class channeledinput(object):
74 """
74 """
75 Read data from in_.
75 Read data from in_.
76
76
77 Requests for input are written to out in the following format:
77 Requests for input are written to out in the following format:
78 channel identifier - 'I' for plain input, 'L' line based (1 byte)
78 channel identifier - 'I' for plain input, 'L' line based (1 byte)
79 how many bytes to send at most (unsigned int),
79 how many bytes to send at most (unsigned int),
80
80
81 The client replies with:
81 The client replies with:
82 data length (unsigned int), 0 meaning EOF
82 data length (unsigned int), 0 meaning EOF
83 data
83 data
84 """
84 """
85
85
86 maxchunksize = 4 * 1024
86 maxchunksize = 4 * 1024
87
87
88 def __init__(self, in_, out, channel):
88 def __init__(self, in_, out, channel):
89 self.in_ = in_
89 self.in_ = in_
90 self.out = out
90 self.out = out
91 self.channel = channel
91 self.channel = channel
92
92
93 @property
93 @property
94 def name(self):
94 def name(self):
95 return '<%c-channel>' % self.channel
95 return '<%c-channel>' % self.channel
96
96
97 def read(self, size=-1):
97 def read(self, size=-1):
98 if size < 0:
98 if size < 0:
99 # if we need to consume all the clients input, ask for 4k chunks
99 # if we need to consume all the clients input, ask for 4k chunks
100 # so the pipe doesn't fill up risking a deadlock
100 # so the pipe doesn't fill up risking a deadlock
101 size = self.maxchunksize
101 size = self.maxchunksize
102 s = self._read(size, self.channel)
102 s = self._read(size, self.channel)
103 buf = s
103 buf = s
104 while s:
104 while s:
105 s = self._read(size, self.channel)
105 s = self._read(size, self.channel)
106 buf += s
106 buf += s
107
107
108 return buf
108 return buf
109 else:
109 else:
110 return self._read(size, self.channel)
110 return self._read(size, self.channel)
111
111
112 def _read(self, size, channel):
112 def _read(self, size, channel):
113 if not size:
113 if not size:
114 return ''
114 return ''
115 assert size > 0
115 assert size > 0
116
116
117 # tell the client we need at most size bytes
117 # tell the client we need at most size bytes
118 self.out.write(struct.pack('>cI', channel, size))
118 self.out.write(struct.pack('>cI', channel, size))
119 self.out.flush()
119 self.out.flush()
120
120
121 length = self.in_.read(4)
121 length = self.in_.read(4)
122 length = struct.unpack('>I', length)[0]
122 length = struct.unpack('>I', length)[0]
123 if not length:
123 if not length:
124 return ''
124 return ''
125 else:
125 else:
126 return self.in_.read(length)
126 return self.in_.read(length)
127
127
128 def readline(self, size=-1):
128 def readline(self, size=-1):
129 if size < 0:
129 if size < 0:
130 size = self.maxchunksize
130 size = self.maxchunksize
131 s = self._read(size, 'L')
131 s = self._read(size, 'L')
132 buf = s
132 buf = s
133 # keep asking for more until there's either no more or
133 # keep asking for more until there's either no more or
134 # we got a full line
134 # we got a full line
135 while s and s[-1] != '\n':
135 while s and s[-1] != '\n':
136 s = self._read(size, 'L')
136 s = self._read(size, 'L')
137 buf += s
137 buf += s
138
138
139 return buf
139 return buf
140 else:
140 else:
141 return self._read(size, 'L')
141 return self._read(size, 'L')
142
142
143 def __iter__(self):
143 def __iter__(self):
144 return self
144 return self
145
145
146 def next(self):
146 def next(self):
147 l = self.readline()
147 l = self.readline()
148 if not l:
148 if not l:
149 raise StopIteration
149 raise StopIteration
150 return l
150 return l
151
151
152 __next__ = next
152 __next__ = next
153
153
154 def __getattr__(self, attr):
154 def __getattr__(self, attr):
155 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
155 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
156 raise AttributeError(attr)
156 raise AttributeError(attr)
157 return getattr(self.in_, attr)
157 return getattr(self.in_, attr)
158
158
159 class server(object):
159 class server(object):
160 """
160 """
161 Listens for commands on fin, runs them and writes the output on a channel
161 Listens for commands on fin, runs them and writes the output on a channel
162 based stream to fout.
162 based stream to fout.
163 """
163 """
164 def __init__(self, ui, repo, fin, fout):
164 def __init__(self, ui, repo, fin, fout):
165 self.cwd = encoding.getcwd()
165 self.cwd = encoding.getcwd()
166
166
167 # developer config: cmdserver.log
167 # developer config: cmdserver.log
168 logpath = ui.config("cmdserver", "log")
168 logpath = ui.config("cmdserver", "log")
169 if logpath:
169 if logpath:
170 global logfile
170 global logfile
171 if logpath == '-':
171 if logpath == '-':
172 # write log on a special 'd' (debug) channel
172 # write log on a special 'd' (debug) channel
173 logfile = channeledoutput(fout, 'd')
173 logfile = channeledoutput(fout, 'd')
174 else:
174 else:
175 logfile = open(logpath, 'a')
175 logfile = open(logpath, 'a')
176
176
177 if repo:
177 if repo:
178 # the ui here is really the repo ui so take its baseui so we don't
178 # the ui here is really the repo ui so take its baseui so we don't
179 # end up with its local configuration
179 # end up with its local configuration
180 self.ui = repo.baseui
180 self.ui = repo.baseui
181 self.repo = repo
181 self.repo = repo
182 self.repoui = repo.ui
182 self.repoui = repo.ui
183 else:
183 else:
184 self.ui = ui
184 self.ui = ui
185 self.repo = self.repoui = None
185 self.repo = self.repoui = None
186
186
187 self.cerr = channeledoutput(fout, 'e')
187 self.cerr = channeledoutput(fout, 'e')
188 self.cout = channeledoutput(fout, 'o')
188 self.cout = channeledoutput(fout, 'o')
189 self.cin = channeledinput(fin, fout, 'I')
189 self.cin = channeledinput(fin, fout, 'I')
190 self.cresult = channeledoutput(fout, 'r')
190 self.cresult = channeledoutput(fout, 'r')
191
191
192 self.client = fin
192 self.client = fin
193
193
194 def cleanup(self):
194 def cleanup(self):
195 """release and restore resources taken during server session"""
195 """release and restore resources taken during server session"""
196
196
197 def _read(self, size):
197 def _read(self, size):
198 if not size:
198 if not size:
199 return ''
199 return ''
200
200
201 data = self.client.read(size)
201 data = self.client.read(size)
202
202
203 # is the other end closed?
203 # is the other end closed?
204 if not data:
204 if not data:
205 raise EOFError
205 raise EOFError
206
206
207 return data
207 return data
208
208
209 def _readstr(self):
209 def _readstr(self):
210 """read a string from the channel
210 """read a string from the channel
211
211
212 format:
212 format:
213 data length (uint32), data
213 data length (uint32), data
214 """
214 """
215 length = struct.unpack('>I', self._read(4))[0]
215 length = struct.unpack('>I', self._read(4))[0]
216 if not length:
216 if not length:
217 return ''
217 return ''
218 return self._read(length)
218 return self._read(length)
219
219
220 def _readlist(self):
220 def _readlist(self):
221 """read a list of NULL separated strings from the channel"""
221 """read a list of NULL separated strings from the channel"""
222 s = self._readstr()
222 s = self._readstr()
223 if s:
223 if s:
224 return s.split('\0')
224 return s.split('\0')
225 else:
225 else:
226 return []
226 return []
227
227
228 def runcommand(self):
228 def runcommand(self):
229 """ reads a list of \0 terminated arguments, executes
229 """ reads a list of \0 terminated arguments, executes
230 and writes the return code to the result channel """
230 and writes the return code to the result channel """
231 from . import dispatch # avoid cycle
231 from . import dispatch # avoid cycle
232
232
233 args = self._readlist()
233 args = self._readlist()
234
234
235 # copy the uis so changes (e.g. --config or --verbose) don't
235 # copy the uis so changes (e.g. --config or --verbose) don't
236 # persist between requests
236 # persist between requests
237 copiedui = self.ui.copy()
237 copiedui = self.ui.copy()
238 uis = [copiedui]
238 uis = [copiedui]
239 if self.repo:
239 if self.repo:
240 self.repo.baseui = copiedui
240 self.repo.baseui = copiedui
241 # clone ui without using ui.copy because this is protected
241 # clone ui without using ui.copy because this is protected
242 repoui = self.repoui.__class__(self.repoui)
242 repoui = self.repoui.__class__(self.repoui)
243 repoui.copy = copiedui.copy # redo copy protection
243 repoui.copy = copiedui.copy # redo copy protection
244 uis.append(repoui)
244 uis.append(repoui)
245 self.repo.ui = self.repo.dirstate._ui = repoui
245 self.repo.ui = self.repo.dirstate._ui = repoui
246 self.repo.invalidateall()
246 self.repo.invalidateall()
247
247
248 for ui in uis:
248 for ui in uis:
249 ui.resetstate()
249 ui.resetstate()
250 # any kind of interaction must use server channels, but chg may
250 # any kind of interaction must use server channels, but chg may
251 # replace channels by fully functional tty files. so nontty is
251 # replace channels by fully functional tty files. so nontty is
252 # enforced only if cin is a channel.
252 # enforced only if cin is a channel.
253 if not util.safehasattr(self.cin, 'fileno'):
253 if not util.safehasattr(self.cin, 'fileno'):
254 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
254 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
255
255
256 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
256 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
257 self.cout, self.cerr)
257 self.cout, self.cerr)
258
258
259 try:
259 try:
260 ret = dispatch.dispatch(req) & 255
260 ret = dispatch.dispatch(req) & 255
261 self.cresult.write(struct.pack('>i', int(ret)))
261 self.cresult.write(struct.pack('>i', int(ret)))
262 finally:
262 finally:
263 # restore old cwd
263 # restore old cwd
264 if '--cwd' in args:
264 if '--cwd' in args:
265 os.chdir(self.cwd)
265 os.chdir(self.cwd)
266
266
267 def getencoding(self):
267 def getencoding(self):
268 """ writes the current encoding to the result channel """
268 """ writes the current encoding to the result channel """
269 self.cresult.write(encoding.encoding)
269 self.cresult.write(encoding.encoding)
270
270
271 def serveone(self):
271 def serveone(self):
272 cmd = self.client.readline()[:-1]
272 cmd = self.client.readline()[:-1]
273 if cmd:
273 if cmd:
274 handler = self.capabilities.get(cmd)
274 handler = self.capabilities.get(cmd)
275 if handler:
275 if handler:
276 handler(self)
276 handler(self)
277 else:
277 else:
278 # clients are expected to check what commands are supported by
278 # clients are expected to check what commands are supported by
279 # looking at the servers capabilities
279 # looking at the servers capabilities
280 raise error.Abort(_('unknown command %s') % cmd)
280 raise error.Abort(_('unknown command %s') % cmd)
281
281
282 return cmd != ''
282 return cmd != ''
283
283
284 capabilities = {'runcommand': runcommand,
284 capabilities = {'runcommand': runcommand,
285 'getencoding': getencoding}
285 'getencoding': getencoding}
286
286
287 def serve(self):
287 def serve(self):
288 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
288 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
289 hellomsg += '\n'
289 hellomsg += '\n'
290 hellomsg += 'encoding: ' + encoding.encoding
290 hellomsg += 'encoding: ' + encoding.encoding
291 hellomsg += '\n'
291 hellomsg += '\n'
292 hellomsg += 'pid: %d' % procutil.getpid()
292 hellomsg += 'pid: %d' % procutil.getpid()
293 if util.safehasattr(os, 'getpgid'):
293 if util.safehasattr(os, 'getpgid'):
294 hellomsg += '\n'
294 hellomsg += '\n'
295 hellomsg += 'pgid: %d' % os.getpgid(0)
295 hellomsg += 'pgid: %d' % os.getpgid(0)
296
296
297 # write the hello msg in -one- chunk
297 # write the hello msg in -one- chunk
298 self.cout.write(hellomsg)
298 self.cout.write(hellomsg)
299
299
300 try:
300 try:
301 while self.serveone():
301 while self.serveone():
302 pass
302 pass
303 except EOFError:
303 except EOFError:
304 # we'll get here if the client disconnected while we were reading
304 # we'll get here if the client disconnected while we were reading
305 # its request
305 # its request
306 return 1
306 return 1
307
307
308 return 0
308 return 0
309
309
310 class pipeservice(object):
310 class pipeservice(object):
311 def __init__(self, ui, repo, opts):
311 def __init__(self, ui, repo, opts):
312 self.ui = ui
312 self.ui = ui
313 self.repo = repo
313 self.repo = repo
314
314
315 def init(self):
315 def init(self):
316 pass
316 pass
317
317
318 def run(self):
318 def run(self):
319 ui = self.ui
319 ui = self.ui
320 # redirect stdio to null device so that broken extensions or in-process
320 # redirect stdio to null device so that broken extensions or in-process
321 # hooks will never cause corruption of channel protocol.
321 # hooks will never cause corruption of channel protocol.
322 with procutil.protectedstdio(ui.fin, ui.fout) as (fin, fout):
322 with procutil.protectedstdio(ui.fin, ui.fout) as (fin, fout):
323 try:
323 try:
324 sv = server(ui, self.repo, fin, fout)
324 sv = server(ui, self.repo, fin, fout)
325 return sv.serve()
325 return sv.serve()
326 finally:
326 finally:
327 sv.cleanup()
327 sv.cleanup()
328
328
329 def _initworkerprocess():
329 def _initworkerprocess():
330 # use a different process group from the master process, in order to:
330 # use a different process group from the master process, in order to:
331 # 1. make the current process group no longer "orphaned" (because the
331 # 1. make the current process group no longer "orphaned" (because the
332 # parent of this process is in a different process group while
332 # parent of this process is in a different process group while
333 # remains in a same session)
333 # remains in a same session)
334 # according to POSIX 2.2.2.52, orphaned process group will ignore
334 # according to POSIX 2.2.2.52, orphaned process group will ignore
335 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
335 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
336 # cause trouble for things like ncurses.
336 # cause trouble for things like ncurses.
337 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
337 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
338 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
338 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
339 # processes like ssh will be killed properly, without affecting
339 # processes like ssh will be killed properly, without affecting
340 # unrelated processes.
340 # unrelated processes.
341 os.setpgid(0, 0)
341 os.setpgid(0, 0)
342 # change random state otherwise forked request handlers would have a
342 # change random state otherwise forked request handlers would have a
343 # same state inherited from parent.
343 # same state inherited from parent.
344 random.seed()
344 random.seed()
345
345
346 def _serverequest(ui, repo, conn, createcmdserver):
346 def _serverequest(ui, repo, conn, createcmdserver):
347 fin = conn.makefile(r'rb')
347 fin = conn.makefile(r'rb')
348 fout = conn.makefile(r'wb')
348 fout = conn.makefile(r'wb')
349 sv = None
349 sv = None
350 try:
350 try:
351 sv = createcmdserver(repo, conn, fin, fout)
351 sv = createcmdserver(repo, conn, fin, fout)
352 try:
352 try:
353 sv.serve()
353 sv.serve()
354 # handle exceptions that may be raised by command server. most of
354 # handle exceptions that may be raised by command server. most of
355 # known exceptions are caught by dispatch.
355 # known exceptions are caught by dispatch.
356 except error.Abort as inst:
356 except error.Abort as inst:
357 ui.error(_('abort: %s\n') % inst)
357 ui.error(_('abort: %s\n') % inst)
358 except IOError as inst:
358 except IOError as inst:
359 if inst.errno != errno.EPIPE:
359 if inst.errno != errno.EPIPE:
360 raise
360 raise
361 except KeyboardInterrupt:
361 except KeyboardInterrupt:
362 pass
362 pass
363 finally:
363 finally:
364 sv.cleanup()
364 sv.cleanup()
365 except: # re-raises
365 except: # re-raises
366 # also write traceback to error channel. otherwise client cannot
366 # also write traceback to error channel. otherwise client cannot
367 # see it because it is written to server's stderr by default.
367 # see it because it is written to server's stderr by default.
368 if sv:
368 if sv:
369 cerr = sv.cerr
369 cerr = sv.cerr
370 else:
370 else:
371 cerr = channeledoutput(fout, 'e')
371 cerr = channeledoutput(fout, 'e')
372 cerr.write(encoding.strtolocal(traceback.format_exc()))
372 cerr.write(encoding.strtolocal(traceback.format_exc()))
373 raise
373 raise
374 finally:
374 finally:
375 fin.close()
375 fin.close()
376 try:
376 try:
377 fout.close() # implicit flush() may cause another EPIPE
377 fout.close() # implicit flush() may cause another EPIPE
378 except IOError as inst:
378 except IOError as inst:
379 if inst.errno != errno.EPIPE:
379 if inst.errno != errno.EPIPE:
380 raise
380 raise
381
381
382 class unixservicehandler(object):
382 class unixservicehandler(object):
383 """Set of pluggable operations for unix-mode services
383 """Set of pluggable operations for unix-mode services
384
384
385 Almost all methods except for createcmdserver() are called in the main
385 Almost all methods except for createcmdserver() are called in the main
386 process. You can't pass mutable resource back from createcmdserver().
386 process. You can't pass mutable resource back from createcmdserver().
387 """
387 """
388
388
389 pollinterval = None
389 pollinterval = None
390
390
391 def __init__(self, ui):
391 def __init__(self, ui):
392 self.ui = ui
392 self.ui = ui
393
393
394 def bindsocket(self, sock, address):
394 def bindsocket(self, sock, address):
395 util.bindunixsocket(sock, address)
395 util.bindunixsocket(sock, address)
396 sock.listen(socket.SOMAXCONN)
396 sock.listen(socket.SOMAXCONN)
397 self.ui.status(_('listening at %s\n') % address)
397 self.ui.status(_('listening at %s\n') % address)
398 self.ui.flush() # avoid buffering of status message
398 self.ui.flush() # avoid buffering of status message
399
399
400 def unlinksocket(self, address):
400 def unlinksocket(self, address):
401 os.unlink(address)
401 os.unlink(address)
402
402
403 def shouldexit(self):
403 def shouldexit(self):
404 """True if server should shut down; checked per pollinterval"""
404 """True if server should shut down; checked per pollinterval"""
405 return False
405 return False
406
406
407 def newconnection(self):
407 def newconnection(self):
408 """Called when main process notices new connection"""
408 """Called when main process notices new connection"""
409
409
410 def createcmdserver(self, repo, conn, fin, fout):
410 def createcmdserver(self, repo, conn, fin, fout):
411 """Create new command server instance; called in the process that
411 """Create new command server instance; called in the process that
412 serves for the current connection"""
412 serves for the current connection"""
413 return server(self.ui, repo, fin, fout)
413 return server(self.ui, repo, fin, fout)
414
414
415 class unixforkingservice(object):
415 class unixforkingservice(object):
416 """
416 """
417 Listens on unix domain socket and forks server per connection
417 Listens on unix domain socket and forks server per connection
418 """
418 """
419
419
420 def __init__(self, ui, repo, opts, handler=None):
420 def __init__(self, ui, repo, opts, handler=None):
421 self.ui = ui
421 self.ui = ui
422 self.repo = repo
422 self.repo = repo
423 self.address = opts['address']
423 self.address = opts['address']
424 if not util.safehasattr(socket, 'AF_UNIX'):
424 if not util.safehasattr(socket, 'AF_UNIX'):
425 raise error.Abort(_('unsupported platform'))
425 raise error.Abort(_('unsupported platform'))
426 if not self.address:
426 if not self.address:
427 raise error.Abort(_('no socket path specified with --address'))
427 raise error.Abort(_('no socket path specified with --address'))
428 self._servicehandler = handler or unixservicehandler(ui)
428 self._servicehandler = handler or unixservicehandler(ui)
429 self._sock = None
429 self._sock = None
430 self._oldsigchldhandler = None
430 self._oldsigchldhandler = None
431 self._workerpids = set() # updated by signal handler; do not iterate
431 self._workerpids = set() # updated by signal handler; do not iterate
432 self._socketunlinked = None
432 self._socketunlinked = None
433
433
434 def init(self):
434 def init(self):
435 self._sock = socket.socket(socket.AF_UNIX)
435 self._sock = socket.socket(socket.AF_UNIX)
436 self._servicehandler.bindsocket(self._sock, self.address)
436 self._servicehandler.bindsocket(self._sock, self.address)
437 if util.safehasattr(procutil, 'unblocksignal'):
437 if util.safehasattr(procutil, 'unblocksignal'):
438 procutil.unblocksignal(signal.SIGCHLD)
438 procutil.unblocksignal(signal.SIGCHLD)
439 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
439 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
440 self._oldsigchldhandler = o
440 self._oldsigchldhandler = o
441 self._socketunlinked = False
441 self._socketunlinked = False
442
442
443 def _unlinksocket(self):
443 def _unlinksocket(self):
444 if not self._socketunlinked:
444 if not self._socketunlinked:
445 self._servicehandler.unlinksocket(self.address)
445 self._servicehandler.unlinksocket(self.address)
446 self._socketunlinked = True
446 self._socketunlinked = True
447
447
448 def _cleanup(self):
448 def _cleanup(self):
449 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
449 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
450 self._sock.close()
450 self._sock.close()
451 self._unlinksocket()
451 self._unlinksocket()
452 # don't kill child processes as they have active clients, just wait
452 # don't kill child processes as they have active clients, just wait
453 self._reapworkers(0)
453 self._reapworkers(0)
454
454
455 def run(self):
455 def run(self):
456 try:
456 try:
457 self._mainloop()
457 self._mainloop()
458 finally:
458 finally:
459 self._cleanup()
459 self._cleanup()
460
460
461 def _mainloop(self):
461 def _mainloop(self):
462 exiting = False
462 exiting = False
463 h = self._servicehandler
463 h = self._servicehandler
464 selector = selectors.DefaultSelector()
464 selector = selectors.DefaultSelector()
465 selector.register(self._sock, selectors.EVENT_READ)
465 selector.register(self._sock, selectors.EVENT_READ)
466 while True:
466 while True:
467 if not exiting and h.shouldexit():
467 if not exiting and h.shouldexit():
468 # clients can no longer connect() to the domain socket, so
468 # clients can no longer connect() to the domain socket, so
469 # we stop queuing new requests.
469 # we stop queuing new requests.
470 # for requests that are queued (connect()-ed, but haven't been
470 # for requests that are queued (connect()-ed, but haven't been
471 # accept()-ed), handle them before exit. otherwise, clients
471 # accept()-ed), handle them before exit. otherwise, clients
472 # waiting for recv() will receive ECONNRESET.
472 # waiting for recv() will receive ECONNRESET.
473 self._unlinksocket()
473 self._unlinksocket()
474 exiting = True
474 exiting = True
475 ready = selector.select(timeout=h.pollinterval)
475 try:
476 ready = selector.select(timeout=h.pollinterval)
477 except OSError as inst:
478 # selectors2 raises ETIMEDOUT if timeout exceeded while
479 # handling signal interrupt. That's probably wrong, but
480 # we can easily get around it.
481 if inst.errno != errno.ETIMEDOUT:
482 raise
483 ready = []
476 if not ready:
484 if not ready:
477 # only exit if we completed all queued requests
485 # only exit if we completed all queued requests
478 if exiting:
486 if exiting:
479 break
487 break
480 continue
488 continue
481 try:
489 try:
482 conn, _addr = self._sock.accept()
490 conn, _addr = self._sock.accept()
483 except socket.error as inst:
491 except socket.error as inst:
484 if inst.args[0] == errno.EINTR:
492 if inst.args[0] == errno.EINTR:
485 continue
493 continue
486 raise
494 raise
487
495
488 pid = os.fork()
496 pid = os.fork()
489 if pid:
497 if pid:
490 try:
498 try:
491 self.ui.debug('forked worker process (pid=%d)\n' % pid)
499 self.ui.debug('forked worker process (pid=%d)\n' % pid)
492 self._workerpids.add(pid)
500 self._workerpids.add(pid)
493 h.newconnection()
501 h.newconnection()
494 finally:
502 finally:
495 conn.close() # release handle in parent process
503 conn.close() # release handle in parent process
496 else:
504 else:
497 try:
505 try:
498 selector.close()
506 selector.close()
499 self._sock.close()
507 self._sock.close()
500 self._runworker(conn)
508 self._runworker(conn)
501 conn.close()
509 conn.close()
502 os._exit(0)
510 os._exit(0)
503 except: # never return, hence no re-raises
511 except: # never return, hence no re-raises
504 try:
512 try:
505 self.ui.traceback(force=True)
513 self.ui.traceback(force=True)
506 finally:
514 finally:
507 os._exit(255)
515 os._exit(255)
508 selector.close()
516 selector.close()
509
517
510 def _sigchldhandler(self, signal, frame):
518 def _sigchldhandler(self, signal, frame):
511 self._reapworkers(os.WNOHANG)
519 self._reapworkers(os.WNOHANG)
512
520
513 def _reapworkers(self, options):
521 def _reapworkers(self, options):
514 while self._workerpids:
522 while self._workerpids:
515 try:
523 try:
516 pid, _status = os.waitpid(-1, options)
524 pid, _status = os.waitpid(-1, options)
517 except OSError as inst:
525 except OSError as inst:
518 if inst.errno == errno.EINTR:
526 if inst.errno == errno.EINTR:
519 continue
527 continue
520 if inst.errno != errno.ECHILD:
528 if inst.errno != errno.ECHILD:
521 raise
529 raise
522 # no child processes at all (reaped by other waitpid()?)
530 # no child processes at all (reaped by other waitpid()?)
523 self._workerpids.clear()
531 self._workerpids.clear()
524 return
532 return
525 if pid == 0:
533 if pid == 0:
526 # no waitable child processes
534 # no waitable child processes
527 return
535 return
528 self.ui.debug('worker process exited (pid=%d)\n' % pid)
536 self.ui.debug('worker process exited (pid=%d)\n' % pid)
529 self._workerpids.discard(pid)
537 self._workerpids.discard(pid)
530
538
531 def _runworker(self, conn):
539 def _runworker(self, conn):
532 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
540 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
533 _initworkerprocess()
541 _initworkerprocess()
534 h = self._servicehandler
542 h = self._servicehandler
535 try:
543 try:
536 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
544 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
537 finally:
545 finally:
538 gc.collect() # trigger __del__ since worker process uses os._exit
546 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now