##// END OF EJS Templates
commandserver: use selectors2...
Jun Wu -
r33503:27d23fe3 default
parent child Browse files
Show More
@@ -1,548 +1,551 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import select
15 15 import signal
16 16 import socket
17 17 import struct
18 18 import traceback
19 19
20 20 from .i18n import _
21 21 from . import (
22 22 encoding,
23 23 error,
24 24 pycompat,
25 selectors2,
25 26 util,
26 27 )
27 28
28 29 logfile = None
29 30
30 31 def log(*args):
31 32 if not logfile:
32 33 return
33 34
34 35 for a in args:
35 36 logfile.write(str(a))
36 37
37 38 logfile.flush()
38 39
39 40 class channeledoutput(object):
40 41 """
41 42 Write data to out in the following format:
42 43
43 44 data length (unsigned int),
44 45 data
45 46 """
46 47 def __init__(self, out, channel):
47 48 self.out = out
48 49 self.channel = channel
49 50
50 51 @property
51 52 def name(self):
52 53 return '<%c-channel>' % self.channel
53 54
54 55 def write(self, data):
55 56 if not data:
56 57 return
57 58 # single write() to guarantee the same atomicity as the underlying file
58 59 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 60 self.out.flush()
60 61
61 62 def __getattr__(self, attr):
62 63 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 64 raise AttributeError(attr)
64 65 return getattr(self.out, attr)
65 66
66 67 class channeledinput(object):
67 68 """
68 69 Read data from in_.
69 70
70 71 Requests for input are written to out in the following format:
71 72 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 73 how many bytes to send at most (unsigned int),
73 74
74 75 The client replies with:
75 76 data length (unsigned int), 0 meaning EOF
76 77 data
77 78 """
78 79
79 80 maxchunksize = 4 * 1024
80 81
81 82 def __init__(self, in_, out, channel):
82 83 self.in_ = in_
83 84 self.out = out
84 85 self.channel = channel
85 86
86 87 @property
87 88 def name(self):
88 89 return '<%c-channel>' % self.channel
89 90
90 91 def read(self, size=-1):
91 92 if size < 0:
92 93 # if we need to consume all the clients input, ask for 4k chunks
93 94 # so the pipe doesn't fill up risking a deadlock
94 95 size = self.maxchunksize
95 96 s = self._read(size, self.channel)
96 97 buf = s
97 98 while s:
98 99 s = self._read(size, self.channel)
99 100 buf += s
100 101
101 102 return buf
102 103 else:
103 104 return self._read(size, self.channel)
104 105
105 106 def _read(self, size, channel):
106 107 if not size:
107 108 return ''
108 109 assert size > 0
109 110
110 111 # tell the client we need at most size bytes
111 112 self.out.write(struct.pack('>cI', channel, size))
112 113 self.out.flush()
113 114
114 115 length = self.in_.read(4)
115 116 length = struct.unpack('>I', length)[0]
116 117 if not length:
117 118 return ''
118 119 else:
119 120 return self.in_.read(length)
120 121
121 122 def readline(self, size=-1):
122 123 if size < 0:
123 124 size = self.maxchunksize
124 125 s = self._read(size, 'L')
125 126 buf = s
126 127 # keep asking for more until there's either no more or
127 128 # we got a full line
128 129 while s and s[-1] != '\n':
129 130 s = self._read(size, 'L')
130 131 buf += s
131 132
132 133 return buf
133 134 else:
134 135 return self._read(size, 'L')
135 136
136 137 def __iter__(self):
137 138 return self
138 139
139 140 def next(self):
140 141 l = self.readline()
141 142 if not l:
142 143 raise StopIteration
143 144 return l
144 145
145 146 def __getattr__(self, attr):
146 147 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 148 raise AttributeError(attr)
148 149 return getattr(self.in_, attr)
149 150
150 151 class server(object):
151 152 """
152 153 Listens for commands on fin, runs them and writes the output on a channel
153 154 based stream to fout.
154 155 """
155 156 def __init__(self, ui, repo, fin, fout):
156 157 self.cwd = pycompat.getcwd()
157 158
158 159 # developer config: cmdserver.log
159 160 logpath = ui.config("cmdserver", "log")
160 161 if logpath:
161 162 global logfile
162 163 if logpath == '-':
163 164 # write log on a special 'd' (debug) channel
164 165 logfile = channeledoutput(fout, 'd')
165 166 else:
166 167 logfile = open(logpath, 'a')
167 168
168 169 if repo:
169 170 # the ui here is really the repo ui so take its baseui so we don't
170 171 # end up with its local configuration
171 172 self.ui = repo.baseui
172 173 self.repo = repo
173 174 self.repoui = repo.ui
174 175 else:
175 176 self.ui = ui
176 177 self.repo = self.repoui = None
177 178
178 179 self.cerr = channeledoutput(fout, 'e')
179 180 self.cout = channeledoutput(fout, 'o')
180 181 self.cin = channeledinput(fin, fout, 'I')
181 182 self.cresult = channeledoutput(fout, 'r')
182 183
183 184 self.client = fin
184 185
185 186 def cleanup(self):
186 187 """release and restore resources taken during server session"""
187 188 pass
188 189
189 190 def _read(self, size):
190 191 if not size:
191 192 return ''
192 193
193 194 data = self.client.read(size)
194 195
195 196 # is the other end closed?
196 197 if not data:
197 198 raise EOFError
198 199
199 200 return data
200 201
201 202 def _readstr(self):
202 203 """read a string from the channel
203 204
204 205 format:
205 206 data length (uint32), data
206 207 """
207 208 length = struct.unpack('>I', self._read(4))[0]
208 209 if not length:
209 210 return ''
210 211 return self._read(length)
211 212
212 213 def _readlist(self):
213 214 """read a list of NULL separated strings from the channel"""
214 215 s = self._readstr()
215 216 if s:
216 217 return s.split('\0')
217 218 else:
218 219 return []
219 220
220 221 def runcommand(self):
221 222 """ reads a list of \0 terminated arguments, executes
222 223 and writes the return code to the result channel """
223 224 from . import dispatch # avoid cycle
224 225
225 226 args = self._readlist()
226 227
227 228 # copy the uis so changes (e.g. --config or --verbose) don't
228 229 # persist between requests
229 230 copiedui = self.ui.copy()
230 231 uis = [copiedui]
231 232 if self.repo:
232 233 self.repo.baseui = copiedui
233 234 # clone ui without using ui.copy because this is protected
234 235 repoui = self.repoui.__class__(self.repoui)
235 236 repoui.copy = copiedui.copy # redo copy protection
236 237 uis.append(repoui)
237 238 self.repo.ui = self.repo.dirstate._ui = repoui
238 239 self.repo.invalidateall()
239 240
240 241 for ui in uis:
241 242 ui.resetstate()
242 243 # any kind of interaction must use server channels, but chg may
243 244 # replace channels by fully functional tty files. so nontty is
244 245 # enforced only if cin is a channel.
245 246 if not util.safehasattr(self.cin, 'fileno'):
246 247 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
247 248
248 249 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
249 250 self.cout, self.cerr)
250 251
251 252 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
252 253
253 254 # restore old cwd
254 255 if '--cwd' in args:
255 256 os.chdir(self.cwd)
256 257
257 258 self.cresult.write(struct.pack('>i', int(ret)))
258 259
259 260 def getencoding(self):
260 261 """ writes the current encoding to the result channel """
261 262 self.cresult.write(encoding.encoding)
262 263
263 264 def serveone(self):
264 265 cmd = self.client.readline()[:-1]
265 266 if cmd:
266 267 handler = self.capabilities.get(cmd)
267 268 if handler:
268 269 handler(self)
269 270 else:
270 271 # clients are expected to check what commands are supported by
271 272 # looking at the servers capabilities
272 273 raise error.Abort(_('unknown command %s') % cmd)
273 274
274 275 return cmd != ''
275 276
276 277 capabilities = {'runcommand' : runcommand,
277 278 'getencoding' : getencoding}
278 279
279 280 def serve(self):
280 281 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
281 282 hellomsg += '\n'
282 283 hellomsg += 'encoding: ' + encoding.encoding
283 284 hellomsg += '\n'
284 285 hellomsg += 'pid: %d' % util.getpid()
285 286 if util.safehasattr(os, 'getpgid'):
286 287 hellomsg += '\n'
287 288 hellomsg += 'pgid: %d' % os.getpgid(0)
288 289
289 290 # write the hello msg in -one- chunk
290 291 self.cout.write(hellomsg)
291 292
292 293 try:
293 294 while self.serveone():
294 295 pass
295 296 except EOFError:
296 297 # we'll get here if the client disconnected while we were reading
297 298 # its request
298 299 return 1
299 300
300 301 return 0
301 302
302 303 def _protectio(ui):
303 304 """ duplicates streams and redirect original to null if ui uses stdio """
304 305 ui.flush()
305 306 newfiles = []
306 307 nullfd = os.open(os.devnull, os.O_RDWR)
307 308 for f, sysf, mode in [(ui.fin, util.stdin, pycompat.sysstr('rb')),
308 309 (ui.fout, util.stdout, pycompat.sysstr('wb'))]:
309 310 if f is sysf:
310 311 newfd = os.dup(f.fileno())
311 312 os.dup2(nullfd, f.fileno())
312 313 f = os.fdopen(newfd, mode)
313 314 newfiles.append(f)
314 315 os.close(nullfd)
315 316 return tuple(newfiles)
316 317
317 318 def _restoreio(ui, fin, fout):
318 319 """ restores streams from duplicated ones """
319 320 ui.flush()
320 321 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
321 322 if f is not uif:
322 323 os.dup2(f.fileno(), uif.fileno())
323 324 f.close()
324 325
325 326 class pipeservice(object):
326 327 def __init__(self, ui, repo, opts):
327 328 self.ui = ui
328 329 self.repo = repo
329 330
330 331 def init(self):
331 332 pass
332 333
333 334 def run(self):
334 335 ui = self.ui
335 336 # redirect stdio to null device so that broken extensions or in-process
336 337 # hooks will never cause corruption of channel protocol.
337 338 fin, fout = _protectio(ui)
338 339 try:
339 340 sv = server(ui, self.repo, fin, fout)
340 341 return sv.serve()
341 342 finally:
342 343 sv.cleanup()
343 344 _restoreio(ui, fin, fout)
344 345
345 346 def _initworkerprocess():
346 347 # use a different process group from the master process, in order to:
347 348 # 1. make the current process group no longer "orphaned" (because the
348 349 # parent of this process is in a different process group while
349 350 # remains in a same session)
350 351 # according to POSIX 2.2.2.52, orphaned process group will ignore
351 352 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
352 353 # cause trouble for things like ncurses.
353 354 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
354 355 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
355 356 # processes like ssh will be killed properly, without affecting
356 357 # unrelated processes.
357 358 os.setpgid(0, 0)
358 359 # change random state otherwise forked request handlers would have a
359 360 # same state inherited from parent.
360 361 random.seed()
361 362
362 363 def _serverequest(ui, repo, conn, createcmdserver):
363 364 fin = conn.makefile('rb')
364 365 fout = conn.makefile('wb')
365 366 sv = None
366 367 try:
367 368 sv = createcmdserver(repo, conn, fin, fout)
368 369 try:
369 370 sv.serve()
370 371 # handle exceptions that may be raised by command server. most of
371 372 # known exceptions are caught by dispatch.
372 373 except error.Abort as inst:
373 374 ui.warn(_('abort: %s\n') % inst)
374 375 except IOError as inst:
375 376 if inst.errno != errno.EPIPE:
376 377 raise
377 378 except KeyboardInterrupt:
378 379 pass
379 380 finally:
380 381 sv.cleanup()
381 382 except: # re-raises
382 383 # also write traceback to error channel. otherwise client cannot
383 384 # see it because it is written to server's stderr by default.
384 385 if sv:
385 386 cerr = sv.cerr
386 387 else:
387 388 cerr = channeledoutput(fout, 'e')
388 389 traceback.print_exc(file=cerr)
389 390 raise
390 391 finally:
391 392 fin.close()
392 393 try:
393 394 fout.close() # implicit flush() may cause another EPIPE
394 395 except IOError as inst:
395 396 if inst.errno != errno.EPIPE:
396 397 raise
397 398
398 399 class unixservicehandler(object):
399 400 """Set of pluggable operations for unix-mode services
400 401
401 402 Almost all methods except for createcmdserver() are called in the main
402 403 process. You can't pass mutable resource back from createcmdserver().
403 404 """
404 405
405 406 pollinterval = None
406 407
407 408 def __init__(self, ui):
408 409 self.ui = ui
409 410
410 411 def bindsocket(self, sock, address):
411 412 util.bindunixsocket(sock, address)
412 413 sock.listen(socket.SOMAXCONN)
413 414 self.ui.status(_('listening at %s\n') % address)
414 415 self.ui.flush() # avoid buffering of status message
415 416
416 417 def unlinksocket(self, address):
417 418 os.unlink(address)
418 419
419 420 def shouldexit(self):
420 421 """True if server should shut down; checked per pollinterval"""
421 422 return False
422 423
423 424 def newconnection(self):
424 425 """Called when main process notices new connection"""
425 426 pass
426 427
427 428 def createcmdserver(self, repo, conn, fin, fout):
428 429 """Create new command server instance; called in the process that
429 430 serves for the current connection"""
430 431 return server(self.ui, repo, fin, fout)
431 432
432 433 class unixforkingservice(object):
433 434 """
434 435 Listens on unix domain socket and forks server per connection
435 436 """
436 437
437 438 def __init__(self, ui, repo, opts, handler=None):
438 439 self.ui = ui
439 440 self.repo = repo
440 441 self.address = opts['address']
441 442 if not util.safehasattr(socket, 'AF_UNIX'):
442 443 raise error.Abort(_('unsupported platform'))
443 444 if not self.address:
444 445 raise error.Abort(_('no socket path specified with --address'))
445 446 self._servicehandler = handler or unixservicehandler(ui)
446 447 self._sock = None
447 448 self._oldsigchldhandler = None
448 449 self._workerpids = set() # updated by signal handler; do not iterate
449 450 self._socketunlinked = None
450 451
451 452 def init(self):
452 453 self._sock = socket.socket(socket.AF_UNIX)
453 454 self._servicehandler.bindsocket(self._sock, self.address)
454 455 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
455 456 self._oldsigchldhandler = o
456 457 self._socketunlinked = False
457 458
458 459 def _unlinksocket(self):
459 460 if not self._socketunlinked:
460 461 self._servicehandler.unlinksocket(self.address)
461 462 self._socketunlinked = True
462 463
463 464 def _cleanup(self):
464 465 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
465 466 self._sock.close()
466 467 self._unlinksocket()
467 468 # don't kill child processes as they have active clients, just wait
468 469 self._reapworkers(0)
469 470
470 471 def run(self):
471 472 try:
472 473 self._mainloop()
473 474 finally:
474 475 self._cleanup()
475 476
476 477 def _mainloop(self):
477 478 exiting = False
478 479 h = self._servicehandler
480 selector = selectors2.DefaultSelector()
481 selector.register(self._sock, selectors2.EVENT_READ)
479 482 while True:
480 483 if not exiting and h.shouldexit():
481 484 # clients can no longer connect() to the domain socket, so
482 485 # we stop queuing new requests.
483 486 # for requests that are queued (connect()-ed, but haven't been
484 487 # accept()-ed), handle them before exit. otherwise, clients
485 488 # waiting for recv() will receive ECONNRESET.
486 489 self._unlinksocket()
487 490 exiting = True
488 491 try:
489 ready = select.select([self._sock], [], [], h.pollinterval)[0]
492 ready = selector.select(timeout=h.pollinterval)
490 493 if not ready:
491 494 # only exit if we completed all queued requests
492 495 if exiting:
493 496 break
494 497 continue
495 498 conn, _addr = self._sock.accept()
496 499 except (select.error, socket.error) as inst:
497 500 if inst.args[0] == errno.EINTR:
498 501 continue
499 502 raise
500 503
501 504 pid = os.fork()
502 505 if pid:
503 506 try:
504 507 self.ui.debug('forked worker process (pid=%d)\n' % pid)
505 508 self._workerpids.add(pid)
506 509 h.newconnection()
507 510 finally:
508 511 conn.close() # release handle in parent process
509 512 else:
510 513 try:
511 514 self._runworker(conn)
512 515 conn.close()
513 516 os._exit(0)
514 517 except: # never return, hence no re-raises
515 518 try:
516 519 self.ui.traceback(force=True)
517 520 finally:
518 521 os._exit(255)
519 522
520 523 def _sigchldhandler(self, signal, frame):
521 524 self._reapworkers(os.WNOHANG)
522 525
523 526 def _reapworkers(self, options):
524 527 while self._workerpids:
525 528 try:
526 529 pid, _status = os.waitpid(-1, options)
527 530 except OSError as inst:
528 531 if inst.errno == errno.EINTR:
529 532 continue
530 533 if inst.errno != errno.ECHILD:
531 534 raise
532 535 # no child processes at all (reaped by other waitpid()?)
533 536 self._workerpids.clear()
534 537 return
535 538 if pid == 0:
536 539 # no waitable child processes
537 540 return
538 541 self.ui.debug('worker process exited (pid=%d)\n' % pid)
539 542 self._workerpids.discard(pid)
540 543
541 544 def _runworker(self, conn):
542 545 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
543 546 _initworkerprocess()
544 547 h = self._servicehandler
545 548 try:
546 549 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
547 550 finally:
548 551 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now