##// END OF EJS Templates
commandserver: add _readstr and _readlist...
Jun Wu -
r28156:75f586a1 default
parent child Browse files
Show More
@@ -1,382 +1,401 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import SocketServer
10 import SocketServer
11 import errno
11 import errno
12 import os
12 import os
13 import struct
13 import struct
14 import sys
14 import sys
15 import traceback
15 import traceback
16
16
17 from .i18n import _
17 from .i18n import _
18 from . import (
18 from . import (
19 encoding,
19 encoding,
20 error,
20 error,
21 util,
21 util,
22 )
22 )
23
23
24 logfile = None
24 logfile = None
25
25
26 def log(*args):
26 def log(*args):
27 if not logfile:
27 if not logfile:
28 return
28 return
29
29
30 for a in args:
30 for a in args:
31 logfile.write(str(a))
31 logfile.write(str(a))
32
32
33 logfile.flush()
33 logfile.flush()
34
34
35 class channeledoutput(object):
35 class channeledoutput(object):
36 """
36 """
37 Write data to out in the following format:
37 Write data to out in the following format:
38
38
39 data length (unsigned int),
39 data length (unsigned int),
40 data
40 data
41 """
41 """
42 def __init__(self, out, channel):
42 def __init__(self, out, channel):
43 self.out = out
43 self.out = out
44 self.channel = channel
44 self.channel = channel
45
45
46 @property
46 @property
47 def name(self):
47 def name(self):
48 return '<%c-channel>' % self.channel
48 return '<%c-channel>' % self.channel
49
49
50 def write(self, data):
50 def write(self, data):
51 if not data:
51 if not data:
52 return
52 return
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 self.out.write(data)
54 self.out.write(data)
55 self.out.flush()
55 self.out.flush()
56
56
57 def __getattr__(self, attr):
57 def __getattr__(self, attr):
58 if attr in ('isatty', 'fileno', 'tell', 'seek'):
58 if attr in ('isatty', 'fileno', 'tell', 'seek'):
59 raise AttributeError(attr)
59 raise AttributeError(attr)
60 return getattr(self.out, attr)
60 return getattr(self.out, attr)
61
61
62 class channeledinput(object):
62 class channeledinput(object):
63 """
63 """
64 Read data from in_.
64 Read data from in_.
65
65
66 Requests for input are written to out in the following format:
66 Requests for input are written to out in the following format:
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 how many bytes to send at most (unsigned int),
68 how many bytes to send at most (unsigned int),
69
69
70 The client replies with:
70 The client replies with:
71 data length (unsigned int), 0 meaning EOF
71 data length (unsigned int), 0 meaning EOF
72 data
72 data
73 """
73 """
74
74
75 maxchunksize = 4 * 1024
75 maxchunksize = 4 * 1024
76
76
77 def __init__(self, in_, out, channel):
77 def __init__(self, in_, out, channel):
78 self.in_ = in_
78 self.in_ = in_
79 self.out = out
79 self.out = out
80 self.channel = channel
80 self.channel = channel
81
81
82 @property
82 @property
83 def name(self):
83 def name(self):
84 return '<%c-channel>' % self.channel
84 return '<%c-channel>' % self.channel
85
85
86 def read(self, size=-1):
86 def read(self, size=-1):
87 if size < 0:
87 if size < 0:
88 # if we need to consume all the clients input, ask for 4k chunks
88 # if we need to consume all the clients input, ask for 4k chunks
89 # so the pipe doesn't fill up risking a deadlock
89 # so the pipe doesn't fill up risking a deadlock
90 size = self.maxchunksize
90 size = self.maxchunksize
91 s = self._read(size, self.channel)
91 s = self._read(size, self.channel)
92 buf = s
92 buf = s
93 while s:
93 while s:
94 s = self._read(size, self.channel)
94 s = self._read(size, self.channel)
95 buf += s
95 buf += s
96
96
97 return buf
97 return buf
98 else:
98 else:
99 return self._read(size, self.channel)
99 return self._read(size, self.channel)
100
100
101 def _read(self, size, channel):
101 def _read(self, size, channel):
102 if not size:
102 if not size:
103 return ''
103 return ''
104 assert size > 0
104 assert size > 0
105
105
106 # tell the client we need at most size bytes
106 # tell the client we need at most size bytes
107 self.out.write(struct.pack('>cI', channel, size))
107 self.out.write(struct.pack('>cI', channel, size))
108 self.out.flush()
108 self.out.flush()
109
109
110 length = self.in_.read(4)
110 length = self.in_.read(4)
111 length = struct.unpack('>I', length)[0]
111 length = struct.unpack('>I', length)[0]
112 if not length:
112 if not length:
113 return ''
113 return ''
114 else:
114 else:
115 return self.in_.read(length)
115 return self.in_.read(length)
116
116
117 def readline(self, size=-1):
117 def readline(self, size=-1):
118 if size < 0:
118 if size < 0:
119 size = self.maxchunksize
119 size = self.maxchunksize
120 s = self._read(size, 'L')
120 s = self._read(size, 'L')
121 buf = s
121 buf = s
122 # keep asking for more until there's either no more or
122 # keep asking for more until there's either no more or
123 # we got a full line
123 # we got a full line
124 while s and s[-1] != '\n':
124 while s and s[-1] != '\n':
125 s = self._read(size, 'L')
125 s = self._read(size, 'L')
126 buf += s
126 buf += s
127
127
128 return buf
128 return buf
129 else:
129 else:
130 return self._read(size, 'L')
130 return self._read(size, 'L')
131
131
132 def __iter__(self):
132 def __iter__(self):
133 return self
133 return self
134
134
135 def next(self):
135 def next(self):
136 l = self.readline()
136 l = self.readline()
137 if not l:
137 if not l:
138 raise StopIteration
138 raise StopIteration
139 return l
139 return l
140
140
141 def __getattr__(self, attr):
141 def __getattr__(self, attr):
142 if attr in ('isatty', 'fileno', 'tell', 'seek'):
142 if attr in ('isatty', 'fileno', 'tell', 'seek'):
143 raise AttributeError(attr)
143 raise AttributeError(attr)
144 return getattr(self.in_, attr)
144 return getattr(self.in_, attr)
145
145
146 class server(object):
146 class server(object):
147 """
147 """
148 Listens for commands on fin, runs them and writes the output on a channel
148 Listens for commands on fin, runs them and writes the output on a channel
149 based stream to fout.
149 based stream to fout.
150 """
150 """
151 def __init__(self, ui, repo, fin, fout):
151 def __init__(self, ui, repo, fin, fout):
152 self.cwd = os.getcwd()
152 self.cwd = os.getcwd()
153
153
154 # developer config: cmdserver.log
154 # developer config: cmdserver.log
155 logpath = ui.config("cmdserver", "log", None)
155 logpath = ui.config("cmdserver", "log", None)
156 if logpath:
156 if logpath:
157 global logfile
157 global logfile
158 if logpath == '-':
158 if logpath == '-':
159 # write log on a special 'd' (debug) channel
159 # write log on a special 'd' (debug) channel
160 logfile = channeledoutput(fout, 'd')
160 logfile = channeledoutput(fout, 'd')
161 else:
161 else:
162 logfile = open(logpath, 'a')
162 logfile = open(logpath, 'a')
163
163
164 if repo:
164 if repo:
165 # the ui here is really the repo ui so take its baseui so we don't
165 # the ui here is really the repo ui so take its baseui so we don't
166 # end up with its local configuration
166 # end up with its local configuration
167 self.ui = repo.baseui
167 self.ui = repo.baseui
168 self.repo = repo
168 self.repo = repo
169 self.repoui = repo.ui
169 self.repoui = repo.ui
170 else:
170 else:
171 self.ui = ui
171 self.ui = ui
172 self.repo = self.repoui = None
172 self.repo = self.repoui = None
173
173
174 self.cerr = channeledoutput(fout, 'e')
174 self.cerr = channeledoutput(fout, 'e')
175 self.cout = channeledoutput(fout, 'o')
175 self.cout = channeledoutput(fout, 'o')
176 self.cin = channeledinput(fin, fout, 'I')
176 self.cin = channeledinput(fin, fout, 'I')
177 self.cresult = channeledoutput(fout, 'r')
177 self.cresult = channeledoutput(fout, 'r')
178
178
179 self.client = fin
179 self.client = fin
180
180
181 def _read(self, size):
181 def _read(self, size):
182 if not size:
182 if not size:
183 return ''
183 return ''
184
184
185 data = self.client.read(size)
185 data = self.client.read(size)
186
186
187 # is the other end closed?
187 # is the other end closed?
188 if not data:
188 if not data:
189 raise EOFError
189 raise EOFError
190
190
191 return data
191 return data
192
192
193 def _readstr(self):
194 """read a string from the channel
195
196 format:
197 data length (uint32), data
198 """
199 length = struct.unpack('>I', self._read(4))[0]
200 if not length:
201 return ''
202 return self._read(length)
203
204 def _readlist(self):
205 """read a list of NULL separated strings from the channel"""
206 s = self._readstr()
207 if s:
208 return s.split('\0')
209 else:
210 return []
211
193 def runcommand(self):
212 def runcommand(self):
194 """ reads a list of \0 terminated arguments, executes
213 """ reads a list of \0 terminated arguments, executes
195 and writes the return code to the result channel """
214 and writes the return code to the result channel """
196 from . import dispatch # avoid cycle
215 from . import dispatch # avoid cycle
197
216
198 length = struct.unpack('>I', self._read(4))[0]
217 length = struct.unpack('>I', self._read(4))[0]
199 if not length:
218 if not length:
200 args = []
219 args = []
201 else:
220 else:
202 args = self._read(length).split('\0')
221 args = self._read(length).split('\0')
203
222
204 # copy the uis so changes (e.g. --config or --verbose) don't
223 # copy the uis so changes (e.g. --config or --verbose) don't
205 # persist between requests
224 # persist between requests
206 copiedui = self.ui.copy()
225 copiedui = self.ui.copy()
207 uis = [copiedui]
226 uis = [copiedui]
208 if self.repo:
227 if self.repo:
209 self.repo.baseui = copiedui
228 self.repo.baseui = copiedui
210 # clone ui without using ui.copy because this is protected
229 # clone ui without using ui.copy because this is protected
211 repoui = self.repoui.__class__(self.repoui)
230 repoui = self.repoui.__class__(self.repoui)
212 repoui.copy = copiedui.copy # redo copy protection
231 repoui.copy = copiedui.copy # redo copy protection
213 uis.append(repoui)
232 uis.append(repoui)
214 self.repo.ui = self.repo.dirstate._ui = repoui
233 self.repo.ui = self.repo.dirstate._ui = repoui
215 self.repo.invalidateall()
234 self.repo.invalidateall()
216
235
217 # reset last-print time of progress bar per command
236 # reset last-print time of progress bar per command
218 # (progbar is singleton, we don't have to do for all uis)
237 # (progbar is singleton, we don't have to do for all uis)
219 if copiedui._progbar:
238 if copiedui._progbar:
220 copiedui._progbar.resetstate()
239 copiedui._progbar.resetstate()
221
240
222 for ui in uis:
241 for ui in uis:
223 # any kind of interaction must use server channels, but chg may
242 # any kind of interaction must use server channels, but chg may
224 # replace channels by fully functional tty files. so nontty is
243 # replace channels by fully functional tty files. so nontty is
225 # enforced only if cin is a channel.
244 # enforced only if cin is a channel.
226 if not util.safehasattr(self.cin, 'fileno'):
245 if not util.safehasattr(self.cin, 'fileno'):
227 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
228
247
229 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
230 self.cout, self.cerr)
249 self.cout, self.cerr)
231
250
232 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
233
252
234 # restore old cwd
253 # restore old cwd
235 if '--cwd' in args:
254 if '--cwd' in args:
236 os.chdir(self.cwd)
255 os.chdir(self.cwd)
237
256
238 self.cresult.write(struct.pack('>i', int(ret)))
257 self.cresult.write(struct.pack('>i', int(ret)))
239
258
240 def getencoding(self):
259 def getencoding(self):
241 """ writes the current encoding to the result channel """
260 """ writes the current encoding to the result channel """
242 self.cresult.write(encoding.encoding)
261 self.cresult.write(encoding.encoding)
243
262
244 def serveone(self):
263 def serveone(self):
245 cmd = self.client.readline()[:-1]
264 cmd = self.client.readline()[:-1]
246 if cmd:
265 if cmd:
247 handler = self.capabilities.get(cmd)
266 handler = self.capabilities.get(cmd)
248 if handler:
267 if handler:
249 handler(self)
268 handler(self)
250 else:
269 else:
251 # clients are expected to check what commands are supported by
270 # clients are expected to check what commands are supported by
252 # looking at the servers capabilities
271 # looking at the servers capabilities
253 raise error.Abort(_('unknown command %s') % cmd)
272 raise error.Abort(_('unknown command %s') % cmd)
254
273
255 return cmd != ''
274 return cmd != ''
256
275
257 capabilities = {'runcommand' : runcommand,
276 capabilities = {'runcommand' : runcommand,
258 'getencoding' : getencoding}
277 'getencoding' : getencoding}
259
278
260 def serve(self):
279 def serve(self):
261 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
262 hellomsg += '\n'
281 hellomsg += '\n'
263 hellomsg += 'encoding: ' + encoding.encoding
282 hellomsg += 'encoding: ' + encoding.encoding
264 hellomsg += '\n'
283 hellomsg += '\n'
265 hellomsg += 'pid: %d' % util.getpid()
284 hellomsg += 'pid: %d' % util.getpid()
266
285
267 # write the hello msg in -one- chunk
286 # write the hello msg in -one- chunk
268 self.cout.write(hellomsg)
287 self.cout.write(hellomsg)
269
288
270 try:
289 try:
271 while self.serveone():
290 while self.serveone():
272 pass
291 pass
273 except EOFError:
292 except EOFError:
274 # we'll get here if the client disconnected while we were reading
293 # we'll get here if the client disconnected while we were reading
275 # its request
294 # its request
276 return 1
295 return 1
277
296
278 return 0
297 return 0
279
298
280 def _protectio(ui):
299 def _protectio(ui):
281 """ duplicates streams and redirect original to null if ui uses stdio """
300 """ duplicates streams and redirect original to null if ui uses stdio """
282 ui.flush()
301 ui.flush()
283 newfiles = []
302 newfiles = []
284 nullfd = os.open(os.devnull, os.O_RDWR)
303 nullfd = os.open(os.devnull, os.O_RDWR)
285 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
304 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
286 (ui.fout, sys.stdout, 'wb')]:
305 (ui.fout, sys.stdout, 'wb')]:
287 if f is sysf:
306 if f is sysf:
288 newfd = os.dup(f.fileno())
307 newfd = os.dup(f.fileno())
289 os.dup2(nullfd, f.fileno())
308 os.dup2(nullfd, f.fileno())
290 f = os.fdopen(newfd, mode)
309 f = os.fdopen(newfd, mode)
291 newfiles.append(f)
310 newfiles.append(f)
292 os.close(nullfd)
311 os.close(nullfd)
293 return tuple(newfiles)
312 return tuple(newfiles)
294
313
295 def _restoreio(ui, fin, fout):
314 def _restoreio(ui, fin, fout):
296 """ restores streams from duplicated ones """
315 """ restores streams from duplicated ones """
297 ui.flush()
316 ui.flush()
298 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
317 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
299 if f is not uif:
318 if f is not uif:
300 os.dup2(f.fileno(), uif.fileno())
319 os.dup2(f.fileno(), uif.fileno())
301 f.close()
320 f.close()
302
321
303 class pipeservice(object):
322 class pipeservice(object):
304 def __init__(self, ui, repo, opts):
323 def __init__(self, ui, repo, opts):
305 self.ui = ui
324 self.ui = ui
306 self.repo = repo
325 self.repo = repo
307
326
308 def init(self):
327 def init(self):
309 pass
328 pass
310
329
311 def run(self):
330 def run(self):
312 ui = self.ui
331 ui = self.ui
313 # redirect stdio to null device so that broken extensions or in-process
332 # redirect stdio to null device so that broken extensions or in-process
314 # hooks will never cause corruption of channel protocol.
333 # hooks will never cause corruption of channel protocol.
315 fin, fout = _protectio(ui)
334 fin, fout = _protectio(ui)
316 try:
335 try:
317 sv = server(ui, self.repo, fin, fout)
336 sv = server(ui, self.repo, fin, fout)
318 return sv.serve()
337 return sv.serve()
319 finally:
338 finally:
320 _restoreio(ui, fin, fout)
339 _restoreio(ui, fin, fout)
321
340
322 class _requesthandler(SocketServer.StreamRequestHandler):
341 class _requesthandler(SocketServer.StreamRequestHandler):
323 def handle(self):
342 def handle(self):
324 ui = self.server.ui
343 ui = self.server.ui
325 repo = self.server.repo
344 repo = self.server.repo
326 sv = server(ui, repo, self.rfile, self.wfile)
345 sv = server(ui, repo, self.rfile, self.wfile)
327 try:
346 try:
328 try:
347 try:
329 sv.serve()
348 sv.serve()
330 # handle exceptions that may be raised by command server. most of
349 # handle exceptions that may be raised by command server. most of
331 # known exceptions are caught by dispatch.
350 # known exceptions are caught by dispatch.
332 except error.Abort as inst:
351 except error.Abort as inst:
333 ui.warn(_('abort: %s\n') % inst)
352 ui.warn(_('abort: %s\n') % inst)
334 except IOError as inst:
353 except IOError as inst:
335 if inst.errno != errno.EPIPE:
354 if inst.errno != errno.EPIPE:
336 raise
355 raise
337 except KeyboardInterrupt:
356 except KeyboardInterrupt:
338 pass
357 pass
339 except: # re-raises
358 except: # re-raises
340 # also write traceback to error channel. otherwise client cannot
359 # also write traceback to error channel. otherwise client cannot
341 # see it because it is written to server's stderr by default.
360 # see it because it is written to server's stderr by default.
342 traceback.print_exc(file=sv.cerr)
361 traceback.print_exc(file=sv.cerr)
343 raise
362 raise
344
363
345 class unixservice(object):
364 class unixservice(object):
346 """
365 """
347 Listens on unix domain socket and forks server per connection
366 Listens on unix domain socket and forks server per connection
348 """
367 """
349 def __init__(self, ui, repo, opts):
368 def __init__(self, ui, repo, opts):
350 self.ui = ui
369 self.ui = ui
351 self.repo = repo
370 self.repo = repo
352 self.address = opts['address']
371 self.address = opts['address']
353 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
372 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
354 raise error.Abort(_('unsupported platform'))
373 raise error.Abort(_('unsupported platform'))
355 if not self.address:
374 if not self.address:
356 raise error.Abort(_('no socket path specified with --address'))
375 raise error.Abort(_('no socket path specified with --address'))
357
376
358 def init(self):
377 def init(self):
359 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
378 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
360 ui = self.ui
379 ui = self.ui
361 repo = self.repo
380 repo = self.repo
362 self.server = cls(self.address, _requesthandler)
381 self.server = cls(self.address, _requesthandler)
363 self.ui.status(_('listening at %s\n') % self.address)
382 self.ui.status(_('listening at %s\n') % self.address)
364 self.ui.flush() # avoid buffering of status message
383 self.ui.flush() # avoid buffering of status message
365
384
366 def run(self):
385 def run(self):
367 try:
386 try:
368 self.server.serve_forever()
387 self.server.serve_forever()
369 finally:
388 finally:
370 os.unlink(self.address)
389 os.unlink(self.address)
371
390
372 _servicemap = {
391 _servicemap = {
373 'pipe': pipeservice,
392 'pipe': pipeservice,
374 'unix': unixservice,
393 'unix': unixservice,
375 }
394 }
376
395
377 def createservice(ui, repo, opts):
396 def createservice(ui, repo, opts):
378 mode = opts['cmdserver']
397 mode = opts['cmdserver']
379 try:
398 try:
380 return _servicemap[mode](ui, repo, opts)
399 return _servicemap[mode](ui, repo, opts)
381 except KeyError:
400 except KeyError:
382 raise error.Abort(_('unknown mode %s') % mode)
401 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now