##// END OF EJS Templates
commandserver: close selector explicitly...
Jun Wu -
r33506:8a1a7935 default
parent child Browse files
Show More
@@ -1,551 +1,552 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import errno
10 import errno
11 import gc
11 import gc
12 import os
12 import os
13 import random
13 import random
14 import select
14 import select
15 import signal
15 import signal
16 import socket
16 import socket
17 import struct
17 import struct
18 import traceback
18 import traceback
19
19
20 from .i18n import _
20 from .i18n import _
21 from . import (
21 from . import (
22 encoding,
22 encoding,
23 error,
23 error,
24 pycompat,
24 pycompat,
25 selectors2,
25 selectors2,
26 util,
26 util,
27 )
27 )
28
28
29 logfile = None
29 logfile = None
30
30
31 def log(*args):
31 def log(*args):
32 if not logfile:
32 if not logfile:
33 return
33 return
34
34
35 for a in args:
35 for a in args:
36 logfile.write(str(a))
36 logfile.write(str(a))
37
37
38 logfile.flush()
38 logfile.flush()
39
39
40 class channeledoutput(object):
40 class channeledoutput(object):
41 """
41 """
42 Write data to out in the following format:
42 Write data to out in the following format:
43
43
44 data length (unsigned int),
44 data length (unsigned int),
45 data
45 data
46 """
46 """
47 def __init__(self, out, channel):
47 def __init__(self, out, channel):
48 self.out = out
48 self.out = out
49 self.channel = channel
49 self.channel = channel
50
50
51 @property
51 @property
52 def name(self):
52 def name(self):
53 return '<%c-channel>' % self.channel
53 return '<%c-channel>' % self.channel
54
54
55 def write(self, data):
55 def write(self, data):
56 if not data:
56 if not data:
57 return
57 return
58 # single write() to guarantee the same atomicity as the underlying file
58 # single write() to guarantee the same atomicity as the underlying file
59 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
60 self.out.flush()
60 self.out.flush()
61
61
62 def __getattr__(self, attr):
62 def __getattr__(self, attr):
63 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 if attr in ('isatty', 'fileno', 'tell', 'seek'):
64 raise AttributeError(attr)
64 raise AttributeError(attr)
65 return getattr(self.out, attr)
65 return getattr(self.out, attr)
66
66
67 class channeledinput(object):
67 class channeledinput(object):
68 """
68 """
69 Read data from in_.
69 Read data from in_.
70
70
71 Requests for input are written to out in the following format:
71 Requests for input are written to out in the following format:
72 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 channel identifier - 'I' for plain input, 'L' line based (1 byte)
73 how many bytes to send at most (unsigned int),
73 how many bytes to send at most (unsigned int),
74
74
75 The client replies with:
75 The client replies with:
76 data length (unsigned int), 0 meaning EOF
76 data length (unsigned int), 0 meaning EOF
77 data
77 data
78 """
78 """
79
79
80 maxchunksize = 4 * 1024
80 maxchunksize = 4 * 1024
81
81
82 def __init__(self, in_, out, channel):
82 def __init__(self, in_, out, channel):
83 self.in_ = in_
83 self.in_ = in_
84 self.out = out
84 self.out = out
85 self.channel = channel
85 self.channel = channel
86
86
87 @property
87 @property
88 def name(self):
88 def name(self):
89 return '<%c-channel>' % self.channel
89 return '<%c-channel>' % self.channel
90
90
91 def read(self, size=-1):
91 def read(self, size=-1):
92 if size < 0:
92 if size < 0:
93 # if we need to consume all the clients input, ask for 4k chunks
93 # if we need to consume all the clients input, ask for 4k chunks
94 # so the pipe doesn't fill up risking a deadlock
94 # so the pipe doesn't fill up risking a deadlock
95 size = self.maxchunksize
95 size = self.maxchunksize
96 s = self._read(size, self.channel)
96 s = self._read(size, self.channel)
97 buf = s
97 buf = s
98 while s:
98 while s:
99 s = self._read(size, self.channel)
99 s = self._read(size, self.channel)
100 buf += s
100 buf += s
101
101
102 return buf
102 return buf
103 else:
103 else:
104 return self._read(size, self.channel)
104 return self._read(size, self.channel)
105
105
106 def _read(self, size, channel):
106 def _read(self, size, channel):
107 if not size:
107 if not size:
108 return ''
108 return ''
109 assert size > 0
109 assert size > 0
110
110
111 # tell the client we need at most size bytes
111 # tell the client we need at most size bytes
112 self.out.write(struct.pack('>cI', channel, size))
112 self.out.write(struct.pack('>cI', channel, size))
113 self.out.flush()
113 self.out.flush()
114
114
115 length = self.in_.read(4)
115 length = self.in_.read(4)
116 length = struct.unpack('>I', length)[0]
116 length = struct.unpack('>I', length)[0]
117 if not length:
117 if not length:
118 return ''
118 return ''
119 else:
119 else:
120 return self.in_.read(length)
120 return self.in_.read(length)
121
121
122 def readline(self, size=-1):
122 def readline(self, size=-1):
123 if size < 0:
123 if size < 0:
124 size = self.maxchunksize
124 size = self.maxchunksize
125 s = self._read(size, 'L')
125 s = self._read(size, 'L')
126 buf = s
126 buf = s
127 # keep asking for more until there's either no more or
127 # keep asking for more until there's either no more or
128 # we got a full line
128 # we got a full line
129 while s and s[-1] != '\n':
129 while s and s[-1] != '\n':
130 s = self._read(size, 'L')
130 s = self._read(size, 'L')
131 buf += s
131 buf += s
132
132
133 return buf
133 return buf
134 else:
134 else:
135 return self._read(size, 'L')
135 return self._read(size, 'L')
136
136
137 def __iter__(self):
137 def __iter__(self):
138 return self
138 return self
139
139
140 def next(self):
140 def next(self):
141 l = self.readline()
141 l = self.readline()
142 if not l:
142 if not l:
143 raise StopIteration
143 raise StopIteration
144 return l
144 return l
145
145
146 def __getattr__(self, attr):
146 def __getattr__(self, attr):
147 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 if attr in ('isatty', 'fileno', 'tell', 'seek'):
148 raise AttributeError(attr)
148 raise AttributeError(attr)
149 return getattr(self.in_, attr)
149 return getattr(self.in_, attr)
150
150
151 class server(object):
151 class server(object):
152 """
152 """
153 Listens for commands on fin, runs them and writes the output on a channel
153 Listens for commands on fin, runs them and writes the output on a channel
154 based stream to fout.
154 based stream to fout.
155 """
155 """
156 def __init__(self, ui, repo, fin, fout):
156 def __init__(self, ui, repo, fin, fout):
157 self.cwd = pycompat.getcwd()
157 self.cwd = pycompat.getcwd()
158
158
159 # developer config: cmdserver.log
159 # developer config: cmdserver.log
160 logpath = ui.config("cmdserver", "log")
160 logpath = ui.config("cmdserver", "log")
161 if logpath:
161 if logpath:
162 global logfile
162 global logfile
163 if logpath == '-':
163 if logpath == '-':
164 # write log on a special 'd' (debug) channel
164 # write log on a special 'd' (debug) channel
165 logfile = channeledoutput(fout, 'd')
165 logfile = channeledoutput(fout, 'd')
166 else:
166 else:
167 logfile = open(logpath, 'a')
167 logfile = open(logpath, 'a')
168
168
169 if repo:
169 if repo:
170 # the ui here is really the repo ui so take its baseui so we don't
170 # the ui here is really the repo ui so take its baseui so we don't
171 # end up with its local configuration
171 # end up with its local configuration
172 self.ui = repo.baseui
172 self.ui = repo.baseui
173 self.repo = repo
173 self.repo = repo
174 self.repoui = repo.ui
174 self.repoui = repo.ui
175 else:
175 else:
176 self.ui = ui
176 self.ui = ui
177 self.repo = self.repoui = None
177 self.repo = self.repoui = None
178
178
179 self.cerr = channeledoutput(fout, 'e')
179 self.cerr = channeledoutput(fout, 'e')
180 self.cout = channeledoutput(fout, 'o')
180 self.cout = channeledoutput(fout, 'o')
181 self.cin = channeledinput(fin, fout, 'I')
181 self.cin = channeledinput(fin, fout, 'I')
182 self.cresult = channeledoutput(fout, 'r')
182 self.cresult = channeledoutput(fout, 'r')
183
183
184 self.client = fin
184 self.client = fin
185
185
186 def cleanup(self):
186 def cleanup(self):
187 """release and restore resources taken during server session"""
187 """release and restore resources taken during server session"""
188 pass
188 pass
189
189
190 def _read(self, size):
190 def _read(self, size):
191 if not size:
191 if not size:
192 return ''
192 return ''
193
193
194 data = self.client.read(size)
194 data = self.client.read(size)
195
195
196 # is the other end closed?
196 # is the other end closed?
197 if not data:
197 if not data:
198 raise EOFError
198 raise EOFError
199
199
200 return data
200 return data
201
201
202 def _readstr(self):
202 def _readstr(self):
203 """read a string from the channel
203 """read a string from the channel
204
204
205 format:
205 format:
206 data length (uint32), data
206 data length (uint32), data
207 """
207 """
208 length = struct.unpack('>I', self._read(4))[0]
208 length = struct.unpack('>I', self._read(4))[0]
209 if not length:
209 if not length:
210 return ''
210 return ''
211 return self._read(length)
211 return self._read(length)
212
212
213 def _readlist(self):
213 def _readlist(self):
214 """read a list of NULL separated strings from the channel"""
214 """read a list of NULL separated strings from the channel"""
215 s = self._readstr()
215 s = self._readstr()
216 if s:
216 if s:
217 return s.split('\0')
217 return s.split('\0')
218 else:
218 else:
219 return []
219 return []
220
220
221 def runcommand(self):
221 def runcommand(self):
222 """ reads a list of \0 terminated arguments, executes
222 """ reads a list of \0 terminated arguments, executes
223 and writes the return code to the result channel """
223 and writes the return code to the result channel """
224 from . import dispatch # avoid cycle
224 from . import dispatch # avoid cycle
225
225
226 args = self._readlist()
226 args = self._readlist()
227
227
228 # copy the uis so changes (e.g. --config or --verbose) don't
228 # copy the uis so changes (e.g. --config or --verbose) don't
229 # persist between requests
229 # persist between requests
230 copiedui = self.ui.copy()
230 copiedui = self.ui.copy()
231 uis = [copiedui]
231 uis = [copiedui]
232 if self.repo:
232 if self.repo:
233 self.repo.baseui = copiedui
233 self.repo.baseui = copiedui
234 # clone ui without using ui.copy because this is protected
234 # clone ui without using ui.copy because this is protected
235 repoui = self.repoui.__class__(self.repoui)
235 repoui = self.repoui.__class__(self.repoui)
236 repoui.copy = copiedui.copy # redo copy protection
236 repoui.copy = copiedui.copy # redo copy protection
237 uis.append(repoui)
237 uis.append(repoui)
238 self.repo.ui = self.repo.dirstate._ui = repoui
238 self.repo.ui = self.repo.dirstate._ui = repoui
239 self.repo.invalidateall()
239 self.repo.invalidateall()
240
240
241 for ui in uis:
241 for ui in uis:
242 ui.resetstate()
242 ui.resetstate()
243 # any kind of interaction must use server channels, but chg may
243 # any kind of interaction must use server channels, but chg may
244 # replace channels by fully functional tty files. so nontty is
244 # replace channels by fully functional tty files. so nontty is
245 # enforced only if cin is a channel.
245 # enforced only if cin is a channel.
246 if not util.safehasattr(self.cin, 'fileno'):
246 if not util.safehasattr(self.cin, 'fileno'):
247 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
247 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
248
248
249 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
249 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
250 self.cout, self.cerr)
250 self.cout, self.cerr)
251
251
252 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
252 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
253
253
254 # restore old cwd
254 # restore old cwd
255 if '--cwd' in args:
255 if '--cwd' in args:
256 os.chdir(self.cwd)
256 os.chdir(self.cwd)
257
257
258 self.cresult.write(struct.pack('>i', int(ret)))
258 self.cresult.write(struct.pack('>i', int(ret)))
259
259
260 def getencoding(self):
260 def getencoding(self):
261 """ writes the current encoding to the result channel """
261 """ writes the current encoding to the result channel """
262 self.cresult.write(encoding.encoding)
262 self.cresult.write(encoding.encoding)
263
263
264 def serveone(self):
264 def serveone(self):
265 cmd = self.client.readline()[:-1]
265 cmd = self.client.readline()[:-1]
266 if cmd:
266 if cmd:
267 handler = self.capabilities.get(cmd)
267 handler = self.capabilities.get(cmd)
268 if handler:
268 if handler:
269 handler(self)
269 handler(self)
270 else:
270 else:
271 # clients are expected to check what commands are supported by
271 # clients are expected to check what commands are supported by
272 # looking at the servers capabilities
272 # looking at the servers capabilities
273 raise error.Abort(_('unknown command %s') % cmd)
273 raise error.Abort(_('unknown command %s') % cmd)
274
274
275 return cmd != ''
275 return cmd != ''
276
276
277 capabilities = {'runcommand' : runcommand,
277 capabilities = {'runcommand' : runcommand,
278 'getencoding' : getencoding}
278 'getencoding' : getencoding}
279
279
280 def serve(self):
280 def serve(self):
281 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
281 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
282 hellomsg += '\n'
282 hellomsg += '\n'
283 hellomsg += 'encoding: ' + encoding.encoding
283 hellomsg += 'encoding: ' + encoding.encoding
284 hellomsg += '\n'
284 hellomsg += '\n'
285 hellomsg += 'pid: %d' % util.getpid()
285 hellomsg += 'pid: %d' % util.getpid()
286 if util.safehasattr(os, 'getpgid'):
286 if util.safehasattr(os, 'getpgid'):
287 hellomsg += '\n'
287 hellomsg += '\n'
288 hellomsg += 'pgid: %d' % os.getpgid(0)
288 hellomsg += 'pgid: %d' % os.getpgid(0)
289
289
290 # write the hello msg in -one- chunk
290 # write the hello msg in -one- chunk
291 self.cout.write(hellomsg)
291 self.cout.write(hellomsg)
292
292
293 try:
293 try:
294 while self.serveone():
294 while self.serveone():
295 pass
295 pass
296 except EOFError:
296 except EOFError:
297 # we'll get here if the client disconnected while we were reading
297 # we'll get here if the client disconnected while we were reading
298 # its request
298 # its request
299 return 1
299 return 1
300
300
301 return 0
301 return 0
302
302
303 def _protectio(ui):
303 def _protectio(ui):
304 """ duplicates streams and redirect original to null if ui uses stdio """
304 """ duplicates streams and redirect original to null if ui uses stdio """
305 ui.flush()
305 ui.flush()
306 newfiles = []
306 newfiles = []
307 nullfd = os.open(os.devnull, os.O_RDWR)
307 nullfd = os.open(os.devnull, os.O_RDWR)
308 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
308 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
309 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
309 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
310 if f is sysf:
310 if f is sysf:
311 newfd = os.dup(f.fileno())
311 newfd = os.dup(f.fileno())
312 os.dup2(nullfd, f.fileno())
312 os.dup2(nullfd, f.fileno())
313 f = os.fdopen(newfd, mode)
313 f = os.fdopen(newfd, mode)
314 newfiles.append(f)
314 newfiles.append(f)
315 os.close(nullfd)
315 os.close(nullfd)
316 return tuple(newfiles)
316 return tuple(newfiles)
317
317
318 def _restoreio(ui, fin, fout):
318 def _restoreio(ui, fin, fout):
319 """ restores streams from duplicated ones """
319 """ restores streams from duplicated ones """
320 ui.flush()
320 ui.flush()
321 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
321 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
322 if f is not uif:
322 if f is not uif:
323 os.dup2(f.fileno(), uif.fileno())
323 os.dup2(f.fileno(), uif.fileno())
324 f.close()
324 f.close()
325
325
326 class pipeservice(object):
326 class pipeservice(object):
327 def __init__(self, ui, repo, opts):
327 def __init__(self, ui, repo, opts):
328 self.ui = ui
328 self.ui = ui
329 self.repo = repo
329 self.repo = repo
330
330
331 def init(self):
331 def init(self):
332 pass
332 pass
333
333
334 def run(self):
334 def run(self):
335 ui = self.ui
335 ui = self.ui
336 # redirect stdio to null device so that broken extensions or in-process
336 # redirect stdio to null device so that broken extensions or in-process
337 # hooks will never cause corruption of channel protocol.
337 # hooks will never cause corruption of channel protocol.
338 fin, fout = _protectio(ui)
338 fin, fout = _protectio(ui)
339 try:
339 try:
340 sv = server(ui, self.repo, fin, fout)
340 sv = server(ui, self.repo, fin, fout)
341 return sv.serve()
341 return sv.serve()
342 finally:
342 finally:
343 sv.cleanup()
343 sv.cleanup()
344 _restoreio(ui, fin, fout)
344 _restoreio(ui, fin, fout)
345
345
346 def _initworkerprocess():
346 def _initworkerprocess():
347 # use a different process group from the master process, in order to:
347 # use a different process group from the master process, in order to:
348 # 1. make the current process group no longer "orphaned" (because the
348 # 1. make the current process group no longer "orphaned" (because the
349 # parent of this process is in a different process group while
349 # parent of this process is in a different process group while
350 # remains in a same session)
350 # remains in a same session)
351 # according to POSIX 2.2.2.52, orphaned process group will ignore
351 # according to POSIX 2.2.2.52, orphaned process group will ignore
352 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
352 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
353 # cause trouble for things like ncurses.
353 # cause trouble for things like ncurses.
354 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
354 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
355 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
355 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
356 # processes like ssh will be killed properly, without affecting
356 # processes like ssh will be killed properly, without affecting
357 # unrelated processes.
357 # unrelated processes.
358 os.setpgid(0, 0)
358 os.setpgid(0, 0)
359 # change random state otherwise forked request handlers would have a
359 # change random state otherwise forked request handlers would have a
360 # same state inherited from parent.
360 # same state inherited from parent.
361 random.seed()
361 random.seed()
362
362
363 def _serverequest(ui, repo, conn, createcmdserver):
363 def _serverequest(ui, repo, conn, createcmdserver):
364 fin = conn.makefile('rb')
364 fin = conn.makefile('rb')
365 fout = conn.makefile('wb')
365 fout = conn.makefile('wb')
366 sv = None
366 sv = None
367 try:
367 try:
368 sv = createcmdserver(repo, conn, fin, fout)
368 sv = createcmdserver(repo, conn, fin, fout)
369 try:
369 try:
370 sv.serve()
370 sv.serve()
371 # handle exceptions that may be raised by command server. most of
371 # handle exceptions that may be raised by command server. most of
372 # known exceptions are caught by dispatch.
372 # known exceptions are caught by dispatch.
373 except error.Abort as inst:
373 except error.Abort as inst:
374 ui.warn(_('abort: %s\n') % inst)
374 ui.warn(_('abort: %s\n') % inst)
375 except IOError as inst:
375 except IOError as inst:
376 if inst.errno != errno.EPIPE:
376 if inst.errno != errno.EPIPE:
377 raise
377 raise
378 except KeyboardInterrupt:
378 except KeyboardInterrupt:
379 pass
379 pass
380 finally:
380 finally:
381 sv.cleanup()
381 sv.cleanup()
382 except: # re-raises
382 except: # re-raises
383 # also write traceback to error channel. otherwise client cannot
383 # also write traceback to error channel. otherwise client cannot
384 # see it because it is written to server's stderr by default.
384 # see it because it is written to server's stderr by default.
385 if sv:
385 if sv:
386 cerr = sv.cerr
386 cerr = sv.cerr
387 else:
387 else:
388 cerr = channeledoutput(fout, 'e')
388 cerr = channeledoutput(fout, 'e')
389 traceback.print_exc(file=cerr)
389 traceback.print_exc(file=cerr)
390 raise
390 raise
391 finally:
391 finally:
392 fin.close()
392 fin.close()
393 try:
393 try:
394 fout.close() # implicit flush() may cause another EPIPE
394 fout.close() # implicit flush() may cause another EPIPE
395 except IOError as inst:
395 except IOError as inst:
396 if inst.errno != errno.EPIPE:
396 if inst.errno != errno.EPIPE:
397 raise
397 raise
398
398
399 class unixservicehandler(object):
399 class unixservicehandler(object):
400 """Set of pluggable operations for unix-mode services
400 """Set of pluggable operations for unix-mode services
401
401
402 Almost all methods except for createcmdserver() are called in the main
402 Almost all methods except for createcmdserver() are called in the main
403 process. You can't pass mutable resource back from createcmdserver().
403 process. You can't pass mutable resource back from createcmdserver().
404 """
404 """
405
405
406 pollinterval = None
406 pollinterval = None
407
407
408 def __init__(self, ui):
408 def __init__(self, ui):
409 self.ui = ui
409 self.ui = ui
410
410
411 def bindsocket(self, sock, address):
411 def bindsocket(self, sock, address):
412 util.bindunixsocket(sock, address)
412 util.bindunixsocket(sock, address)
413 sock.listen(socket.SOMAXCONN)
413 sock.listen(socket.SOMAXCONN)
414 self.ui.status(_('listening at %s\n') % address)
414 self.ui.status(_('listening at %s\n') % address)
415 self.ui.flush() # avoid buffering of status message
415 self.ui.flush() # avoid buffering of status message
416
416
417 def unlinksocket(self, address):
417 def unlinksocket(self, address):
418 os.unlink(address)
418 os.unlink(address)
419
419
420 def shouldexit(self):
420 def shouldexit(self):
421 """True if server should shut down; checked per pollinterval"""
421 """True if server should shut down; checked per pollinterval"""
422 return False
422 return False
423
423
424 def newconnection(self):
424 def newconnection(self):
425 """Called when main process notices new connection"""
425 """Called when main process notices new connection"""
426 pass
426 pass
427
427
428 def createcmdserver(self, repo, conn, fin, fout):
428 def createcmdserver(self, repo, conn, fin, fout):
429 """Create new command server instance; called in the process that
429 """Create new command server instance; called in the process that
430 serves for the current connection"""
430 serves for the current connection"""
431 return server(self.ui, repo, fin, fout)
431 return server(self.ui, repo, fin, fout)
432
432
433 class unixforkingservice(object):
433 class unixforkingservice(object):
434 """
434 """
435 Listens on unix domain socket and forks server per connection
435 Listens on unix domain socket and forks server per connection
436 """
436 """
437
437
438 def __init__(self, ui, repo, opts, handler=None):
438 def __init__(self, ui, repo, opts, handler=None):
439 self.ui = ui
439 self.ui = ui
440 self.repo = repo
440 self.repo = repo
441 self.address = opts['address']
441 self.address = opts['address']
442 if not util.safehasattr(socket, 'AF_UNIX'):
442 if not util.safehasattr(socket, 'AF_UNIX'):
443 raise error.Abort(_('unsupported platform'))
443 raise error.Abort(_('unsupported platform'))
444 if not self.address:
444 if not self.address:
445 raise error.Abort(_('no socket path specified with --address'))
445 raise error.Abort(_('no socket path specified with --address'))
446 self._servicehandler = handler or unixservicehandler(ui)
446 self._servicehandler = handler or unixservicehandler(ui)
447 self._sock = None
447 self._sock = None
448 self._oldsigchldhandler = None
448 self._oldsigchldhandler = None
449 self._workerpids = set() # updated by signal handler; do not iterate
449 self._workerpids = set() # updated by signal handler; do not iterate
450 self._socketunlinked = None
450 self._socketunlinked = None
451
451
452 def init(self):
452 def init(self):
453 self._sock = socket.socket(socket.AF_UNIX)
453 self._sock = socket.socket(socket.AF_UNIX)
454 self._servicehandler.bindsocket(self._sock, self.address)
454 self._servicehandler.bindsocket(self._sock, self.address)
455 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
455 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
456 self._oldsigchldhandler = o
456 self._oldsigchldhandler = o
457 self._socketunlinked = False
457 self._socketunlinked = False
458
458
459 def _unlinksocket(self):
459 def _unlinksocket(self):
460 if not self._socketunlinked:
460 if not self._socketunlinked:
461 self._servicehandler.unlinksocket(self.address)
461 self._servicehandler.unlinksocket(self.address)
462 self._socketunlinked = True
462 self._socketunlinked = True
463
463
464 def _cleanup(self):
464 def _cleanup(self):
465 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
465 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
466 self._sock.close()
466 self._sock.close()
467 self._unlinksocket()
467 self._unlinksocket()
468 # don't kill child processes as they have active clients, just wait
468 # don't kill child processes as they have active clients, just wait
469 self._reapworkers(0)
469 self._reapworkers(0)
470
470
471 def run(self):
471 def run(self):
472 try:
472 try:
473 self._mainloop()
473 self._mainloop()
474 finally:
474 finally:
475 self._cleanup()
475 self._cleanup()
476
476
477 def _mainloop(self):
477 def _mainloop(self):
478 exiting = False
478 exiting = False
479 h = self._servicehandler
479 h = self._servicehandler
480 selector = selectors2.DefaultSelector()
480 selector = selectors2.DefaultSelector()
481 selector.register(self._sock, selectors2.EVENT_READ)
481 selector.register(self._sock, selectors2.EVENT_READ)
482 while True:
482 while True:
483 if not exiting and h.shouldexit():
483 if not exiting and h.shouldexit():
484 # clients can no longer connect() to the domain socket, so
484 # clients can no longer connect() to the domain socket, so
485 # we stop queuing new requests.
485 # we stop queuing new requests.
486 # for requests that are queued (connect()-ed, but haven't been
486 # for requests that are queued (connect()-ed, but haven't been
487 # accept()-ed), handle them before exit. otherwise, clients
487 # accept()-ed), handle them before exit. otherwise, clients
488 # waiting for recv() will receive ECONNRESET.
488 # waiting for recv() will receive ECONNRESET.
489 self._unlinksocket()
489 self._unlinksocket()
490 exiting = True
490 exiting = True
491 try:
491 try:
492 ready = selector.select(timeout=h.pollinterval)
492 ready = selector.select(timeout=h.pollinterval)
493 if not ready:
493 if not ready:
494 # only exit if we completed all queued requests
494 # only exit if we completed all queued requests
495 if exiting:
495 if exiting:
496 break
496 break
497 continue
497 continue
498 conn, _addr = self._sock.accept()
498 conn, _addr = self._sock.accept()
499 except (select.error, socket.error) as inst:
499 except (select.error, socket.error) as inst:
500 if inst.args[0] == errno.EINTR:
500 if inst.args[0] == errno.EINTR:
501 continue
501 continue
502 raise
502 raise
503
503
504 pid = os.fork()
504 pid = os.fork()
505 if pid:
505 if pid:
506 try:
506 try:
507 self.ui.debug('forked worker process (pid=%d)\n' % pid)
507 self.ui.debug('forked worker process (pid=%d)\n' % pid)
508 self._workerpids.add(pid)
508 self._workerpids.add(pid)
509 h.newconnection()
509 h.newconnection()
510 finally:
510 finally:
511 conn.close() # release handle in parent process
511 conn.close() # release handle in parent process
512 else:
512 else:
513 try:
513 try:
514 self._runworker(conn)
514 self._runworker(conn)
515 conn.close()
515 conn.close()
516 os._exit(0)
516 os._exit(0)
517 except: # never return, hence no re-raises
517 except: # never return, hence no re-raises
518 try:
518 try:
519 self.ui.traceback(force=True)
519 self.ui.traceback(force=True)
520 finally:
520 finally:
521 os._exit(255)
521 os._exit(255)
522 selector.close()
522
523
523 def _sigchldhandler(self, signal, frame):
524 def _sigchldhandler(self, signal, frame):
524 self._reapworkers(os.WNOHANG)
525 self._reapworkers(os.WNOHANG)
525
526
526 def _reapworkers(self, options):
527 def _reapworkers(self, options):
527 while self._workerpids:
528 while self._workerpids:
528 try:
529 try:
529 pid, _status = os.waitpid(-1, options)
530 pid, _status = os.waitpid(-1, options)
530 except OSError as inst:
531 except OSError as inst:
531 if inst.errno == errno.EINTR:
532 if inst.errno == errno.EINTR:
532 continue
533 continue
533 if inst.errno != errno.ECHILD:
534 if inst.errno != errno.ECHILD:
534 raise
535 raise
535 # no child processes at all (reaped by other waitpid()?)
536 # no child processes at all (reaped by other waitpid()?)
536 self._workerpids.clear()
537 self._workerpids.clear()
537 return
538 return
538 if pid == 0:
539 if pid == 0:
539 # no waitable child processes
540 # no waitable child processes
540 return
541 return
541 self.ui.debug('worker process exited (pid=%d)\n' % pid)
542 self.ui.debug('worker process exited (pid=%d)\n' % pid)
542 self._workerpids.discard(pid)
543 self._workerpids.discard(pid)
543
544
544 def _runworker(self, conn):
545 def _runworker(self, conn):
545 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
546 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
546 _initworkerprocess()
547 _initworkerprocess()
547 h = self._servicehandler
548 h = self._servicehandler
548 try:
549 try:
549 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
550 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
550 finally:
551 finally:
551 gc.collect() # trigger __del__ since worker process uses os._exit
552 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now