##// END OF EJS Templates
commandserver: loop over selector events...
Yuya Nishihara -
r40914:2525faf4 default
parent child Browse files
Show More
@@ -1,631 +1,633
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import signal
15 15 import socket
16 16 import struct
17 17 import traceback
18 18
19 19 try:
20 20 import selectors
21 21 selectors.BaseSelector
22 22 except ImportError:
23 23 from .thirdparty import selectors2 as selectors
24 24
25 25 from .i18n import _
26 26 from . import (
27 27 encoding,
28 28 error,
29 29 loggingutil,
30 30 pycompat,
31 31 util,
32 32 vfs as vfsmod,
33 33 )
34 34 from .utils import (
35 35 cborutil,
36 36 procutil,
37 37 )
38 38
39 39 class channeledoutput(object):
40 40 """
41 41 Write data to out in the following format:
42 42
43 43 data length (unsigned int),
44 44 data
45 45 """
46 46 def __init__(self, out, channel):
47 47 self.out = out
48 48 self.channel = channel
49 49
50 50 @property
51 51 def name(self):
52 52 return '<%c-channel>' % self.channel
53 53
54 54 def write(self, data):
55 55 if not data:
56 56 return
57 57 # single write() to guarantee the same atomicity as the underlying file
58 58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 59 self.out.flush()
60 60
61 61 def __getattr__(self, attr):
62 62 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
63 63 raise AttributeError(attr)
64 64 return getattr(self.out, attr)
65 65
66 66 class channeledmessage(object):
67 67 """
68 68 Write encoded message and metadata to out in the following format:
69 69
70 70 data length (unsigned int),
71 71 encoded message and metadata, as a flat key-value dict.
72 72
73 73 Each message should have 'type' attribute. Messages of unknown type
74 74 should be ignored.
75 75 """
76 76
77 77 # teach ui that write() can take **opts
78 78 structured = True
79 79
80 80 def __init__(self, out, channel, encodename, encodefn):
81 81 self._cout = channeledoutput(out, channel)
82 82 self.encoding = encodename
83 83 self._encodefn = encodefn
84 84
85 85 def write(self, data, **opts):
86 86 opts = pycompat.byteskwargs(opts)
87 87 if data is not None:
88 88 opts[b'data'] = data
89 89 self._cout.write(self._encodefn(opts))
90 90
91 91 def __getattr__(self, attr):
92 92 return getattr(self._cout, attr)
93 93
94 94 class channeledinput(object):
95 95 """
96 96 Read data from in_.
97 97
98 98 Requests for input are written to out in the following format:
99 99 channel identifier - 'I' for plain input, 'L' line based (1 byte)
100 100 how many bytes to send at most (unsigned int),
101 101
102 102 The client replies with:
103 103 data length (unsigned int), 0 meaning EOF
104 104 data
105 105 """
106 106
107 107 maxchunksize = 4 * 1024
108 108
109 109 def __init__(self, in_, out, channel):
110 110 self.in_ = in_
111 111 self.out = out
112 112 self.channel = channel
113 113
114 114 @property
115 115 def name(self):
116 116 return '<%c-channel>' % self.channel
117 117
118 118 def read(self, size=-1):
119 119 if size < 0:
120 120 # if we need to consume all the clients input, ask for 4k chunks
121 121 # so the pipe doesn't fill up risking a deadlock
122 122 size = self.maxchunksize
123 123 s = self._read(size, self.channel)
124 124 buf = s
125 125 while s:
126 126 s = self._read(size, self.channel)
127 127 buf += s
128 128
129 129 return buf
130 130 else:
131 131 return self._read(size, self.channel)
132 132
133 133 def _read(self, size, channel):
134 134 if not size:
135 135 return ''
136 136 assert size > 0
137 137
138 138 # tell the client we need at most size bytes
139 139 self.out.write(struct.pack('>cI', channel, size))
140 140 self.out.flush()
141 141
142 142 length = self.in_.read(4)
143 143 length = struct.unpack('>I', length)[0]
144 144 if not length:
145 145 return ''
146 146 else:
147 147 return self.in_.read(length)
148 148
149 149 def readline(self, size=-1):
150 150 if size < 0:
151 151 size = self.maxchunksize
152 152 s = self._read(size, 'L')
153 153 buf = s
154 154 # keep asking for more until there's either no more or
155 155 # we got a full line
156 156 while s and s[-1] != '\n':
157 157 s = self._read(size, 'L')
158 158 buf += s
159 159
160 160 return buf
161 161 else:
162 162 return self._read(size, 'L')
163 163
164 164 def __iter__(self):
165 165 return self
166 166
167 167 def next(self):
168 168 l = self.readline()
169 169 if not l:
170 170 raise StopIteration
171 171 return l
172 172
173 173 __next__ = next
174 174
175 175 def __getattr__(self, attr):
176 176 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
177 177 raise AttributeError(attr)
178 178 return getattr(self.in_, attr)
179 179
180 180 _messageencoders = {
181 181 b'cbor': lambda v: b''.join(cborutil.streamencode(v)),
182 182 }
183 183
184 184 def _selectmessageencoder(ui):
185 185 # experimental config: cmdserver.message-encodings
186 186 encnames = ui.configlist(b'cmdserver', b'message-encodings')
187 187 for n in encnames:
188 188 f = _messageencoders.get(n)
189 189 if f:
190 190 return n, f
191 191 raise error.Abort(b'no supported message encodings: %s'
192 192 % b' '.join(encnames))
193 193
194 194 class server(object):
195 195 """
196 196 Listens for commands on fin, runs them and writes the output on a channel
197 197 based stream to fout.
198 198 """
199 199 def __init__(self, ui, repo, fin, fout, prereposetups=None):
200 200 self.cwd = encoding.getcwd()
201 201
202 202 if repo:
203 203 # the ui here is really the repo ui so take its baseui so we don't
204 204 # end up with its local configuration
205 205 self.ui = repo.baseui
206 206 self.repo = repo
207 207 self.repoui = repo.ui
208 208 else:
209 209 self.ui = ui
210 210 self.repo = self.repoui = None
211 211 self._prereposetups = prereposetups
212 212
213 213 self.cdebug = channeledoutput(fout, 'd')
214 214 self.cerr = channeledoutput(fout, 'e')
215 215 self.cout = channeledoutput(fout, 'o')
216 216 self.cin = channeledinput(fin, fout, 'I')
217 217 self.cresult = channeledoutput(fout, 'r')
218 218
219 219 if self.ui.config(b'cmdserver', b'log') == b'-':
220 220 # switch log stream of server's ui to the 'd' (debug) channel
221 221 # (don't touch repo.ui as its lifetime is longer than the server)
222 222 self.ui = self.ui.copy()
223 223 setuplogging(self.ui, repo=None, fp=self.cdebug)
224 224
225 225 # TODO: add this to help/config.txt when stabilized
226 226 # ``channel``
227 227 # Use separate channel for structured output. (Command-server only)
228 228 self.cmsg = None
229 229 if ui.config(b'ui', b'message-output') == b'channel':
230 230 encname, encfn = _selectmessageencoder(ui)
231 231 self.cmsg = channeledmessage(fout, b'm', encname, encfn)
232 232
233 233 self.client = fin
234 234
235 235 def cleanup(self):
236 236 """release and restore resources taken during server session"""
237 237
238 238 def _read(self, size):
239 239 if not size:
240 240 return ''
241 241
242 242 data = self.client.read(size)
243 243
244 244 # is the other end closed?
245 245 if not data:
246 246 raise EOFError
247 247
248 248 return data
249 249
250 250 def _readstr(self):
251 251 """read a string from the channel
252 252
253 253 format:
254 254 data length (uint32), data
255 255 """
256 256 length = struct.unpack('>I', self._read(4))[0]
257 257 if not length:
258 258 return ''
259 259 return self._read(length)
260 260
261 261 def _readlist(self):
262 262 """read a list of NULL separated strings from the channel"""
263 263 s = self._readstr()
264 264 if s:
265 265 return s.split('\0')
266 266 else:
267 267 return []
268 268
269 269 def runcommand(self):
270 270 """ reads a list of \0 terminated arguments, executes
271 271 and writes the return code to the result channel """
272 272 from . import dispatch # avoid cycle
273 273
274 274 args = self._readlist()
275 275
276 276 # copy the uis so changes (e.g. --config or --verbose) don't
277 277 # persist between requests
278 278 copiedui = self.ui.copy()
279 279 uis = [copiedui]
280 280 if self.repo:
281 281 self.repo.baseui = copiedui
282 282 # clone ui without using ui.copy because this is protected
283 283 repoui = self.repoui.__class__(self.repoui)
284 284 repoui.copy = copiedui.copy # redo copy protection
285 285 uis.append(repoui)
286 286 self.repo.ui = self.repo.dirstate._ui = repoui
287 287 self.repo.invalidateall()
288 288
289 289 for ui in uis:
290 290 ui.resetstate()
291 291 # any kind of interaction must use server channels, but chg may
292 292 # replace channels by fully functional tty files. so nontty is
293 293 # enforced only if cin is a channel.
294 294 if not util.safehasattr(self.cin, 'fileno'):
295 295 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
296 296
297 297 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
298 298 self.cout, self.cerr, self.cmsg,
299 299 prereposetups=self._prereposetups)
300 300
301 301 try:
302 302 ret = dispatch.dispatch(req) & 255
303 303 self.cresult.write(struct.pack('>i', int(ret)))
304 304 finally:
305 305 # restore old cwd
306 306 if '--cwd' in args:
307 307 os.chdir(self.cwd)
308 308
309 309 def getencoding(self):
310 310 """ writes the current encoding to the result channel """
311 311 self.cresult.write(encoding.encoding)
312 312
313 313 def serveone(self):
314 314 cmd = self.client.readline()[:-1]
315 315 if cmd:
316 316 handler = self.capabilities.get(cmd)
317 317 if handler:
318 318 handler(self)
319 319 else:
320 320 # clients are expected to check what commands are supported by
321 321 # looking at the servers capabilities
322 322 raise error.Abort(_('unknown command %s') % cmd)
323 323
324 324 return cmd != ''
325 325
326 326 capabilities = {'runcommand': runcommand,
327 327 'getencoding': getencoding}
328 328
329 329 def serve(self):
330 330 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
331 331 hellomsg += '\n'
332 332 hellomsg += 'encoding: ' + encoding.encoding
333 333 hellomsg += '\n'
334 334 if self.cmsg:
335 335 hellomsg += 'message-encoding: %s\n' % self.cmsg.encoding
336 336 hellomsg += 'pid: %d' % procutil.getpid()
337 337 if util.safehasattr(os, 'getpgid'):
338 338 hellomsg += '\n'
339 339 hellomsg += 'pgid: %d' % os.getpgid(0)
340 340
341 341 # write the hello msg in -one- chunk
342 342 self.cout.write(hellomsg)
343 343
344 344 try:
345 345 while self.serveone():
346 346 pass
347 347 except EOFError:
348 348 # we'll get here if the client disconnected while we were reading
349 349 # its request
350 350 return 1
351 351
352 352 return 0
353 353
354 354 def setuplogging(ui, repo=None, fp=None):
355 355 """Set up server logging facility
356 356
357 357 If cmdserver.log is '-', log messages will be sent to the given fp.
358 358 It should be the 'd' channel while a client is connected, and otherwise
359 359 is the stderr of the server process.
360 360 """
361 361 # developer config: cmdserver.log
362 362 logpath = ui.config(b'cmdserver', b'log')
363 363 if not logpath:
364 364 return
365 365 # developer config: cmdserver.track-log
366 366 tracked = set(ui.configlist(b'cmdserver', b'track-log'))
367 367
368 368 if logpath == b'-' and fp:
369 369 logger = loggingutil.fileobjectlogger(fp, tracked)
370 370 elif logpath == b'-':
371 371 logger = loggingutil.fileobjectlogger(ui.ferr, tracked)
372 372 else:
373 373 logpath = os.path.abspath(util.expandpath(logpath))
374 374 # developer config: cmdserver.max-log-files
375 375 maxfiles = ui.configint(b'cmdserver', b'max-log-files')
376 376 # developer config: cmdserver.max-log-size
377 377 maxsize = ui.configbytes(b'cmdserver', b'max-log-size')
378 378 vfs = vfsmod.vfs(os.path.dirname(logpath))
379 379 logger = loggingutil.filelogger(vfs, os.path.basename(logpath), tracked,
380 380 maxfiles=maxfiles, maxsize=maxsize)
381 381
382 382 targetuis = {ui}
383 383 if repo:
384 384 targetuis.add(repo.baseui)
385 385 targetuis.add(repo.ui)
386 386 for u in targetuis:
387 387 u.setlogger(b'cmdserver', logger)
388 388
389 389 class pipeservice(object):
390 390 def __init__(self, ui, repo, opts):
391 391 self.ui = ui
392 392 self.repo = repo
393 393
394 394 def init(self):
395 395 pass
396 396
397 397 def run(self):
398 398 ui = self.ui
399 399 # redirect stdio to null device so that broken extensions or in-process
400 400 # hooks will never cause corruption of channel protocol.
401 401 with procutil.protectedstdio(ui.fin, ui.fout) as (fin, fout):
402 402 sv = server(ui, self.repo, fin, fout)
403 403 try:
404 404 return sv.serve()
405 405 finally:
406 406 sv.cleanup()
407 407
408 408 def _initworkerprocess():
409 409 # use a different process group from the master process, in order to:
410 410 # 1. make the current process group no longer "orphaned" (because the
411 411 # parent of this process is in a different process group while
412 412 # remains in a same session)
413 413 # according to POSIX 2.2.2.52, orphaned process group will ignore
414 414 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
415 415 # cause trouble for things like ncurses.
416 416 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
417 417 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
418 418 # processes like ssh will be killed properly, without affecting
419 419 # unrelated processes.
420 420 os.setpgid(0, 0)
421 421 # change random state otherwise forked request handlers would have a
422 422 # same state inherited from parent.
423 423 random.seed()
424 424
425 425 def _serverequest(ui, repo, conn, createcmdserver, prereposetups):
426 426 fin = conn.makefile(r'rb')
427 427 fout = conn.makefile(r'wb')
428 428 sv = None
429 429 try:
430 430 sv = createcmdserver(repo, conn, fin, fout, prereposetups)
431 431 try:
432 432 sv.serve()
433 433 # handle exceptions that may be raised by command server. most of
434 434 # known exceptions are caught by dispatch.
435 435 except error.Abort as inst:
436 436 ui.error(_('abort: %s\n') % inst)
437 437 except IOError as inst:
438 438 if inst.errno != errno.EPIPE:
439 439 raise
440 440 except KeyboardInterrupt:
441 441 pass
442 442 finally:
443 443 sv.cleanup()
444 444 except: # re-raises
445 445 # also write traceback to error channel. otherwise client cannot
446 446 # see it because it is written to server's stderr by default.
447 447 if sv:
448 448 cerr = sv.cerr
449 449 else:
450 450 cerr = channeledoutput(fout, 'e')
451 451 cerr.write(encoding.strtolocal(traceback.format_exc()))
452 452 raise
453 453 finally:
454 454 fin.close()
455 455 try:
456 456 fout.close() # implicit flush() may cause another EPIPE
457 457 except IOError as inst:
458 458 if inst.errno != errno.EPIPE:
459 459 raise
460 460
461 461 class unixservicehandler(object):
462 462 """Set of pluggable operations for unix-mode services
463 463
464 464 Almost all methods except for createcmdserver() are called in the main
465 465 process. You can't pass mutable resource back from createcmdserver().
466 466 """
467 467
468 468 pollinterval = None
469 469
470 470 def __init__(self, ui):
471 471 self.ui = ui
472 472
473 473 def bindsocket(self, sock, address):
474 474 util.bindunixsocket(sock, address)
475 475 sock.listen(socket.SOMAXCONN)
476 476 self.ui.status(_('listening at %s\n') % address)
477 477 self.ui.flush() # avoid buffering of status message
478 478
479 479 def unlinksocket(self, address):
480 480 os.unlink(address)
481 481
482 482 def shouldexit(self):
483 483 """True if server should shut down; checked per pollinterval"""
484 484 return False
485 485
486 486 def newconnection(self):
487 487 """Called when main process notices new connection"""
488 488
489 489 def createcmdserver(self, repo, conn, fin, fout, prereposetups):
490 490 """Create new command server instance; called in the process that
491 491 serves for the current connection"""
492 492 return server(self.ui, repo, fin, fout, prereposetups)
493 493
494 494 class unixforkingservice(object):
495 495 """
496 496 Listens on unix domain socket and forks server per connection
497 497 """
498 498
499 499 def __init__(self, ui, repo, opts, handler=None):
500 500 self.ui = ui
501 501 self.repo = repo
502 502 self.address = opts['address']
503 503 if not util.safehasattr(socket, 'AF_UNIX'):
504 504 raise error.Abort(_('unsupported platform'))
505 505 if not self.address:
506 506 raise error.Abort(_('no socket path specified with --address'))
507 507 self._servicehandler = handler or unixservicehandler(ui)
508 508 self._sock = None
509 509 self._oldsigchldhandler = None
510 510 self._workerpids = set() # updated by signal handler; do not iterate
511 511 self._socketunlinked = None
512 512
513 513 def init(self):
514 514 self._sock = socket.socket(socket.AF_UNIX)
515 515 self._servicehandler.bindsocket(self._sock, self.address)
516 516 if util.safehasattr(procutil, 'unblocksignal'):
517 517 procutil.unblocksignal(signal.SIGCHLD)
518 518 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
519 519 self._oldsigchldhandler = o
520 520 self._socketunlinked = False
521 521
522 522 def _unlinksocket(self):
523 523 if not self._socketunlinked:
524 524 self._servicehandler.unlinksocket(self.address)
525 525 self._socketunlinked = True
526 526
527 527 def _cleanup(self):
528 528 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
529 529 self._sock.close()
530 530 self._unlinksocket()
531 531 # don't kill child processes as they have active clients, just wait
532 532 self._reapworkers(0)
533 533
534 534 def run(self):
535 535 try:
536 536 self._mainloop()
537 537 finally:
538 538 self._cleanup()
539 539
540 540 def _mainloop(self):
541 541 exiting = False
542 542 h = self._servicehandler
543 543 selector = selectors.DefaultSelector()
544 selector.register(self._sock, selectors.EVENT_READ)
544 selector.register(self._sock, selectors.EVENT_READ,
545 self._acceptnewconnection)
545 546 while True:
546 547 if not exiting and h.shouldexit():
547 548 # clients can no longer connect() to the domain socket, so
548 549 # we stop queuing new requests.
549 550 # for requests that are queued (connect()-ed, but haven't been
550 551 # accept()-ed), handle them before exit. otherwise, clients
551 552 # waiting for recv() will receive ECONNRESET.
552 553 self._unlinksocket()
553 554 exiting = True
554 555 try:
555 ready = selector.select(timeout=h.pollinterval)
556 events = selector.select(timeout=h.pollinterval)
556 557 except OSError as inst:
557 558 # selectors2 raises ETIMEDOUT if timeout exceeded while
558 559 # handling signal interrupt. That's probably wrong, but
559 560 # we can easily get around it.
560 561 if inst.errno != errno.ETIMEDOUT:
561 562 raise
562 ready = []
563 if not ready:
563 events = []
564 if not events:
564 565 # only exit if we completed all queued requests
565 566 if exiting:
566 567 break
567 568 continue
568 self._acceptnewconnection(self._sock, selector)
569 for key, _mask in events:
570 key.data(key.fileobj, selector)
569 571 selector.close()
570 572
571 573 def _acceptnewconnection(self, sock, selector):
572 574 h = self._servicehandler
573 575 try:
574 576 conn, _addr = sock.accept()
575 577 except socket.error as inst:
576 578 if inst.args[0] == errno.EINTR:
577 579 return
578 580 raise
579 581
580 582 pid = os.fork()
581 583 if pid:
582 584 try:
583 585 self.ui.log(b'cmdserver', b'forked worker process (pid=%d)\n',
584 586 pid)
585 587 self._workerpids.add(pid)
586 588 h.newconnection()
587 589 finally:
588 590 conn.close() # release handle in parent process
589 591 else:
590 592 try:
591 593 selector.close()
592 594 sock.close()
593 595 self._runworker(conn)
594 596 conn.close()
595 597 os._exit(0)
596 598 except: # never return, hence no re-raises
597 599 try:
598 600 self.ui.traceback(force=True)
599 601 finally:
600 602 os._exit(255)
601 603
602 604 def _sigchldhandler(self, signal, frame):
603 605 self._reapworkers(os.WNOHANG)
604 606
605 607 def _reapworkers(self, options):
606 608 while self._workerpids:
607 609 try:
608 610 pid, _status = os.waitpid(-1, options)
609 611 except OSError as inst:
610 612 if inst.errno == errno.EINTR:
611 613 continue
612 614 if inst.errno != errno.ECHILD:
613 615 raise
614 616 # no child processes at all (reaped by other waitpid()?)
615 617 self._workerpids.clear()
616 618 return
617 619 if pid == 0:
618 620 # no waitable child processes
619 621 return
620 622 self.ui.log(b'cmdserver', b'worker process exited (pid=%d)\n', pid)
621 623 self._workerpids.discard(pid)
622 624
623 625 def _runworker(self, conn):
624 626 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
625 627 _initworkerprocess()
626 628 h = self._servicehandler
627 629 try:
628 630 _serverequest(self.ui, self.repo, conn, h.createcmdserver,
629 631 prereposetups=None) # TODO: pass in hook functions
630 632 finally:
631 633 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now