##// END OF EJS Templates
commandserver: extract method to create commandserver instance per request...
Yuya Nishihara -
r29511:540c01a1 default
parent child Browse files
Show More
@@ -1,712 +1,716 b''
1 1 # chgserver.py - command server extension for cHg
2 2 #
3 3 # Copyright 2011 Yuya Nishihara <yuya@tcha.org>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 """command server extension for cHg (EXPERIMENTAL)
9 9
10 10 'S' channel (read/write)
11 11 propagate ui.system() request to client
12 12
13 13 'attachio' command
14 14 attach client's stdio passed by sendmsg()
15 15
16 16 'chdir' command
17 17 change current directory
18 18
19 19 'getpager' command
20 20 checks if pager is enabled and which pager should be executed
21 21
22 22 'setenv' command
23 23 replace os.environ completely
24 24
25 25 'setumask' command
26 26 set umask
27 27
28 28 'validate' command
29 29 reload the config and check if the server is up to date
30 30
31 31 Config
32 32 ------
33 33
34 34 ::
35 35
36 36 [chgserver]
37 37 idletimeout = 3600 # seconds, after which an idle server will exit
38 38 skiphash = False # whether to skip config or env change checks
39 39 """
40 40
41 41 from __future__ import absolute_import
42 42
43 43 import errno
44 44 import gc
45 45 import hashlib
46 46 import inspect
47 47 import os
48 48 import random
49 49 import re
50 50 import signal
51 51 import struct
52 52 import sys
53 53 import threading
54 54 import time
55 55 import traceback
56 56
57 57 from mercurial.i18n import _
58 58
59 59 from mercurial import (
60 60 cmdutil,
61 61 commands,
62 62 commandserver,
63 63 dispatch,
64 64 error,
65 65 extensions,
66 66 osutil,
67 67 util,
68 68 )
69 69
70 70 socketserver = util.socketserver
71 71
72 72 # Note for extension authors: ONLY specify testedwith = 'internal' for
73 73 # extensions which SHIP WITH MERCURIAL. Non-mainline extensions should
74 74 # be specifying the version(s) of Mercurial they are tested with, or
75 75 # leave the attribute unspecified.
76 76 testedwith = 'internal'
77 77
78 78 _log = commandserver.log
79 79
80 80 def _hashlist(items):
81 81 """return sha1 hexdigest for a list"""
82 82 return hashlib.sha1(str(items)).hexdigest()
83 83
84 84 # sensitive config sections affecting confighash
85 85 _configsections = [
86 86 'alias', # affects global state commands.table
87 87 'extdiff', # uisetup will register new commands
88 88 'extensions',
89 89 ]
90 90
91 91 # sensitive environment variables affecting confighash
92 92 _envre = re.compile(r'''\A(?:
93 93 CHGHG
94 94 |HG.*
95 95 |LANG(?:UAGE)?
96 96 |LC_.*
97 97 |LD_.*
98 98 |PATH
99 99 |PYTHON.*
100 100 |TERM(?:INFO)?
101 101 |TZ
102 102 )\Z''', re.X)
103 103
104 104 def _confighash(ui):
105 105 """return a quick hash for detecting config/env changes
106 106
107 107 confighash is the hash of sensitive config items and environment variables.
108 108
109 109 for chgserver, it is designed that once confighash changes, the server is
110 110 not qualified to serve its client and should redirect the client to a new
111 111 server. different from mtimehash, confighash change will not mark the
112 112 server outdated and exit since the user can have different configs at the
113 113 same time.
114 114 """
115 115 sectionitems = []
116 116 for section in _configsections:
117 117 sectionitems.append(ui.configitems(section))
118 118 sectionhash = _hashlist(sectionitems)
119 119 envitems = [(k, v) for k, v in os.environ.iteritems() if _envre.match(k)]
120 120 envhash = _hashlist(sorted(envitems))
121 121 return sectionhash[:6] + envhash[:6]
122 122
123 123 def _getmtimepaths(ui):
124 124 """get a list of paths that should be checked to detect change
125 125
126 126 The list will include:
127 127 - extensions (will not cover all files for complex extensions)
128 128 - mercurial/__version__.py
129 129 - python binary
130 130 """
131 131 modules = [m for n, m in extensions.extensions(ui)]
132 132 try:
133 133 from mercurial import __version__
134 134 modules.append(__version__)
135 135 except ImportError:
136 136 pass
137 137 files = [sys.executable]
138 138 for m in modules:
139 139 try:
140 140 files.append(inspect.getabsfile(m))
141 141 except TypeError:
142 142 pass
143 143 return sorted(set(files))
144 144
145 145 def _mtimehash(paths):
146 146 """return a quick hash for detecting file changes
147 147
148 148 mtimehash calls stat on given paths and calculate a hash based on size and
149 149 mtime of each file. mtimehash does not read file content because reading is
150 150 expensive. therefore it's not 100% reliable for detecting content changes.
151 151 it's possible to return different hashes for same file contents.
152 152 it's also possible to return a same hash for different file contents for
153 153 some carefully crafted situation.
154 154
155 155 for chgserver, it is designed that once mtimehash changes, the server is
156 156 considered outdated immediately and should no longer provide service.
157 157
158 158 mtimehash is not included in confighash because we only know the paths of
159 159 extensions after importing them (there is imp.find_module but that faces
160 160 race conditions). We need to calculate confighash without importing.
161 161 """
162 162 def trystat(path):
163 163 try:
164 164 st = os.stat(path)
165 165 return (st.st_mtime, st.st_size)
166 166 except OSError:
167 167 # could be ENOENT, EPERM etc. not fatal in any case
168 168 pass
169 169 return _hashlist(map(trystat, paths))[:12]
170 170
171 171 class hashstate(object):
172 172 """a structure storing confighash, mtimehash, paths used for mtimehash"""
173 173 def __init__(self, confighash, mtimehash, mtimepaths):
174 174 self.confighash = confighash
175 175 self.mtimehash = mtimehash
176 176 self.mtimepaths = mtimepaths
177 177
178 178 @staticmethod
179 179 def fromui(ui, mtimepaths=None):
180 180 if mtimepaths is None:
181 181 mtimepaths = _getmtimepaths(ui)
182 182 confighash = _confighash(ui)
183 183 mtimehash = _mtimehash(mtimepaths)
184 184 _log('confighash = %s mtimehash = %s\n' % (confighash, mtimehash))
185 185 return hashstate(confighash, mtimehash, mtimepaths)
186 186
187 187 # copied from hgext/pager.py:uisetup()
188 188 def _setuppagercmd(ui, options, cmd):
189 189 if not ui.formatted():
190 190 return
191 191
192 192 p = ui.config("pager", "pager", os.environ.get("PAGER"))
193 193 usepager = False
194 194 always = util.parsebool(options['pager'])
195 195 auto = options['pager'] == 'auto'
196 196
197 197 if not p:
198 198 pass
199 199 elif always:
200 200 usepager = True
201 201 elif not auto:
202 202 usepager = False
203 203 else:
204 204 attended = ['annotate', 'cat', 'diff', 'export', 'glog', 'log', 'qdiff']
205 205 attend = ui.configlist('pager', 'attend', attended)
206 206 ignore = ui.configlist('pager', 'ignore')
207 207 cmds, _ = cmdutil.findcmd(cmd, commands.table)
208 208
209 209 for cmd in cmds:
210 210 var = 'attend-%s' % cmd
211 211 if ui.config('pager', var):
212 212 usepager = ui.configbool('pager', var)
213 213 break
214 214 if (cmd in attend or
215 215 (cmd not in ignore and not attend)):
216 216 usepager = True
217 217 break
218 218
219 219 if usepager:
220 220 ui.setconfig('ui', 'formatted', ui.formatted(), 'pager')
221 221 ui.setconfig('ui', 'interactive', False, 'pager')
222 222 return p
223 223
224 224 def _newchgui(srcui, csystem):
225 225 class chgui(srcui.__class__):
226 226 def __init__(self, src=None):
227 227 super(chgui, self).__init__(src)
228 228 if src:
229 229 self._csystem = getattr(src, '_csystem', csystem)
230 230 else:
231 231 self._csystem = csystem
232 232
233 233 def system(self, cmd, environ=None, cwd=None, onerr=None,
234 234 errprefix=None):
235 235 # fallback to the original system method if the output needs to be
236 236 # captured (to self._buffers), or the output stream is not stdout
237 237 # (e.g. stderr, cStringIO), because the chg client is not aware of
238 238 # these situations and will behave differently (write to stdout).
239 239 if (any(s[1] for s in self._bufferstates)
240 240 or not util.safehasattr(self.fout, 'fileno')
241 241 or self.fout.fileno() != sys.stdout.fileno()):
242 242 return super(chgui, self).system(cmd, environ, cwd, onerr,
243 243 errprefix)
244 244 # copied from mercurial/util.py:system()
245 245 self.flush()
246 246 def py2shell(val):
247 247 if val is None or val is False:
248 248 return '0'
249 249 if val is True:
250 250 return '1'
251 251 return str(val)
252 252 env = os.environ.copy()
253 253 if environ:
254 254 env.update((k, py2shell(v)) for k, v in environ.iteritems())
255 255 env['HG'] = util.hgexecutable()
256 256 rc = self._csystem(cmd, env, cwd)
257 257 if rc and onerr:
258 258 errmsg = '%s %s' % (os.path.basename(cmd.split(None, 1)[0]),
259 259 util.explainexit(rc)[0])
260 260 if errprefix:
261 261 errmsg = '%s: %s' % (errprefix, errmsg)
262 262 raise onerr(errmsg)
263 263 return rc
264 264
265 265 return chgui(srcui)
266 266
267 267 def _loadnewui(srcui, args):
268 268 newui = srcui.__class__()
269 269 for a in ['fin', 'fout', 'ferr', 'environ']:
270 270 setattr(newui, a, getattr(srcui, a))
271 271 if util.safehasattr(srcui, '_csystem'):
272 272 newui._csystem = srcui._csystem
273 273
274 274 # internal config: extensions.chgserver
275 275 newui.setconfig('extensions', 'chgserver',
276 276 srcui.config('extensions', 'chgserver'), '--config')
277 277
278 278 # command line args
279 279 args = args[:]
280 280 dispatch._parseconfig(newui, dispatch._earlygetopt(['--config'], args))
281 281
282 282 # stolen from tortoisehg.util.copydynamicconfig()
283 283 for section, name, value in srcui.walkconfig():
284 284 source = srcui.configsource(section, name)
285 285 if ':' in source or source == '--config':
286 286 # path:line or command line
287 287 continue
288 288 if source == 'none':
289 289 # ui.configsource returns 'none' by default
290 290 source = ''
291 291 newui.setconfig(section, name, value, source)
292 292
293 293 # load wd and repo config, copied from dispatch.py
294 294 cwds = dispatch._earlygetopt(['--cwd'], args)
295 295 cwd = cwds and os.path.realpath(cwds[-1]) or None
296 296 rpath = dispatch._earlygetopt(["-R", "--repository", "--repo"], args)
297 297 path, newlui = dispatch._getlocal(newui, rpath, wd=cwd)
298 298
299 299 return (newui, newlui)
300 300
301 301 class channeledsystem(object):
302 302 """Propagate ui.system() request in the following format:
303 303
304 304 payload length (unsigned int),
305 305 cmd, '\0',
306 306 cwd, '\0',
307 307 envkey, '=', val, '\0',
308 308 ...
309 309 envkey, '=', val
310 310
311 311 and waits:
312 312
313 313 exitcode length (unsigned int),
314 314 exitcode (int)
315 315 """
316 316 def __init__(self, in_, out, channel):
317 317 self.in_ = in_
318 318 self.out = out
319 319 self.channel = channel
320 320
321 321 def __call__(self, cmd, environ, cwd):
322 322 args = [util.quotecommand(cmd), os.path.abspath(cwd or '.')]
323 323 args.extend('%s=%s' % (k, v) for k, v in environ.iteritems())
324 324 data = '\0'.join(args)
325 325 self.out.write(struct.pack('>cI', self.channel, len(data)))
326 326 self.out.write(data)
327 327 self.out.flush()
328 328
329 329 length = self.in_.read(4)
330 330 length, = struct.unpack('>I', length)
331 331 if length != 4:
332 332 raise error.Abort(_('invalid response'))
333 333 rc, = struct.unpack('>i', self.in_.read(4))
334 334 return rc
335 335
336 336 _iochannels = [
337 337 # server.ch, ui.fp, mode
338 338 ('cin', 'fin', 'rb'),
339 339 ('cout', 'fout', 'wb'),
340 340 ('cerr', 'ferr', 'wb'),
341 341 ]
342 342
343 343 class chgcmdserver(commandserver.server):
344 344 def __init__(self, ui, repo, fin, fout, sock, hashstate, baseaddress):
345 345 super(chgcmdserver, self).__init__(
346 346 _newchgui(ui, channeledsystem(fin, fout, 'S')), repo, fin, fout)
347 347 self.clientsock = sock
348 348 self._oldios = [] # original (self.ch, ui.fp, fd) before "attachio"
349 349 self.hashstate = hashstate
350 350 self.baseaddress = baseaddress
351 351 if hashstate is not None:
352 352 self.capabilities = self.capabilities.copy()
353 353 self.capabilities['validate'] = chgcmdserver.validate
354 354
355 355 def cleanup(self):
356 356 # dispatch._runcatch() does not flush outputs if exception is not
357 357 # handled by dispatch._dispatch()
358 358 self.ui.flush()
359 359 self._restoreio()
360 360
361 361 def attachio(self):
362 362 """Attach to client's stdio passed via unix domain socket; all
363 363 channels except cresult will no longer be used
364 364 """
365 365 # tell client to sendmsg() with 1-byte payload, which makes it
366 366 # distinctive from "attachio\n" command consumed by client.read()
367 367 self.clientsock.sendall(struct.pack('>cI', 'I', 1))
368 368 clientfds = osutil.recvfds(self.clientsock.fileno())
369 369 _log('received fds: %r\n' % clientfds)
370 370
371 371 ui = self.ui
372 372 ui.flush()
373 373 first = self._saveio()
374 374 for fd, (cn, fn, mode) in zip(clientfds, _iochannels):
375 375 assert fd > 0
376 376 fp = getattr(ui, fn)
377 377 os.dup2(fd, fp.fileno())
378 378 os.close(fd)
379 379 if not first:
380 380 continue
381 381 # reset buffering mode when client is first attached. as we want
382 382 # to see output immediately on pager, the mode stays unchanged
383 383 # when client re-attached. ferr is unchanged because it should
384 384 # be unbuffered no matter if it is a tty or not.
385 385 if fn == 'ferr':
386 386 newfp = fp
387 387 else:
388 388 # make it line buffered explicitly because the default is
389 389 # decided on first write(), where fout could be a pager.
390 390 if fp.isatty():
391 391 bufsize = 1 # line buffered
392 392 else:
393 393 bufsize = -1 # system default
394 394 newfp = os.fdopen(fp.fileno(), mode, bufsize)
395 395 setattr(ui, fn, newfp)
396 396 setattr(self, cn, newfp)
397 397
398 398 self.cresult.write(struct.pack('>i', len(clientfds)))
399 399
400 400 def _saveio(self):
401 401 if self._oldios:
402 402 return False
403 403 ui = self.ui
404 404 for cn, fn, _mode in _iochannels:
405 405 ch = getattr(self, cn)
406 406 fp = getattr(ui, fn)
407 407 fd = os.dup(fp.fileno())
408 408 self._oldios.append((ch, fp, fd))
409 409 return True
410 410
411 411 def _restoreio(self):
412 412 ui = self.ui
413 413 for (ch, fp, fd), (cn, fn, _mode) in zip(self._oldios, _iochannels):
414 414 newfp = getattr(ui, fn)
415 415 # close newfp while it's associated with client; otherwise it
416 416 # would be closed when newfp is deleted
417 417 if newfp is not fp:
418 418 newfp.close()
419 419 # restore original fd: fp is open again
420 420 os.dup2(fd, fp.fileno())
421 421 os.close(fd)
422 422 setattr(self, cn, ch)
423 423 setattr(ui, fn, fp)
424 424 del self._oldios[:]
425 425
426 426 def validate(self):
427 427 """Reload the config and check if the server is up to date
428 428
429 429 Read a list of '\0' separated arguments.
430 430 Write a non-empty list of '\0' separated instruction strings or '\0'
431 431 if the list is empty.
432 432 An instruction string could be either:
433 433 - "unlink $path", the client should unlink the path to stop the
434 434 outdated server.
435 435 - "redirect $path", the client should attempt to connect to $path
436 436 first. If it does not work, start a new server. It implies
437 437 "reconnect".
438 438 - "exit $n", the client should exit directly with code n.
439 439 This may happen if we cannot parse the config.
440 440 - "reconnect", the client should close the connection and
441 441 reconnect.
442 442 If neither "reconnect" nor "redirect" is included in the instruction
443 443 list, the client can continue with this server after completing all
444 444 the instructions.
445 445 """
446 446 args = self._readlist()
447 447 try:
448 448 self.ui, lui = _loadnewui(self.ui, args)
449 449 except error.ParseError as inst:
450 450 dispatch._formatparse(self.ui.warn, inst)
451 451 self.ui.flush()
452 452 self.cresult.write('exit 255')
453 453 return
454 454 newhash = hashstate.fromui(lui, self.hashstate.mtimepaths)
455 455 insts = []
456 456 if newhash.mtimehash != self.hashstate.mtimehash:
457 457 addr = _hashaddress(self.baseaddress, self.hashstate.confighash)
458 458 insts.append('unlink %s' % addr)
459 459 # mtimehash is empty if one or more extensions fail to load.
460 460 # to be compatible with hg, still serve the client this time.
461 461 if self.hashstate.mtimehash:
462 462 insts.append('reconnect')
463 463 if newhash.confighash != self.hashstate.confighash:
464 464 addr = _hashaddress(self.baseaddress, newhash.confighash)
465 465 insts.append('redirect %s' % addr)
466 466 _log('validate: %s\n' % insts)
467 467 self.cresult.write('\0'.join(insts) or '\0')
468 468
469 469 def chdir(self):
470 470 """Change current directory
471 471
472 472 Note that the behavior of --cwd option is bit different from this.
473 473 It does not affect --config parameter.
474 474 """
475 475 path = self._readstr()
476 476 if not path:
477 477 return
478 478 _log('chdir to %r\n' % path)
479 479 os.chdir(path)
480 480
481 481 def setumask(self):
482 482 """Change umask"""
483 483 mask = struct.unpack('>I', self._read(4))[0]
484 484 _log('setumask %r\n' % mask)
485 485 os.umask(mask)
486 486
487 487 def getpager(self):
488 488 """Read cmdargs and write pager command to r-channel if enabled
489 489
490 490 If pager isn't enabled, this writes '\0' because channeledoutput
491 491 does not allow to write empty data.
492 492 """
493 493 args = self._readlist()
494 494 try:
495 495 cmd, _func, args, options, _cmdoptions = dispatch._parse(self.ui,
496 496 args)
497 497 except (error.Abort, error.AmbiguousCommand, error.CommandError,
498 498 error.UnknownCommand):
499 499 cmd = None
500 500 options = {}
501 501 if not cmd or 'pager' not in options:
502 502 self.cresult.write('\0')
503 503 return
504 504
505 505 pagercmd = _setuppagercmd(self.ui, options, cmd)
506 506 if pagercmd:
507 507 # Python's SIGPIPE is SIG_IGN by default. change to SIG_DFL so
508 508 # we can exit if the pipe to the pager is closed
509 509 if util.safehasattr(signal, 'SIGPIPE') and \
510 510 signal.getsignal(signal.SIGPIPE) == signal.SIG_IGN:
511 511 signal.signal(signal.SIGPIPE, signal.SIG_DFL)
512 512 self.cresult.write(pagercmd)
513 513 else:
514 514 self.cresult.write('\0')
515 515
516 516 def setenv(self):
517 517 """Clear and update os.environ
518 518
519 519 Note that not all variables can make an effect on the running process.
520 520 """
521 521 l = self._readlist()
522 522 try:
523 523 newenv = dict(s.split('=', 1) for s in l)
524 524 except ValueError:
525 525 raise ValueError('unexpected value in setenv request')
526 526 _log('setenv: %r\n' % sorted(newenv.keys()))
527 527 os.environ.clear()
528 528 os.environ.update(newenv)
529 529
530 530 capabilities = commandserver.server.capabilities.copy()
531 531 capabilities.update({'attachio': attachio,
532 532 'chdir': chdir,
533 533 'getpager': getpager,
534 534 'setenv': setenv,
535 535 'setumask': setumask})
536 536
537 537 # copied from mercurial/commandserver.py
538 538 class _requesthandler(socketserver.StreamRequestHandler):
539 539 def handle(self):
540 540 # use a different process group from the master process, making this
541 541 # process pass kernel "is_current_pgrp_orphaned" check so signals like
542 542 # SIGTSTP, SIGTTIN, SIGTTOU are not ignored.
543 543 os.setpgid(0, 0)
544 544 # change random state otherwise forked request handlers would have a
545 545 # same state inherited from parent.
546 546 random.seed()
547 547 ui = self.server.ui
548 repo = self.server.repo
549 548 sv = None
550 549 try:
551 sv = chgcmdserver(ui, repo, self.rfile, self.wfile, self.connection,
552 self.server.hashstate, self.server.baseaddress)
550 sv = self._createcmdserver()
553 551 try:
554 552 sv.serve()
555 553 # handle exceptions that may be raised by command server. most of
556 554 # known exceptions are caught by dispatch.
557 555 except error.Abort as inst:
558 556 ui.warn(_('abort: %s\n') % inst)
559 557 except IOError as inst:
560 558 if inst.errno != errno.EPIPE:
561 559 raise
562 560 except KeyboardInterrupt:
563 561 pass
564 562 finally:
565 563 sv.cleanup()
566 564 except: # re-raises
567 565 # also write traceback to error channel. otherwise client cannot
568 566 # see it because it is written to server's stderr by default.
569 567 if sv:
570 568 cerr = sv.cerr
571 569 else:
572 570 cerr = commandserver.channeledoutput(self.wfile, 'e')
573 571 traceback.print_exc(file=cerr)
574 572 raise
575 573 finally:
576 574 # trigger __del__ since ForkingMixIn uses os._exit
577 575 gc.collect()
578 576
577 def _createcmdserver(self):
578 ui = self.server.ui
579 repo = self.server.repo
580 return chgcmdserver(ui, repo, self.rfile, self.wfile, self.connection,
581 self.server.hashstate, self.server.baseaddress)
582
579 583 def _tempaddress(address):
580 584 return '%s.%d.tmp' % (address, os.getpid())
581 585
582 586 def _hashaddress(address, hashstr):
583 587 return '%s-%s' % (address, hashstr)
584 588
585 589 class AutoExitMixIn: # use old-style to comply with SocketServer design
586 590 lastactive = time.time()
587 591 idletimeout = 3600 # default 1 hour
588 592
589 593 def startautoexitthread(self):
590 594 # note: the auto-exit check here is cheap enough to not use a thread,
591 595 # be done in serve_forever. however SocketServer is hook-unfriendly,
592 596 # you simply cannot hook serve_forever without copying a lot of code.
593 597 # besides, serve_forever's docstring suggests using thread.
594 598 thread = threading.Thread(target=self._autoexitloop)
595 599 thread.daemon = True
596 600 thread.start()
597 601
598 602 def _autoexitloop(self, interval=1):
599 603 while True:
600 604 time.sleep(interval)
601 605 if not self.issocketowner():
602 606 _log('%s is not owned, exiting.\n' % self.server_address)
603 607 break
604 608 if time.time() - self.lastactive > self.idletimeout:
605 609 _log('being idle too long. exiting.\n')
606 610 break
607 611 self.shutdown()
608 612
609 613 def process_request(self, request, address):
610 614 self.lastactive = time.time()
611 615 return socketserver.ForkingMixIn.process_request(
612 616 self, request, address)
613 617
614 618 def server_bind(self):
615 619 # use a unique temp address so we can stat the file and do ownership
616 620 # check later
617 621 tempaddress = _tempaddress(self.server_address)
618 622 # use relative path instead of full path at bind() if possible, since
619 623 # AF_UNIX path has very small length limit (107 chars) on common
620 624 # platforms (see sys/un.h)
621 625 dirname, basename = os.path.split(tempaddress)
622 626 bakwdfd = None
623 627 if dirname:
624 628 bakwdfd = os.open('.', os.O_DIRECTORY)
625 629 os.chdir(dirname)
626 630 self.socket.bind(basename)
627 631 self._socketstat = os.stat(basename)
628 632 # rename will replace the old socket file if exists atomically. the
629 633 # old server will detect ownership change and exit.
630 634 util.rename(basename, self.server_address)
631 635 if bakwdfd:
632 636 os.fchdir(bakwdfd)
633 637 os.close(bakwdfd)
634 638
635 639 def issocketowner(self):
636 640 try:
637 641 stat = os.stat(self.server_address)
638 642 return (stat.st_ino == self._socketstat.st_ino and
639 643 stat.st_mtime == self._socketstat.st_mtime)
640 644 except OSError:
641 645 return False
642 646
643 647 def unlinksocketfile(self):
644 648 if not self.issocketowner():
645 649 return
646 650 # it is possible to have a race condition here that we may
647 651 # remove another server's socket file. but that's okay
648 652 # since that server will detect and exit automatically and
649 653 # the client will start a new server on demand.
650 654 try:
651 655 os.unlink(self.server_address)
652 656 except OSError as exc:
653 657 if exc.errno != errno.ENOENT:
654 658 raise
655 659
656 660 class chgunixservice(commandserver.unixservice):
657 661 def init(self):
658 662 if self.repo:
659 663 # one chgserver can serve multiple repos. drop repo infomation
660 664 self.ui.setconfig('bundle', 'mainreporoot', '', 'repo')
661 665 self.repo = None
662 666 self._inithashstate()
663 667 self._checkextensions()
664 668 class cls(AutoExitMixIn, socketserver.ForkingMixIn,
665 669 socketserver.UnixStreamServer):
666 670 ui = self.ui
667 671 repo = self.repo
668 672 hashstate = self.hashstate
669 673 baseaddress = self.baseaddress
670 674 self.server = cls(self.address, _requesthandler)
671 675 self.server.idletimeout = self.ui.configint(
672 676 'chgserver', 'idletimeout', self.server.idletimeout)
673 677 self.server.startautoexitthread()
674 678 self._createsymlink()
675 679
676 680 def _inithashstate(self):
677 681 self.baseaddress = self.address
678 682 if self.ui.configbool('chgserver', 'skiphash', False):
679 683 self.hashstate = None
680 684 return
681 685 self.hashstate = hashstate.fromui(self.ui)
682 686 self.address = _hashaddress(self.address, self.hashstate.confighash)
683 687
684 688 def _checkextensions(self):
685 689 if not self.hashstate:
686 690 return
687 691 if extensions.notloaded():
688 692 # one or more extensions failed to load. mtimehash becomes
689 693 # meaningless because we do not know the paths of those extensions.
690 694 # set mtimehash to an illegal hash value to invalidate the server.
691 695 self.hashstate.mtimehash = ''
692 696
693 697 def _createsymlink(self):
694 698 if self.baseaddress == self.address:
695 699 return
696 700 tempaddress = _tempaddress(self.baseaddress)
697 701 os.symlink(os.path.basename(self.address), tempaddress)
698 702 util.rename(tempaddress, self.baseaddress)
699 703
700 704 def run(self):
701 705 try:
702 706 self.server.serve_forever()
703 707 finally:
704 708 self.server.unlinksocketfile()
705 709
706 710 def uisetup(ui):
707 711 commandserver._servicemap['chgunix'] = chgunixservice
708 712
709 713 # CHGINTERNALMARK is temporarily set by chg client to detect if chg will
710 714 # start another chg. drop it to avoid possible side effects.
711 715 if 'CHGINTERNALMARK' in os.environ:
712 716 del os.environ['CHGINTERNALMARK']
@@ -1,399 +1,403 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import os
12 12 import struct
13 13 import sys
14 14 import traceback
15 15
16 16 from .i18n import _
17 17 from . import (
18 18 encoding,
19 19 error,
20 20 util,
21 21 )
22 22
23 23 socketserver = util.socketserver
24 24
25 25 logfile = None
26 26
27 27 def log(*args):
28 28 if not logfile:
29 29 return
30 30
31 31 for a in args:
32 32 logfile.write(str(a))
33 33
34 34 logfile.flush()
35 35
36 36 class channeledoutput(object):
37 37 """
38 38 Write data to out in the following format:
39 39
40 40 data length (unsigned int),
41 41 data
42 42 """
43 43 def __init__(self, out, channel):
44 44 self.out = out
45 45 self.channel = channel
46 46
47 47 @property
48 48 def name(self):
49 49 return '<%c-channel>' % self.channel
50 50
51 51 def write(self, data):
52 52 if not data:
53 53 return
54 54 self.out.write(struct.pack('>cI', self.channel, len(data)))
55 55 self.out.write(data)
56 56 self.out.flush()
57 57
58 58 def __getattr__(self, attr):
59 59 if attr in ('isatty', 'fileno', 'tell', 'seek'):
60 60 raise AttributeError(attr)
61 61 return getattr(self.out, attr)
62 62
63 63 class channeledinput(object):
64 64 """
65 65 Read data from in_.
66 66
67 67 Requests for input are written to out in the following format:
68 68 channel identifier - 'I' for plain input, 'L' line based (1 byte)
69 69 how many bytes to send at most (unsigned int),
70 70
71 71 The client replies with:
72 72 data length (unsigned int), 0 meaning EOF
73 73 data
74 74 """
75 75
76 76 maxchunksize = 4 * 1024
77 77
78 78 def __init__(self, in_, out, channel):
79 79 self.in_ = in_
80 80 self.out = out
81 81 self.channel = channel
82 82
83 83 @property
84 84 def name(self):
85 85 return '<%c-channel>' % self.channel
86 86
87 87 def read(self, size=-1):
88 88 if size < 0:
89 89 # if we need to consume all the clients input, ask for 4k chunks
90 90 # so the pipe doesn't fill up risking a deadlock
91 91 size = self.maxchunksize
92 92 s = self._read(size, self.channel)
93 93 buf = s
94 94 while s:
95 95 s = self._read(size, self.channel)
96 96 buf += s
97 97
98 98 return buf
99 99 else:
100 100 return self._read(size, self.channel)
101 101
102 102 def _read(self, size, channel):
103 103 if not size:
104 104 return ''
105 105 assert size > 0
106 106
107 107 # tell the client we need at most size bytes
108 108 self.out.write(struct.pack('>cI', channel, size))
109 109 self.out.flush()
110 110
111 111 length = self.in_.read(4)
112 112 length = struct.unpack('>I', length)[0]
113 113 if not length:
114 114 return ''
115 115 else:
116 116 return self.in_.read(length)
117 117
118 118 def readline(self, size=-1):
119 119 if size < 0:
120 120 size = self.maxchunksize
121 121 s = self._read(size, 'L')
122 122 buf = s
123 123 # keep asking for more until there's either no more or
124 124 # we got a full line
125 125 while s and s[-1] != '\n':
126 126 s = self._read(size, 'L')
127 127 buf += s
128 128
129 129 return buf
130 130 else:
131 131 return self._read(size, 'L')
132 132
133 133 def __iter__(self):
134 134 return self
135 135
136 136 def next(self):
137 137 l = self.readline()
138 138 if not l:
139 139 raise StopIteration
140 140 return l
141 141
142 142 def __getattr__(self, attr):
143 143 if attr in ('isatty', 'fileno', 'tell', 'seek'):
144 144 raise AttributeError(attr)
145 145 return getattr(self.in_, attr)
146 146
147 147 class server(object):
148 148 """
149 149 Listens for commands on fin, runs them and writes the output on a channel
150 150 based stream to fout.
151 151 """
152 152 def __init__(self, ui, repo, fin, fout):
153 153 self.cwd = os.getcwd()
154 154
155 155 # developer config: cmdserver.log
156 156 logpath = ui.config("cmdserver", "log", None)
157 157 if logpath:
158 158 global logfile
159 159 if logpath == '-':
160 160 # write log on a special 'd' (debug) channel
161 161 logfile = channeledoutput(fout, 'd')
162 162 else:
163 163 logfile = open(logpath, 'a')
164 164
165 165 if repo:
166 166 # the ui here is really the repo ui so take its baseui so we don't
167 167 # end up with its local configuration
168 168 self.ui = repo.baseui
169 169 self.repo = repo
170 170 self.repoui = repo.ui
171 171 else:
172 172 self.ui = ui
173 173 self.repo = self.repoui = None
174 174
175 175 self.cerr = channeledoutput(fout, 'e')
176 176 self.cout = channeledoutput(fout, 'o')
177 177 self.cin = channeledinput(fin, fout, 'I')
178 178 self.cresult = channeledoutput(fout, 'r')
179 179
180 180 self.client = fin
181 181
182 182 def _read(self, size):
183 183 if not size:
184 184 return ''
185 185
186 186 data = self.client.read(size)
187 187
188 188 # is the other end closed?
189 189 if not data:
190 190 raise EOFError
191 191
192 192 return data
193 193
194 194 def _readstr(self):
195 195 """read a string from the channel
196 196
197 197 format:
198 198 data length (uint32), data
199 199 """
200 200 length = struct.unpack('>I', self._read(4))[0]
201 201 if not length:
202 202 return ''
203 203 return self._read(length)
204 204
205 205 def _readlist(self):
206 206 """read a list of NULL separated strings from the channel"""
207 207 s = self._readstr()
208 208 if s:
209 209 return s.split('\0')
210 210 else:
211 211 return []
212 212
213 213 def runcommand(self):
214 214 """ reads a list of \0 terminated arguments, executes
215 215 and writes the return code to the result channel """
216 216 from . import dispatch # avoid cycle
217 217
218 218 args = self._readlist()
219 219
220 220 # copy the uis so changes (e.g. --config or --verbose) don't
221 221 # persist between requests
222 222 copiedui = self.ui.copy()
223 223 uis = [copiedui]
224 224 if self.repo:
225 225 self.repo.baseui = copiedui
226 226 # clone ui without using ui.copy because this is protected
227 227 repoui = self.repoui.__class__(self.repoui)
228 228 repoui.copy = copiedui.copy # redo copy protection
229 229 uis.append(repoui)
230 230 self.repo.ui = self.repo.dirstate._ui = repoui
231 231 self.repo.invalidateall()
232 232
233 233 for ui in uis:
234 234 ui.resetstate()
235 235 # any kind of interaction must use server channels, but chg may
236 236 # replace channels by fully functional tty files. so nontty is
237 237 # enforced only if cin is a channel.
238 238 if not util.safehasattr(self.cin, 'fileno'):
239 239 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
240 240
241 241 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
242 242 self.cout, self.cerr)
243 243
244 244 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
245 245
246 246 # restore old cwd
247 247 if '--cwd' in args:
248 248 os.chdir(self.cwd)
249 249
250 250 self.cresult.write(struct.pack('>i', int(ret)))
251 251
252 252 def getencoding(self):
253 253 """ writes the current encoding to the result channel """
254 254 self.cresult.write(encoding.encoding)
255 255
256 256 def serveone(self):
257 257 cmd = self.client.readline()[:-1]
258 258 if cmd:
259 259 handler = self.capabilities.get(cmd)
260 260 if handler:
261 261 handler(self)
262 262 else:
263 263 # clients are expected to check what commands are supported by
264 264 # looking at the servers capabilities
265 265 raise error.Abort(_('unknown command %s') % cmd)
266 266
267 267 return cmd != ''
268 268
269 269 capabilities = {'runcommand' : runcommand,
270 270 'getencoding' : getencoding}
271 271
272 272 def serve(self):
273 273 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
274 274 hellomsg += '\n'
275 275 hellomsg += 'encoding: ' + encoding.encoding
276 276 hellomsg += '\n'
277 277 hellomsg += 'pid: %d' % util.getpid()
278 278
279 279 # write the hello msg in -one- chunk
280 280 self.cout.write(hellomsg)
281 281
282 282 try:
283 283 while self.serveone():
284 284 pass
285 285 except EOFError:
286 286 # we'll get here if the client disconnected while we were reading
287 287 # its request
288 288 return 1
289 289
290 290 return 0
291 291
292 292 def _protectio(ui):
293 293 """ duplicates streams and redirect original to null if ui uses stdio """
294 294 ui.flush()
295 295 newfiles = []
296 296 nullfd = os.open(os.devnull, os.O_RDWR)
297 297 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
298 298 (ui.fout, sys.stdout, 'wb')]:
299 299 if f is sysf:
300 300 newfd = os.dup(f.fileno())
301 301 os.dup2(nullfd, f.fileno())
302 302 f = os.fdopen(newfd, mode)
303 303 newfiles.append(f)
304 304 os.close(nullfd)
305 305 return tuple(newfiles)
306 306
307 307 def _restoreio(ui, fin, fout):
308 308 """ restores streams from duplicated ones """
309 309 ui.flush()
310 310 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
311 311 if f is not uif:
312 312 os.dup2(f.fileno(), uif.fileno())
313 313 f.close()
314 314
315 315 class pipeservice(object):
316 316 def __init__(self, ui, repo, opts):
317 317 self.ui = ui
318 318 self.repo = repo
319 319
320 320 def init(self):
321 321 pass
322 322
323 323 def run(self):
324 324 ui = self.ui
325 325 # redirect stdio to null device so that broken extensions or in-process
326 326 # hooks will never cause corruption of channel protocol.
327 327 fin, fout = _protectio(ui)
328 328 try:
329 329 sv = server(ui, self.repo, fin, fout)
330 330 return sv.serve()
331 331 finally:
332 332 _restoreio(ui, fin, fout)
333 333
334 334 class _requesthandler(socketserver.StreamRequestHandler):
335 335 def handle(self):
336 336 ui = self.server.ui
337 repo = self.server.repo
338 337 sv = None
339 338 try:
340 sv = server(ui, repo, self.rfile, self.wfile)
339 sv = self._createcmdserver()
341 340 try:
342 341 sv.serve()
343 342 # handle exceptions that may be raised by command server. most of
344 343 # known exceptions are caught by dispatch.
345 344 except error.Abort as inst:
346 345 ui.warn(_('abort: %s\n') % inst)
347 346 except IOError as inst:
348 347 if inst.errno != errno.EPIPE:
349 348 raise
350 349 except KeyboardInterrupt:
351 350 pass
352 351 except: # re-raises
353 352 # also write traceback to error channel. otherwise client cannot
354 353 # see it because it is written to server's stderr by default.
355 354 if sv:
356 355 cerr = sv.cerr
357 356 else:
358 357 cerr = channeledoutput(self.wfile, 'e')
359 358 traceback.print_exc(file=cerr)
360 359 raise
361 360
361 def _createcmdserver(self):
362 ui = self.server.ui
363 repo = self.server.repo
364 return server(ui, repo, self.rfile, self.wfile)
365
362 366 class unixservice(object):
363 367 """
364 368 Listens on unix domain socket and forks server per connection
365 369 """
366 370 def __init__(self, ui, repo, opts):
367 371 self.ui = ui
368 372 self.repo = repo
369 373 self.address = opts['address']
370 374 if not util.safehasattr(socketserver, 'UnixStreamServer'):
371 375 raise error.Abort(_('unsupported platform'))
372 376 if not self.address:
373 377 raise error.Abort(_('no socket path specified with --address'))
374 378
375 379 def init(self):
376 380 class cls(socketserver.ForkingMixIn, socketserver.UnixStreamServer):
377 381 ui = self.ui
378 382 repo = self.repo
379 383 self.server = cls(self.address, _requesthandler)
380 384 self.ui.status(_('listening at %s\n') % self.address)
381 385 self.ui.flush() # avoid buffering of status message
382 386
383 387 def run(self):
384 388 try:
385 389 self.server.serve_forever()
386 390 finally:
387 391 os.unlink(self.address)
388 392
389 393 _servicemap = {
390 394 'pipe': pipeservice,
391 395 'unix': unixservice,
392 396 }
393 397
394 398 def createservice(ui, repo, opts):
395 399 mode = opts['cmdserver']
396 400 try:
397 401 return _servicemap[mode](ui, repo, opts)
398 402 except KeyError:
399 403 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now