##// END OF EJS Templates
thirdparty: remove Python 2-specific selectors2 copy...
Manuel Jacob -
r50175:311fcc5a default
parent child Browse files
Show More
@@ -1,770 +1,756 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Olivia Mackall <olivia@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8
9 9 import errno
10 10 import gc
11 11 import os
12 12 import random
13 import selectors
13 14 import signal
14 15 import socket
15 16 import struct
16 17 import traceback
17 18
18 try:
19 import selectors
20
21 selectors.BaseSelector
22 except ImportError:
23 from .thirdparty import selectors2 as selectors
24
25 19 from .i18n import _
26 20 from .pycompat import getattr
27 21 from . import (
28 22 encoding,
29 23 error,
30 24 loggingutil,
31 25 pycompat,
32 26 repocache,
33 27 util,
34 28 vfs as vfsmod,
35 29 )
36 30 from .utils import (
37 31 cborutil,
38 32 procutil,
39 33 )
40 34
41 35
42 36 class channeledoutput:
43 37 """
44 38 Write data to out in the following format:
45 39
46 40 data length (unsigned int),
47 41 data
48 42 """
49 43
50 44 def __init__(self, out, channel):
51 45 self.out = out
52 46 self.channel = channel
53 47
54 48 @property
55 49 def name(self):
56 50 return b'<%c-channel>' % self.channel
57 51
58 52 def write(self, data):
59 53 if not data:
60 54 return
61 55 # single write() to guarantee the same atomicity as the underlying file
62 56 self.out.write(struct.pack(b'>cI', self.channel, len(data)) + data)
63 57 self.out.flush()
64 58
65 59 def __getattr__(self, attr):
66 60 if attr in ('isatty', 'fileno', 'tell', 'seek'):
67 61 raise AttributeError(attr)
68 62 return getattr(self.out, attr)
69 63
70 64
class channeledmessage:
    """Writer for encoded messages plus metadata on a channel.

    Each payload sent to the client has the form:

    data length (unsigned int),
    encoded message and metadata, as a flat key-value dict.

    Every message should carry a 'type' attribute; receivers are expected
    to skip messages whose type they do not understand.
    """

    # let ui know that write() accepts arbitrary **opts metadata
    structured = True

    def __init__(self, out, channel, encodename, encodefn):
        self._cout = channeledoutput(out, channel)
        self.encoding = encodename
        self._encodefn = encodefn

    def write(self, data, **opts):
        # normalize keyword metadata to bytes keys, then fold the payload in
        metadata = pycompat.byteskwargs(opts)
        if data is not None:
            metadata[b'data'] = data
        self._cout.write(self._encodefn(metadata))

    def __getattr__(self, attr):
        # anything else behaves like the underlying channeled output
        return getattr(self._cout, attr)
99 93
class channeledinput:
    """Request-driven reader pulling data from in_ via the client.

    A request is written to out in the following format:
    channel identifier - 'I' for plain input, 'L' line based (1 byte)
    how many bytes to send at most (unsigned int),

    The client answers with:
    data length (unsigned int), 0 meaning EOF
    data
    """

    # cap per-request size so the request pipe can't fill up and deadlock
    maxchunksize = 4 * 1024

    def __init__(self, in_, out, channel):
        self.in_ = in_
        self.out = out
        self.channel = channel

    @property
    def name(self):
        return b'<%c-channel>' % self.channel

    def read(self, size=-1):
        if size >= 0:
            return self._read(size, self.channel)
        # size < 0: drain everything the client has, one bounded chunk at
        # a time
        chunks = []
        while True:
            chunk = self._read(self.maxchunksize, self.channel)
            if not chunk:
                break
            chunks.append(chunk)
        return b''.join(chunks)

    def _read(self, size, channel):
        """Ask the client for at most size bytes on channel; b'' means EOF"""
        if not size:
            return b''
        assert size > 0

        # request: tell the client how many bytes we want at most
        self.out.write(struct.pack(b'>cI', channel, size))
        self.out.flush()

        # reply: length prefix, then that many bytes of data
        payloadlen = struct.unpack(b'>I', self.in_.read(4))[0]
        if not payloadlen:
            return b''
        return self.in_.read(payloadlen)

    def readline(self, size=-1):
        if size >= 0:
            return self._read(size, b'L')
        # keep requesting chunks until EOF or a complete line arrived
        pieces = []
        while True:
            piece = self._read(self.maxchunksize, b'L')
            pieces.append(piece)
            if not piece or piece.endswith(b'\n'):
                break
        return b''.join(pieces)

    def __iter__(self):
        return self

    def next(self):
        line = self.readline()
        if not line:
            raise StopIteration
        return line

    __next__ = next

    def __getattr__(self, attr):
        # present ourselves as a plain pipe: no tty/seek-style capabilities
        if attr in ('isatty', 'fileno', 'tell', 'seek'):
            raise AttributeError(attr)
        return getattr(self.in_, attr)
185 179
186 180
# Registry of supported wire message encodings: maps an encoding name to a
# callable turning a flat key-value dict into bytes. Only CBOR is supported.
_messageencoders = {
    b'cbor': lambda v: b''.join(cborutil.streamencode(v)),
}
190 184
191 185
def _selectmessageencoder(ui):
    """Pick the first configured message encoding we can handle.

    Scans cmdserver.message-encodings in order and returns a
    (name, encoder) pair for the first entry with a registered encoder.
    Aborts when none of the requested encodings is supported.
    """
    encnames = ui.configlist(b'cmdserver', b'message-encodings')
    for name in encnames:
        encoder = _messageencoders.get(name)
        if encoder is not None:
            return name, encoder
    raise error.Abort(
        b'no supported message encodings: %s' % b' '.join(encnames)
    )
