##// END OF EJS Templates
commandserver: loop over selector events...
Yuya Nishihara -
r40914:2525faf4 default
parent child Browse files
Show More
@@ -1,631 +1,633
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import errno
10 import errno
11 import gc
11 import gc
12 import os
12 import os
13 import random
13 import random
14 import signal
14 import signal
15 import socket
15 import socket
16 import struct
16 import struct
17 import traceback
17 import traceback
18
18
19 try:
19 try:
20 import selectors
20 import selectors
21 selectors.BaseSelector
21 selectors.BaseSelector
22 except ImportError:
22 except ImportError:
23 from .thirdparty import selectors2 as selectors
23 from .thirdparty import selectors2 as selectors
24
24
25 from .i18n import _
25 from .i18n import _
26 from . import (
26 from . import (
27 encoding,
27 encoding,
28 error,
28 error,
29 loggingutil,
29 loggingutil,
30 pycompat,
30 pycompat,
31 util,
31 util,
32 vfs as vfsmod,
32 vfs as vfsmod,
33 )
33 )
34 from .utils import (
34 from .utils import (
35 cborutil,
35 cborutil,
36 procutil,
36 procutil,
37 )
37 )
38
38
39 class channeledoutput(object):
39 class channeledoutput(object):
40 """
40 """
41 Write data to out in the following format:
41 Write data to out in the following format:
42
42
43 data length (unsigned int),
43 data length (unsigned int),
44 data
44 data
45 """
45 """
46 def __init__(self, out, channel):
46 def __init__(self, out, channel):
47 self.out = out
47 self.out = out
48 self.channel = channel
48 self.channel = channel
49
49
50 @property
50 @property
51 def name(self):
51 def name(self):
52 return '<%c-channel>' % self.channel
52 return '<%c-channel>' % self.channel
53
53
54 def write(self, data):
54 def write(self, data):
55 if not data:
55 if not data:
56 return
56 return
57 # single write() to guarantee the same atomicity as the underlying file
57 # single write() to guarantee the same atomicity as the underlying file
58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
58 self.out.write(struct.pack('>cI', self.channel, len(data)) + data)
59 self.out.flush()
59 self.out.flush()
60
60
61 def __getattr__(self, attr):
61 def __getattr__(self, attr):
62 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
62 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
63 raise AttributeError(attr)
63 raise AttributeError(attr)
64 return getattr(self.out, attr)
64 return getattr(self.out, attr)
65
65
66 class channeledmessage(object):
66 class channeledmessage(object):
67 """
67 """
68 Write encoded message and metadata to out in the following format:
68 Write encoded message and metadata to out in the following format:
69
69
70 data length (unsigned int),
70 data length (unsigned int),
71 encoded message and metadata, as a flat key-value dict.
71 encoded message and metadata, as a flat key-value dict.
72
72
73 Each message should have 'type' attribute. Messages of unknown type
73 Each message should have 'type' attribute. Messages of unknown type
74 should be ignored.
74 should be ignored.
75 """
75 """
76
76
77 # teach ui that write() can take **opts
77 # teach ui that write() can take **opts
78 structured = True
78 structured = True
79
79
80 def __init__(self, out, channel, encodename, encodefn):
80 def __init__(self, out, channel, encodename, encodefn):
81 self._cout = channeledoutput(out, channel)
81 self._cout = channeledoutput(out, channel)
82 self.encoding = encodename
82 self.encoding = encodename
83 self._encodefn = encodefn
83 self._encodefn = encodefn
84
84
85 def write(self, data, **opts):
85 def write(self, data, **opts):
86 opts = pycompat.byteskwargs(opts)
86 opts = pycompat.byteskwargs(opts)
87 if data is not None:
87 if data is not None:
88 opts[b'data'] = data
88 opts[b'data'] = data
89 self._cout.write(self._encodefn(opts))
89 self._cout.write(self._encodefn(opts))
90
90
91 def __getattr__(self, attr):
91 def __getattr__(self, attr):
92 return getattr(self._cout, attr)
92 return getattr(self._cout, attr)
93
93
94 class channeledinput(object):
94 class channeledinput(object):
95 """
95 """
96 Read data from in_.
96 Read data from in_.
97
97
98 Requests for input are written to out in the following format:
98 Requests for input are written to out in the following format:
99 channel identifier - 'I' for plain input, 'L' line based (1 byte)
99 channel identifier - 'I' for plain input, 'L' line based (1 byte)
100 how many bytes to send at most (unsigned int),
100 how many bytes to send at most (unsigned int),
101
101
102 The client replies with:
102 The client replies with:
103 data length (unsigned int), 0 meaning EOF
103 data length (unsigned int), 0 meaning EOF
104 data
104 data
105 """
105 """
106
106
107 maxchunksize = 4 * 1024
107 maxchunksize = 4 * 1024
108
108
109 def __init__(self, in_, out, channel):
109 def __init__(self, in_, out, channel):
110 self.in_ = in_
110 self.in_ = in_
111 self.out = out
111 self.out = out
112 self.channel = channel
112 self.channel = channel
113
113
114 @property
114 @property
115 def name(self):
115 def name(self):
116 return '<%c-channel>' % self.channel
116 return '<%c-channel>' % self.channel
117
117
118 def read(self, size=-1):
118 def read(self, size=-1):
119 if size < 0:
119 if size < 0:
120 # if we need to consume all the clients input, ask for 4k chunks
120 # if we need to consume all the clients input, ask for 4k chunks
121 # so the pipe doesn't fill up risking a deadlock
121 # so the pipe doesn't fill up risking a deadlock
122 size = self.maxchunksize
122 size = self.maxchunksize
123 s = self._read(size, self.channel)
123 s = self._read(size, self.channel)
124 buf = s
124 buf = s
125 while s:
125 while s:
126 s = self._read(size, self.channel)
126 s = self._read(size, self.channel)
127 buf += s
127 buf += s
128
128
129 return buf
129 return buf
130 else:
130 else:
131 return self._read(size, self.channel)
131 return self._read(size, self.channel)
132
132
133 def _read(self, size, channel):
133 def _read(self, size, channel):
134 if not size:
134 if not size:
135 return ''
135 return ''
136 assert size > 0
136 assert size > 0
137
137
138 # tell the client we need at most size bytes
138 # tell the client we need at most size bytes
139 self.out.write(struct.pack('>cI', channel, size))
139 self.out.write(struct.pack('>cI', channel, size))
140 self.out.flush()
140 self.out.flush()
141
141
142 length = self.in_.read(4)
142 length = self.in_.read(4)
143 length = struct.unpack('>I', length)[0]
143 length = struct.unpack('>I', length)[0]
144 if not length:
144 if not length:
145 return ''
145 return ''
146 else:
146 else:
147 return self.in_.read(length)
147 return self.in_.read(length)
148
148
149 def readline(self, size=-1):
149 def readline(self, size=-1):
150 if size < 0:
150 if size < 0:
151 size = self.maxchunksize
151 size = self.maxchunksize
152 s = self._read(size, 'L')
152 s = self._read(size, 'L')
153 buf = s
153 buf = s
154 # keep asking for more until there's either no more or
154 # keep asking for more until there's either no more or
155 # we got a full line
155 # we got a full line
156 while s and s[-1] != '\n':
156 while s and s[-1] != '\n':
157 s = self._read(size, 'L')
157 s = self._read(size, 'L')
158 buf += s
158 buf += s
159
159
160 return buf
160 return buf
161 else:
161 else:
162 return self._read(size, 'L')
162 return self._read(size, 'L')
163
163
164 def __iter__(self):
164 def __iter__(self):
165 return self
165 return self
166
166
167 def next(self):
167 def next(self):
168 l = self.readline()
168 l = self.readline()
169 if not l:
169 if not l:
170 raise StopIteration
170 raise StopIteration
171 return l
171 return l
172
172
173 __next__ = next
173 __next__ = next
174
174
175 def __getattr__(self, attr):
175 def __getattr__(self, attr):
176 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
176 if attr in (r'isatty', r'fileno', r'tell', r'seek'):
177 raise AttributeError(attr)
177 raise AttributeError(attr)
178 return getattr(self.in_, attr)
178 return getattr(self.in_, attr)
179
179
180 _messageencoders = {
180 _messageencoders = {
181 b'cbor': lambda v: b''.join(cborutil.streamencode(v)),
181 b'cbor': lambda v: b''.join(cborutil.streamencode(v)),
182 }
182 }
183
183
184 def _selectmessageencoder(ui):
184 def _selectmessageencoder(ui):
185 # experimental config: cmdserver.message-encodings
185 # experimental config: cmdserver.message-encodings
186 encnames = ui.configlist(b'cmdserver', b'message-encodings')
186 encnames = ui.configlist(b'cmdserver', b'message-encodings')
187 for n in encnames:
187 for n in encnames:
188 f = _messageencoders.get(n)
188 f = _messageencoders.get(n)
189 if f:
189 if f:
190 return n, f
190 return n, f
191 raise error.Abort(b'no supported message encodings: %s'
191 raise error.Abort(b'no supported message encodings: %s'
192 % b' '.join(encnames))
192 % b' '.join(encnames))
193
193
194 class server(object):
194 class server(object):
195 """
195 """
196 Listens for commands on fin, runs them and writes the output on a channel
196 Listens for commands on fin, runs them and writes the output on a channel
197 based stream to fout.
197 based stream to fout.
198 """
198 """
199 def __init__(self, ui, repo, fin, fout, prereposetups=None):
199 def __init__(self, ui, repo, fin, fout, prereposetups=None):
200 self.cwd = encoding.getcwd()
200 self.cwd = encoding.getcwd()
201
201
202 if repo:
202 if repo:
203 # the ui here is really the repo ui so take its baseui so we don't
203 # the ui here is really the repo ui so take its baseui so we don't
204 # end up with its local configuration
204 # end up with its local configuration
205 self.ui = repo.baseui
205 self.ui = repo.baseui
206 self.repo = repo
206 self.repo = repo
207 self.repoui = repo.ui
207 self.repoui = repo.ui
208 else:
208 else:
209 self.ui = ui
209 self.ui = ui
210 self.repo = self.repoui = None
210 self.repo = self.repoui = None
211 self._prereposetups = prereposetups
211 self._prereposetups = prereposetups
212
212
213 self.cdebug = channeledoutput(fout, 'd')
213 self.cdebug = channeledoutput(fout, 'd')
214 self.cerr = channeledoutput(fout, 'e')
214 self.cerr = channeledoutput(fout, 'e')
215 self.cout = channeledoutput(fout, 'o')
215 self.cout = channeledoutput(fout, 'o')
216 self.cin = channeledinput(fin, fout, 'I')
216 self.cin = channeledinput(fin, fout, 'I')
217 self.cresult = channeledoutput(fout, 'r')
217 self.cresult = channeledoutput(fout, 'r')
218
218
219 if self.ui.config(b'cmdserver', b'log') == b'-':
219 if self.ui.config(b'cmdserver', b'log') == b'-':
220 # switch log stream of server's ui to the 'd' (debug) channel
220 # switch log stream of server's ui to the 'd' (debug) channel
221 # (don't touch repo.ui as its lifetime is longer than the server)
221 # (don't touch repo.ui as its lifetime is longer than the server)
222 self.ui = self.ui.copy()
222 self.ui = self.ui.copy()
223 setuplogging(self.ui, repo=None, fp=self.cdebug)
223 setuplogging(self.ui, repo=None, fp=self.cdebug)
224
224
225 # TODO: add this to help/config.txt when stabilized
225 # TODO: add this to help/config.txt when stabilized
226 # ``channel``
226 # ``channel``
227 # Use separate channel for structured output. (Command-server only)
227 # Use separate channel for structured output. (Command-server only)
228 self.cmsg = None
228 self.cmsg = None
229 if ui.config(b'ui', b'message-output') == b'channel':
229 if ui.config(b'ui', b'message-output') == b'channel':
230 encname, encfn = _selectmessageencoder(ui)
230 encname, encfn = _selectmessageencoder(ui)
231 self.cmsg = channeledmessage(fout, b'm', encname, encfn)
231 self.cmsg = channeledmessage(fout, b'm', encname, encfn)
232
232
233 self.client = fin
233 self.client = fin
234
234
235 def cleanup(self):
235 def cleanup(self):
236 """release and restore resources taken during server session"""
236 """release and restore resources taken during server session"""
237
237
238 def _read(self, size):
238 def _read(self, size):
239 if not size:
239 if not size:
240 return ''
240 return ''
241
241
242 data = self.client.read(size)
242 data = self.client.read(size)
243
243
244 # is the other end closed?
244 # is the other end closed?
245 if not data:
245 if not data:
246 raise EOFError
246 raise EOFError
247
247
248 return data
248 return data
249
249
250 def _readstr(self):
250 def _readstr(self):
251 """read a string from the channel
251 """read a string from the channel
252
252
253 format:
253 format:
254 data length (uint32), data
254 data length (uint32), data
255 """
255 """
256 length = struct.unpack('>I', self._read(4))[0]
256 length = struct.unpack('>I', self._read(4))[0]
257 if not length:
257 if not length:
258 return ''
258 return ''
259 return self._read(length)
259 return self._read(length)
260
260
261 def _readlist(self):
261 def _readlist(self):
262 """read a list of NULL separated strings from the channel"""
262 """read a list of NULL separated strings from the channel"""
263 s = self._readstr()
263 s = self._readstr()
264 if s:
264 if s:
265 return s.split('\0')
265 return s.split('\0')
266 else:
266 else:
267 return []
267 return []
268
268
269 def runcommand(self):
269 def runcommand(self):
270 """ reads a list of \0 terminated arguments, executes
270 """ reads a list of \0 terminated arguments, executes
271 and writes the return code to the result channel """
271 and writes the return code to the result channel """
272 from . import dispatch # avoid cycle
272 from . import dispatch # avoid cycle
273
273
274 args = self._readlist()
274 args = self._readlist()
275
275
276 # copy the uis so changes (e.g. --config or --verbose) don't
276 # copy the uis so changes (e.g. --config or --verbose) don't
277 # persist between requests
277 # persist between requests
278 copiedui = self.ui.copy()
278 copiedui = self.ui.copy()
279 uis = [copiedui]
279 uis = [copiedui]
280 if self.repo:
280 if self.repo:
281 self.repo.baseui = copiedui
281 self.repo.baseui = copiedui
282 # clone ui without using ui.copy because this is protected
282 # clone ui without using ui.copy because this is protected
283 repoui = self.repoui.__class__(self.repoui)
283 repoui = self.repoui.__class__(self.repoui)
284 repoui.copy = copiedui.copy # redo copy protection
284 repoui.copy = copiedui.copy # redo copy protection
285 uis.append(repoui)
285 uis.append(repoui)
286 self.repo.ui = self.repo.dirstate._ui = repoui
286 self.repo.ui = self.repo.dirstate._ui = repoui
287 self.repo.invalidateall()
287 self.repo.invalidateall()
288
288
289 for ui in uis:
289 for ui in uis:
290 ui.resetstate()
290 ui.resetstate()
291 # any kind of interaction must use server channels, but chg may
291 # any kind of interaction must use server channels, but chg may
292 # replace channels by fully functional tty files. so nontty is
292 # replace channels by fully functional tty files. so nontty is
293 # enforced only if cin is a channel.
293 # enforced only if cin is a channel.
294 if not util.safehasattr(self.cin, 'fileno'):
294 if not util.safehasattr(self.cin, 'fileno'):
295 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
295 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
296
296
297 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
297 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
298 self.cout, self.cerr, self.cmsg,
298 self.cout, self.cerr, self.cmsg,
299 prereposetups=self._prereposetups)
299 prereposetups=self._prereposetups)
300
300
301 try:
301 try:
302 ret = dispatch.dispatch(req) & 255
302 ret = dispatch.dispatch(req) & 255
303 self.cresult.write(struct.pack('>i', int(ret)))
303 self.cresult.write(struct.pack('>i', int(ret)))
304 finally:
304 finally:
305 # restore old cwd
305 # restore old cwd
306 if '--cwd' in args:
306 if '--cwd' in args:
307 os.chdir(self.cwd)
307 os.chdir(self.cwd)
308
308
309 def getencoding(self):
309 def getencoding(self):
310 """ writes the current encoding to the result channel """
310 """ writes the current encoding to the result channel """
311 self.cresult.write(encoding.encoding)
311 self.cresult.write(encoding.encoding)
312
312
313 def serveone(self):
313 def serveone(self):
314 cmd = self.client.readline()[:-1]
314 cmd = self.client.readline()[:-1]
315 if cmd:
315 if cmd:
316 handler = self.capabilities.get(cmd)
316 handler = self.capabilities.get(cmd)
317 if handler:
317 if handler:
318 handler(self)
318 handler(self)
319 else:
319 else:
320 # clients are expected to check what commands are supported by
320 # clients are expected to check what commands are supported by
321 # looking at the servers capabilities
321 # looking at the servers capabilities
322 raise error.Abort(_('unknown command %s') % cmd)
322 raise error.Abort(_('unknown command %s') % cmd)
323
323
324 return cmd != ''
324 return cmd != ''
325
325
326 capabilities = {'runcommand': runcommand,
326 capabilities = {'runcommand': runcommand,
327 'getencoding': getencoding}
327 'getencoding': getencoding}
328
328
329 def serve(self):
329 def serve(self):
330 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
330 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
331 hellomsg += '\n'
331 hellomsg += '\n'
332 hellomsg += 'encoding: ' + encoding.encoding
332 hellomsg += 'encoding: ' + encoding.encoding
333 hellomsg += '\n'
333 hellomsg += '\n'
334 if self.cmsg:
334 if self.cmsg:
335 hellomsg += 'message-encoding: %s\n' % self.cmsg.encoding
335 hellomsg += 'message-encoding: %s\n' % self.cmsg.encoding
336 hellomsg += 'pid: %d' % procutil.getpid()
336 hellomsg += 'pid: %d' % procutil.getpid()
337 if util.safehasattr(os, 'getpgid'):
337 if util.safehasattr(os, 'getpgid'):
338 hellomsg += '\n'
338 hellomsg += '\n'
339 hellomsg += 'pgid: %d' % os.getpgid(0)
339 hellomsg += 'pgid: %d' % os.getpgid(0)
340
340
341 # write the hello msg in -one- chunk
341 # write the hello msg in -one- chunk
342 self.cout.write(hellomsg)
342 self.cout.write(hellomsg)
343
343
344 try:
344 try:
345 while self.serveone():
345 while self.serveone():
346 pass
346 pass
347 except EOFError:
347 except EOFError:
348 # we'll get here if the client disconnected while we were reading
348 # we'll get here if the client disconnected while we were reading
349 # its request
349 # its request
350 return 1
350 return 1
351
351
352 return 0
352 return 0
353
353
354 def setuplogging(ui, repo=None, fp=None):
354 def setuplogging(ui, repo=None, fp=None):
355 """Set up server logging facility
355 """Set up server logging facility
356
356
357 If cmdserver.log is '-', log messages will be sent to the given fp.
357 If cmdserver.log is '-', log messages will be sent to the given fp.
358 It should be the 'd' channel while a client is connected, and otherwise
358 It should be the 'd' channel while a client is connected, and otherwise
359 is the stderr of the server process.
359 is the stderr of the server process.
360 """
360 """
361 # developer config: cmdserver.log
361 # developer config: cmdserver.log
362 logpath = ui.config(b'cmdserver', b'log')
362 logpath = ui.config(b'cmdserver', b'log')
363 if not logpath:
363 if not logpath:
364 return
364 return
365 # developer config: cmdserver.track-log
365 # developer config: cmdserver.track-log
366 tracked = set(ui.configlist(b'cmdserver', b'track-log'))
366 tracked = set(ui.configlist(b'cmdserver', b'track-log'))
367
367
368 if logpath == b'-' and fp:
368 if logpath == b'-' and fp:
369 logger = loggingutil.fileobjectlogger(fp, tracked)
369 logger = loggingutil.fileobjectlogger(fp, tracked)
370 elif logpath == b'-':
370 elif logpath == b'-':
371 logger = loggingutil.fileobjectlogger(ui.ferr, tracked)
371 logger = loggingutil.fileobjectlogger(ui.ferr, tracked)
372 else:
372 else:
373 logpath = os.path.abspath(util.expandpath(logpath))
373 logpath = os.path.abspath(util.expandpath(logpath))
374 # developer config: cmdserver.max-log-files
374 # developer config: cmdserver.max-log-files
375 maxfiles = ui.configint(b'cmdserver', b'max-log-files')
375 maxfiles = ui.configint(b'cmdserver', b'max-log-files')
376 # developer config: cmdserver.max-log-size
376 # developer config: cmdserver.max-log-size
377 maxsize = ui.configbytes(b'cmdserver', b'max-log-size')
377 maxsize = ui.configbytes(b'cmdserver', b'max-log-size')
378 vfs = vfsmod.vfs(os.path.dirname(logpath))
378 vfs = vfsmod.vfs(os.path.dirname(logpath))
379 logger = loggingutil.filelogger(vfs, os.path.basename(logpath), tracked,
379 logger = loggingutil.filelogger(vfs, os.path.basename(logpath), tracked,
380 maxfiles=maxfiles, maxsize=maxsize)
380 maxfiles=maxfiles, maxsize=maxsize)
381
381
382 targetuis = {ui}
382 targetuis = {ui}
383 if repo:
383 if repo:
384 targetuis.add(repo.baseui)
384 targetuis.add(repo.baseui)
385 targetuis.add(repo.ui)
385 targetuis.add(repo.ui)
386 for u in targetuis:
386 for u in targetuis:
387 u.setlogger(b'cmdserver', logger)
387 u.setlogger(b'cmdserver', logger)
388
388
389 class pipeservice(object):
389 class pipeservice(object):
390 def __init__(self, ui, repo, opts):
390 def __init__(self, ui, repo, opts):
391 self.ui = ui
391 self.ui = ui
392 self.repo = repo
392 self.repo = repo
393
393
394 def init(self):
394 def init(self):
395 pass
395 pass
396
396
397 def run(self):
397 def run(self):
398 ui = self.ui
398 ui = self.ui
399 # redirect stdio to null device so that broken extensions or in-process
399 # redirect stdio to null device so that broken extensions or in-process
400 # hooks will never cause corruption of channel protocol.
400 # hooks will never cause corruption of channel protocol.
401 with procutil.protectedstdio(ui.fin, ui.fout) as (fin, fout):
401 with procutil.protectedstdio(ui.fin, ui.fout) as (fin, fout):
402 sv = server(ui, self.repo, fin, fout)
402 sv = server(ui, self.repo, fin, fout)
403 try:
403 try:
404 return sv.serve()
404 return sv.serve()
405 finally:
405 finally:
406 sv.cleanup()
406 sv.cleanup()
407
407
408 def _initworkerprocess():
408 def _initworkerprocess():
409 # use a different process group from the master process, in order to:
409 # use a different process group from the master process, in order to:
410 # 1. make the current process group no longer "orphaned" (because the
410 # 1. make the current process group no longer "orphaned" (because the
411 # parent of this process is in a different process group while
411 # parent of this process is in a different process group while
412 # remains in a same session)
412 # remains in a same session)
413 # according to POSIX 2.2.2.52, orphaned process group will ignore
413 # according to POSIX 2.2.2.52, orphaned process group will ignore
414 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
414 # terminal-generated stop signals like SIGTSTP (Ctrl+Z), which will
415 # cause trouble for things like ncurses.
415 # cause trouble for things like ncurses.
416 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
416 # 2. the client can use kill(-pgid, sig) to simulate terminal-generated
417 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
417 # SIGINT (Ctrl+C) and process-exit-generated SIGHUP. our child
418 # processes like ssh will be killed properly, without affecting
418 # processes like ssh will be killed properly, without affecting
419 # unrelated processes.
419 # unrelated processes.
420 os.setpgid(0, 0)
420 os.setpgid(0, 0)
421 # change random state otherwise forked request handlers would have a
421 # change random state otherwise forked request handlers would have a
422 # same state inherited from parent.
422 # same state inherited from parent.
423 random.seed()
423 random.seed()
424
424
425 def _serverequest(ui, repo, conn, createcmdserver, prereposetups):
425 def _serverequest(ui, repo, conn, createcmdserver, prereposetups):
426 fin = conn.makefile(r'rb')
426 fin = conn.makefile(r'rb')
427 fout = conn.makefile(r'wb')
427 fout = conn.makefile(r'wb')
428 sv = None
428 sv = None
429 try:
429 try:
430 sv = createcmdserver(repo, conn, fin, fout, prereposetups)
430 sv = createcmdserver(repo, conn, fin, fout, prereposetups)
431 try:
431 try:
432 sv.serve()
432 sv.serve()
433 # handle exceptions that may be raised by command server. most of
433 # handle exceptions that may be raised by command server. most of
434 # known exceptions are caught by dispatch.
434 # known exceptions are caught by dispatch.
435 except error.Abort as inst:
435 except error.Abort as inst:
436 ui.error(_('abort: %s\n') % inst)
436 ui.error(_('abort: %s\n') % inst)
437 except IOError as inst:
437 except IOError as inst:
438 if inst.errno != errno.EPIPE:
438 if inst.errno != errno.EPIPE:
439 raise
439 raise
440 except KeyboardInterrupt:
440 except KeyboardInterrupt:
441 pass
441 pass
442 finally:
442 finally:
443 sv.cleanup()
443 sv.cleanup()
444 except: # re-raises
444 except: # re-raises
445 # also write traceback to error channel. otherwise client cannot
445 # also write traceback to error channel. otherwise client cannot
446 # see it because it is written to server's stderr by default.
446 # see it because it is written to server's stderr by default.
447 if sv:
447 if sv:
448 cerr = sv.cerr
448 cerr = sv.cerr
449 else:
449 else:
450 cerr = channeledoutput(fout, 'e')
450 cerr = channeledoutput(fout, 'e')
451 cerr.write(encoding.strtolocal(traceback.format_exc()))
451 cerr.write(encoding.strtolocal(traceback.format_exc()))
452 raise
452 raise
453 finally:
453 finally:
454 fin.close()
454 fin.close()
455 try:
455 try:
456 fout.close() # implicit flush() may cause another EPIPE
456 fout.close() # implicit flush() may cause another EPIPE
457 except IOError as inst:
457 except IOError as inst:
458 if inst.errno != errno.EPIPE:
458 if inst.errno != errno.EPIPE:
459 raise
459 raise
460
460
461 class unixservicehandler(object):
461 class unixservicehandler(object):
462 """Set of pluggable operations for unix-mode services
462 """Set of pluggable operations for unix-mode services
463
463
464 Almost all methods except for createcmdserver() are called in the main
464 Almost all methods except for createcmdserver() are called in the main
465 process. You can't pass mutable resource back from createcmdserver().
465 process. You can't pass mutable resource back from createcmdserver().
466 """
466 """
467
467
468 pollinterval = None
468 pollinterval = None
469
469
470 def __init__(self, ui):
470 def __init__(self, ui):
471 self.ui = ui
471 self.ui = ui
472
472
473 def bindsocket(self, sock, address):
473 def bindsocket(self, sock, address):
474 util.bindunixsocket(sock, address)
474 util.bindunixsocket(sock, address)
475 sock.listen(socket.SOMAXCONN)
475 sock.listen(socket.SOMAXCONN)
476 self.ui.status(_('listening at %s\n') % address)
476 self.ui.status(_('listening at %s\n') % address)
477 self.ui.flush() # avoid buffering of status message
477 self.ui.flush() # avoid buffering of status message
478
478
479 def unlinksocket(self, address):
479 def unlinksocket(self, address):
480 os.unlink(address)
480 os.unlink(address)
481
481
482 def shouldexit(self):
482 def shouldexit(self):
483 """True if server should shut down; checked per pollinterval"""
483 """True if server should shut down; checked per pollinterval"""
484 return False
484 return False
485
485
486 def newconnection(self):
486 def newconnection(self):
487 """Called when main process notices new connection"""
487 """Called when main process notices new connection"""
488
488
489 def createcmdserver(self, repo, conn, fin, fout, prereposetups):
489 def createcmdserver(self, repo, conn, fin, fout, prereposetups):
490 """Create new command server instance; called in the process that
490 """Create new command server instance; called in the process that
491 serves for the current connection"""
491 serves for the current connection"""
492 return server(self.ui, repo, fin, fout, prereposetups)
492 return server(self.ui, repo, fin, fout, prereposetups)
493
493
494 class unixforkingservice(object):
494 class unixforkingservice(object):
495 """
495 """
496 Listens on unix domain socket and forks server per connection
496 Listens on unix domain socket and forks server per connection
497 """
497 """
498
498
499 def __init__(self, ui, repo, opts, handler=None):
499 def __init__(self, ui, repo, opts, handler=None):
500 self.ui = ui
500 self.ui = ui
501 self.repo = repo
501 self.repo = repo
502 self.address = opts['address']
502 self.address = opts['address']
503 if not util.safehasattr(socket, 'AF_UNIX'):
503 if not util.safehasattr(socket, 'AF_UNIX'):
504 raise error.Abort(_('unsupported platform'))
504 raise error.Abort(_('unsupported platform'))
505 if not self.address:
505 if not self.address:
506 raise error.Abort(_('no socket path specified with --address'))
506 raise error.Abort(_('no socket path specified with --address'))
507 self._servicehandler = handler or unixservicehandler(ui)
507 self._servicehandler = handler or unixservicehandler(ui)
508 self._sock = None
508 self._sock = None
509 self._oldsigchldhandler = None
509 self._oldsigchldhandler = None
510 self._workerpids = set() # updated by signal handler; do not iterate
510 self._workerpids = set() # updated by signal handler; do not iterate
511 self._socketunlinked = None
511 self._socketunlinked = None
512
512
513 def init(self):
513 def init(self):
514 self._sock = socket.socket(socket.AF_UNIX)
514 self._sock = socket.socket(socket.AF_UNIX)
515 self._servicehandler.bindsocket(self._sock, self.address)
515 self._servicehandler.bindsocket(self._sock, self.address)
516 if util.safehasattr(procutil, 'unblocksignal'):
516 if util.safehasattr(procutil, 'unblocksignal'):
517 procutil.unblocksignal(signal.SIGCHLD)
517 procutil.unblocksignal(signal.SIGCHLD)
518 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
518 o = signal.signal(signal.SIGCHLD, self._sigchldhandler)
519 self._oldsigchldhandler = o
519 self._oldsigchldhandler = o
520 self._socketunlinked = False
520 self._socketunlinked = False
521
521
522 def _unlinksocket(self):
522 def _unlinksocket(self):
523 if not self._socketunlinked:
523 if not self._socketunlinked:
524 self._servicehandler.unlinksocket(self.address)
524 self._servicehandler.unlinksocket(self.address)
525 self._socketunlinked = True
525 self._socketunlinked = True
526
526
527 def _cleanup(self):
527 def _cleanup(self):
528 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
528 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
529 self._sock.close()
529 self._sock.close()
530 self._unlinksocket()
530 self._unlinksocket()
531 # don't kill child processes as they have active clients, just wait
531 # don't kill child processes as they have active clients, just wait
532 self._reapworkers(0)
532 self._reapworkers(0)
533
533
534 def run(self):
534 def run(self):
535 try:
535 try:
536 self._mainloop()
536 self._mainloop()
537 finally:
537 finally:
538 self._cleanup()
538 self._cleanup()
539
539
540 def _mainloop(self):
540 def _mainloop(self):
541 exiting = False
541 exiting = False
542 h = self._servicehandler
542 h = self._servicehandler
543 selector = selectors.DefaultSelector()
543 selector = selectors.DefaultSelector()
544 selector.register(self._sock, selectors.EVENT_READ)
544 selector.register(self._sock, selectors.EVENT_READ,
545 self._acceptnewconnection)
545 while True:
546 while True:
546 if not exiting and h.shouldexit():
547 if not exiting and h.shouldexit():
547 # clients can no longer connect() to the domain socket, so
548 # clients can no longer connect() to the domain socket, so
548 # we stop queuing new requests.
549 # we stop queuing new requests.
549 # for requests that are queued (connect()-ed, but haven't been
550 # for requests that are queued (connect()-ed, but haven't been
550 # accept()-ed), handle them before exit. otherwise, clients
551 # accept()-ed), handle them before exit. otherwise, clients
551 # waiting for recv() will receive ECONNRESET.
552 # waiting for recv() will receive ECONNRESET.
552 self._unlinksocket()
553 self._unlinksocket()
553 exiting = True
554 exiting = True
554 try:
555 try:
555 ready = selector.select(timeout=h.pollinterval)
556 events = selector.select(timeout=h.pollinterval)
556 except OSError as inst:
557 except OSError as inst:
557 # selectors2 raises ETIMEDOUT if timeout exceeded while
558 # selectors2 raises ETIMEDOUT if timeout exceeded while
558 # handling signal interrupt. That's probably wrong, but
559 # handling signal interrupt. That's probably wrong, but
559 # we can easily get around it.
560 # we can easily get around it.
560 if inst.errno != errno.ETIMEDOUT:
561 if inst.errno != errno.ETIMEDOUT:
561 raise
562 raise
562 ready = []
563 events = []
563 if not ready:
564 if not events:
564 # only exit if we completed all queued requests
565 # only exit if we completed all queued requests
565 if exiting:
566 if exiting:
566 break
567 break
567 continue
568 continue
568 self._acceptnewconnection(self._sock, selector)
569 for key, _mask in events:
570 key.data(key.fileobj, selector)
569 selector.close()
571 selector.close()
570
572
571 def _acceptnewconnection(self, sock, selector):
573 def _acceptnewconnection(self, sock, selector):
572 h = self._servicehandler
574 h = self._servicehandler
573 try:
575 try:
574 conn, _addr = sock.accept()
576 conn, _addr = sock.accept()
575 except socket.error as inst:
577 except socket.error as inst:
576 if inst.args[0] == errno.EINTR:
578 if inst.args[0] == errno.EINTR:
577 return
579 return
578 raise
580 raise
579
581
580 pid = os.fork()
582 pid = os.fork()
581 if pid:
583 if pid:
582 try:
584 try:
583 self.ui.log(b'cmdserver', b'forked worker process (pid=%d)\n',
585 self.ui.log(b'cmdserver', b'forked worker process (pid=%d)\n',
584 pid)
586 pid)
585 self._workerpids.add(pid)
587 self._workerpids.add(pid)
586 h.newconnection()
588 h.newconnection()
587 finally:
589 finally:
588 conn.close() # release handle in parent process
590 conn.close() # release handle in parent process
589 else:
591 else:
590 try:
592 try:
591 selector.close()
593 selector.close()
592 sock.close()
594 sock.close()
593 self._runworker(conn)
595 self._runworker(conn)
594 conn.close()
596 conn.close()
595 os._exit(0)
597 os._exit(0)
596 except: # never return, hence no re-raises
598 except: # never return, hence no re-raises
597 try:
599 try:
598 self.ui.traceback(force=True)
600 self.ui.traceback(force=True)
599 finally:
601 finally:
600 os._exit(255)
602 os._exit(255)
601
603
602 def _sigchldhandler(self, signal, frame):
604 def _sigchldhandler(self, signal, frame):
603 self._reapworkers(os.WNOHANG)
605 self._reapworkers(os.WNOHANG)
604
606
605 def _reapworkers(self, options):
607 def _reapworkers(self, options):
606 while self._workerpids:
608 while self._workerpids:
607 try:
609 try:
608 pid, _status = os.waitpid(-1, options)
610 pid, _status = os.waitpid(-1, options)
609 except OSError as inst:
611 except OSError as inst:
610 if inst.errno == errno.EINTR:
612 if inst.errno == errno.EINTR:
611 continue
613 continue
612 if inst.errno != errno.ECHILD:
614 if inst.errno != errno.ECHILD:
613 raise
615 raise
614 # no child processes at all (reaped by other waitpid()?)
616 # no child processes at all (reaped by other waitpid()?)
615 self._workerpids.clear()
617 self._workerpids.clear()
616 return
618 return
617 if pid == 0:
619 if pid == 0:
618 # no waitable child processes
620 # no waitable child processes
619 return
621 return
620 self.ui.log(b'cmdserver', b'worker process exited (pid=%d)\n', pid)
622 self.ui.log(b'cmdserver', b'worker process exited (pid=%d)\n', pid)
621 self._workerpids.discard(pid)
623 self._workerpids.discard(pid)
622
624
623 def _runworker(self, conn):
625 def _runworker(self, conn):
624 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
626 signal.signal(signal.SIGCHLD, self._oldsigchldhandler)
625 _initworkerprocess()
627 _initworkerprocess()
626 h = self._servicehandler
628 h = self._servicehandler
627 try:
629 try:
628 _serverequest(self.ui, self.repo, conn, h.createcmdserver,
630 _serverequest(self.ui, self.repo, conn, h.createcmdserver,
629 prereposetups=None) # TODO: pass in hook functions
631 prereposetups=None) # TODO: pass in hook functions
630 finally:
632 finally:
631 gc.collect() # trigger __del__ since worker process uses os._exit
633 gc.collect() # trigger __del__ since worker process uses os._exit
General Comments 0
You need to be logged in to leave comments. Login now