##// END OF EJS Templates
commandserver: unblock SIGCHLD...
Jun Wu -
r35477:3a119a42 @9 default
parent child Browse files
Show More
@@ -1,549 +1,551 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import signal
15 15 import socket
16 16 import struct
17 17 import traceback
18 18
19 19 from .i18n import _
20 20 from .thirdparty import selectors2
21 21 from . import (
22 22 encoding,
23 23 error,
24 24 pycompat,
25 25 util,
26 26 )
27 27
28 28 logfile = None
29 29
30 30 def log(*args):
31 31 if not logfile:
32 32 return
33 33
34 34 for a in args:
35 35 logfile.write(str(a))
36 36
37 37 logfile.flush()
38 38
39 39 class channeledoutput(object):
40 40 """
41 41 Write data to out in the following format:
42 42
43 43 data length (unsigned int),
44 44 data
45 45 """
46 46 def __init__(self, out, channel):
47 47 self.out = out
48 48 self.channel = channel
49 49
50 50 @property
51 51 def name(self):
52 52 return '<%c-channel>' % self.channel
53 53
54 54 def write(self, data):
55 55 if not data:
56 56 return
57 57 # single write() to guarantee the same atomicity as the underlying file
58 58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 59 self.out.flush()
60 60
61 61 def __getattr__(self, attr):
62 62 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 63 raise AttributeError(attr)
64 64 return getattr(self.out, attr)
65 65
66 66 class channeledinput(object):
67 67 """
68 68 Read data from in_.
69 69
70 70 Requests for input are written to out in the following format:
71 71 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 72 how many bytes to send at most (unsigned int),
73 73
74 74 The client replies with:
75 75 data length (unsigned int), 0 meaning EOF
76 76 data
77 77 """
78 78
79 79 maxchunksize = 4 * 1024
80 80
81 81 def __init__(self, in_, out, channel):
82 82 self.in_ = in_
83 83 self.out = out
84 84 self.channel = channel
85 85
86 86 @property
87 87 def name(self):
88 88 return '<%c-channel>' % self.channel
89 89
90 90 def read(self, size=-1):
91 91 if size < 0:
92 92 # if we need to consume all the clients input, ask for 4k chunks
93 93 # so the pipe doesn't fill up risking a deadlock
94 94 size = self.maxchunksize
95 95 s = self._read(size, self.channel)
96 96 buf = s
97 97 while s:
98 98 s = self._read(size, self.channel)
99 99 buf += s
100 100
101 101 return buf
102 102 else:
103 103 return self._read(size, self.channel)
104 104
105 105 def _read(self, size, channel):
106 106 if not size:
107 107 return ''
108 108 assert size > 0
109 109
110 110 # tell the client we need at most size bytes
111 111 self.out.write(struct.pack('>cI', channel, size))
112 112 self.out.flush()
113 113
114 114 length = self.in_.read(4)
115 115 length = struct.unpack('>I', length)[0]
116 116 if not length:
117 117 return ''
118 118 else:
119 119 return self.in_.read(length)
120 120
121 121 def readline(self, size=-1):
122 122 if size < 0:
123 123 size = self.maxchunksize
124 124 s = self._read(size, 'L')
125 125 buf = s
126 126 # keep asking for more until there's either no more or
127 127 # we got a full line
128 128 while s and s[-1] != '\n':
129 129 s = self._read(size, 'L')
130 130 buf += s
131 131
132 132 return buf
133 133 else:
134 134 return self._read(size, 'L')
135 135
136 136 def __iter__(self):
137 137 return self
138 138
139 139 def next(self):
140 140 l = self.readline()
141 141 if not l:
142 142 raise StopIteration
143 143 return l
144 144
145 145 def __getattr__(self, attr):
146 146 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 147 raise AttributeError(attr)
148 148 return getattr(self.in_, attr)
149 149
150 150 class server(object):
151 151 """
152 152 Listens for commands on fin, runs them and writes the output on a channel
153 153 based stream to fout.
154 154 """
155 155 def __init__(self, ui, repo, fin, fout):
156 156 self.cwd = pycompat.getcwd()
157 157
158 158 # developer config: cmdserver.log
159 159 logpath = ui.config("cmdserver", "log")
160 160 if logpath:
161 161 global logfile
162 162 if logpath == '-':
163 163 # write log on a special 'd' (debug) channel
164 164 logfile = channeledoutput(fout, 'd')
165 165 else:
166 166 logfile = open(logpath, 'a')
167 167
168 168 if repo:
169 169 # the ui here is really the repo ui so take its baseui so we don't
170 170 # end up with its local configuration
171 171 self.ui = repo.baseui
172 172 self.repo = repo
173 173 self.repoui = repo.ui
174 174 else:
175 175 self.ui = ui
176 176 self.repo = self.repoui = None
177 177
178 178 self.cerr = channeledoutput(fout, 'e')
179 179 self.cout = channeledoutput(fout, 'o')
180 180 self.cin = channeledinput(fin, fout, 'I')
181 181 self.cresult = channeledoutput(fout, 'r')
182 182
183 183 self.client = fin
184 184
185 185 def cleanup(self):
186 186 """release and restore resources taken during server session"""
187 187
188 188 def _read(self, size):
189 189 if not size:
190 190 return ''
191 191
192 192 data = self.client.read(size)
193 193
194 194 # is the other end closed?
195 195 if not data:
196 196 raise EOFError
197 197
198 198 return data
199 199
200 200 def _readstr(self):
201 201 """read a string from the channel
202 202
203 203 format:
204 204 data length (uint32), data
205 205 """
206 206 length = struct.unpack('>I', self._read(4))[0]
207 207 if not length:
208 208 return ''
209 209 return self._read(length)
210 210
211 211 def _readlist(self):
212 212 """read a list of NULL separated strings from the channel"""
213 213 s = self._readstr()
214 214 if s:
215 215 return s.split('\0')
216 216 else:
217 217 return []
218 218
219 219 def runcommand(self):
220 220 """ reads a list of \0 terminated arguments, executes
221 221 and writes the return code to the result channel """
222 222 from . import dispatch # avoid cycle
223 223
224 224 args = self._readlist()
225 225
226 226 # copy the uis so changes (e.g. --config or --verbose) don't
227 227 # persist between requests
228 228 copiedui = self.ui.copy()
229 229 uis = [copiedui]
230 230 if self.repo:
231 231 self.repo.baseui = copiedui
232 232 # clone ui without using ui.copy because this is protected
233 233 repoui = self.repoui.__class__(self.repoui)
234 234 repoui.copy = copiedui.copy # redo copy protection
235 235 uis.append(repoui)
236 236 self.repo.ui = self.repo.dirstate._ui = repoui
237 237 self.repo.invalidateall()
238 238
239 239 for ui in uis:
240 240 ui.resetstate()
241 241 # any kind of interaction must use server channels, but chg may
242 242 # replace channels by fully functional tty files. so nontty is
243 243 # enforced only if cin is a channel.
244 244 if not util.safehasattr(self.cin, 'fileno'):
245 245 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
246 246
247 247 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
248 248 self.cout, self.cerr)
249 249
250 250 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
251 251
252 252 # restore old cwd
253 253 if '--cwd' in args:
254 254 os.chdir(self.cwd)
255 255
256 256 self.cresult.write(struct.pack('>i', int(ret)))
257 257
258 258 def getencoding(self):
259 259 """ writes the current encoding to the result channel """
260 260 self.cresult.write(encoding.encoding)
261 261
262 262 def serveone(self):
263 263 cmd = self.client.readline()[:-1]
264 264 if cmd:
265 265 handler = self.capabilities.get(cmd)
266 266 if handler:
267 267 handler(self)
268 268 else:
269 269 # clients are expected to check what commands are supported by
270 270 # looking at the servers capabilities
271 271 raise error.Abort(_('unknown command %s') % cmd)
272 272
273 273 return cmd != ''
274 274
275 275 capabilities = {'runcommand': runcommand,
276 276 'getencoding': getencoding}
277 277
278 278 def serve(self):
279 279 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
280 280 hellomsg += '\n'
281 281 hellomsg += 'encoding: ' + encoding.encoding
282 282 hellomsg += '\n'
283 283 hellomsg += 'pid: %d' % util.getpid()
284 284 if util.safehasattr(os, 'getpgid'):
285 285 hellomsg += '\n'
286 286 hellomsg += 'pgid: %d' % os.getpgid(0)
287 287
288 288 # write the hello msg in -one- chunk
289 289 self.cout.write(hellomsg)
290 290
291 291 try:
292 292 while self.serveone():
293 293 pass
294 294 except EOFError:
295 295 # we'll get here if the client disconnected while we were reading
296 296 # its request
297 297 return 1
298 298
299 299 return 0
300 300
301 301 def _protectio(ui):
302 302 """ duplicates streams and redirect original to null if ui uses stdio """
303 303 ui.flush()
304 304 newfiles = []
305 305 nullfd = os.open(os.devnull, os.O_RDWR)
306 306 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
307 307 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
308 308 if f is sysf:
309 309 newfd = os.dup(f.fileno())
310 310 os.dup2(nullfd, f.fileno())
311 311 f = os.fdopen(newfd, mode)
312 312 newfiles.append(f)
313 313 os.close(nullfd)
314 314 return tuple(newfiles)
315 315
316 316 def _restoreio(ui, fin, fout):
317 317 """ restores streams from duplicated ones """
318 318 ui.flush()
319 319 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
320 320 if f is not uif:
321 321 os.dup2(f.fileno(), uif.fileno())
322 322 f.close()
323 323
324 324 class pipeservice(object):
325 325 def __init__(self, ui, repo, opts):
326 326 self.ui = ui
327 327 self.repo = repo
328 328
329 329 def init(self):
330 330 pass
331 331
332 332 def run(self):
333 333 ui = self.ui
334 334 # redirect stdio to null device so that broken extensions or in-process
335 335 # hooks will never cause corruption of channel protocol.
336 336 fin, fout = _protectio(ui)
337 337 try:
338 338 sv = server(ui, self.repo, fin, fout)
339 339 return sv.serve()
340 340 finally:
341 341 sv.cleanup()
342 342 _restoreio(ui, fin, fout)
343 343
344 344 def _initworkerprocess():
345 345 # use a different process group from the master process, in order to:
346 346 # 1. make the current process group no longer "orphaned" (because the
347 347 # parent of this process is in a different process group while
348 348 # remains in a same session)
349 349 # according to POSIX 2.2.2.52, orphaned process group will ignore
350 350 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
351 351 # cause trouble for things like ncurses.
352 352 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
353 353 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
354 354 # processes like ssh will be killed properly, without affecting
355 355 # unrelated processes.
356 356 os.setpgid(0, 0)
357 357 # change random state otherwise forked request handlers would have a
358 358 # same state inherited from parent.
359 359 random.seed()
360 360
361 361 def _serverequest(ui, repo, conn, createcmdserver):
362 362 fin = conn.makefile('rb')
363 363 fout = conn.makefile('wb')
364 364 sv = None
365 365 try:
366 366 sv = createcmdserver(repo, conn, fin, fout)
367 367 try:
368 368 sv.serve()
369 369 # handle exceptions that may be raised by command server. most of
370 370 # known exceptions are caught by dispatch.
371 371 except error.Abort as inst:
372 372 ui.warn(_('abort: %s\n') % inst)
373 373 except IOError as inst:
374 374 if inst.errno != errno.EPIPE:
375 375 raise
376 376 except KeyboardInterrupt:
377 377 pass
378 378 finally:
379 379 sv.cleanup()
380 380 except: # re-raises
381 381 # also write traceback to error channel. otherwise client cannot
382 382 # see it because it is written to server's stderr by default.
383 383 if sv:
384 384 cerr = sv.cerr
385 385 else:
386 386 cerr = channeledoutput(fout, 'e')
387 387 traceback.print_exc(file=cerr)
388 388 raise
389 389 finally:
390 390 fin.close()
391 391 try:
392 392 fout.close() # implicit flush() may cause another EPIPE
393 393 except IOError as inst:
394 394 if inst.errno != errno.EPIPE:
395 395 raise
396 396
397 397 class unixservicehandler(object):
398 398 """Set of pluggable operations for unix-mode services
399 399
400 400 Almost all methods except for createcmdserver() are called in the main
401 401 process. You can't pass mutable resource back from createcmdserver().
402 402 """
403 403
404 404 pollinterval = None
405 405
406 406 def __init__(self, ui):
407 407 self.ui = ui
408 408
409 409 def bindsocket(self, sock, address):
410 410 util.bindunixsocket(sock, address)
411 411 sock.listen(socket.SOMAXCONN)
412 412 self.ui.status(_('listening at %s\n') % address)
413 413 self.ui.flush() # avoid buffering of status message
414 414
415 415 def unlinksocket(self, address):
416 416 os.unlink(address)
417 417
418 418 def shouldexit(self):
419 419 """True if server should shut down; checked per pollinterval"""
420 420 return False
421 421
422 422 def newconnection(self):
423 423 """Called when main process notices new connection"""
424 424
425 425 def createcmdserver(self, repo, conn, fin, fout):
426 426 """Create new command server instance; called in the process that
427 427 serves for the current connection"""
428 428 return server(self.ui, repo, fin, fout)
429 429
430 430 class unixforkingservice(object):
431 431 """
432 432 Listens on unix domain socket and forks server per connection
433 433 """
434 434
435 435 def __init__(self, ui, repo, opts, handler=None):
436 436 self.ui = ui
437 437 self.repo = repo
438 438 self.address = opts['address']
439 439 if not util.safehasattr(socket, 'AF_UNIX'):
440 440 raise error.Abort(_('unsupported platform'))
441 441 if not self.address:
442 442 raise error.Abort(_('no socket path specified with --address'))
443 443 self._servicehandler = handler or unixservicehandler(ui)
444 444 self._sock = None
445 445 self._oldsigchldhandler = None
446 446 self._workerpids = set() # updated by signal handler; do not iterate
447 447 self._socketunlinked = None
448 448
449 449 def init(self):
450 450 self._sock = socket.socket(socket.AF_UNIX)
451 451 self._servicehandler.bindsocket(self._sock, self.address)
452 if util.safehasattr(util, 'unblocksignal'):
453 util.unblocksignal(signal.SIGCHLD)
452 454 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
453 455 self._oldsigchldhandler = o
454 456 self._socketunlinked = False
455 457
456 458 def _unlinksocket(self):
457 459 if not self._socketunlinked:
458 460 self._servicehandler.unlinksocket(self.address)
459 461 self._socketunlinked = True
460 462
461 463 def _cleanup(self):
462 464 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
463 465 self._sock.close()
464 466 self._unlinksocket()
465 467 # don't kill child processes as they have active clients, just wait
466 468 self._reapworkers(0)
467 469
468 470 def run(self):
469 471 try:
470 472 self._mainloop()
471 473 finally:
472 474 self._cleanup()
473 475
474 476 def _mainloop(self):
475 477 exiting = False
476 478 h = self._servicehandler
477 479 selector = selectors2.DefaultSelector()
478 480 selector.register(self._sock, selectors2.EVENT_READ)
479 481 while True:
480 482 if not exiting and h.shouldexit():
481 483 # clients can no longer connect() to the domain socket, so
482 484 # we stop queuing new requests.
483 485 # for requests that are queued (connect()-ed, but haven't been
484 486 # accept()-ed), handle them before exit. otherwise, clients
485 487 # waiting for recv() will receive ECONNRESET.
486 488 self._unlinksocket()
487 489 exiting = True
488 490 ready = selector.select(timeout=h.pollinterval)
489 491 if not ready:
490 492 # only exit if we completed all queued requests
491 493 if exiting:
492 494 break
493 495 continue
494 496 try:
495 497 conn, _addr = self._sock.accept()
496 498 except socket.error as inst:
497 499 if inst.args[0] == errno.EINTR:
498 500 continue
499 501 raise
500 502
501 503 pid = os.fork()
502 504 if pid:
503 505 try:
504 506 self.ui.debug('forked worker process (pid=%d)\n' % pid)
505 507 self._workerpids.add(pid)
506 508 h.newconnection()
507 509 finally:
508 510 conn.close() # release handle in parent process
509 511 else:
510 512 try:
511 513 self._runworker(conn)
512 514 conn.close()
513 515 os._exit(0)
514 516 except: # never return, hence no re-raises
515 517 try:
516 518 self.ui.traceback(force=True)
517 519 finally:
518 520 os._exit(255)
519 521 selector.close()
520 522
521 523 def _sigchldhandler(self, signal, frame):
522 524 self._reapworkers(os.WNOHANG)
523 525
524 526 def _reapworkers(self, options):
525 527 while self._workerpids:
526 528 try:
527 529 pid, _status = os.waitpid(-1, options)
528 530 except OSError as inst:
529 531 if inst.errno == errno.EINTR:
530 532 continue
531 533 if inst.errno != errno.ECHILD:
532 534 raise
533 535 # no child processes at all (reaped by other waitpid()?)
534 536 self._workerpids.clear()
535 537 return
536 538 if pid == 0:
537 539 # no waitable child processes
538 540 return
539 541 self.ui.debug('worker process exited (pid=%d)\n' % pid)
540 542 self._workerpids.discard(pid)
541 543
542 544 def _runworker(self, conn):
543 545 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
544 546 _initworkerprocess()
545 547 h = self._servicehandler
546 548 try:
547 549 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
548 550 finally:
549 551 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now