201 195
202 196
class server:
    """
    Listens for commands on fin, runs them and writes the output on a channel
    based stream to fout.

    One instance serves one client session; channel objects multiplex the
    various output streams ('o', 'e', 'd', 'r', optionally 'm') and input
    requests ('I'/'L') over the single fout/fin pair.
    """

    def __init__(self, ui, repo, fin, fout, prereposetups=None):
        # remembered so runcommand() can restore it after a --cwd request
        self.cwd = encoding.getcwd()

        if repo:
            # the ui here is really the repo ui so take its baseui so we don't
            # end up with its local configuration
            self.ui = repo.baseui
            self.repo = repo
            self.repoui = repo.ui
        else:
            self.ui = ui
            self.repo = self.repoui = None
        self._prereposetups = prereposetups

        # output channels: debug, error, output, result; input channel 'I'
        self.cdebug = channeledoutput(fout, b'd')
        self.cerr = channeledoutput(fout, b'e')
        self.cout = channeledoutput(fout, b'o')
        self.cin = channeledinput(fin, fout, b'I')
        self.cresult = channeledoutput(fout, b'r')

        if self.ui.config(b'cmdserver', b'log') == b'-':
            # switch log stream of server's ui to the 'd' (debug) channel
            # (don't touch repo.ui as its lifetime is longer than the server)
            self.ui = self.ui.copy()
            setuplogging(self.ui, repo=None, fp=self.cdebug)

        # optional structured-message channel 'm'
        self.cmsg = None
        if ui.config(b'ui', b'message-output') == b'channel':
            encname, encfn = _selectmessageencoder(ui)
            self.cmsg = channeledmessage(fout, b'm', encname, encfn)

        self.client = fin

        # If shutdown-on-interrupt is off, the default SIGINT handler is
        # removed so that client-server communication wouldn't be interrupted.
        # For example, 'runcommand' handler will issue three short read()s.
        # If one of the first two read()s were interrupted, the communication
        # channel would be left at dirty state and the subsequent request
        # wouldn't be parsed. So catching KeyboardInterrupt isn't enough.
        self._shutdown_on_interrupt = ui.configbool(
            b'cmdserver', b'shutdown-on-interrupt'
        )
        self._old_inthandler = None
        if not self._shutdown_on_interrupt:
            self._old_inthandler = signal.signal(signal.SIGINT, signal.SIG_IGN)

    def cleanup(self):
        """release and restore resources taken during server session"""
        if not self._shutdown_on_interrupt:
            # reinstall the SIGINT handler saved in __init__
            signal.signal(signal.SIGINT, self._old_inthandler)

    def _read(self, size):
        """Read exactly up to size bytes from the client; raise EOFError on
        a closed connection"""
        if not size:
            return b''

        data = self.client.read(size)

        # is the other end closed?
        if not data:
            raise EOFError

        return data

    def _readstr(self):
        """read a string from the channel

        format:
        data length (uint32), data
        """
        length = struct.unpack(b'>I', self._read(4))[0]
        if not length:
            return b''
        return self._read(length)

    def _readlist(self):
        """read a list of NULL separated strings from the channel"""
        s = self._readstr()
        if s:
            return s.split(b'\0')
        else:
            return []

    def _dispatchcommand(self, req):
        """Run one dispatch request, juggling the SIGINT handler so the
        command itself is interruptible while protocol I/O is not"""
        from . import dispatch  # avoid cycle

        if self._shutdown_on_interrupt:
            # no need to restore SIGINT handler as it is unmodified.
            return dispatch.dispatch(req)

        try:
            # let the user interrupt the command, not the protocol
            signal.signal(signal.SIGINT, self._old_inthandler)
            return dispatch.dispatch(req)
        except error.SignalInterrupt:
            # propagate SIGBREAK, SIGHUP, or SIGTERM.
            raise
        except KeyboardInterrupt:
            # SIGINT may be received out of the try-except block of dispatch(),
            # so catch it as last ditch. Another KeyboardInterrupt may be
            # raised while handling exceptions here, but there's no way to
            # avoid that except for doing everything in C.
            pass
        finally:
            signal.signal(signal.SIGINT, signal.SIG_IGN)
        # On KeyboardInterrupt, print error message and exit *after* SIGINT
        # handler removed.
        req.ui.error(_(b'interrupted!\n'))
        return -1

    def runcommand(self):
        """reads a list of \0 terminated arguments, executes
        and writes the return code to the result channel"""
        from . import dispatch  # avoid cycle

        args = self._readlist()

        # copy the uis so changes (e.g. --config or --verbose) don't
        # persist between requests
        copiedui = self.ui.copy()
        uis = [copiedui]
        if self.repo:
            self.repo.baseui = copiedui
            # clone ui without using ui.copy because this is protected
            repoui = self.repoui.__class__(self.repoui)
            repoui.copy = copiedui.copy  # redo copy protection
            uis.append(repoui)
            self.repo.ui = self.repo.dirstate._ui = repoui
            self.repo.invalidateall()

        for ui in uis:
            ui.resetstate()
            # any kind of interaction must use server channels, but chg may
            # replace channels by fully functional tty files. so nontty is
            # enforced only if cin is a channel.
            if not util.safehasattr(self.cin, b'fileno'):
                ui.setconfig(b'ui', b'nontty', b'true', b'commandserver')

        req = dispatch.request(
            args[:],
            copiedui,
            self.repo,
            self.cin,
            self.cout,
            self.cerr,
            self.cmsg,
            prereposetups=self._prereposetups,
        )

        try:
            ret = self._dispatchcommand(req) & 255
            # If shutdown-on-interrupt is off, it's important to write the
            # result code *after* SIGINT handler removed. If the result code
            # were lost, the client wouldn't be able to continue processing.
            self.cresult.write(struct.pack(b'>i', int(ret)))
        finally:
            # restore old cwd
            if b'--cwd' in args:
                os.chdir(self.cwd)

    def getencoding(self):
        """writes the current encoding to the result channel"""
        self.cresult.write(encoding.encoding)

    def serveone(self):
        """Read and handle a single command; return False once the client
        sends an empty command (i.e. disconnects cleanly)"""
        cmd = self.client.readline()[:-1]
        if cmd:
            handler = self.capabilities.get(cmd)
            if handler:
                handler(self)
            else:
                # clients are expected to check what commands are supported by
                # looking at the servers capabilities
                raise error.Abort(_(b'unknown command %s') % cmd)

        return cmd != b''

    # command name -> unbound handler method, advertised in the hello message
    capabilities = {b'runcommand': runcommand, b'getencoding': getencoding}

    def serve(self):
        """Send the hello message, then serve commands until EOF; returns
        the process exit code"""
        hellomsg = b'capabilities: ' + b' '.join(sorted(self.capabilities))
        hellomsg += b'\n'
        hellomsg += b'encoding: ' + encoding.encoding
        hellomsg += b'\n'
        if self.cmsg:
            hellomsg += b'message-encoding: %s\n' % self.cmsg.encoding
        hellomsg += b'pid: %d' % procutil.getpid()
        if util.safehasattr(os, b'getpgid'):
            hellomsg += b'\n'
            hellomsg += b'pgid: %d' % os.getpgid(0)

        # write the hello msg in -one- chunk
        self.cout.write(hellomsg)

        try:
            while self.serveone():
                pass
        except EOFError:
            # we'll get here if the client disconnected while we were reading
            # its request
            return 1

        return 0
