##// END OF EJS Templates
commandserver: do not handle EINTR for selector.select...
Jun Wu -
r33543:3ef3bf70 default
parent child Browse files
Show More
@@ -1,552 +1,551 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 import select
15 14 import signal
16 15 import socket
17 16 import struct
18 17 import traceback
19 18
20 19 from .i18n import _
21 20 from . import (
22 21 encoding,
23 22 error,
24 23 pycompat,
25 24 selectors2,
26 25 util,
27 26 )
28 27
29 28 logfile = None
30 29
31 30 def log(*args):
32 31 if not logfile:
33 32 return
34 33
35 34 for a in args:
36 35 logfile.write(str(a))
37 36
38 37 logfile.flush()
39 38
40 39 class channeledoutput(object):
41 40 """
42 41 Write data to out in the following format:
43 42
44 43 data length (unsigned int),
45 44 data
46 45 """
47 46 def __init__(self, out, channel):
48 47 self.out = out
49 48 self.channel = channel
50 49
51 50 @property
52 51 def name(self):
53 52 return '<%c-channel>' % self.channel
54 53
55 54 def write(self, data):
56 55 if not data:
57 56 return
58 57 # single write() to guarantee the same atomicity as the underlying file
59 58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
60 59 self.out.flush()
61 60
62 61 def __getattr__(self, attr):
63 62 if attr in ('isatty', 'fileno', 'tell', 'seek'):
64 63 raise AttributeError(attr)
65 64 return getattr(self.out, attr)
66 65
67 66 class channeledinput(object):
68 67 """
69 68 Read data from in_.
70 69
71 70 Requests for input are written to out in the following format:
72 71 channel identifier - 'I' for plain input, 'L' line based (1 byte)
73 72 how many bytes to send at most (unsigned int),
74 73
75 74 The client replies with:
76 75 data length (unsigned int), 0 meaning EOF
77 76 data
78 77 """
79 78
80 79 maxchunksize = 4 * 1024
81 80
82 81 def __init__(self, in_, out, channel):
83 82 self.in_ = in_
84 83 self.out = out
85 84 self.channel = channel
86 85
87 86 @property
88 87 def name(self):
89 88 return '<%c-channel>' % self.channel
90 89
91 90 def read(self, size=-1):
92 91 if size < 0:
93 92 # if we need to consume all the clients input, ask for 4k chunks
94 93 # so the pipe doesn't fill up risking a deadlock
95 94 size = self.maxchunksize
96 95 s = self._read(size, self.channel)
97 96 buf = s
98 97 while s:
99 98 s = self._read(size, self.channel)
100 99 buf += s
101 100
102 101 return buf
103 102 else:
104 103 return self._read(size, self.channel)
105 104
106 105 def _read(self, size, channel):
107 106 if not size:
108 107 return ''
109 108 assert size > 0
110 109
111 110 # tell the client we need at most size bytes
112 111 self.out.write(struct.pack('>cI', channel, size))
113 112 self.out.flush()
114 113
115 114 length = self.in_.read(4)
116 115 length = struct.unpack('>I', length)[0]
117 116 if not length:
118 117 return ''
119 118 else:
120 119 return self.in_.read(length)
121 120
122 121 def readline(self, size=-1):
123 122 if size < 0:
124 123 size = self.maxchunksize
125 124 s = self._read(size, 'L')
126 125 buf = s
127 126 # keep asking for more until there's either no more or
128 127 # we got a full line
129 128 while s and s[-1] != '\n':
130 129 s = self._read(size, 'L')
131 130 buf += s
132 131
133 132 return buf
134 133 else:
135 134 return self._read(size, 'L')
136 135
137 136 def __iter__(self):
138 137 return self
139 138
140 139 def next(self):
141 140 l = self.readline()
142 141 if not l:
143 142 raise StopIteration
144 143 return l
145 144
146 145 def __getattr__(self, attr):
147 146 if attr in ('isatty', 'fileno', 'tell', 'seek'):
148 147 raise AttributeError(attr)
149 148 return getattr(self.in_, attr)
150 149
151 150 class server(object):
152 151 """
153 152 Listens for commands on fin, runs them and writes the output on a channel
154 153 based stream to fout.
155 154 """
156 155 def __init__(self, ui, repo, fin, fout):
157 156 self.cwd = pycompat.getcwd()
158 157
159 158 # developer config: cmdserver.log
160 159 logpath = ui.config("cmdserver", "log")
161 160 if logpath:
162 161 global logfile
163 162 if logpath == '-':
164 163 # write log on a special 'd' (debug) channel
165 164 logfile = channeledoutput(fout, 'd')
166 165 else:
167 166 logfile = open(logpath, 'a')
168 167
169 168 if repo:
170 169 # the ui here is really the repo ui so take its baseui so we don't
171 170 # end up with its local configuration
172 171 self.ui = repo.baseui
173 172 self.repo = repo
174 173 self.repoui = repo.ui
175 174 else:
176 175 self.ui = ui
177 176 self.repo = self.repoui = None
178 177
179 178 self.cerr = channeledoutput(fout, 'e')
180 179 self.cout = channeledoutput(fout, 'o')
181 180 self.cin = channeledinput(fin, fout, 'I')
182 181 self.cresult = channeledoutput(fout, 'r')
183 182
184 183 self.client = fin
185 184
186 185 def cleanup(self):
187 186 """release and restore resources taken during server session"""
188 187 pass
189 188
190 189 def _read(self, size):
191 190 if not size:
192 191 return ''
193 192
194 193 data = self.client.read(size)
195 194
196 195 # is the other end closed?
197 196 if not data:
198 197 raise EOFError
199 198
200 199 return data
201 200
202 201 def _readstr(self):
203 202 """read a string from the channel
204 203
205 204 format:
206 205 data length (uint32), data
207 206 """
208 207 length = struct.unpack('>I', self._read(4))[0]
209 208 if not length:
210 209 return ''
211 210 return self._read(length)
212 211
213 212 def _readlist(self):
214 213 """read a list of NULL separated strings from the channel"""
215 214 s = self._readstr()
216 215 if s:
217 216 return s.split('\0')
218 217 else:
219 218 return []
220 219
221 220 def runcommand(self):
222 221 """ reads a list of \0 terminated arguments, executes
223 222 and writes the return code to the result channel """
224 223 from . import dispatch # avoid cycle
225 224
226 225 args = self._readlist()
227 226
228 227 # copy the uis so changes (e.g. --config or --verbose) don't
229 228 # persist between requests
230 229 copiedui = self.ui.copy()
231 230 uis = [copiedui]
232 231 if self.repo:
233 232 self.repo.baseui = copiedui
234 233 # clone ui without using ui.copy because this is protected
235 234 repoui = self.repoui.__class__(self.repoui)
236 235 repoui.copy = copiedui.copy # redo copy protection
237 236 uis.append(repoui)
238 237 self.repo.ui = self.repo.dirstate._ui = repoui
239 238 self.repo.invalidateall()
240 239
241 240 for ui in uis:
242 241 ui.resetstate()
243 242 # any kind of interaction must use server channels, but chg may
244 243 # replace channels by fully functional tty files. so nontty is
245 244 # enforced only if cin is a channel.
246 245 if not util.safehasattr(self.cin, 'fileno'):
247 246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
248 247
249 248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
250 249 self.cout, self.cerr)
251 250
252 251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
253 252
254 253 # restore old cwd
255 254 if '--cwd' in args:
256 255 os.chdir(self.cwd)
257 256
258 257 self.cresult.write(struct.pack('>i', int(ret)))
259 258
260 259 def getencoding(self):
261 260 """ writes the current encoding to the result channel """
262 261 self.cresult.write(encoding.encoding)
263 262
264 263 def serveone(self):
265 264 cmd = self.client.readline()[:-1]
266 265 if cmd:
267 266 handler = self.capabilities.get(cmd)
268 267 if handler:
269 268 handler(self)
270 269 else:
271 270 # clients are expected to check what commands are supported by
272 271 # looking at the servers capabilities
273 272 raise error.Abort(_('unknown command %s') % cmd)
274 273
275 274 return cmd != ''
276 275
277 276 capabilities = {'runcommand' : runcommand,
278 277 'getencoding' : getencoding}
279 278
280 279 def serve(self):
281 280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
282 281 hellomsg += '\n'
283 282 hellomsg += 'encoding: ' + encoding.encoding
284 283 hellomsg += '\n'
285 284 hellomsg += 'pid: %d' % util.getpid()
286 285 if util.safehasattr(os, 'getpgid'):
287 286 hellomsg += '\n'
288 287 hellomsg += 'pgid: %d' % os.getpgid(0)
289 288
290 289 # write the hello msg in -one- chunk
291 290 self.cout.write(hellomsg)
292 291
293 292 try:
294 293 while self.serveone():
295 294 pass
296 295 except EOFError:
297 296 # we'll get here if the client disconnected while we were reading
298 297 # its request
299 298 return 1
300 299
301 300 return 0
302 301
303 302 def _protectio(ui):
304 303 """ duplicates streams and redirect original to null if ui uses stdio """
305 304 ui.flush()
306 305 newfiles = []
307 306 nullfd = os.open(os.devnull, os.O_RDWR)
308 307 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
309 308 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
310 309 if f is sysf:
311 310 newfd = os.dup(f.fileno())
312 311 os.dup2(nullfd, f.fileno())
313 312 f = os.fdopen(newfd, mode)
314 313 newfiles.append(f)
315 314 os.close(nullfd)
316 315 return tuple(newfiles)
317 316
318 317 def _restoreio(ui, fin, fout):
319 318 """ restores streams from duplicated ones """
320 319 ui.flush()
321 320 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
322 321 if f is not uif:
323 322 os.dup2(f.fileno(), uif.fileno())
324 323 f.close()
325 324
326 325 class pipeservice(object):
327 326 def __init__(self, ui, repo, opts):
328 327 self.ui = ui
329 328 self.repo = repo
330 329
331 330 def init(self):
332 331 pass
333 332
334 333 def run(self):
335 334 ui = self.ui
336 335 # redirect stdio to null device so that broken extensions or in-process
337 336 # hooks will never cause corruption of channel protocol.
338 337 fin, fout = _protectio(ui)
339 338 try:
340 339 sv = server(ui, self.repo, fin, fout)
341 340 return sv.serve()
342 341 finally:
343 342 sv.cleanup()
344 343 _restoreio(ui, fin, fout)
345 344
346 345 def _initworkerprocess():
347 346 # use a different process group from the master process, in order to:
348 347 # 1. make the current process group no longer "orphaned" (because the
349 348 # parent of this process is in a different process group while
350 349 # remains in a same session)
351 350 # according to POSIX 2.2.2.52, orphaned process group will ignore
352 351 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
353 352 # cause trouble for things like ncurses.
354 353 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
355 354 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
356 355 # processes like ssh will be killed properly, without affecting
357 356 # unrelated processes.
358 357 os.setpgid(0, 0)
359 358 # change random state otherwise forked request handlers would have a
360 359 # same state inherited from parent.
361 360 random.seed()
362 361
363 362 def _serverequest(ui, repo, conn, createcmdserver):
364 363 fin = conn.makefile('rb')
365 364 fout = conn.makefile('wb')
366 365 sv = None
367 366 try:
368 367 sv = createcmdserver(repo, conn, fin, fout)
369 368 try:
370 369 sv.serve()
371 370 # handle exceptions that may be raised by command server. most of
372 371 # known exceptions are caught by dispatch.
373 372 except error.Abort as inst:
374 373 ui.warn(_('abort: %s\n') % inst)
375 374 except IOError as inst:
376 375 if inst.errno != errno.EPIPE:
377 376 raise
378 377 except KeyboardInterrupt:
379 378 pass
380 379 finally:
381 380 sv.cleanup()
382 381 except: # re-raises
383 382 # also write traceback to error channel. otherwise client cannot
384 383 # see it because it is written to server's stderr by default.
385 384 if sv:
386 385 cerr = sv.cerr
387 386 else:
388 387 cerr = channeledoutput(fout, 'e')
389 388 traceback.print_exc(file=cerr)
390 389 raise
391 390 finally:
392 391 fin.close()
393 392 try:
394 393 fout.close() # implicit flush() may cause another EPIPE
395 394 except IOError as inst:
396 395 if inst.errno != errno.EPIPE:
397 396 raise
398 397
399 398 class unixservicehandler(object):
400 399 """Set of pluggable operations for unix-mode services
401 400
402 401 Almost all methods except for createcmdserver() are called in the main
403 402 process. You can't pass mutable resource back from createcmdserver().
404 403 """
405 404
406 405 pollinterval = None
407 406
408 407 def __init__(self, ui):
409 408 self.ui = ui
410 409
411 410 def bindsocket(self, sock, address):
412 411 util.bindunixsocket(sock, address)
413 412 sock.listen(socket.SOMAXCONN)
414 413 self.ui.status(_('listening at %s\n') % address)
415 414 self.ui.flush() # avoid buffering of status message
416 415
417 416 def unlinksocket(self, address):
418 417 os.unlink(address)
419 418
420 419 def shouldexit(self):
421 420 """True if server should shut down; checked per pollinterval"""
422 421 return False
423 422
424 423 def newconnection(self):
425 424 """Called when main process notices new connection"""
426 425 pass
427 426
428 427 def createcmdserver(self, repo, conn, fin, fout):
429 428 """Create new command server instance; called in the process that
430 429 serves for the current connection"""
431 430 return server(self.ui, repo, fin, fout)
432 431
433 432 class unixforkingservice(object):
434 433 """
435 434 Listens on unix domain socket and forks server per connection
436 435 """
437 436
438 437 def __init__(self, ui, repo, opts, handler=None):
439 438 self.ui = ui
440 439 self.repo = repo
441 440 self.address = opts['address']
442 441 if not util.safehasattr(socket, 'AF_UNIX'):
443 442 raise error.Abort(_('unsupported platform'))
444 443 if not self.address:
445 444 raise error.Abort(_('no socket path specified with --address'))
446 445 self._servicehandler = handler or unixservicehandler(ui)
447 446 self._sock = None
448 447 self._oldsigchldhandler = None
449 448 self._workerpids = set() # updated by signal handler; do not iterate
450 449 self._socketunlinked = None
451 450
452 451 def init(self):
453 452 self._sock = socket.socket(socket.AF_UNIX)
454 453 self._servicehandler.bindsocket(self._sock, self.address)
455 454 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
456 455 self._oldsigchldhandler = o
457 456 self._socketunlinked = False
458 457
459 458 def _unlinksocket(self):
460 459 if not self._socketunlinked:
461 460 self._servicehandler.unlinksocket(self.address)
462 461 self._socketunlinked = True
463 462
464 463 def _cleanup(self):
465 464 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
466 465 self._sock.close()
467 466 self._unlinksocket()
468 467 # don't kill child processes as they have active clients, just wait
469 468 self._reapworkers(0)
470 469
471 470 def run(self):
472 471 try:
473 472 self._mainloop()
474 473 finally:
475 474 self._cleanup()
476 475
477 476 def _mainloop(self):
478 477 exiting = False
479 478 h = self._servicehandler
480 479 selector = selectors2.DefaultSelector()
481 480 selector.register(self._sock, selectors2.EVENT_READ)
482 481 while True:
483 482 if not exiting and h.shouldexit():
484 483 # clients can no longer connect() to the domain socket, so
485 484 # we stop queuing new requests.
486 485 # for requests that are queued (connect()-ed, but haven't been
487 486 # accept()-ed), handle them before exit. otherwise, clients
488 487 # waiting for recv() will receive ECONNRESET.
489 488 self._unlinksocket()
490 489 exiting = True
490 ready = selector.select(timeout=h.pollinterval)
491 if not ready:
492 # only exit if we completed all queued requests
493 if exiting:
494 break
495 continue
491 496 try:
492 ready = selector.select(timeout=h.pollinterval)
493 if not ready:
494 # only exit if we completed all queued requests
495 if exiting:
496 break
497 continue
498 497 conn, _addr = self._sock.accept()
499 except (select.error, socket.error) as inst:
498 except socket.error as inst:
500 499 if inst.args[0] == errno.EINTR:
501 500 continue
502 501 raise
503 502
504 503 pid = os.fork()
505 504 if pid:
506 505 try:
507 506 self.ui.debug('forked worker process (pid=%d)\n' % pid)
508 507 self._workerpids.add(pid)
509 508 h.newconnection()
510 509 finally:
511 510 conn.close() # release handle in parent process
512 511 else:
513 512 try:
514 513 self._runworker(conn)
515 514 conn.close()
516 515 os._exit(0)
517 516 except: # never return, hence no re-raises
518 517 try:
519 518 self.ui.traceback(force=True)
520 519 finally:
521 520 os._exit(255)
522 521 selector.close()
523 522
524 523 def _sigchldhandler(self, signal, frame):
525 524 self._reapworkers(os.WNOHANG)
526 525
527 526 def _reapworkers(self, options):
528 527 while self._workerpids:
529 528 try:
530 529 pid, _status = os.waitpid(-1, options)
531 530 except OSError as inst:
532 531 if inst.errno == errno.EINTR:
533 532 continue
534 533 if inst.errno != errno.ECHILD:
535 534 raise
536 535 # no child processes at all (reaped by other waitpid()?)
537 536 self._workerpids.clear()
538 537 return
539 538 if pid == 0:
540 539 # no waitable child processes
541 540 return
542 541 self.ui.debug('worker process exited (pid=%d)\n' % pid)
543 542 self._workerpids.discard(pid)
544 543
545 544 def _runworker(self, conn):
546 545 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
547 546 _initworkerprocess()
548 547 h = self._servicehandler
549 548 try:
550 549 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
551 550 finally:
552 551 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now