##// END OF EJS Templates
commandserver: do not handle EINTR for selector.select...
Jun Wu -
r33543:3ef3bf70 default
parent child Browse files
Show More
@@ -1,552 +1,551 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import errno
10 import errno
11 import gc
11 import gc
12 import os
12 import os
13 import random
13 import random
14 import select
15 import signal
14 import signal
16 import socket
15 import socket
17 import struct
16 import struct
18 import traceback
17 import traceback
19
18
20 from .i18n import _
19 from .i18n import _
21 from . import (
20 from . import (
22 encoding,
21 encoding,
23 error,
22 error,
24 pycompat,
23 pycompat,
25 selectors2,
24 selectors2,
26 util,
25 util,
27 )
26 )
28
27
29 logfile = None
28 logfile = None
30
29
31 def log(*args):
30 def log(*args):
32 if not logfile:
31 if not logfile:
33 return
32 return
34
33
35 for a in args:
34 for a in args:
36 logfile.write(str(a))
35 logfile.write(str(a))
37
36
38 logfile.flush()
37 logfile.flush()
39
38
40 class channeledoutput(object):
39 class channeledoutput(object):
41 """
40 """
42 Write data to out in the following format:
41 Write data to out in the following format:
43
42
44 data length (unsigned int),
43 data length (unsigned int),
45 data
44 data
46 """
45 """
47 def __init__(self, out, channel):
46 def __init__(self, out, channel):
48 self.out = out
47 self.out = out
49 self.channel = channel
48 self.channel = channel
50
49
51 @property
50 @property
52 def name(self):
51 def name(self):
53 return '<%c-channel>' % self.channel
52 return '<%c-channel>' % self.channel
54
53
55 def write(self, data):
54 def write(self, data):
56 if not data:
55 if not data:
57 return
56 return
58 # single write() to guarantee the same atomicity as the underlying file
57 # single write() to guarantee the same atomicity as the underlying file
59 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
60 self.out.flush()
59 self.out.flush()
61
60
62 def __getattr__(self, attr):
61 def __getattr__(self, attr):
63 if attr in ('isatty', 'fileno', 'tell', 'seek'):
62 if attr in ('isatty', 'fileno', 'tell', 'seek'):
64 raise AttributeError(attr)
63 raise AttributeError(attr)
65 return getattr(self.out, attr)
64 return getattr(self.out, attr)
66
65
67 class channeledinput(object):
66 class channeledinput(object):
68 """
67 """
69 Read data from in_.
68 Read data from in_.
70
69
71 Requests for input are written to out in the following format:
70 Requests for input are written to out in the following format:
72 channel identifier - 'I' for plain input, 'L' line based (1 byte)
71 channel identifier - 'I' for plain input, 'L' line based (1 byte)
73 how many bytes to send at most (unsigned int),
72 how many bytes to send at most (unsigned int),
74
73
75 The client replies with:
74 The client replies with:
76 data length (unsigned int), 0 meaning EOF
75 data length (unsigned int), 0 meaning EOF
77 data
76 data
78 """
77 """
79
78
80 maxchunksize = 4 * 1024
79 maxchunksize = 4 * 1024
81
80
82 def __init__(self, in_, out, channel):
81 def __init__(self, in_, out, channel):
83 self.in_ = in_
82 self.in_ = in_
84 self.out = out
83 self.out = out
85 self.channel = channel
84 self.channel = channel
86
85
87 @property
86 @property
88 def name(self):
87 def name(self):
89 return '<%c-channel>' % self.channel
88 return '<%c-channel>' % self.channel
90
89
91 def read(self, size=-1):
90 def read(self, size=-1):
92 if size < 0:
91 if size < 0:
93 # if we need to consume all the clients input, ask for 4k chunks
92 # if we need to consume all the clients input, ask for 4k chunks
94 # so the pipe doesn't fill up risking a deadlock
93 # so the pipe doesn't fill up risking a deadlock
95 size = self.maxchunksize
94 size = self.maxchunksize
96 s = self._read(size, self.channel)
95 s = self._read(size, self.channel)
97 buf = s
96 buf = s
98 while s:
97 while s:
99 s = self._read(size, self.channel)
98 s = self._read(size, self.channel)
100 buf += s
99 buf += s
101
100
102 return buf
101 return buf
103 else:
102 else:
104 return self._read(size, self.channel)
103 return self._read(size, self.channel)
105
104
106 def _read(self, size, channel):
105 def _read(self, size, channel):
107 if not size:
106 if not size:
108 return ''
107 return ''
109 assert size > 0
108 assert size > 0
110
109
111 # tell the client we need at most size bytes
110 # tell the client we need at most size bytes
112 self.out.write(struct.pack('>cI', channel, size))
111 self.out.write(struct.pack('>cI', channel, size))
113 self.out.flush()
112 self.out.flush()
114
113
115 length = self.in_.read(4)
114 length = self.in_.read(4)
116 length = struct.unpack('>I', length)[0]
115 length = struct.unpack('>I', length)[0]
117 if not length:
116 if not length:
118 return ''
117 return ''
119 else:
118 else:
120 return self.in_.read(length)
119 return self.in_.read(length)
121
120
122 def readline(self, size=-1):
121 def readline(self, size=-1):
123 if size < 0:
122 if size < 0:
124 size = self.maxchunksize
123 size = self.maxchunksize
125 s = self._read(size, 'L')
124 s = self._read(size, 'L')
126 buf = s
125 buf = s
127 # keep asking for more until there's either no more or
126 # keep asking for more until there's either no more or
128 # we got a full line
127 # we got a full line
129 while s and s[-1] != '\n':
128 while s and s[-1] != '\n':
130 s = self._read(size, 'L')
129 s = self._read(size, 'L')
131 buf += s
130 buf += s
132
131
133 return buf
132 return buf
134 else:
133 else:
135 return self._read(size, 'L')
134 return self._read(size, 'L')
136
135
137 def __iter__(self):
136 def __iter__(self):
138 return self
137 return self
139
138
140 def next(self):
139 def next(self):
141 l = self.readline()
140 l = self.readline()
142 if not l:
141 if not l:
143 raise StopIteration
142 raise StopIteration
144 return l
143 return l
145
144
146 def __getattr__(self, attr):
145 def __getattr__(self, attr):
147 if attr in ('isatty', 'fileno', 'tell', 'seek'):
146 if attr in ('isatty', 'fileno', 'tell', 'seek'):
148 raise AttributeError(attr)
147 raise AttributeError(attr)
149 return getattr(self.in_, attr)
148 return getattr(self.in_, attr)
150
149
151 class server(object):
150 class server(object):
152 """
151 """
153 Listens for commands on fin, runs them and writes the output on a channel
152 Listens for commands on fin, runs them and writes the output on a channel
154 based stream to fout.
153 based stream to fout.
155 """
154 """
156 def __init__(self, ui, repo, fin, fout):
155 def __init__(self, ui, repo, fin, fout):
157 self.cwd = pycompat.getcwd()
156 self.cwd = pycompat.getcwd()
158
157
159 # developer config: cmdserver.log
158 # developer config: cmdserver.log
160 logpath = ui.config("cmdserver", "log")
159 logpath = ui.config("cmdserver", "log")
161 if logpath:
160 if logpath:
162 global logfile
161 global logfile
163 if logpath == '-':
162 if logpath == '-':
164 # write log on a special 'd' (debug) channel
163 # write log on a special 'd' (debug) channel
165 logfile = channeledoutput(fout, 'd')
164 logfile = channeledoutput(fout, 'd')
166 else:
165 else:
167 logfile = open(logpath, 'a')
166 logfile = open(logpath, 'a')
168
167
169 if repo:
168 if repo:
170 # the ui here is really the repo ui so take its baseui so we don't
169 # the ui here is really the repo ui so take its baseui so we don't
171 # end up with its local configuration
170 # end up with its local configuration
172 self.ui = repo.baseui
171 self.ui = repo.baseui
173 self.repo = repo
172 self.repo = repo
174 self.repoui = repo.ui
173 self.repoui = repo.ui
175 else:
174 else:
176 self.ui = ui
175 self.ui = ui
177 self.repo = self.repoui = None
176 self.repo = self.repoui = None
178
177
179 self.cerr = channeledoutput(fout, 'e')
178 self.cerr = channeledoutput(fout, 'e')
180 self.cout = channeledoutput(fout, 'o')
179 self.cout = channeledoutput(fout, 'o')
181 self.cin = channeledinput(fin, fout, 'I')
180 self.cin = channeledinput(fin, fout, 'I')
182 self.cresult = channeledoutput(fout, 'r')
181 self.cresult = channeledoutput(fout, 'r')
183
182
184 self.client = fin
183 self.client = fin
185
184
186 def cleanup(self):
185 def cleanup(self):
187 """release and restore resources taken during server session"""
186 """release and restore resources taken during server session"""
188 pass
187 pass
189
188
190 def _read(self, size):
189 def _read(self, size):
191 if not size:
190 if not size:
192 return ''
191 return ''
193
192
194 data = self.client.read(size)
193 data = self.client.read(size)
195
194
196 # is the other end closed?
195 # is the other end closed?
197 if not data:
196 if not data:
198 raise EOFError
197 raise EOFError
199
198
200 return data
199 return data
201
200
202 def _readstr(self):
201 def _readstr(self):
203 """read a string from the channel
202 """read a string from the channel
204
203
205 format:
204 format:
206 data length (uint32), data
205 data length (uint32), data
207 """
206 """
208 length = struct.unpack('>I', self._read(4))[0]
207 length = struct.unpack('>I', self._read(4))[0]
209 if not length:
208 if not length:
210 return ''
209 return ''
211 return self._read(length)
210 return self._read(length)
212
211
213 def _readlist(self):
212 def _readlist(self):
214 """read a list of NULL separated strings from the channel"""
213 """read a list of NULL separated strings from the channel"""
215 s = self._readstr()
214 s = self._readstr()
216 if s:
215 if s:
217 return s.split('\0')
216 return s.split('\0')
218 else:
217 else:
219 return []
218 return []
220
219
221 def runcommand(self):
220 def runcommand(self):
222 """ reads a list of \0 terminated arguments, executes
221 """ reads a list of \0 terminated arguments, executes
223 and writes the return code to the result channel """
222 and writes the return code to the result channel """
224 from . import dispatch # avoid cycle
223 from . import dispatch # avoid cycle
225
224
226 args = self._readlist()
225 args = self._readlist()
227
226
228 # copy the uis so changes (e.g. --config or --verbose) don't
227 # copy the uis so changes (e.g. --config or --verbose) don't
229 # persist between requests
228 # persist between requests
230 copiedui = self.ui.copy()
229 copiedui = self.ui.copy()
231 uis = [copiedui]
230 uis = [copiedui]
232 if self.repo:
231 if self.repo:
233 self.repo.baseui = copiedui
232 self.repo.baseui = copiedui
234 # clone ui without using ui.copy because this is protected
233 # clone ui without using ui.copy because this is protected
235 repoui = self.repoui.__class__(self.repoui)
234 repoui = self.repoui.__class__(self.repoui)
236 repoui.copy = copiedui.copy # redo copy protection
235 repoui.copy = copiedui.copy # redo copy protection
237 uis.append(repoui)
236 uis.append(repoui)
238 self.repo.ui = self.repo.dirstate._ui = repoui
237 self.repo.ui = self.repo.dirstate._ui = repoui
239 self.repo.invalidateall()
238 self.repo.invalidateall()
240
239
241 for ui in uis:
240 for ui in uis:
242 ui.resetstate()
241 ui.resetstate()
243 # any kind of interaction must use server channels, but chg may
242 # any kind of interaction must use server channels, but chg may
244 # replace channels by fully functional tty files. so nontty is
243 # replace channels by fully functional tty files. so nontty is
245 # enforced only if cin is a channel.
244 # enforced only if cin is a channel.
246 if not util.safehasattr(self.cin, 'fileno'):
245 if not util.safehasattr(self.cin, 'fileno'):
247 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
248
247
249 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
250 self.cout, self.cerr)
249 self.cout, self.cerr)
251
250
252 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
253
252
254 # restore old cwd
253 # restore old cwd
255 if '--cwd' in args:
254 if '--cwd' in args:
256 os.chdir(self.cwd)
255 os.chdir(self.cwd)
257
256
258 self.cresult.write(struct.pack('>i', int(ret)))
257 self.cresult.write(struct.pack('>i', int(ret)))
259
258
260 def getencoding(self):
259 def getencoding(self):
261 """ writes the current encoding to the result channel """
260 """ writes the current encoding to the result channel """
262 self.cresult.write(encoding.encoding)
261 self.cresult.write(encoding.encoding)
263
262
264 def serveone(self):
263 def serveone(self):
265 cmd = self.client.readline()[:-1]
264 cmd = self.client.readline()[:-1]
266 if cmd:
265 if cmd:
267 handler = self.capabilities.get(cmd)
266 handler = self.capabilities.get(cmd)
268 if handler:
267 if handler:
269 handler(self)
268 handler(self)
270 else:
269 else:
271 # clients are expected to check what commands are supported by
270 # clients are expected to check what commands are supported by
272 # looking at the servers capabilities
271 # looking at the servers capabilities
273 raise error.Abort(_('unknown command %s') % cmd)
272 raise error.Abort(_('unknown command %s') % cmd)
274
273
275 return cmd != ''
274 return cmd != ''
276
275
277 capabilities = {'runcommand' : runcommand,
276 capabilities = {'runcommand' : runcommand,
278 'getencoding' : getencoding}
277 'getencoding' : getencoding}
279
278
280 def serve(self):
279 def serve(self):
281 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
282 hellomsg += '\n'
281 hellomsg += '\n'
283 hellomsg += 'encoding: ' + encoding.encoding
282 hellomsg += 'encoding: ' + encoding.encoding
284 hellomsg += '\n'
283 hellomsg += '\n'
285 hellomsg += 'pid: %d' % util.getpid()
284 hellomsg += 'pid: %d' % util.getpid()
286 if util.safehasattr(os, 'getpgid'):
285 if util.safehasattr(os, 'getpgid'):
287 hellomsg += '\n'
286 hellomsg += '\n'
288 hellomsg += 'pgid: %d' % os.getpgid(0)
287 hellomsg += 'pgid: %d' % os.getpgid(0)
289
288
290 # write the hello msg in -one- chunk
289 # write the hello msg in -one- chunk
291 self.cout.write(hellomsg)
290 self.cout.write(hellomsg)
292
291
293 try:
292 try:
294 while self.serveone():
293 while self.serveone():
295 pass
294 pass
296 except EOFError:
295 except EOFError:
297 # we'll get here if the client disconnected while we were reading
296 # we'll get here if the client disconnected while we were reading
298 # its request
297 # its request
299 return 1
298 return 1
300
299
301 return 0
300 return 0
302
301
303 def _protectio(ui):
302 def _protectio(ui):
304 """ duplicates streams and redirect original to null if ui uses stdio """
303 """ duplicates streams and redirect original to null if ui uses stdio """
305 ui.flush()
304 ui.flush()
306 newfiles = []
305 newfiles = []
307 nullfd = os.open(os.devnull, os.O_RDWR)
306 nullfd = os.open(os.devnull, os.O_RDWR)
308 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
307 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
309 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
308 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
310 if f is sysf:
309 if f is sysf:
311 newfd = os.dup(f.fileno())
310 newfd = os.dup(f.fileno())
312 os.dup2(nullfd, f.fileno())
311 os.dup2(nullfd, f.fileno())
313 f = os.fdopen(newfd, mode)
312 f = os.fdopen(newfd, mode)
314 newfiles.append(f)
313 newfiles.append(f)
315 os.close(nullfd)
314 os.close(nullfd)
316 return tuple(newfiles)
315 return tuple(newfiles)
317
316
318 def _restoreio(ui, fin, fout):
317 def _restoreio(ui, fin, fout):
319 """ restores streams from duplicated ones """
318 """ restores streams from duplicated ones """
320 ui.flush()
319 ui.flush()
321 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
320 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
322 if f is not uif:
321 if f is not uif:
323 os.dup2(f.fileno(), uif.fileno())
322 os.dup2(f.fileno(), uif.fileno())
324 f.close()
323 f.close()
325
324
326 class pipeservice(object):
325 class pipeservice(object):
327 def __init__(self, ui, repo, opts):
326 def __init__(self, ui, repo, opts):
328 self.ui = ui
327 self.ui = ui
329 self.repo = repo
328 self.repo = repo
330
329
331 def init(self):
330 def init(self):
332 pass
331 pass
333
332
334 def run(self):
333 def run(self):
335 ui = self.ui
334 ui = self.ui
336 # redirect stdio to null device so that broken extensions or in-process
335 # redirect stdio to null device so that broken extensions or in-process
337 # hooks will never cause corruption of channel protocol.
336 # hooks will never cause corruption of channel protocol.
338 fin, fout = _protectio(ui)
337 fin, fout = _protectio(ui)
339 try:
338 try:
340 sv = server(ui, self.repo, fin, fout)
339 sv = server(ui, self.repo, fin, fout)
341 return sv.serve()
340 return sv.serve()
342 finally:
341 finally:
343 sv.cleanup()
342 sv.cleanup()
344 _restoreio(ui, fin, fout)
343 _restoreio(ui, fin, fout)
345
344
346 def _initworkerprocess():
345 def _initworkerprocess():
347 # use a different process group from the master process, in order to:
346 # use a different process group from the master process, in order to:
348 # 1. make the current process group no longer "orphaned" (because the
347 # 1. make the current process group no longer "orphaned" (because the
349 # parent of this process is in a different process group while
348 # parent of this process is in a different process group while
350 # remains in a same session)
349 # remains in a same session)
351 # according to POSIX 2.2.2.52, orphaned process group will ignore
350 # according to POSIX 2.2.2.52, orphaned process group will ignore
352 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
351 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
353 # cause trouble for things like ncurses.
352 # cause trouble for things like ncurses.
354 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
353 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
355 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
354 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
356 # processes like ssh will be killed properly, without affecting
355 # processes like ssh will be killed properly, without affecting
357 # unrelated processes.
356 # unrelated processes.
358 os.setpgid(0, 0)
357 os.setpgid(0, 0)
359 # change random state otherwise forked request handlers would have a
358 # change random state otherwise forked request handlers would have a
360 # same state inherited from parent.
359 # same state inherited from parent.
361 random.seed()
360 random.seed()
362
361
363 def _serverequest(ui, repo, conn, createcmdserver):
362 def _serverequest(ui, repo, conn, createcmdserver):
364 fin = conn.makefile('rb')
363 fin = conn.makefile('rb')
365 fout = conn.makefile('wb')
364 fout = conn.makefile('wb')
366 sv = None
365 sv = None
367 try:
366 try:
368 sv = createcmdserver(repo, conn, fin, fout)
367 sv = createcmdserver(repo, conn, fin, fout)
369 try:
368 try:
370 sv.serve()
369 sv.serve()
371 # handle exceptions that may be raised by command server. most of
370 # handle exceptions that may be raised by command server. most of
372 # known exceptions are caught by dispatch.
371 # known exceptions are caught by dispatch.
373 except error.Abort as inst:
372 except error.Abort as inst:
374 ui.warn(_('abort: %s\n') % inst)
373 ui.warn(_('abort: %s\n') % inst)
375 except IOError as inst:
374 except IOError as inst:
376 if inst.errno != errno.EPIPE:
375 if inst.errno != errno.EPIPE:
377 raise
376 raise
378 except KeyboardInterrupt:
377 except KeyboardInterrupt:
379 pass
378 pass
380 finally:
379 finally:
381 sv.cleanup()
380 sv.cleanup()
382 except: # re-raises
381 except: # re-raises
383 # also write traceback to error channel. otherwise client cannot
382 # also write traceback to error channel. otherwise client cannot
384 # see it because it is written to server's stderr by default.
383 # see it because it is written to server's stderr by default.
385 if sv:
384 if sv:
386 cerr = sv.cerr
385 cerr = sv.cerr
387 else:
386 else:
388 cerr = channeledoutput(fout, 'e')
387 cerr = channeledoutput(fout, 'e')
389 traceback.print_exc(file=cerr)
388 traceback.print_exc(file=cerr)
390 raise
389 raise
391 finally:
390 finally:
392 fin.close()
391 fin.close()
393 try:
392 try:
394 fout.close() # implicit flush() may cause another EPIPE
393 fout.close() # implicit flush() may cause another EPIPE
395 except IOError as inst:
394 except IOError as inst:
396 if inst.errno != errno.EPIPE:
395 if inst.errno != errno.EPIPE:
397 raise
396 raise
398
397
399 class unixservicehandler(object):
398 class unixservicehandler(object):
400 """Set of pluggable operations for unix-mode services
399 """Set of pluggable operations for unix-mode services
401
400
402 Almost all methods except for createcmdserver() are called in the main
401 Almost all methods except for createcmdserver() are called in the main
403 process. You can't pass mutable resource back from createcmdserver().
402 process. You can't pass mutable resource back from createcmdserver().
404 """
403 """
405
404
406 pollinterval = None
405 pollinterval = None
407
406
408 def __init__(self, ui):
407 def __init__(self, ui):
409 self.ui = ui
408 self.ui = ui
410
409
411 def bindsocket(self, sock, address):
410 def bindsocket(self, sock, address):
412 util.bindunixsocket(sock, address)
411 util.bindunixsocket(sock, address)
413 sock.listen(socket.SOMAXCONN)
412 sock.listen(socket.SOMAXCONN)
414 self.ui.status(_('listening at %s\n') % address)
413 self.ui.status(_('listening at %s\n') % address)
415 self.ui.flush() # avoid buffering of status message
414 self.ui.flush() # avoid buffering of status message
416
415
417 def unlinksocket(self, address):
416 def unlinksocket(self, address):
418 os.unlink(address)
417 os.unlink(address)
419
418
420 def shouldexit(self):
419 def shouldexit(self):
421 """True if server should shut down; checked per pollinterval"""
420 """True if server should shut down; checked per pollinterval"""
422 return False
421 return False
423
422
424 def newconnection(self):
423 def newconnection(self):
425 """Called when main process notices new connection"""
424 """Called when main process notices new connection"""
426 pass
425 pass
427
426
428 def createcmdserver(self, repo, conn, fin, fout):
427 def createcmdserver(self, repo, conn, fin, fout):
429 """Create new command server instance; called in the process that
428 """Create new command server instance; called in the process that
430 serves for the current connection"""
429 serves for the current connection"""
431 return server(self.ui, repo, fin, fout)
430 return server(self.ui, repo, fin, fout)
432
431
433 class unixforkingservice(object):
432 class unixforkingservice(object):
434 """
433 """
435 Listens on unix domain socket and forks server per connection
434 Listens on unix domain socket and forks server per connection
436 """
435 """
437
436
438 def __init__(self, ui, repo, opts, handler=None):
437 def __init__(self, ui, repo, opts, handler=None):
439 self.ui = ui
438 self.ui = ui
440 self.repo = repo
439 self.repo = repo
441 self.address = opts['address']
440 self.address = opts['address']
442 if not util.safehasattr(socket, 'AF_UNIX'):
441 if not util.safehasattr(socket, 'AF_UNIX'):
443 raise error.Abort(_('unsupported platform'))
442 raise error.Abort(_('unsupported platform'))
444 if not self.address:
443 if not self.address:
445 raise error.Abort(_('no socket path specified with --address'))
444 raise error.Abort(_('no socket path specified with --address'))
446 self._servicehandler = handler or unixservicehandler(ui)
445 self._servicehandler = handler or unixservicehandler(ui)
447 self._sock = None
446 self._sock = None
448 self._oldsigchldhandler = None
447 self._oldsigchldhandler = None
449 self._workerpids = set() # updated by signal handler; do not iterate
448 self._workerpids = set() # updated by signal handler; do not iterate
450 self._socketunlinked = None
449 self._socketunlinked = None
451
450
452 def init(self):
451 def init(self):
453 self._sock = socket.socket(socket.AF_UNIX)
452 self._sock = socket.socket(socket.AF_UNIX)
454 self._servicehandler.bindsocket(self._sock, self.address)
453 self._servicehandler.bindsocket(self._sock, self.address)
455 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
454 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
456 self._oldsigchldhandler = o
455 self._oldsigchldhandler = o
457 self._socketunlinked = False
456 self._socketunlinked = False
458
457
459 def _unlinksocket(self):
458 def _unlinksocket(self):
460 if not self._socketunlinked:
459 if not self._socketunlinked:
461 self._servicehandler.unlinksocket(self.address)
460 self._servicehandler.unlinksocket(self.address)
462 self._socketunlinked = True
461 self._socketunlinked = True
463
462
464 def _cleanup(self):
463 def _cleanup(self):
465 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
464 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
466 self._sock.close()
465 self._sock.close()
467 self._unlinksocket()
466 self._unlinksocket()
468 # don't kill child processes as they have active clients, just wait
467 # don't kill child processes as they have active clients, just wait
469 self._reapworkers(0)
468 self._reapworkers(0)
470
469
471 def run(self):
470 def run(self):
472 try:
471 try:
473 self._mainloop()
472 self._mainloop()
474 finally:
473 finally:
475 self._cleanup()
474 self._cleanup()
476
475
477 def _mainloop(self):
476 def _mainloop(self):
478 exiting = False
477 exiting = False
479 h = self._servicehandler
478 h = self._servicehandler
480 selector = selectors2.DefaultSelector()
479 selector = selectors2.DefaultSelector()
481 selector.register(self._sock, selectors2.EVENT_READ)
480 selector.register(self._sock, selectors2.EVENT_READ)
482 while True:
481 while True:
483 if not exiting and h.shouldexit():
482 if not exiting and h.shouldexit():
484 # clients can no longer connect() to the domain socket, so
483 # clients can no longer connect() to the domain socket, so
485 # we stop queuing new requests.
484 # we stop queuing new requests.
486 # for requests that are queued (connect()-ed, but haven't been
485 # for requests that are queued (connect()-ed, but haven't been
487 # accept()-ed), handle them before exit. otherwise, clients
486 # accept()-ed), handle them before exit. otherwise, clients
488 # waiting for recv() will receive ECONNRESET.
487 # waiting for recv() will receive ECONNRESET.
489 self._unlinksocket()
488 self._unlinksocket()
490 exiting = True
489 exiting = True
490 ready = selector.select(timeout=h.pollinterval)
491 if not ready:
492 # only exit if we completed all queued requests
493 if exiting:
494 break
495 continue
491 try:
496 try:
492 ready = selector.select(timeout=h.pollinterval)
493 if not ready:
494 # only exit if we completed all queued requests
495 if exiting:
496 break
497 continue
498 conn, _addr = self._sock.accept()
497 conn, _addr = self._sock.accept()
499 except (select.error, socket.error) as inst:
498 except socket.error as inst:
500 if inst.args[0] == errno.EINTR:
499 if inst.args[0] == errno.EINTR:
501 continue
500 continue
502 raise
501 raise
503
502
504 pid = os.fork()
503 pid = os.fork()
505 if pid:
504 if pid:
506 try:
505 try:
507 self.ui.debug('forked worker process (pid=%d)\n' % pid)
506 self.ui.debug('forked worker process (pid=%d)\n' % pid)
508 self._workerpids.add(pid)
507 self._workerpids.add(pid)
509 h.newconnection()
508 h.newconnection()
510 finally:
509 finally:
511 conn.close() # release handle in parent process
510 conn.close() # release handle in parent process
512 else:
511 else:
513 try:
512 try:
514 self._runworker(conn)
513 self._runworker(conn)
515 conn.close()
514 conn.close()
516 os._exit(0)
515 os._exit(0)
517 except: # never return, hence no re-raises
516 except: # never return, hence no re-raises
518 try:
517 try:
519 self.ui.traceback(force=True)
518 self.ui.traceback(force=True)
520 finally:
519 finally:
521 os._exit(255)
520 os._exit(255)
522 selector.close()
521 selector.close()
523
522
524 def _sigchldhandler(self, signal, frame):
523 def _sigchldhandler(self, signal, frame):
525 self._reapworkers(os.WNOHANG)
524 self._reapworkers(os.WNOHANG)
526
525
527 def _reapworkers(self, options):
526 def _reapworkers(self, options):
528 while self._workerpids:
527 while self._workerpids:
529 try:
528 try:
530 pid, _status = os.waitpid(-1, options)
529 pid, _status = os.waitpid(-1, options)
531 except OSError as inst:
530 except OSError as inst:
532 if inst.errno == errno.EINTR:
531 if inst.errno == errno.EINTR:
533 continue
532 continue
534 if inst.errno != errno.ECHILD:
533 if inst.errno != errno.ECHILD:
535 raise
534 raise
536 # no child processes at all (reaped by other waitpid()?)
535 # no child processes at all (reaped by other waitpid()?)
537 self._workerpids.clear()
536 self._workerpids.clear()
538 return
537 return
539 if pid == 0:
538 if pid == 0:
540 # no waitable child processes
539 # no waitable child processes
541 return
540 return
542 self.ui.debug('worker process exited (pid=%d)\n' % pid)
541 self.ui.debug('worker process exited (pid=%d)\n' % pid)
543 self._workerpids.discard(pid)
542 self._workerpids.discard(pid)
544
543
545 def _runworker(self, conn):
544 def _runworker(self, conn):
546 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
545 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
547 _initworkerprocess()
546 _initworkerprocess()
548 h = self._servicehandler
547 h = self._servicehandler
549 try:
548 try:
550 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
549 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
551 finally:
550 finally:
552 gc.collect() # trigger __del__ since worker process uses os._exit
551 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now