410 404
411 405
def setuplogging(ui, repo=None, fp=None):
    """Set up server logging facility

    If cmdserver.log is '-', log messages will be sent to the given fp.
    It should be the 'd' channel while a client is connected, and otherwise
    is the stderr of the server process.
    """
    # developer config: cmdserver.log
    logpath = ui.config(b'cmdserver', b'log')
    if not logpath:
        return
    # developer config: cmdserver.track-log
    tracked = set(ui.configlist(b'cmdserver', b'track-log'))

    if logpath == b'-':
        # stream logger: the provided fp if any, otherwise server stderr
        dest = fp if fp else ui.ferr
        logger = loggingutil.fileobjectlogger(dest, tracked)
    else:
        # rotating file logger rooted at the configured path
        logpath = util.abspath(util.expandpath(logpath))
        # developer config: cmdserver.max-log-files
        maxfiles = ui.configint(b'cmdserver', b'max-log-files')
        # developer config: cmdserver.max-log-size
        maxsize = ui.configbytes(b'cmdserver', b'max-log-size')
        vfs = vfsmod.vfs(os.path.dirname(logpath))
        logger = loggingutil.filelogger(
            vfs,
            os.path.basename(logpath),
            tracked,
            maxfiles=maxfiles,
            maxsize=maxsize,
        )

    # attach the logger to every ui object tied to this server/repo
    targetuis = {ui}
    if repo:
        targetuis.add(repo.baseui)
        targetuis.add(repo.ui)
    for targetui in targetuis:
        targetui.setlogger(b'cmdserver', logger)
451 445
452 446
class pipeservice:
    """Run a command server over the process's own stdin/stdout pipes."""

    def __init__(self, ui, repo, opts):
        self.ui = ui
        self.repo = repo

    def init(self):
        # nothing to set up; the stdio pipes already exist
        pass

    def run(self):
        ui = self.ui
        # redirect stdio to null device so that broken extensions or
        # in-process hooks will never cause corruption of channel protocol
        with ui.protectedfinout() as (fin, fout):
            srv = server(ui, self.repo, fin, fout)
            try:
                return srv.serve()
            finally:
                srv.cleanup()
471 465
472 466
def _initworkerprocess():
    """Per-fork initialization run in each worker process."""
    # Move into a process group of our own, in order to:
    # 1. make the current process group no longer "orphaned" (because the
    #    parent of this process is in a different process group while it
    #    remains in the same session); according to POSIX 2.2.2.52, an
    #    orphaned process group ignores terminal-generated stop signals
    #    like SIGTSTP (Ctrl+Z), which would break things like ncurses.
    # 2. let the client use kill(-pgid, sig) to simulate terminal-generated
    #    SIGINT (Ctrl+C) and process-exit-generated SIGHUP, so our child
    #    processes like ssh get killed properly without affecting unrelated
    #    processes.
    os.setpgid(0, 0)
    # reseed the RNG; otherwise every forked request handler would share the
    # random state inherited from the parent
    random.seed()
