##// END OF EJS Templates
commandserver: add new forking server implemented without using SocketServer...
Yuya Nishihara -
r29544:024e8f82 default
parent child Browse files
Show More
@@ -1,437 +1,562 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import errno
11 11 import gc
12 12 import os
13 13 import random
14 import select
15 import signal
16 import socket
14 17 import struct
15 18 import sys
16 19 import traceback
17 20
18 21 from .i18n import _
19 22 from . import (
20 23 encoding,
21 24 error,
22 25 util,
23 26 )
24 27
25 28 socketserver = util.socketserver
26 29
27 30 logfile = None
28 31
29 32 def log(*args):
30 33 if not logfile:
31 34 return
32 35
33 36 for a in args:
34 37 logfile.write(str(a))
35 38
36 39 logfile.flush()
37 40
38 41 class channeledoutput(object):
39 42 """
40 43 Write data to out in the following format:
41 44
42 45 data length (unsigned int),
43 46 data
44 47 """
45 48 def __init__(self, out, channel):
46 49 self.out = out
47 50 self.channel = channel
48 51
49 52 @property
50 53 def name(self):
51 54 return '<%c-channel>' % self.channel
52 55
53 56 def write(self, data):
54 57 if not data:
55 58 return
56 59 self.out.write(struct.pack('>cI', self.channel, len(data)))
57 60 self.out.write(data)
58 61 self.out.flush()
59 62
60 63 def __getattr__(self, attr):
61 64 if attr in ('isatty', 'fileno', 'tell', 'seek'):
62 65 raise AttributeError(attr)
63 66 return getattr(self.out, attr)
64 67
65 68 class channeledinput(object):
66 69 """
67 70 Read data from in_.
68 71
69 72 Requests for input are written to out in the following format:
70 73 channel identifier - 'I' for plain input, 'L' line based (1 byte)
71 74 how many bytes to send at most (unsigned int),
72 75
73 76 The client replies with:
74 77 data length (unsigned int), 0 meaning EOF
75 78 data
76 79 """
77 80
78 81 maxchunksize = 4 * 1024
79 82
80 83 def __init__(self, in_, out, channel):
81 84 self.in_ = in_
82 85 self.out = out
83 86 self.channel = channel
84 87
85 88 @property
86 89 def name(self):
87 90 return '<%c-channel>' % self.channel
88 91
89 92 def read(self, size=-1):
90 93 if size < 0:
91 94 # if we need to consume all the clients input, ask for 4k chunks
92 95 # so the pipe doesn't fill up risking a deadlock
93 96 size = self.maxchunksize
94 97 s = self._read(size, self.channel)
95 98 buf = s
96 99 while s:
97 100 s = self._read(size, self.channel)
98 101 buf += s
99 102
100 103 return buf
101 104 else:
102 105 return self._read(size, self.channel)
103 106
104 107 def _read(self, size, channel):
105 108 if not size:
106 109 return ''
107 110 assert size > 0
108 111
109 112 # tell the client we need at most size bytes
110 113 self.out.write(struct.pack('>cI', channel, size))
111 114 self.out.flush()
112 115
113 116 length = self.in_.read(4)
114 117 length = struct.unpack('>I', length)[0]
115 118 if not length:
116 119 return ''
117 120 else:
118 121 return self.in_.read(length)
119 122
120 123 def readline(self, size=-1):
121 124 if size < 0:
122 125 size = self.maxchunksize
123 126 s = self._read(size, 'L')
124 127 buf = s
125 128 # keep asking for more until there's either no more or
126 129 # we got a full line
127 130 while s and s[-1] != '\n':
128 131 s = self._read(size, 'L')
129 132 buf += s
130 133
131 134 return buf
132 135 else:
133 136 return self._read(size, 'L')
134 137
135 138 def __iter__(self):
136 139 return self
137 140
138 141 def next(self):
139 142 l = self.readline()
140 143 if not l:
141 144 raise StopIteration
142 145 return l
143 146
144 147 def __getattr__(self, attr):
145 148 if attr in ('isatty', 'fileno', 'tell', 'seek'):
146 149 raise AttributeError(attr)
147 150 return getattr(self.in_, attr)
148 151
149 152 class server(object):
150 153 """
151 154 Listens for commands on fin, runs them and writes the output on a channel
152 155 based stream to fout.
153 156 """
154 157 def __init__(self, ui, repo, fin, fout):
155 158 self.cwd = os.getcwd()
156 159
157 160 # developer config: cmdserver.log
158 161 logpath = ui.config("cmdserver", "log", None)
159 162 if logpath:
160 163 global logfile
161 164 if logpath == '-':
162 165 # write log on a special 'd' (debug) channel
163 166 logfile = channeledoutput(fout, 'd')
164 167 else:
165 168 logfile = open(logpath, 'a')
166 169
167 170 if repo:
168 171 # the ui here is really the repo ui so take its baseui so we don't
169 172 # end up with its local configuration
170 173 self.ui = repo.baseui
171 174 self.repo = repo
172 175 self.repoui = repo.ui
173 176 else:
174 177 self.ui = ui
175 178 self.repo = self.repoui = None
176 179
177 180 self.cerr = channeledoutput(fout, 'e')
178 181 self.cout = channeledoutput(fout, 'o')
179 182 self.cin = channeledinput(fin, fout, 'I')
180 183 self.cresult = channeledoutput(fout, 'r')
181 184
182 185 self.client = fin
183 186
184 187 def cleanup(self):
185 188 """release and restore resources taken during server session"""
186 189 pass
187 190
188 191 def _read(self, size):
189 192 if not size:
190 193 return ''
191 194
192 195 data = self.client.read(size)
193 196
194 197 # is the other end closed?
195 198 if not data:
196 199 raise EOFError
197 200
198 201 return data
199 202
200 203 def _readstr(self):
201 204 """read a string from the channel
202 205
203 206 format:
204 207 data length (uint32), data
205 208 """
206 209 length = struct.unpack('>I', self._read(4))[0]
207 210 if not length:
208 211 return ''
209 212 return self._read(length)
210 213
211 214 def _readlist(self):
212 215 """read a list of NULL separated strings from the channel"""
213 216 s = self._readstr()
214 217 if s:
215 218 return s.split('\0')
216 219 else:
217 220 return []
218 221
219 222 def runcommand(self):
220 223 """ reads a list of \0 terminated arguments, executes
221 224 and writes the return code to the result channel """
222 225 from . import dispatch # avoid cycle
223 226
224 227 args = self._readlist()
225 228
226 229 # copy the uis so changes (e.g. --config or --verbose) don't
227 230 # persist between requests
228 231 copiedui = self.ui.copy()
229 232 uis = [copiedui]
230 233 if self.repo:
231 234 self.repo.baseui = copiedui
232 235 # clone ui without using ui.copy because this is protected
233 236 repoui = self.repoui.__class__(self.repoui)
234 237 repoui.copy = copiedui.copy # redo copy protection
235 238 uis.append(repoui)
236 239 self.repo.ui = self.repo.dirstate._ui = repoui
237 240 self.repo.invalidateall()
238 241
239 242 for ui in uis:
240 243 ui.resetstate()
241 244 # any kind of interaction must use server channels, but chg may
242 245 # replace channels by fully functional tty files. so nontty is
243 246 # enforced only if cin is a channel.
244 247 if not util.safehasattr(self.cin, 'fileno'):
245 248 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
246 249
247 250 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
248 251 self.cout, self.cerr)
249 252
250 253 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
251 254
252 255 # restore old cwd
253 256 if '--cwd' in args:
254 257 os.chdir(self.cwd)
255 258
256 259 self.cresult.write(struct.pack('>i', int(ret)))
257 260
258 261 def getencoding(self):
259 262 """ writes the current encoding to the result channel """
260 263 self.cresult.write(encoding.encoding)
261 264
262 265 def serveone(self):
263 266 cmd = self.client.readline()[:-1]
264 267 if cmd:
265 268 handler = self.capabilities.get(cmd)
266 269 if handler:
267 270 handler(self)
268 271 else:
269 272 # clients are expected to check what commands are supported by
270 273 # looking at the servers capabilities
271 274 raise error.Abort(_('unknown command %s') % cmd)
272 275
273 276 return cmd != ''
274 277
275 278 capabilities = {'runcommand' : runcommand,
276 279 'getencoding' : getencoding}
277 280
278 281 def serve(self):
279 282 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
280 283 hellomsg += '\n'
281 284 hellomsg += 'encoding: ' + encoding.encoding
282 285 hellomsg += '\n'
283 286 hellomsg += 'pid: %d' % util.getpid()
284 287
285 288 # write the hello msg in -one- chunk
286 289 self.cout.write(hellomsg)
287 290
288 291 try:
289 292 while self.serveone():
290 293 pass
291 294 except EOFError:
292 295 # we'll get here if the client disconnected while we were reading
293 296 # its request
294 297 return 1
295 298
296 299 return 0
297 300
298 301 def _protectio(ui):
299 302 """ duplicates streams and redirect original to null if ui uses stdio """
300 303 ui.flush()
301 304 newfiles = []
302 305 nullfd = os.open(os.devnull, os.O_RDWR)
303 306 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
304 307 (ui.fout, sys.stdout, 'wb')]:
305 308 if f is sysf:
306 309 newfd = os.dup(f.fileno())
307 310 os.dup2(nullfd, f.fileno())
308 311 f = os.fdopen(newfd, mode)
309 312 newfiles.append(f)
310 313 os.close(nullfd)
311 314 return tuple(newfiles)
312 315
313 316 def _restoreio(ui, fin, fout):
314 317 """ restores streams from duplicated ones """
315 318 ui.flush()
316 319 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
317 320 if f is not uif:
318 321 os.dup2(f.fileno(), uif.fileno())
319 322 f.close()
320 323
321 324 class pipeservice(object):
322 325 def __init__(self, ui, repo, opts):
323 326 self.ui = ui
324 327 self.repo = repo
325 328
326 329 def init(self):
327 330 pass
328 331
329 332 def run(self):
330 333 ui = self.ui
331 334 # redirect stdio to null device so that broken extensions or in-process
332 335 # hooks will never cause corruption of channel protocol.
333 336 fin, fout = _protectio(ui)
334 337 try:
335 338 sv = server(ui, self.repo, fin, fout)
336 339 return sv.serve()
337 340 finally:
338 341 sv.cleanup()
339 342 _restoreio(ui, fin, fout)
340 343
341 344 def _serverequest(ui, repo, conn, createcmdserver):
342 345 if True: # TODO: unindent
343 346 # use a different process group from the master process, making this
344 347 # process pass kernel "is_current_pgrp_orphaned" check so signals like
345 348 # SIGTSTP, SIGTTIN, SIGTTOU are not ignored.
346 349 os.setpgid(0, 0)
347 350 # change random state otherwise forked request handlers would have a
348 351 # same state inherited from parent.
349 352 random.seed()
350 353
351 354 fin = conn.makefile('rb')
352 355 fout = conn.makefile('wb')
353 356 sv = None
354 357 try:
355 358 sv = createcmdserver(repo, conn, fin, fout)
356 359 try:
357 360 sv.serve()
358 361 # handle exceptions that may be raised by command server. most of
359 362 # known exceptions are caught by dispatch.
360 363 except error.Abort as inst:
361 364 ui.warn(_('abort: %s\n') % inst)
362 365 except IOError as inst:
363 366 if inst.errno != errno.EPIPE:
364 367 raise
365 368 except KeyboardInterrupt:
366 369 pass
367 370 finally:
368 371 sv.cleanup()
369 372 except: # re-raises
370 373 # also write traceback to error channel. otherwise client cannot
371 374 # see it because it is written to server's stderr by default.
372 375 if sv:
373 376 cerr = sv.cerr
374 377 else:
375 378 cerr = channeledoutput(fout, 'e')
376 379 traceback.print_exc(file=cerr)
377 380 raise
378 381 finally:
379 382 fin.close()
380 383 try:
381 384 fout.close() # implicit flush() may cause another EPIPE
382 385 except IOError as inst:
383 386 if inst.errno != errno.EPIPE:
384 387 raise
385 388 # trigger __del__ since ForkingMixIn uses os._exit
386 389 gc.collect()
387 390
391 class unixservicehandler(object):
392 """Set of pluggable operations for unix-mode services
393
394 Almost all methods except for createcmdserver() are called in the main
395 process. You can't pass mutable resource back from createcmdserver().
396 """
397
398 pollinterval = None
399
400 def __init__(self, ui):
401 self.ui = ui
402
403 def bindsocket(self, sock, address):
404 util.bindunixsocket(sock, address)
405
406 def unlinksocket(self, address):
407 os.unlink(address)
408
409 def printbanner(self, address):
410 self.ui.status(_('listening at %s\n') % address)
411 self.ui.flush() # avoid buffering of status message
412
413 def shouldexit(self):
414 """True if server should shut down; checked per pollinterval"""
415 return False
416
417 def newconnection(self):
418 """Called when main process notices new connection"""
419 pass
420
421 def createcmdserver(self, repo, conn, fin, fout):
422 """Create new command server instance; called in the process that
423 serves for the current connection"""
424 return server(self.ui, repo, fin, fout)
425
388 426 class _requesthandler(socketserver.BaseRequestHandler):
389 427 def handle(self):
390 428 _serverequest(self.server.ui, self.server.repo, self.request,
391 429 self._createcmdserver)
392 430
393 431 def _createcmdserver(self, repo, conn, fin, fout):
394 432 ui = self.server.ui
395 433 return server(ui, repo, fin, fout)
396 434
397 435 class unixservice(object):
398 436 """
399 437 Listens on unix domain socket and forks server per connection
400 438 """
401 439 def __init__(self, ui, repo, opts):
402 440 self.ui = ui
403 441 self.repo = repo
404 442 self.address = opts['address']
405 443 if not util.safehasattr(socketserver, 'UnixStreamServer'):
406 444 raise error.Abort(_('unsupported platform'))
407 445 if not self.address:
408 446 raise error.Abort(_('no socket path specified with --address'))
409 447
410 448 def init(self):
411 449 class cls(socketserver.ForkingMixIn, socketserver.UnixStreamServer):
412 450 ui = self.ui
413 451 repo = self.repo
414 452 self.server = cls(self.address, _requesthandler)
415 453 self.ui.status(_('listening at %s\n') % self.address)
416 454 self.ui.flush() # avoid buffering of status message
417 455
418 456 def _cleanup(self):
419 457 os.unlink(self.address)
420 458
421 459 def run(self):
422 460 try:
423 461 self.server.serve_forever()
424 462 finally:
425 463 self._cleanup()
426 464
465 class unixforkingservice(unixservice):
466 def __init__(self, ui, repo, opts, handler=None):
467 super(unixforkingservice, self).__init__(ui, repo, opts)
468 self._servicehandler = handler or unixservicehandler(ui)
469 self._sock = None
470 self._oldsigchldhandler = None
471 self._workerpids = set() # updated by signal handler; do not iterate
472
473 def init(self):
474 self._sock = socket.socket(socket.AF_UNIX)
475 self._servicehandler.bindsocket(self._sock, self.address)
476 self._sock.listen(5)
477 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
478 self._oldsigchldhandler = o
479 self._servicehandler.printbanner(self.address)
480
481 def _cleanup(self):
482 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
483 self._sock.close()
484 self._servicehandler.unlinksocket(self.address)
485 # don't kill child processes as they have active clients, just wait
486 self._reapworkers(0)
487
488 def run(self):
489 try:
490 self._mainloop()
491 finally:
492 self._cleanup()
493
494 def _mainloop(self):
495 h = self._servicehandler
496 while not h.shouldexit():
497 try:
498 ready = select.select([self._sock], [], [], h.pollinterval)[0]
499 if not ready:
500 continue
501 conn, _addr = self._sock.accept()
502 except (select.error, socket.error) as inst:
503 if inst.args[0] == errno.EINTR:
504 continue
505 raise
506
507 pid = os.fork()
508 if pid:
509 try:
510 self.ui.debug('forked worker process (pid=%d)\n' % pid)
511 self._workerpids.add(pid)
512 h.newconnection()
513 finally:
514 conn.close() # release handle in parent process
515 else:
516 try:
517 self._serveworker(conn)
518 conn.close()
519 os._exit(0)
520 except: # never return, hence no re-raises
521 try:
522 self.ui.traceback(force=True)
523 finally:
524 os._exit(255)
525
526 def _sigchldhandler(self, signal, frame):
527 self._reapworkers(os.WNOHANG)
528
529 def _reapworkers(self, options):
530 while self._workerpids:
531 try:
532 pid, _status = os.waitpid(-1, options)
533 except OSError as inst:
534 if inst.errno == errno.EINTR:
535 continue
536 if inst.errno != errno.ECHILD:
537 raise
538 # no child processes at all (reaped by other waitpid()?)
539 self._workerpids.clear()
540 return
541 if pid == 0:
542 # no waitable child processes
543 return
544 self.ui.debug('worker process exited (pid=%d)\n' % pid)
545 self._workerpids.discard(pid)
546
547 def _serveworker(self, conn):
548 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
549 h = self._servicehandler
550 _serverequest(self.ui, self.repo, conn, h.createcmdserver)
551
427 552 _servicemap = {
428 553 'pipe': pipeservice,
429 'unix': unixservice,
554 'unix': unixforkingservice,
430 555 }
431 556
432 557 def createservice(ui, repo, opts):
433 558 mode = opts['cmdserver']
434 559 try:
435 560 return _servicemap[mode](ui, repo, opts)
436 561 except KeyError:
437 562 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now