##// END OF EJS Templates
commandserver: prefer first-party selectors module from Python 3 to backport...
Augie Fackler -
r36958:b0ffcb54 default
parent child Browse files
Show More
@@ -1,551 +1,556 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import signal
15 15 import socket
16 16 import struct
17 17 import traceback
18 18
19 try:
20 import selectors
21 selectors.BaseSelector
22 except ImportError:
23 from .thirdparty import selectors2 as selectors
24
19 25 from .i18n import _
20 from .thirdparty import selectors2
21 26 from . import (
22 27 encoding,
23 28 error,
24 29 pycompat,
25 30 util,
26 31 )
27 32
28 33 logfile = None
29 34
30 35 def log(*args):
31 36 if not logfile:
32 37 return
33 38
34 39 for a in args:
35 40 logfile.write(str(a))
36 41
37 42 logfile.flush()
38 43
39 44 class channeledoutput(object):
40 45 """
41 46 Write data to out in the following format:
42 47
43 48 data length (unsigned int),
44 49 data
45 50 """
46 51 def __init__(self, out, channel):
47 52 self.out = out
48 53 self.channel = channel
49 54
50 55 @property
51 56 def name(self):
52 57 return '<%c-channel>' % self.channel
53 58
54 59 def write(self, data):
55 60 if not data:
56 61 return
57 62 # single write() to guarantee the same atomicity as the underlying file
58 63 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 64 self.out.flush()
60 65
61 66 def __getattr__(self, attr):
62 67 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 68 raise AttributeError(attr)
64 69 return getattr(self.out, attr)
65 70
66 71 class channeledinput(object):
67 72 """
68 73 Read data from in_.
69 74
70 75 Requests for input are written to out in the following format:
71 76 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 77 how many bytes to send at most (unsigned int),
73 78
74 79 The client replies with:
75 80 data length (unsigned int), 0 meaning EOF
76 81 data
77 82 """
78 83
79 84 maxchunksize = 4 * 1024
80 85
81 86 def __init__(self, in_, out, channel):
82 87 self.in_ = in_
83 88 self.out = out
84 89 self.channel = channel
85 90
86 91 @property
87 92 def name(self):
88 93 return '<%c-channel>' % self.channel
89 94
90 95 def read(self, size=-1):
91 96 if size < 0:
92 97 # if we need to consume all the clients input, ask for 4k chunks
93 98 # so the pipe doesn't fill up risking a deadlock
94 99 size = self.maxchunksize
95 100 s = self._read(size, self.channel)
96 101 buf = s
97 102 while s:
98 103 s = self._read(size, self.channel)
99 104 buf += s
100 105
101 106 return buf
102 107 else:
103 108 return self._read(size, self.channel)
104 109
105 110 def _read(self, size, channel):
106 111 if not size:
107 112 return ''
108 113 assert size > 0
109 114
110 115 # tell the client we need at most size bytes
111 116 self.out.write(struct.pack('>cI', channel, size))
112 117 self.out.flush()
113 118
114 119 length = self.in_.read(4)
115 120 length = struct.unpack('>I', length)[0]
116 121 if not length:
117 122 return ''
118 123 else:
119 124 return self.in_.read(length)
120 125
121 126 def readline(self, size=-1):
122 127 if size < 0:
123 128 size = self.maxchunksize
124 129 s = self._read(size, 'L')
125 130 buf = s
126 131 # keep asking for more until there's either no more or
127 132 # we got a full line
128 133 while s and s[-1] != '\n':
129 134 s = self._read(size, 'L')
130 135 buf += s
131 136
132 137 return buf
133 138 else:
134 139 return self._read(size, 'L')
135 140
136 141 def __iter__(self):
137 142 return self
138 143
139 144 def next(self):
140 145 l = self.readline()
141 146 if not l:
142 147 raise StopIteration
143 148 return l
144 149
145 150 def __getattr__(self, attr):
146 151 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 152 raise AttributeError(attr)
148 153 return getattr(self.in_, attr)
149 154
150 155 class server(object):
151 156 """
152 157 Listens for commands on fin, runs them and writes the output on a channel
153 158 based stream to fout.
154 159 """
155 160 def __init__(self, ui, repo, fin, fout):
156 161 self.cwd = pycompat.getcwd()
157 162
158 163 # developer config: cmdserver.log
159 164 logpath = ui.config("cmdserver", "log")
160 165 if logpath:
161 166 global logfile
162 167 if logpath == '-':
163 168 # write log on a special 'd' (debug) channel
164 169 logfile = channeledoutput(fout, 'd')
165 170 else:
166 171 logfile = open(logpath, 'a')
167 172
168 173 if repo:
169 174 # the ui here is really the repo ui so take its baseui so we don't
170 175 # end up with its local configuration
171 176 self.ui = repo.baseui
172 177 self.repo = repo
173 178 self.repoui = repo.ui
174 179 else:
175 180 self.ui = ui
176 181 self.repo = self.repoui = None
177 182
178 183 self.cerr = channeledoutput(fout, 'e')
179 184 self.cout = channeledoutput(fout, 'o')
180 185 self.cin = channeledinput(fin, fout, 'I')
181 186 self.cresult = channeledoutput(fout, 'r')
182 187
183 188 self.client = fin
184 189
185 190 def cleanup(self):
186 191 """release and restore resources taken during server session"""
187 192
188 193 def _read(self, size):
189 194 if not size:
190 195 return ''
191 196
192 197 data = self.client.read(size)
193 198
194 199 # is the other end closed?
195 200 if not data:
196 201 raise EOFError
197 202
198 203 return data
199 204
200 205 def _readstr(self):
201 206 """read a string from the channel
202 207
203 208 format:
204 209 data length (uint32), data
205 210 """
206 211 length = struct.unpack('>I', self._read(4))[0]
207 212 if not length:
208 213 return ''
209 214 return self._read(length)
210 215
211 216 def _readlist(self):
212 217 """read a list of NULL separated strings from the channel"""
213 218 s = self._readstr()
214 219 if s:
215 220 return s.split('\0')
216 221 else:
217 222 return []
218 223
219 224 def runcommand(self):
220 225 """ reads a list of \0 terminated arguments, executes
221 226 and writes the return code to the result channel """
222 227 from . import dispatch # avoid cycle
223 228
224 229 args = self._readlist()
225 230
226 231 # copy the uis so changes (e.g. --config or --verbose) don't
227 232 # persist between requests
228 233 copiedui = self.ui.copy()
229 234 uis = [copiedui]
230 235 if self.repo:
231 236 self.repo.baseui = copiedui
232 237 # clone ui without using ui.copy because this is protected
233 238 repoui = self.repoui.__class__(self.repoui)
234 239 repoui.copy = copiedui.copy # redo copy protection
235 240 uis.append(repoui)
236 241 self.repo.ui = self.repo.dirstate._ui = repoui
237 242 self.repo.invalidateall()
238 243
239 244 for ui in uis:
240 245 ui.resetstate()
241 246 # any kind of interaction must use server channels, but chg may
242 247 # replace channels by fully functional tty files. so nontty is
243 248 # enforced only if cin is a channel.
244 249 if not util.safehasattr(self.cin, 'fileno'):
245 250 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
246 251
247 252 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
248 253 self.cout, self.cerr)
249 254
250 255 try:
251 256 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
252 257 self.cresult.write(struct.pack('>i', int(ret)))
253 258 finally:
254 259 # restore old cwd
255 260 if '--cwd' in args:
256 261 os.chdir(self.cwd)
257 262
258 263 def getencoding(self):
259 264 """ writes the current encoding to the result channel """
260 265 self.cresult.write(encoding.encoding)
261 266
262 267 def serveone(self):
263 268 cmd = self.client.readline()[:-1]
264 269 if cmd:
265 270 handler = self.capabilities.get(cmd)
266 271 if handler:
267 272 handler(self)
268 273 else:
269 274 # clients are expected to check what commands are supported by
270 275 # looking at the servers capabilities
271 276 raise error.Abort(_('unknown command %s') % cmd)
272 277
273 278 return cmd != ''
274 279
275 280 capabilities = {'runcommand': runcommand,
276 281 'getencoding': getencoding}
277 282
278 283 def serve(self):
279 284 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
280 285 hellomsg += '\n'
281 286 hellomsg += 'encoding: ' + encoding.encoding
282 287 hellomsg += '\n'
283 288 hellomsg += 'pid: %d' % util.getpid()
284 289 if util.safehasattr(os, 'getpgid'):
285 290 hellomsg += '\n'
286 291 hellomsg += 'pgid: %d' % os.getpgid(0)
287 292
288 293 # write the hello msg in -one- chunk
289 294 self.cout.write(hellomsg)
290 295
291 296 try:
292 297 while self.serveone():
293 298 pass
294 299 except EOFError:
295 300 # we'll get here if the client disconnected while we were reading
296 301 # its request
297 302 return 1
298 303
299 304 return 0
300 305
301 306 def _protectio(ui):
302 307 """ duplicates streams and redirect original to null if ui uses stdio """
303 308 ui.flush()
304 309 newfiles = []
305 310 nullfd = os.open(os.devnull, os.O_RDWR)
306 311 for f, sysf, mode in [(ui.fin, util.stdin, r'rb'),
307 312 (ui.fout, util.stdout, r'wb')]:
308 313 if f is sysf:
309 314 newfd = os.dup(f.fileno())
310 315 os.dup2(nullfd, f.fileno())
311 316 f = os.fdopen(newfd, mode)
312 317 newfiles.append(f)
313 318 os.close(nullfd)
314 319 return tuple(newfiles)
315 320
316 321 def _restoreio(ui, fin, fout):
317 322 """ restores streams from duplicated ones """
318 323 ui.flush()
319 324 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
320 325 if f is not uif:
321 326 os.dup2(f.fileno(), uif.fileno())
322 327 f.close()
323 328
324 329 class pipeservice(object):
325 330 def __init__(self, ui, repo, opts):
326 331 self.ui = ui
327 332 self.repo = repo
328 333
329 334 def init(self):
330 335 pass
331 336
332 337 def run(self):
333 338 ui = self.ui
334 339 # redirect stdio to null device so that broken extensions or in-process
335 340 # hooks will never cause corruption of channel protocol.
336 341 fin, fout = _protectio(ui)
337 342 try:
338 343 sv = server(ui, self.repo, fin, fout)
339 344 return sv.serve()
340 345 finally:
341 346 sv.cleanup()
342 347 _restoreio(ui, fin, fout)
343 348
344 349 def _initworkerprocess():
345 350 # use a different process group from the master process, in order to:
346 351 # 1. make the current process group no longer "orphaned" (because the
347 352 # parent of this process is in a different process group while
348 353 # remains in a same session)
349 354 # according to POSIX 2.2.2.52, orphaned process group will ignore
350 355 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
351 356 # cause trouble for things like ncurses.
352 357 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
353 358 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
354 359 # processes like ssh will be killed properly, without affecting
355 360 # unrelated processes.
356 361 os.setpgid(0, 0)
357 362 # change random state otherwise forked request handlers would have a
358 363 # same state inherited from parent.
359 364 random.seed()
360 365
361 366 def _serverequest(ui, repo, conn, createcmdserver):
362 367 fin = conn.makefile('rb')
363 368 fout = conn.makefile('wb')
364 369 sv = None
365 370 try:
366 371 sv = createcmdserver(repo, conn, fin, fout)
367 372 try:
368 373 sv.serve()
369 374 # handle exceptions that may be raised by command server. most of
370 375 # known exceptions are caught by dispatch.
371 376 except error.Abort as inst:
372 377 ui.warn(_('abort: %s\n') % inst)
373 378 except IOError as inst:
374 379 if inst.errno != errno.EPIPE:
375 380 raise
376 381 except KeyboardInterrupt:
377 382 pass
378 383 finally:
379 384 sv.cleanup()
380 385 except: # re-raises
381 386 # also write traceback to error channel. otherwise client cannot
382 387 # see it because it is written to server's stderr by default.
383 388 if sv:
384 389 cerr = sv.cerr
385 390 else:
386 391 cerr = channeledoutput(fout, 'e')
387 392 traceback.print_exc(file=cerr)
388 393 raise
389 394 finally:
390 395 fin.close()
391 396 try:
392 397 fout.close() # implicit flush() may cause another EPIPE
393 398 except IOError as inst:
394 399 if inst.errno != errno.EPIPE:
395 400 raise
396 401
397 402 class unixservicehandler(object):
398 403 """Set of pluggable operations for unix-mode services
399 404
400 405 Almost all methods except for createcmdserver() are called in the main
401 406 process. You can't pass mutable resource back from createcmdserver().
402 407 """
403 408
404 409 pollinterval = None
405 410
406 411 def __init__(self, ui):
407 412 self.ui = ui
408 413
409 414 def bindsocket(self, sock, address):
410 415 util.bindunixsocket(sock, address)
411 416 sock.listen(socket.SOMAXCONN)
412 417 self.ui.status(_('listening at %s\n') % address)
413 418 self.ui.flush() # avoid buffering of status message
414 419
415 420 def unlinksocket(self, address):
416 421 os.unlink(address)
417 422
418 423 def shouldexit(self):
419 424 """True if server should shut down; checked per pollinterval"""
420 425 return False
421 426
422 427 def newconnection(self):
423 428 """Called when main process notices new connection"""
424 429
425 430 def createcmdserver(self, repo, conn, fin, fout):
426 431 """Create new command server instance; called in the process that
427 432 serves for the current connection"""
428 433 return server(self.ui, repo, fin, fout)
429 434
430 435 class unixforkingservice(object):
431 436 """
432 437 Listens on unix domain socket and forks server per connection
433 438 """
434 439
435 440 def __init__(self, ui, repo, opts, handler=None):
436 441 self.ui = ui
437 442 self.repo = repo
438 443 self.address = opts['address']
439 444 if not util.safehasattr(socket, 'AF_UNIX'):
440 445 raise error.Abort(_('unsupported platform'))
441 446 if not self.address:
442 447 raise error.Abort(_('no socket path specified with --address'))
443 448 self._servicehandler = handler or unixservicehandler(ui)
444 449 self._sock = None
445 450 self._oldsigchldhandler = None
446 451 self._workerpids = set() # updated by signal handler; do not iterate
447 452 self._socketunlinked = None
448 453
449 454 def init(self):
450 455 self._sock = socket.socket(socket.AF_UNIX)
451 456 self._servicehandler.bindsocket(self._sock, self.address)
452 457 if util.safehasattr(util, 'unblocksignal'):
453 458 util.unblocksignal(signal.SIGCHLD)
454 459 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
455 460 self._oldsigchldhandler = o
456 461 self._socketunlinked = False
457 462
458 463 def _unlinksocket(self):
459 464 if not self._socketunlinked:
460 465 self._servicehandler.unlinksocket(self.address)
461 466 self._socketunlinked = True
462 467
463 468 def _cleanup(self):
464 469 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
465 470 self._sock.close()
466 471 self._unlinksocket()
467 472 # don't kill child processes as they have active clients, just wait
468 473 self._reapworkers(0)
469 474
470 475 def run(self):
471 476 try:
472 477 self._mainloop()
473 478 finally:
474 479 self._cleanup()
475 480
476 481 def _mainloop(self):
477 482 exiting = False
478 483 h = self._servicehandler
479 selector = selectors2.DefaultSelector()
480 selector.register(self._sock, selectors2.EVENT_READ)
484 selector = selectors.DefaultSelector()
485 selector.register(self._sock, selectors.EVENT_READ)
481 486 while True:
482 487 if not exiting and h.shouldexit():
483 488 # clients can no longer connect() to the domain socket, so
484 489 # we stop queuing new requests.
485 490 # for requests that are queued (connect()-ed, but haven't been
486 491 # accept()-ed), handle them before exit. otherwise, clients
487 492 # waiting for recv() will receive ECONNRESET.
488 493 self._unlinksocket()
489 494 exiting = True
490 495 ready = selector.select(timeout=h.pollinterval)
491 496 if not ready:
492 497 # only exit if we completed all queued requests
493 498 if exiting:
494 499 break
495 500 continue
496 501 try:
497 502 conn, _addr = self._sock.accept()
498 503 except socket.error as inst:
499 504 if inst.args[0] == errno.EINTR:
500 505 continue
501 506 raise
502 507
503 508 pid = os.fork()
504 509 if pid:
505 510 try:
506 511 self.ui.debug('forked worker process (pid=%d)\n' % pid)
507 512 self._workerpids.add(pid)
508 513 h.newconnection()
509 514 finally:
510 515 conn.close() # release handle in parent process
511 516 else:
512 517 try:
513 518 self._runworker(conn)
514 519 conn.close()
515 520 os._exit(0)
516 521 except: # never return, hence no re-raises
517 522 try:
518 523 self.ui.traceback(force=True)
519 524 finally:
520 525 os._exit(255)
521 526 selector.close()
522 527
523 528 def _sigchldhandler(self, signal, frame):
524 529 self._reapworkers(os.WNOHANG)
525 530
526 531 def _reapworkers(self, options):
527 532 while self._workerpids:
528 533 try:
529 534 pid, _status = os.waitpid(-1, options)
530 535 except OSError as inst:
531 536 if inst.errno == errno.EINTR:
532 537 continue
533 538 if inst.errno != errno.ECHILD:
534 539 raise
535 540 # no child processes at all (reaped by other waitpid()?)
536 541 self._workerpids.clear()
537 542 return
538 543 if pid == 0:
539 544 # no waitable child processes
540 545 return
541 546 self.ui.debug('worker process exited (pid=%d)\n' % pid)
542 547 self._workerpids.discard(pid)
543 548
544 549 def _runworker(self, conn):
545 550 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
546 551 _initworkerprocess()
547 552 h = self._servicehandler
548 553 try:
549 554 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
550 555 finally:
551 556 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now