489 483
490 484
def _serverequest(ui, repo, conn, createcmdserver, prereposetups):
    """Serve one client connection on conn to completion.

    Builds a command server via createcmdserver() over file objects wrapped
    around the connection, runs it, and makes sure any unexpected traceback
    is mirrored to the client's error channel before re-raising.
    """
    fin = conn.makefile('rb')
    fout = conn.makefile('wb')
    sv = None
    try:
        sv = createcmdserver(repo, conn, fin, fout, prereposetups)
        try:
            sv.serve()
        # handle exceptions that may be raised by command server. most of
        # known exceptions are caught by dispatch.
        except error.Abort as inst:
            ui.error(_(b'abort: %s\n') % inst.message)
        except IOError as inst:
            # EPIPE just means the client went away; anything else is real
            if inst.errno != errno.EPIPE:
                raise
        except KeyboardInterrupt:
            pass
        finally:
            sv.cleanup()
    except:  # re-raises
        # also write traceback to error channel. otherwise client cannot
        # see it because it is written to server's stderr by default.
        if sv:
            cerr = sv.cerr
        else:
            # server construction itself failed; build a bare error channel
            cerr = channeledoutput(fout, b'e')
        cerr.write(encoding.strtolocal(traceback.format_exc()))
        raise
    finally:
        fin.close()
        try:
            fout.close()  # implicit flush() may cause another EPIPE
        except IOError as inst:
            if inst.errno != errno.EPIPE:
                raise
526 520
527 521
class unixservicehandler:
    """Pluggable set of operations for unix-mode services.

    Every method except createcmdserver() runs in the main process, so no
    mutable resource may be passed back from createcmdserver().
    """

    # polling period for shouldexit(); None is passed straight to
    # selector.select() as its timeout
    pollinterval = None

    def __init__(self, ui):
        self.ui = ui

    def bindsocket(self, sock, address):
        """Bind the listening socket and announce the address."""
        util.bindunixsocket(sock, address)
        sock.listen(socket.SOMAXCONN)
        self.ui.status(_(b'listening at %s\n') % address)
        self.ui.flush()  # avoid buffering of status message

    def unlinksocket(self, address):
        """Remove the socket file once the server stops accepting."""
        os.unlink(address)

    def shouldexit(self):
        """True if server should shut down; checked per pollinterval"""
        return False

    def newconnection(self):
        """Called when main process notices new connection"""

    def createcmdserver(self, repo, conn, fin, fout, prereposetups):
        """Create new command server instance; called in the process that
        serves for the current connection"""
        return server(self.ui, repo, fin, fout, prereposetups)
