##// END OF EJS Templates
commandserver: prevent unlink socket twice...
Jun Wu -
r30887:a95fc01a default
parent child Browse files
Show More
@@ -1,532 +1,539 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import select
15 15 import signal
16 16 import socket
17 17 import struct
18 18 import traceback
19 19
20 20 from .i18n import _
21 21 from . import (
22 22 encoding,
23 23 error,
24 24 pycompat,
25 25 util,
26 26 )
27 27
28 28 logfile = None
29 29
30 30 def log(*args):
31 31 if not logfile:
32 32 return
33 33
34 34 for a in args:
35 35 logfile.write(str(a))
36 36
37 37 logfile.flush()
38 38
39 39 class channeledoutput(object):
40 40 """
41 41 Write data to out in the following format:
42 42
43 43 data length (unsigned int),
44 44 data
45 45 """
46 46 def __init__(self, out, channel):
47 47 self.out = out
48 48 self.channel = channel
49 49
50 50 @property
51 51 def name(self):
52 52 return '<%c-channel>' % self.channel
53 53
54 54 def write(self, data):
55 55 if not data:
56 56 return
57 57 # single write() to guarantee the same atomicity as the underlying file
58 58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 59 self.out.flush()
60 60
61 61 def __getattr__(self, attr):
62 62 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 63 raise AttributeError(attr)
64 64 return getattr(self.out, attr)
65 65
66 66 class channeledinput(object):
67 67 """
68 68 Read data from in_.
69 69
70 70 Requests for input are written to out in the following format:
71 71 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 72 how many bytes to send at most (unsigned int),
73 73
74 74 The client replies with:
75 75 data length (unsigned int), 0 meaning EOF
76 76 data
77 77 """
78 78
79 79 maxchunksize = 4 * 1024
80 80
81 81 def __init__(self, in_, out, channel):
82 82 self.in_ = in_
83 83 self.out = out
84 84 self.channel = channel
85 85
86 86 @property
87 87 def name(self):
88 88 return '<%c-channel>' % self.channel
89 89
90 90 def read(self, size=-1):
91 91 if size < 0:
92 92 # if we need to consume all the clients input, ask for 4k chunks
93 93 # so the pipe doesn't fill up risking a deadlock
94 94 size = self.maxchunksize
95 95 s = self._read(size, self.channel)
96 96 buf = s
97 97 while s:
98 98 s = self._read(size, self.channel)
99 99 buf += s
100 100
101 101 return buf
102 102 else:
103 103 return self._read(size, self.channel)
104 104
105 105 def _read(self, size, channel):
106 106 if not size:
107 107 return ''
108 108 assert size > 0
109 109
110 110 # tell the client we need at most size bytes
111 111 self.out.write(struct.pack('>cI', channel, size))
112 112 self.out.flush()
113 113
114 114 length = self.in_.read(4)
115 115 length = struct.unpack('>I', length)[0]
116 116 if not length:
117 117 return ''
118 118 else:
119 119 return self.in_.read(length)
120 120
121 121 def readline(self, size=-1):
122 122 if size < 0:
123 123 size = self.maxchunksize
124 124 s = self._read(size, 'L')
125 125 buf = s
126 126 # keep asking for more until there's either no more or
127 127 # we got a full line
128 128 while s and s[-1] != '\n':
129 129 s = self._read(size, 'L')
130 130 buf += s
131 131
132 132 return buf
133 133 else:
134 134 return self._read(size, 'L')
135 135
136 136 def __iter__(self):
137 137 return self
138 138
139 139 def next(self):
140 140 l = self.readline()
141 141 if not l:
142 142 raise StopIteration
143 143 return l
144 144
145 145 def __getattr__(self, attr):
146 146 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 147 raise AttributeError(attr)
148 148 return getattr(self.in_, attr)
149 149
150 150 class server(object):
151 151 """
152 152 Listens for commands on fin, runs them and writes the output on a channel
153 153 based stream to fout.
154 154 """
155 155 def __init__(self, ui, repo, fin, fout):
156 156 self.cwd = pycompat.getcwd()
157 157
158 158 # developer config: cmdserver.log
159 159 logpath = ui.config("cmdserver", "log", None)
160 160 if logpath:
161 161 global logfile
162 162 if logpath == '-':
163 163 # write log on a special 'd' (debug) channel
164 164 logfile = channeledoutput(fout, 'd')
165 165 else:
166 166 logfile = open(logpath, 'a')
167 167
168 168 if repo:
169 169 # the ui here is really the repo ui so take its baseui so we don't
170 170 # end up with its local configuration
171 171 self.ui = repo.baseui
172 172 self.repo = repo
173 173 self.repoui = repo.ui
174 174 else:
175 175 self.ui = ui
176 176 self.repo = self.repoui = None
177 177
178 178 self.cerr = channeledoutput(fout, 'e')
179 179 self.cout = channeledoutput(fout, 'o')
180 180 self.cin = channeledinput(fin, fout, 'I')
181 181 self.cresult = channeledoutput(fout, 'r')
182 182
183 183 self.client = fin
184 184
185 185 def cleanup(self):
186 186 """release and restore resources taken during server session"""
187 187 pass
188 188
189 189 def _read(self, size):
190 190 if not size:
191 191 return ''
192 192
193 193 data = self.client.read(size)
194 194
195 195 # is the other end closed?
196 196 if not data:
197 197 raise EOFError
198 198
199 199 return data
200 200
201 201 def _readstr(self):
202 202 """read a string from the channel
203 203
204 204 format:
205 205 data length (uint32), data
206 206 """
207 207 length = struct.unpack('>I', self._read(4))[0]
208 208 if not length:
209 209 return ''
210 210 return self._read(length)
211 211
212 212 def _readlist(self):
213 213 """read a list of NULL separated strings from the channel"""
214 214 s = self._readstr()
215 215 if s:
216 216 return s.split('\0')
217 217 else:
218 218 return []
219 219
220 220 def runcommand(self):
221 221 """ reads a list of \0 terminated arguments, executes
222 222 and writes the return code to the result channel """
223 223 from . import dispatch # avoid cycle
224 224
225 225 args = self._readlist()
226 226
227 227 # copy the uis so changes (e.g. --config or --verbose) don't
228 228 # persist between requests
229 229 copiedui = self.ui.copy()
230 230 uis = [copiedui]
231 231 if self.repo:
232 232 self.repo.baseui = copiedui
233 233 # clone ui without using ui.copy because this is protected
234 234 repoui = self.repoui.__class__(self.repoui)
235 235 repoui.copy = copiedui.copy # redo copy protection
236 236 uis.append(repoui)
237 237 self.repo.ui = self.repo.dirstate._ui = repoui
238 238 self.repo.invalidateall()
239 239
240 240 for ui in uis:
241 241 ui.resetstate()
242 242 # any kind of interaction must use server channels, but chg may
243 243 # replace channels by fully functional tty files. so nontty is
244 244 # enforced only if cin is a channel.
245 245 if not util.safehasattr(self.cin, 'fileno'):
246 246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
247 247
248 248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
249 249 self.cout, self.cerr)
250 250
251 251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
252 252
253 253 # restore old cwd
254 254 if '--cwd' in args:
255 255 os.chdir(self.cwd)
256 256
257 257 self.cresult.write(struct.pack('>i', int(ret)))
258 258
259 259 def getencoding(self):
260 260 """ writes the current encoding to the result channel """
261 261 self.cresult.write(encoding.encoding)
262 262
263 263 def serveone(self):
264 264 cmd = self.client.readline()[:-1]
265 265 if cmd:
266 266 handler = self.capabilities.get(cmd)
267 267 if handler:
268 268 handler(self)
269 269 else:
270 270 # clients are expected to check what commands are supported by
271 271 # looking at the servers capabilities
272 272 raise error.Abort(_('unknown command %s') % cmd)
273 273
274 274 return cmd != ''
275 275
276 276 capabilities = {'runcommand' : runcommand,
277 277 'getencoding' : getencoding}
278 278
279 279 def serve(self):
280 280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
281 281 hellomsg += '\n'
282 282 hellomsg += 'encoding: ' + encoding.encoding
283 283 hellomsg += '\n'
284 284 hellomsg += 'pid: %d' % util.getpid()
285 285 if util.safehasattr(os, 'getpgid'):
286 286 hellomsg += '\n'
287 287 hellomsg += 'pgid: %d' % os.getpgid(0)
288 288
289 289 # write the hello msg in -one- chunk
290 290 self.cout.write(hellomsg)
291 291
292 292 try:
293 293 while self.serveone():
294 294 pass
295 295 except EOFError:
296 296 # we'll get here if the client disconnected while we were reading
297 297 # its request
298 298 return 1
299 299
300 300 return 0
301 301
302 302 def _protectio(ui):
303 303 """ duplicates streams and redirect original to null if ui uses stdio """
304 304 ui.flush()
305 305 newfiles = []
306 306 nullfd = os.open(os.devnull, os.O_RDWR)
307 307 for f, sysf, mode in [(ui.fin, util.stdin, 'rb'),
308 308 (ui.fout, util.stdout, 'wb')]:
309 309 if f is sysf:
310 310 newfd = os.dup(f.fileno())
311 311 os.dup2(nullfd, f.fileno())
312 312 f = os.fdopen(newfd, mode)
313 313 newfiles.append(f)
314 314 os.close(nullfd)
315 315 return tuple(newfiles)
316 316
317 317 def _restoreio(ui, fin, fout):
318 318 """ restores streams from duplicated ones """
319 319 ui.flush()
320 320 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
321 321 if f is not uif:
322 322 os.dup2(f.fileno(), uif.fileno())
323 323 f.close()
324 324
325 325 class pipeservice(object):
326 326 def __init__(self, ui, repo, opts):
327 327 self.ui = ui
328 328 self.repo = repo
329 329
330 330 def init(self):
331 331 pass
332 332
333 333 def run(self):
334 334 ui = self.ui
335 335 # redirect stdio to null device so that broken extensions or in-process
336 336 # hooks will never cause corruption of channel protocol.
337 337 fin, fout = _protectio(ui)
338 338 try:
339 339 sv = server(ui, self.repo, fin, fout)
340 340 return sv.serve()
341 341 finally:
342 342 sv.cleanup()
343 343 _restoreio(ui, fin, fout)
344 344
345 345 def _initworkerprocess():
346 346 # use a different process group from the master process, in order to:
347 347 # 1. make the current process group no longer "orphaned" (because the
348 348 # parent of this process is in a different process group while
349 349 # remains in a same session)
350 350 # according to POSIX 2.2.2.52, orphaned process group will ignore
351 351 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
352 352 # cause trouble for things like ncurses.
353 353 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
354 354 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
355 355 # processes like ssh will be killed properly, without affecting
356 356 # unrelated processes.
357 357 os.setpgid(0, 0)
358 358 # change random state otherwise forked request handlers would have a
359 359 # same state inherited from parent.
360 360 random.seed()
361 361
362 362 def _serverequest(ui, repo, conn, createcmdserver):
363 363 fin = conn.makefile('rb')
364 364 fout = conn.makefile('wb')
365 365 sv = None
366 366 try:
367 367 sv = createcmdserver(repo, conn, fin, fout)
368 368 try:
369 369 sv.serve()
370 370 # handle exceptions that may be raised by command server. most of
371 371 # known exceptions are caught by dispatch.
372 372 except error.Abort as inst:
373 373 ui.warn(_('abort: %s\n') % inst)
374 374 except IOError as inst:
375 375 if inst.errno != errno.EPIPE:
376 376 raise
377 377 except KeyboardInterrupt:
378 378 pass
379 379 finally:
380 380 sv.cleanup()
381 381 except: # re-raises
382 382 # also write traceback to error channel. otherwise client cannot
383 383 # see it because it is written to server's stderr by default.
384 384 if sv:
385 385 cerr = sv.cerr
386 386 else:
387 387 cerr = channeledoutput(fout, 'e')
388 388 traceback.print_exc(file=cerr)
389 389 raise
390 390 finally:
391 391 fin.close()
392 392 try:
393 393 fout.close() # implicit flush() may cause another EPIPE
394 394 except IOError as inst:
395 395 if inst.errno != errno.EPIPE:
396 396 raise
397 397
398 398 class unixservicehandler(object):
399 399 """Set of pluggable operations for unix-mode services
400 400
401 401 Almost all methods except for createcmdserver() are called in the main
402 402 process. You can't pass mutable resource back from createcmdserver().
403 403 """
404 404
405 405 pollinterval = None
406 406
407 407 def __init__(self, ui):
408 408 self.ui = ui
409 409
410 410 def bindsocket(self, sock, address):
411 411 util.bindunixsocket(sock, address)
412 412
413 413 def unlinksocket(self, address):
414 414 os.unlink(address)
415 415
416 416 def printbanner(self, address):
417 417 self.ui.status(_('listening at %s\n') % address)
418 418 self.ui.flush() # avoid buffering of status message
419 419
420 420 def shouldexit(self):
421 421 """True if server should shut down; checked per pollinterval"""
422 422 return False
423 423
424 424 def newconnection(self):
425 425 """Called when main process notices new connection"""
426 426 pass
427 427
428 428 def createcmdserver(self, repo, conn, fin, fout):
429 429 """Create new command server instance; called in the process that
430 430 serves for the current connection"""
431 431 return server(self.ui, repo, fin, fout)
432 432
433 433 class unixforkingservice(object):
434 434 """
435 435 Listens on unix domain socket and forks server per connection
436 436 """
437 437
438 438 def __init__(self, ui, repo, opts, handler=None):
439 439 self.ui = ui
440 440 self.repo = repo
441 441 self.address = opts['address']
442 442 if not util.safehasattr(socket, 'AF_UNIX'):
443 443 raise error.Abort(_('unsupported platform'))
444 444 if not self.address:
445 445 raise error.Abort(_('no socket path specified with --address'))
446 446 self._servicehandler = handler or unixservicehandler(ui)
447 447 self._sock = None
448 448 self._oldsigchldhandler = None
449 449 self._workerpids = set() # updated by signal handler; do not iterate
450 self._socketunlinked = None
450 451
451 452 def init(self):
452 453 self._sock = socket.socket(socket.AF_UNIX)
453 454 self._servicehandler.bindsocket(self._sock, self.address)
454 455 self._sock.listen(socket.SOMAXCONN)
455 456 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
456 457 self._oldsigchldhandler = o
457 458 self._servicehandler.printbanner(self.address)
459 self._socketunlinked = False
460
461 def _unlinksocket(self):
462 if not self._socketunlinked:
463 self._servicehandler.unlinksocket(self.address)
464 self._socketunlinked = True
458 465
459 466 def _cleanup(self):
460 467 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
461 468 self._sock.close()
462 self._servicehandler.unlinksocket(self.address)
469 self._unlinksocket()
463 470 # don't kill child processes as they have active clients, just wait
464 471 self._reapworkers(0)
465 472
466 473 def run(self):
467 474 try:
468 475 self._mainloop()
469 476 finally:
470 477 self._cleanup()
471 478
472 479 def _mainloop(self):
473 480 h = self._servicehandler
474 481 while not h.shouldexit():
475 482 try:
476 483 ready = select.select([self._sock], [], [], h.pollinterval)[0]
477 484 if not ready:
478 485 continue
479 486 conn, _addr = self._sock.accept()
480 487 except (select.error, socket.error) as inst:
481 488 if inst.args[0] == errno.EINTR:
482 489 continue
483 490 raise
484 491
485 492 pid = os.fork()
486 493 if pid:
487 494 try:
488 495 self.ui.debug('forked worker process (pid=%d)\n' % pid)
489 496 self._workerpids.add(pid)
490 497 h.newconnection()
491 498 finally:
492 499 conn.close() # release handle in parent process
493 500 else:
494 501 try:
495 502 self._runworker(conn)
496 503 conn.close()
497 504 os._exit(0)
498 505 except: # never return, hence no re-raises
499 506 try:
500 507 self.ui.traceback(force=True)
501 508 finally:
502 509 os._exit(255)
503 510
504 511 def _sigchldhandler(self, signal, frame):
505 512 self._reapworkers(os.WNOHANG)
506 513
507 514 def _reapworkers(self, options):
508 515 while self._workerpids:
509 516 try:
510 517 pid, _status = os.waitpid(-1, options)
511 518 except OSError as inst:
512 519 if inst.errno == errno.EINTR:
513 520 continue
514 521 if inst.errno != errno.ECHILD:
515 522 raise
516 523 # no child processes at all (reaped by other waitpid()?)
517 524 self._workerpids.clear()
518 525 return
519 526 if pid == 0:
520 527 # no waitable child processes
521 528 return
522 529 self.ui.debug('worker process exited (pid=%d)\n' % pid)
523 530 self._workerpids.discard(pid)
524 531
525 532 def _runworker(self, conn):
526 533 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
527 534 _initworkerprocess()
528 535 h = self._servicehandler
529 536 try:
530 537 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
531 538 finally:
532 539 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now