##// END OF EJS Templates
commandserver: separate initialization and cleanup of forked process...
Yuya Nishihara -
r29586:42cdba9c default
parent child Browse files
Show More
@@ -1,533 +1,536 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 14 import select
15 15 import signal
16 16 import socket
17 17 import struct
18 18 import sys
19 19 import traceback
20 20
21 21 from .i18n import _
22 22 from . import (
23 23 encoding,
24 24 error,
25 25 util,
26 26 )
27 27
28 28 logfile = None
29 29
30 30 def log(*args):
31 31 if not logfile:
32 32 return
33 33
34 34 for a in args:
35 35 logfile.write(str(a))
36 36
37 37 logfile.flush()
38 38
39 39 class channeledoutput(object):
40 40 """
41 41 Write data to out in the following format:
42 42
43 43 data length (unsigned int),
44 44 data
45 45 """
46 46 def __init__(self, out, channel):
47 47 self.out = out
48 48 self.channel = channel
49 49
50 50 @property
51 51 def name(self):
52 52 return '<%c-channel>' % self.channel
53 53
54 54 def write(self, data):
55 55 if not data:
56 56 return
57 57 self.out.write(struct.pack('>cI', self.channel, len(data)))
58 58 self.out.write(data)
59 59 self.out.flush()
60 60
61 61 def __getattr__(self, attr):
62 62 if attr in ('isatty', 'fileno', 'tell', 'seek'):
63 63 raise AttributeError(attr)
64 64 return getattr(self.out, attr)
65 65
66 66 class channeledinput(object):
67 67 """
68 68 Read data from in_.
69 69
70 70 Requests for input are written to out in the following format:
71 71 channel identifier - 'I' for plain input, 'L' line based (1 byte)
72 72 how many bytes to send at most (unsigned int),
73 73
74 74 The client replies with:
75 75 data length (unsigned int), 0 meaning EOF
76 76 data
77 77 """
78 78
79 79 maxchunksize = 4 * 1024
80 80
81 81 def __init__(self, in_, out, channel):
82 82 self.in_ = in_
83 83 self.out = out
84 84 self.channel = channel
85 85
86 86 @property
87 87 def name(self):
88 88 return '<%c-channel>' % self.channel
89 89
90 90 def read(self, size=-1):
91 91 if size < 0:
92 92 # if we need to consume all the clients input, ask for 4k chunks
93 93 # so the pipe doesn't fill up risking a deadlock
94 94 size = self.maxchunksize
95 95 s = self._read(size, self.channel)
96 96 buf = s
97 97 while s:
98 98 s = self._read(size, self.channel)
99 99 buf += s
100 100
101 101 return buf
102 102 else:
103 103 return self._read(size, self.channel)
104 104
105 105 def _read(self, size, channel):
106 106 if not size:
107 107 return ''
108 108 assert size > 0
109 109
110 110 # tell the client we need at most size bytes
111 111 self.out.write(struct.pack('>cI', channel, size))
112 112 self.out.flush()
113 113
114 114 length = self.in_.read(4)
115 115 length = struct.unpack('>I', length)[0]
116 116 if not length:
117 117 return ''
118 118 else:
119 119 return self.in_.read(length)
120 120
121 121 def readline(self, size=-1):
122 122 if size < 0:
123 123 size = self.maxchunksize
124 124 s = self._read(size, 'L')
125 125 buf = s
126 126 # keep asking for more until there's either no more or
127 127 # we got a full line
128 128 while s and s[-1] != '\n':
129 129 s = self._read(size, 'L')
130 130 buf += s
131 131
132 132 return buf
133 133 else:
134 134 return self._read(size, 'L')
135 135
136 136 def __iter__(self):
137 137 return self
138 138
139 139 def next(self):
140 140 l = self.readline()
141 141 if not l:
142 142 raise StopIteration
143 143 return l
144 144
145 145 def __getattr__(self, attr):
146 146 if attr in ('isatty', 'fileno', 'tell', 'seek'):
147 147 raise AttributeError(attr)
148 148 return getattr(self.in_, attr)
149 149
150 150 class server(object):
151 151 """
152 152 Listens for commands on fin, runs them and writes the output on a channel
153 153 based stream to fout.
154 154 """
155 155 def __init__(self, ui, repo, fin, fout):
156 156 self.cwd = os.getcwd()
157 157
158 158 # developer config: cmdserver.log
159 159 logpath = ui.config("cmdserver", "log", None)
160 160 if logpath:
161 161 global logfile
162 162 if logpath == '-':
163 163 # write log on a special 'd' (debug) channel
164 164 logfile = channeledoutput(fout, 'd')
165 165 else:
166 166 logfile = open(logpath, 'a')
167 167
168 168 if repo:
169 169 # the ui here is really the repo ui so take its baseui so we don't
170 170 # end up with its local configuration
171 171 self.ui = repo.baseui
172 172 self.repo = repo
173 173 self.repoui = repo.ui
174 174 else:
175 175 self.ui = ui
176 176 self.repo = self.repoui = None
177 177
178 178 self.cerr = channeledoutput(fout, 'e')
179 179 self.cout = channeledoutput(fout, 'o')
180 180 self.cin = channeledinput(fin, fout, 'I')
181 181 self.cresult = channeledoutput(fout, 'r')
182 182
183 183 self.client = fin
184 184
185 185 def cleanup(self):
186 186 """release and restore resources taken during server session"""
187 187 pass
188 188
189 189 def _read(self, size):
190 190 if not size:
191 191 return ''
192 192
193 193 data = self.client.read(size)
194 194
195 195 # is the other end closed?
196 196 if not data:
197 197 raise EOFError
198 198
199 199 return data
200 200
201 201 def _readstr(self):
202 202 """read a string from the channel
203 203
204 204 format:
205 205 data length (uint32), data
206 206 """
207 207 length = struct.unpack('>I', self._read(4))[0]
208 208 if not length:
209 209 return ''
210 210 return self._read(length)
211 211
212 212 def _readlist(self):
213 213 """read a list of NULL separated strings from the channel"""
214 214 s = self._readstr()
215 215 if s:
216 216 return s.split('\0')
217 217 else:
218 218 return []
219 219
220 220 def runcommand(self):
221 221 """ reads a list of \0 terminated arguments, executes
222 222 and writes the return code to the result channel """
223 223 from . import dispatch # avoid cycle
224 224
225 225 args = self._readlist()
226 226
227 227 # copy the uis so changes (e.g. --config or --verbose) don't
228 228 # persist between requests
229 229 copiedui = self.ui.copy()
230 230 uis = [copiedui]
231 231 if self.repo:
232 232 self.repo.baseui = copiedui
233 233 # clone ui without using ui.copy because this is protected
234 234 repoui = self.repoui.__class__(self.repoui)
235 235 repoui.copy = copiedui.copy # redo copy protection
236 236 uis.append(repoui)
237 237 self.repo.ui = self.repo.dirstate._ui = repoui
238 238 self.repo.invalidateall()
239 239
240 240 for ui in uis:
241 241 ui.resetstate()
242 242 # any kind of interaction must use server channels, but chg may
243 243 # replace channels by fully functional tty files. so nontty is
244 244 # enforced only if cin is a channel.
245 245 if not util.safehasattr(self.cin, 'fileno'):
246 246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
247 247
248 248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
249 249 self.cout, self.cerr)
250 250
251 251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
252 252
253 253 # restore old cwd
254 254 if '--cwd' in args:
255 255 os.chdir(self.cwd)
256 256
257 257 self.cresult.write(struct.pack('>i', int(ret)))
258 258
259 259 def getencoding(self):
260 260 """ writes the current encoding to the result channel """
261 261 self.cresult.write(encoding.encoding)
262 262
263 263 def serveone(self):
264 264 cmd = self.client.readline()[:-1]
265 265 if cmd:
266 266 handler = self.capabilities.get(cmd)
267 267 if handler:
268 268 handler(self)
269 269 else:
270 270 # clients are expected to check what commands are supported by
271 271 # looking at the servers capabilities
272 272 raise error.Abort(_('unknown command %s') % cmd)
273 273
274 274 return cmd != ''
275 275
276 276 capabilities = {'runcommand' : runcommand,
277 277 'getencoding' : getencoding}
278 278
279 279 def serve(self):
280 280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
281 281 hellomsg += '\n'
282 282 hellomsg += 'encoding: ' + encoding.encoding
283 283 hellomsg += '\n'
284 284 hellomsg += 'pid: %d' % util.getpid()
285 285 if util.safehasattr(os, 'getpgid'):
286 286 hellomsg += '\n'
287 287 hellomsg += 'pgid: %d' % os.getpgid(0)
288 288
289 289 # write the hello msg in -one- chunk
290 290 self.cout.write(hellomsg)
291 291
292 292 try:
293 293 while self.serveone():
294 294 pass
295 295 except EOFError:
296 296 # we'll get here if the client disconnected while we were reading
297 297 # its request
298 298 return 1
299 299
300 300 return 0
301 301
302 302 def _protectio(ui):
303 303 """ duplicates streams and redirect original to null if ui uses stdio """
304 304 ui.flush()
305 305 newfiles = []
306 306 nullfd = os.open(os.devnull, os.O_RDWR)
307 307 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
308 308 (ui.fout, sys.stdout, 'wb')]:
309 309 if f is sysf:
310 310 newfd = os.dup(f.fileno())
311 311 os.dup2(nullfd, f.fileno())
312 312 f = os.fdopen(newfd, mode)
313 313 newfiles.append(f)
314 314 os.close(nullfd)
315 315 return tuple(newfiles)
316 316
317 317 def _restoreio(ui, fin, fout):
318 318 """ restores streams from duplicated ones """
319 319 ui.flush()
320 320 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
321 321 if f is not uif:
322 322 os.dup2(f.fileno(), uif.fileno())
323 323 f.close()
324 324
325 325 class pipeservice(object):
326 326 def __init__(self, ui, repo, opts):
327 327 self.ui = ui
328 328 self.repo = repo
329 329
330 330 def init(self):
331 331 pass
332 332
333 333 def run(self):
334 334 ui = self.ui
335 335 # redirect stdio to null device so that broken extensions or in-process
336 336 # hooks will never cause corruption of channel protocol.
337 337 fin, fout = _protectio(ui)
338 338 try:
339 339 sv = server(ui, self.repo, fin, fout)
340 340 return sv.serve()
341 341 finally:
342 342 sv.cleanup()
343 343 _restoreio(ui, fin, fout)
344 344
345 def _serverequest(ui, repo, conn, createcmdserver):
345 def _initworkerprocess():
346 346 # use a different process group from the master process, making this
347 347 # process pass kernel "is_current_pgrp_orphaned" check so signals like
348 348 # SIGTSTP, SIGTTIN, SIGTTOU are not ignored.
349 349 os.setpgid(0, 0)
350 350 # change random state otherwise forked request handlers would have a
351 351 # same state inherited from parent.
352 352 random.seed()
353 353
354 def _serverequest(ui, repo, conn, createcmdserver):
354 355 fin = conn.makefile('rb')
355 356 fout = conn.makefile('wb')
356 357 sv = None
357 358 try:
358 359 sv = createcmdserver(repo, conn, fin, fout)
359 360 try:
360 361 sv.serve()
361 362 # handle exceptions that may be raised by command server. most of
362 363 # known exceptions are caught by dispatch.
363 364 except error.Abort as inst:
364 365 ui.warn(_('abort: %s\n') % inst)
365 366 except IOError as inst:
366 367 if inst.errno != errno.EPIPE:
367 368 raise
368 369 except KeyboardInterrupt:
369 370 pass
370 371 finally:
371 372 sv.cleanup()
372 373 except: # re-raises
373 374 # also write traceback to error channel. otherwise client cannot
374 375 # see it because it is written to server's stderr by default.
375 376 if sv:
376 377 cerr = sv.cerr
377 378 else:
378 379 cerr = channeledoutput(fout, 'e')
379 380 traceback.print_exc(file=cerr)
380 381 raise
381 382 finally:
382 383 fin.close()
383 384 try:
384 385 fout.close() # implicit flush() may cause another EPIPE
385 386 except IOError as inst:
386 387 if inst.errno != errno.EPIPE:
387 388 raise
388 # trigger __del__ since ForkingMixIn uses os._exit
389 gc.collect()
390 389
391 390 class unixservicehandler(object):
392 391 """Set of pluggable operations for unix-mode services
393 392
394 393 Almost all methods except for createcmdserver() are called in the main
395 394 process. You can't pass mutable resource back from createcmdserver().
396 395 """
397 396
398 397 pollinterval = None
399 398
400 399 def __init__(self, ui):
401 400 self.ui = ui
402 401
403 402 def bindsocket(self, sock, address):
404 403 util.bindunixsocket(sock, address)
405 404
406 405 def unlinksocket(self, address):
407 406 os.unlink(address)
408 407
409 408 def printbanner(self, address):
410 409 self.ui.status(_('listening at %s\n') % address)
411 410 self.ui.flush() # avoid buffering of status message
412 411
413 412 def shouldexit(self):
414 413 """True if server should shut down; checked per pollinterval"""
415 414 return False
416 415
417 416 def newconnection(self):
418 417 """Called when main process notices new connection"""
419 418 pass
420 419
421 420 def createcmdserver(self, repo, conn, fin, fout):
422 421 """Create new command server instance; called in the process that
423 422 serves for the current connection"""
424 423 return server(self.ui, repo, fin, fout)
425 424
426 425 class unixforkingservice(object):
427 426 """
428 427 Listens on unix domain socket and forks server per connection
429 428 """
430 429
431 430 def __init__(self, ui, repo, opts, handler=None):
432 431 self.ui = ui
433 432 self.repo = repo
434 433 self.address = opts['address']
435 434 if not util.safehasattr(socket, 'AF_UNIX'):
436 435 raise error.Abort(_('unsupported platform'))
437 436 if not self.address:
438 437 raise error.Abort(_('no socket path specified with --address'))
439 438 self._servicehandler = handler or unixservicehandler(ui)
440 439 self._sock = None
441 440 self._oldsigchldhandler = None
442 441 self._workerpids = set() # updated by signal handler; do not iterate
443 442
444 443 def init(self):
445 444 self._sock = socket.socket(socket.AF_UNIX)
446 445 self._servicehandler.bindsocket(self._sock, self.address)
447 446 self._sock.listen(5)
448 447 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
449 448 self._oldsigchldhandler = o
450 449 self._servicehandler.printbanner(self.address)
451 450
452 451 def _cleanup(self):
453 452 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
454 453 self._sock.close()
455 454 self._servicehandler.unlinksocket(self.address)
456 455 # don't kill child processes as they have active clients, just wait
457 456 self._reapworkers(0)
458 457
459 458 def run(self):
460 459 try:
461 460 self._mainloop()
462 461 finally:
463 462 self._cleanup()
464 463
465 464 def _mainloop(self):
466 465 h = self._servicehandler
467 466 while not h.shouldexit():
468 467 try:
469 468 ready = select.select([self._sock], [], [], h.pollinterval)[0]
470 469 if not ready:
471 470 continue
472 471 conn, _addr = self._sock.accept()
473 472 except (select.error, socket.error) as inst:
474 473 if inst.args[0] == errno.EINTR:
475 474 continue
476 475 raise
477 476
478 477 pid = os.fork()
479 478 if pid:
480 479 try:
481 480 self.ui.debug('forked worker process (pid=%d)\n' % pid)
482 481 self._workerpids.add(pid)
483 482 h.newconnection()
484 483 finally:
485 484 conn.close() # release handle in parent process
486 485 else:
487 486 try:
488 487 self._serveworker(conn)
489 488 conn.close()
490 489 os._exit(0)
491 490 except: # never return, hence no re-raises
492 491 try:
493 492 self.ui.traceback(force=True)
494 493 finally:
495 494 os._exit(255)
496 495
497 496 def _sigchldhandler(self, signal, frame):
498 497 self._reapworkers(os.WNOHANG)
499 498
500 499 def _reapworkers(self, options):
501 500 while self._workerpids:
502 501 try:
503 502 pid, _status = os.waitpid(-1, options)
504 503 except OSError as inst:
505 504 if inst.errno == errno.EINTR:
506 505 continue
507 506 if inst.errno != errno.ECHILD:
508 507 raise
509 508 # no child processes at all (reaped by other waitpid()?)
510 509 self._workerpids.clear()
511 510 return
512 511 if pid == 0:
513 512 # no waitable child processes
514 513 return
515 514 self.ui.debug('worker process exited (pid=%d)\n' % pid)
516 515 self._workerpids.discard(pid)
517 516
518 517 def _serveworker(self, conn):
519 518 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
519 _initworkerprocess()
520 520 h = self._servicehandler
521 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
521 try:
522 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
523 finally:
524 gc.collect() # trigger __del__ since worker process uses os._exit
522 525
523 526 _servicemap = {
524 527 'pipe': pipeservice,
525 528 'unix': unixforkingservice,
526 529 }
527 530
528 531 def createservice(ui, repo, opts):
529 532 mode = opts['cmdserver']
530 533 try:
531 534 return _servicemap[mode](ui, repo, opts)
532 535 except KeyError:
533 536 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now