560 554
561 555
class unixforkingservice:
    """
    Listens on unix domain socket and forks server per connection
    """

    def __init__(self, ui, repo, opts, handler=None):
        self.ui = ui
        self.repo = repo
        self.address = opts[b'address']
        if not util.safehasattr(socket, b'AF_UNIX'):
            raise error.Abort(_(b'unsupported platform'))
        if not self.address:
            raise error.Abort(_(b'no socket path specified with --address'))
        self._servicehandler = handler or unixservicehandler(ui)
        self._sock = None
        self._mainipc = None
        self._workeripc = None
        self._oldsigchldhandler = None
        self._workerpids = set()  # updated by signal handler; do not iterate
        self._socketunlinked = None
        # experimental config: cmdserver.max-repo-cache
        maxlen = ui.configint(b'cmdserver', b'max-repo-cache')
        if maxlen < 0:
            raise error.Abort(_(b'negative max-repo-cache size not allowed'))
        self._repoloader = repocache.repoloader(ui, maxlen)
        # attempt to avoid crash in CoreFoundation when using chg after fix in
        # a89381e04c58
        if pycompat.isdarwin:
            procutil.gui()

    def init(self):
        """Create sockets and install the SIGCHLD handler; must be called
        before run()"""
        self._sock = socket.socket(socket.AF_UNIX)
        # IPC channel from many workers to one main process; this is actually
        # a uni-directional pipe, but is backed by a DGRAM socket so each
        # message can be easily separated.
        o = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        self._mainipc, self._workeripc = o
        self._servicehandler.bindsocket(self._sock, self.address)
        if util.safehasattr(procutil, b'unblocksignal'):
            procutil.unblocksignal(signal.SIGCHLD)
        o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
        self._oldsigchldhandler = o
        self._socketunlinked = False
        self._repoloader.start()

    def _unlinksocket(self):
        # idempotent: only the first call removes the socket file
        if not self._socketunlinked:
            self._servicehandler.unlinksocket(self.address)
            self._socketunlinked = True

    def _cleanup(self):
        signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
        self._sock.close()
        self._mainipc.close()
        self._workeripc.close()
        self._unlinksocket()
        self._repoloader.stop()
        # don't kill child processes as they have active clients, just wait
        self._reapworkers(0)

    def run(self):
        """Serve until the handler asks to exit, then release resources"""
        try:
            self._mainloop()
        finally:
            self._cleanup()

    def _mainloop(self):
        """Accept connections and worker IPC messages until shutdown"""
        exiting = False
        h = self._servicehandler
        selector = selectors.DefaultSelector()
        # each registration carries its event callback as data
        selector.register(
            self._sock, selectors.EVENT_READ, self._acceptnewconnection
        )
        selector.register(
            self._mainipc, selectors.EVENT_READ, self._handlemainipc
        )
        while True:
            if not exiting and h.shouldexit():
                # clients can no longer connect() to the domain socket, so
                # we stop queuing new requests.
                # for requests that are queued (connect()-ed, but haven't been
                # accept()-ed), handle them before exit. otherwise, clients
                # waiting for recv() will receive ECONNRESET.
                self._unlinksocket()
                exiting = True
            events = selector.select(timeout=h.pollinterval)
            if not events:
                # only exit if we completed all queued requests
                if exiting:
                    break
                continue
            for key, _mask in events:
                key.data(key.fileobj, selector)
        selector.close()

    def _acceptnewconnection(self, sock, selector):
        """Accept one client and fork a worker process to serve it"""
        h = self._servicehandler
        try:
            conn, _addr = sock.accept()
        except socket.error as inst:
            if inst.args[0] == errno.EINTR:
                return
            raise

        # Future improvement: On Python 3.7, maybe gc.freeze() can be used
        # to prevent COW memory from being touched by GC.
        # https://instagram-engineering.com/
        # copy-on-write-friendly-python-garbage-collection-ad6ed5233ddf
        pid = os.fork()
        if pid:
            # parent: track the worker and drop our copy of the connection
            try:
                self.ui.log(
                    b'cmdserver', b'forked worker process (pid=%d)\n', pid
                )
                self._workerpids.add(pid)
                h.newconnection()
            finally:
                conn.close()  # release handle in parent process
        else:
            # child: close inherited listening-side resources, serve the
            # connection, and always leave via os._exit()
            try:
                selector.close()
                sock.close()
                self._mainipc.close()
                self._runworker(conn)
                conn.close()
                self._workeripc.close()
                os._exit(0)
            except:  # never return, hence no re-raises
                try:
                    self.ui.traceback(force=True)
                finally:
                    os._exit(255)

    def _handlemainipc(self, sock, selector):
        """Process messages sent from a worker"""
        try:
            path = sock.recv(32768)  # large enough to receive path
        except socket.error as inst:
            if inst.args[0] == errno.EINTR:
                return
            raise
        self._repoloader.load(path)

    def _sigchldhandler(self, signal, frame):
        # reap without blocking; we're inside a signal handler
        self._reapworkers(os.WNOHANG)

    def _reapworkers(self, options):
        """waitpid() for exited workers; options is 0 or os.WNOHANG"""
        while self._workerpids:
            try:
                pid, _status = os.waitpid(-1, options)
            except OSError as inst:
                if inst.errno == errno.EINTR:
                    continue
                if inst.errno != errno.ECHILD:
                    raise
                # no child processes at all (reaped by other waitpid()?)
                self._workerpids.clear()
                return
            if pid == 0:
                # no waitable child processes
                return
            self.ui.log(b'cmdserver', b'worker process exited (pid=%d)\n', pid)
            self._workerpids.discard(pid)

    def _runworker(self, conn):
        """Entry point of the forked worker: serve the single connection"""
        signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
        _initworkerprocess()
        h = self._servicehandler
        try:
            _serverequest(
                self.ui,
                self.repo,
                conn,
                h.createcmdserver,
                prereposetups=[self._reposetup],
            )
        finally:
            gc.collect()  # trigger __del__ since worker process uses os._exit

    def _reposetup(self, ui, repo):
        """Hook run on repo construction in the worker: report the repo root
        back to the master over the IPC socket when the repo is closed"""
        if not repo.local():
            return

        class unixcmdserverrepo(repo.__class__):
            def close(self):
                super(unixcmdserverrepo, self).close()
                try:
                    self._cmdserveripc.send(self.root)
                except socket.error:
                    self.ui.log(
                        b'cmdserver', b'failed to send repo root to master\n'
                    )

        repo.__class__ = unixcmdserverrepo
        repo._cmdserveripc = self._workeripc

        cachedrepo = self._repoloader.get(repo.root)
        if cachedrepo is None:
            return
        repo.ui.log(b'repocache', b'repo from cache: %s\n', repo.root)
        repocache.copycache(cachedrepo, repo)
@@ -1,473 +1,469 b''
1 1 # worker.py - master-slave parallelism support
2 2 #
3 3 # Copyright 2013 Facebook, Inc.
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8
9 9 import errno
10 10 import os
11 11 import pickle
12 import selectors
12 13 import signal
13 14 import sys
14 15 import threading
15 16 import time
16 17
17 try:
18 import selectors
19
20 selectors.BaseSelector
21 except ImportError:
22 from .thirdparty import selectors2 as selectors
23
24 18 from .i18n import _
25 19 from . import (
26 20 encoding,
27 21 error,
28 22 pycompat,
29 23 scmutil,
30 24 )
31 25
32 26
def countcpus():
    '''try to count the number of CPUs on the system'''
    # os.cpu_count() (stdlib since 3.4) covers both probes this function
    # used to do by hand: sysconf('SC_NPROCESSORS_ONLN') on POSIX and the
    # NUMBER_OF_PROCESSORS environment variable on Windows. It returns
    # None when the count cannot be determined, in which case we keep the
    # historical conservative fallback of a single CPU.
    n = os.cpu_count()
    if n is not None and n > 0:
        return n
    return 1
53 47
54 48
def _numworkers(ui):
    """Return how many worker processes to use.

    Honors worker.numcpus when it parses as an integer >= 1; otherwise
    derives a count from the CPU count, clamped to the range [4, 32].
    Aborts if the config value is set but not an integer.
    """
    configured = ui.config(b'worker', b'numcpus')
    if configured:
        try:
            count = int(configured)
        except ValueError:
            raise error.Abort(_(b'number of cpus must be an integer'))
        if count >= 1:
            return count
    return min(max(countcpus(), 4), 32)
65 59
66 60
def ismainthread():
    """True when the caller is running in the program's main thread."""
    return threading.main_thread() == threading.current_thread()
69 63
70 64
71 65 class _blockingreader:
72 66 """Wrap unbuffered stream such that pickle.load() works with it.
73 67
74 68 pickle.load() expects that calls to read() and readinto() read as many
75 69 bytes as requested. On EOF, it is fine to read fewer bytes. In this case,
76 70 pickle.load() raises an EOFError.
77 71 """
78 72
79 73 def __init__(self, wrapped):
80 74 self._wrapped = wrapped
81 75
82 76 def readline(self):
83 77 return self._wrapped.readline()
84 78
85 79 def readinto(self, buf):
86 80 pos = 0
87 81 size = len(buf)
88 82
89 83 with memoryview(buf) as view:
90 84 while pos < size:
91 85 with view[pos:] as subview:
92 86 ret = self._wrapped.readinto(subview)
93 87 if not ret:
94 88 break
95 89 pos += ret
96 90
97 91 return pos
98 92
99 93 # issue multiple reads until size is fulfilled (or EOF is encountered)
100 94 def read(self, size=-1):
101 95 if size < 0:
102 96 return self._wrapped.readall()
103 97
104 98 buf = bytearray(size)
105 99 n_read = self.readinto(buf)
106 100 del buf[n_read:]
107 101 return bytes(buf)
108 102
109 103
# Tuning knobs for worthwhile(): the estimated cost of starting one worker,
# and whether thread-unsafe tasks must stay serial on this platform.
if pycompat.isposix or pycompat.iswindows:
    _STARTUP_COST = 0.01
    # The Windows worker is thread based. If tasks are CPU bound, threads
    # in the presence of the GIL result in excessive context switching and
    # this overhead can slow down execution.
    _DISALLOW_THREAD_UNSAFE = pycompat.iswindows
else:
    # unknown platform: make worker startup look prohibitively expensive so
    # worthwhile() never chooses the parallel path
    _STARTUP_COST = 1e30
    _DISALLOW_THREAD_UNSAFE = False
119 113
120 114
def worthwhile(ui, costperop, nops, threadsafe=True):
    """try to determine whether the benefit of multiple processes can
    outweigh the cost of starting them"""

    if not threadsafe and _DISALLOW_THREAD_UNSAFE:
        return False

    # compare the serial cost against startup overhead plus the per-worker
    # share of the work; only parallelize when the saving is significant
    serialcost = costperop * nops
    nworkers = _numworkers(ui)
    parallelcost = _STARTUP_COST * nworkers + serialcost / nworkers
    return (serialcost - parallelcost) >= 0.15
132 126
133 127
def worker(
    ui, costperarg, func, staticargs, args, hasretval=False, threadsafe=True
):
    """run a function, possibly in parallel in multiple worker
    processes.

    returns a progress iterator

    costperarg - cost of a single task

    func - function to run. It is expected to return a progress iterator.

    staticargs - arguments to pass to every invocation of the function

    args - arguments to split into chunks, to pass to individual
    workers

    hasretval - when True, func and the current function return an progress
    iterator then a dict (encoded as an iterator that yield many (False, ..)
    then a (True, dict)). The dicts are joined in some arbitrary order, so
    overlapping keys are a bad idea.

    threadsafe - whether work items are thread safe and can be executed using
    a thread-based worker. Should be disabled for CPU heavy tasks that don't
    release the GIL.
    """
    usable = ui.configbool(b'worker', b'enabled')
    if usable and _platformworker is _posixworker and not ismainthread():
        # the POSIX worker installs a SIGCHLD handler, which Python up to
        # 3.9 allows only from the main thread
        usable = False

    if usable and worthwhile(ui, costperarg, len(args), threadsafe=threadsafe):
        return _platformworker(ui, func, staticargs, args, hasretval)
    # parallelism disabled or not worth it: run everything in-process
    return func(*staticargs + (args,))
169 163
170 164
def _posixworker(ui, func, staticargs, args, hasretval):
    """Fan work out to forked child processes and yield their results.

    Forks up to ``_numworkers(ui)`` children; each child runs
    ``func(*staticargs + (chunk,))`` on its slice of ``args`` and streams
    pickled results back to the parent over a dedicated pipe.  The parent
    multiplexes the pipes with a selector and re-yields every result as
    it arrives.

    If ``hasretval`` is true, results of the form ``(True, dict)`` are
    folded into one dictionary, yielded last as ``(True, retval)``.  A
    non-zero child exit status aborts the run: remaining children are
    killed and ``error.WorkerError`` is raised (for death-by-signal the
    parent re-raises the signal to itself first).
    """
    workers = _numworkers(ui)
    # Ignore SIGINT in the parent while children run; each child restores
    # the original handler right after fork.
    oldhandler = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # 'problem' is a one-element list so the nested closures below can
    # mutate the shared failure status.
    pids, problem = set(), [0]

    def killworkers():
        # unregister SIGCHLD handler as all children will be killed. This
        # function shouldn't be interrupted by another SIGCHLD; otherwise pids
        # could be updated while iterating, which would cause inconsistency.
        signal.signal(signal.SIGCHLD, oldchldhandler)
        # if one worker bails, there's no good reason to wait for the rest
        for p in pids:
            try:
                os.kill(p, signal.SIGTERM)
            except OSError as err:
                # ESRCH: the child is already gone; anything else is real.
                if err.errno != errno.ESRCH:
                    raise

    def waitforworkers(blocking=True):
        # Reap exited children, recording the first failure in problem[0].
        # With blocking=False (the SIGCHLD path) still-running children are
        # left alone via WNOHANG.
        for pid in pids.copy():
            p = st = 0
            while True:
                try:
                    p, st = os.waitpid(pid, (0 if blocking else os.WNOHANG))
                    break
                except OSError as e:
                    if e.errno == errno.EINTR:
                        continue
                    elif e.errno == errno.ECHILD:
                        # child would already be reaped, but pids yet been
                        # updated (maybe interrupted just after waitpid)
                        pids.discard(pid)
                        break
                    else:
                        raise
            if not p:
                # skip subsequent steps, because child process should
                # be still running in this case
                continue
            pids.discard(p)
            st = _exitstatus(st)
            if st and not problem[0]:
                problem[0] = st

    def sigchldhandler(signum, frame):
        # Reap asynchronously; if any child already failed, abort the rest.
        waitforworkers(blocking=False)
        if problem[0]:
            killworkers()

    oldchldhandler = signal.signal(signal.SIGCHLD, sigchldhandler)
    # Flush buffered output before forking so children do not duplicate it.
    ui.flush()
    parentpid = os.getpid()
    pipes = []
    retval = {}
    for pargs in partition(args, min(workers, len(args))):
        # Every worker gets its own pipe to send results on, so we don't have to
        # implement atomic writes larger than PIPE_BUF. Each forked process has
        # its own pipe's descriptors in the local variables, and the parent
        # process has the full list of pipe descriptors (and it doesn't really
        # care what order they're in).
        rfd, wfd = os.pipe()
        pipes.append((rfd, wfd))
        # make sure we use os._exit in all worker code paths. otherwise the
        # worker may do some clean-ups which could cause surprises like
        # deadlock. see sshpeer.cleanup for example.
        # override error handling *before* fork. this is necessary because
        # exception (signal) may arrive after fork, before "pid =" assignment
        # completes, and other exception handler (dispatch.py) can lead to
        # unexpected code path without os._exit.
        ret = -1
        try:
            pid = os.fork()
            if pid == 0:
                # Child: restore inherited signal dispositions before work.
                signal.signal(signal.SIGINT, oldhandler)
                signal.signal(signal.SIGCHLD, oldchldhandler)

                def workerfunc():
                    # Close pipe ends belonging to the other workers and our
                    # own read end; keep only our write end.
                    for r, w in pipes[:-1]:
                        os.close(r)
                        os.close(w)
                    os.close(rfd)
                    with os.fdopen(wfd, 'wb') as wf:
                        for result in func(*(staticargs + (pargs,))):
                            pickle.dump(result, wf)
                            wf.flush()
                    return 0

                ret = scmutil.callcatch(ui, workerfunc)
        except: # parent re-raises, child never returns
            if os.getpid() == parentpid:
                raise
            exctype = sys.exc_info()[0]
            force = not issubclass(exctype, KeyboardInterrupt)
            ui.traceback(force=force)
        finally:
            if os.getpid() != parentpid:
                try:
                    ui.flush()
                except: # never returns, no re-raises
                    pass
                finally:
                    os._exit(ret & 255)
        pids.add(pid)
    selector = selectors.DefaultSelector()
    for rfd, wfd in pipes:
        # Parent only reads; close every write end.
        os.close(wfd)
        # The stream has to be unbuffered. Otherwise, if all data is read from
        # the raw file into the buffer, the selector thinks that the FD is not
        # ready to read while pickle.load() could read from the buffer. This
        # would delay the processing of readable items.
        selector.register(os.fdopen(rfd, 'rb', 0), selectors.EVENT_READ)

    def cleanup():
        # Restore handlers, reap all children, and report the first failure.
        signal.signal(signal.SIGINT, oldhandler)
        waitforworkers()
        signal.signal(signal.SIGCHLD, oldchldhandler)
        selector.close()
        return problem[0]

    try:
        openpipes = len(pipes)
        while openpipes > 0:
            for key, events in selector.select():
                try:
                    # The pytype error likely goes away on a modern version of
                    # pytype having a modern typeshed snapshot.
                    # pytype: disable=wrong-arg-types
                    res = pickle.load(_blockingreader(key.fileobj))
                    # pytype: enable=wrong-arg-types
                    if hasretval and res[0]:
                        retval.update(res[1])
                    else:
                        yield res
                except EOFError:
                    # Worker closed its pipe: it is done producing results.
                    selector.unregister(key.fileobj)
                    # pytype: disable=attribute-error
                    key.fileobj.close()
                    # pytype: enable=attribute-error
                    openpipes -= 1
                except IOError as e:
                    if e.errno == errno.EINTR:
                        continue
                    raise
    except: # re-raises
        killworkers()
        cleanup()
        raise
    status = cleanup()
    if status:
        if status < 0:
            # A child died from a signal; propagate it to ourselves.
            os.kill(os.getpid(), -status)
        raise error.WorkerError(status)
    if hasretval:
        yield True, retval
324 320
325 321
326 322 def _posixexitstatus(code):
327 323 """convert a posix exit status into the same form returned by
328 324 os.spawnv
329 325
330 326 returns None if the process was stopped instead of exiting"""
331 327 if os.WIFEXITED(code):
332 328 return os.WEXITSTATUS(code)
333 329 elif os.WIFSIGNALED(code):
334 330 return -(os.WTERMSIG(code))
335 331
336 332
def _windowsworker(ui, func, staticargs, args, hasretval):
    """Thread-based worker pool used where fork() is unavailable.

    Spawns daemon threads that pull chunks of ``args`` from a shared task
    queue, run ``func(*staticargs + (chunk,))``, and push each produced
    result onto a shared result queue, which this generator drains and
    re-yields.  When ``hasretval`` is true, ``(True, dict)`` results are
    merged into one dictionary yielded last as ``(True, retval)``.
    """

    class PoolThread(threading.Thread):
        # A daemon thread that keeps taking chunks until the queue runs dry,
        # is interrupted, or func raises.
        def __init__(
            self, taskqueue, resultqueue, func, staticargs, *args, **kwargs
        ):
            threading.Thread.__init__(self, *args, **kwargs)
            self._taskqueue = taskqueue
            self._resultqueue = resultqueue
            self._func = func
            self._staticargs = staticargs
            self._interrupted = False
            self.daemon = True
            self.exception = None

        def interrupt(self):
            # Ask the thread to stop at the next produced result; threads
            # cannot be interrupted natively, so run() polls this flag.
            self._interrupted = True

        def run(self):
            try:
                while not self._taskqueue.empty():
                    try:
                        chunk = self._taskqueue.get_nowait()
                        for produced in self._func(*self._staticargs + (chunk,)):
                            self._resultqueue.put(produced)
                            # cooperative cancellation point, checked after
                            # every result
                            if self._interrupted:
                                return
                    except pycompat.queue.Empty:
                        break
            except Exception as e:
                # keep the exception so the main thread can re-raise it as
                # if func had been running without workers
                self.exception = e
                raise

    pool = []

    def stopworkers():
        # Give the threads up to one second, total, to wind down nicely.
        deadline = time.time() + 1
        for member in pool:
            member.interrupt()
        for member in pool:
            member.join(deadline - time.time())
            if member.is_alive():
                # Joining failed; surfacing the original exception matters
                # more than waiting out a worker stuck on a large task.
                ui.warn(
                    _(
                        b"failed to kill worker threads while "
                        b"handling an exception\n"
                    )
                )
                return

    workers = _numworkers(ui)
    resultqueue = pycompat.queue.Queue()
    taskqueue = pycompat.queue.Queue()
    retval = {}

    def drainresults():
        # Forward plain results to the caller; fold (True, dict) records
        # into the shared return value instead.
        while not resultqueue.empty():
            item = resultqueue.get()
            if hasretval and item[0]:
                retval.update(item[1])
            else:
                yield item

    # Hand out more chunks than threads to reduce the chance that a few
    # large tasks land on the same worker.
    for chunk in partition(args, workers * 20):
        taskqueue.put(chunk)
    for _idx in range(workers):
        member = PoolThread(taskqueue, resultqueue, func, staticargs)
        pool.append(member)
        member.start()
    try:
        while pool:
            for item in drainresults():
                yield item
            pool[0].join(0.05)
            for done in [m for m in pool if not m.is_alive()]:
                if done.exception is not None:
                    raise done.exception
                pool.remove(done)
    except (Exception, KeyboardInterrupt): # re-raises
        stopworkers()
        raise
    # Pick up anything produced after the last thread finished.
    for item in drainresults():
        yield item
    if hasretval:
        yield True, retval
434 430
435 431
# Select the worker implementation for this platform: threads on Windows
# (no fork()), forked processes elsewhere.  POSIX additionally needs a
# helper to decode waitpid() status values.
if pycompat.iswindows:
    _platformworker = _windowsworker
else:
    _platformworker = _posixworker
    _exitstatus = _posixexitstatus
441 437
442 438
def partition(lst, nslices):
    """Split lst into nslices round-robin slices of roughly equal size.

    Slice i receives lst[i], lst[i + nslices], lst[i + 2 * nslices], ... —
    every Nth element, offset by the slice index.  If workers ever need
    grouped input, callers should be allowed to pick a different strategy.

    olivia is not a fan of this partitioning strategy when files are
    involved.  In his words:

        Single-threaded Mercurial makes a point of creating and visiting
        files in a fixed order (alphabetical). When creating files in order,
        a typical filesystem is likely to allocate them on nearby regions on
        disk. Thus, when revisiting in the same order, locality is maximized
        and various forms of OS and disk-level caching and read-ahead get a
        chance to work.

        This effect can be quite significant on spinning disks. I discovered
        it circa Mercurial v0.4 when revlogs were named by hashes of
        filenames. Tarring a repo and copying it to another disk effectively
        randomized the revlog ordering on disk by sorting the revlogs by
        hash and suddenly performance of my kernel checkout benchmark
        dropped by ~10x because the "working set" of sectors visited no
        longer fit in the drive's cache and the workload switched from
        streaming to random I/O.

        What we should really be doing is have workers read filenames from a
        ordered queue. This preserves locality and also keeps any worker
        from getting more than one file out of balance.
    """
    offset = 0
    while offset < nslices:
        yield lst[offset::nslices]
        offset += 1
1 NO CONTENT: file was removed
This diff has been collapsed as it changes many lines, (743 lines changed) Show them Hide them
General Comments 0
You need to be logged in to leave comments. Login now