##// END OF EJS Templates
commandserver: drop tell() and seek() from channels (issue5049)...
Yuya Nishihara -
r27915:5f2a308b stable
parent child Browse files
Show More
@@ -1,382 +1,382 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import SocketServer
10 import SocketServer
11 import errno
11 import errno
12 import os
12 import os
13 import struct
13 import struct
14 import sys
14 import sys
15 import traceback
15 import traceback
16
16
17 from .i18n import _
17 from .i18n import _
18 from . import (
18 from . import (
19 encoding,
19 encoding,
20 error,
20 error,
21 util,
21 util,
22 )
22 )
23
23
24 logfile = None
24 logfile = None
25
25
26 def log(*args):
26 def log(*args):
27 if not logfile:
27 if not logfile:
28 return
28 return
29
29
30 for a in args:
30 for a in args:
31 logfile.write(str(a))
31 logfile.write(str(a))
32
32
33 logfile.flush()
33 logfile.flush()
34
34
35 class channeledoutput(object):
35 class channeledoutput(object):
36 """
36 """
37 Write data to out in the following format:
37 Write data to out in the following format:
38
38
39 data length (unsigned int),
39 data length (unsigned int),
40 data
40 data
41 """
41 """
42 def __init__(self, out, channel):
42 def __init__(self, out, channel):
43 self.out = out
43 self.out = out
44 self.channel = channel
44 self.channel = channel
45
45
46 @property
46 @property
47 def name(self):
47 def name(self):
48 return '<%c-channel>' % self.channel
48 return '<%c-channel>' % self.channel
49
49
50 def write(self, data):
50 def write(self, data):
51 if not data:
51 if not data:
52 return
52 return
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 self.out.write(data)
54 self.out.write(data)
55 self.out.flush()
55 self.out.flush()
56
56
57 def __getattr__(self, attr):
57 def __getattr__(self, attr):
58 if attr in ('isatty', 'fileno'):
58 if attr in ('isatty', 'fileno', 'tell', 'seek'):
59 raise AttributeError(attr)
59 raise AttributeError(attr)
60 return getattr(self.out, attr)
60 return getattr(self.out, attr)
61
61
62 class channeledinput(object):
62 class channeledinput(object):
63 """
63 """
64 Read data from in_.
64 Read data from in_.
65
65
66 Requests for input are written to out in the following format:
66 Requests for input are written to out in the following format:
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 how many bytes to send at most (unsigned int),
68 how many bytes to send at most (unsigned int),
69
69
70 The client replies with:
70 The client replies with:
71 data length (unsigned int), 0 meaning EOF
71 data length (unsigned int), 0 meaning EOF
72 data
72 data
73 """
73 """
74
74
75 maxchunksize = 4 * 1024
75 maxchunksize = 4 * 1024
76
76
77 def __init__(self, in_, out, channel):
77 def __init__(self, in_, out, channel):
78 self.in_ = in_
78 self.in_ = in_
79 self.out = out
79 self.out = out
80 self.channel = channel
80 self.channel = channel
81
81
82 @property
82 @property
83 def name(self):
83 def name(self):
84 return '<%c-channel>' % self.channel
84 return '<%c-channel>' % self.channel
85
85
86 def read(self, size=-1):
86 def read(self, size=-1):
87 if size < 0:
87 if size < 0:
88 # if we need to consume all the clients input, ask for 4k chunks
88 # if we need to consume all the clients input, ask for 4k chunks
89 # so the pipe doesn't fill up risking a deadlock
89 # so the pipe doesn't fill up risking a deadlock
90 size = self.maxchunksize
90 size = self.maxchunksize
91 s = self._read(size, self.channel)
91 s = self._read(size, self.channel)
92 buf = s
92 buf = s
93 while s:
93 while s:
94 s = self._read(size, self.channel)
94 s = self._read(size, self.channel)
95 buf += s
95 buf += s
96
96
97 return buf
97 return buf
98 else:
98 else:
99 return self._read(size, self.channel)
99 return self._read(size, self.channel)
100
100
101 def _read(self, size, channel):
101 def _read(self, size, channel):
102 if not size:
102 if not size:
103 return ''
103 return ''
104 assert size > 0
104 assert size > 0
105
105
106 # tell the client we need at most size bytes
106 # tell the client we need at most size bytes
107 self.out.write(struct.pack('>cI', channel, size))
107 self.out.write(struct.pack('>cI', channel, size))
108 self.out.flush()
108 self.out.flush()
109
109
110 length = self.in_.read(4)
110 length = self.in_.read(4)
111 length = struct.unpack('>I', length)[0]
111 length = struct.unpack('>I', length)[0]
112 if not length:
112 if not length:
113 return ''
113 return ''
114 else:
114 else:
115 return self.in_.read(length)
115 return self.in_.read(length)
116
116
117 def readline(self, size=-1):
117 def readline(self, size=-1):
118 if size < 0:
118 if size < 0:
119 size = self.maxchunksize
119 size = self.maxchunksize
120 s = self._read(size, 'L')
120 s = self._read(size, 'L')
121 buf = s
121 buf = s
122 # keep asking for more until there's either no more or
122 # keep asking for more until there's either no more or
123 # we got a full line
123 # we got a full line
124 while s and s[-1] != '\n':
124 while s and s[-1] != '\n':
125 s = self._read(size, 'L')
125 s = self._read(size, 'L')
126 buf += s
126 buf += s
127
127
128 return buf
128 return buf
129 else:
129 else:
130 return self._read(size, 'L')
130 return self._read(size, 'L')
131
131
132 def __iter__(self):
132 def __iter__(self):
133 return self
133 return self
134
134
135 def next(self):
135 def next(self):
136 l = self.readline()
136 l = self.readline()
137 if not l:
137 if not l:
138 raise StopIteration
138 raise StopIteration
139 return l
139 return l
140
140
141 def __getattr__(self, attr):
141 def __getattr__(self, attr):
142 if attr in ('isatty', 'fileno'):
142 if attr in ('isatty', 'fileno', 'tell', 'seek'):
143 raise AttributeError(attr)
143 raise AttributeError(attr)
144 return getattr(self.in_, attr)
144 return getattr(self.in_, attr)
145
145
146 class server(object):
146 class server(object):
147 """
147 """
148 Listens for commands on fin, runs them and writes the output on a channel
148 Listens for commands on fin, runs them and writes the output on a channel
149 based stream to fout.
149 based stream to fout.
150 """
150 """
151 def __init__(self, ui, repo, fin, fout):
151 def __init__(self, ui, repo, fin, fout):
152 self.cwd = os.getcwd()
152 self.cwd = os.getcwd()
153
153
154 # developer config: cmdserver.log
154 # developer config: cmdserver.log
155 logpath = ui.config("cmdserver", "log", None)
155 logpath = ui.config("cmdserver", "log", None)
156 if logpath:
156 if logpath:
157 global logfile
157 global logfile
158 if logpath == '-':
158 if logpath == '-':
159 # write log on a special 'd' (debug) channel
159 # write log on a special 'd' (debug) channel
160 logfile = channeledoutput(fout, 'd')
160 logfile = channeledoutput(fout, 'd')
161 else:
161 else:
162 logfile = open(logpath, 'a')
162 logfile = open(logpath, 'a')
163
163
164 if repo:
164 if repo:
165 # the ui here is really the repo ui so take its baseui so we don't
165 # the ui here is really the repo ui so take its baseui so we don't
166 # end up with its local configuration
166 # end up with its local configuration
167 self.ui = repo.baseui
167 self.ui = repo.baseui
168 self.repo = repo
168 self.repo = repo
169 self.repoui = repo.ui
169 self.repoui = repo.ui
170 else:
170 else:
171 self.ui = ui
171 self.ui = ui
172 self.repo = self.repoui = None
172 self.repo = self.repoui = None
173
173
174 self.cerr = channeledoutput(fout, 'e')
174 self.cerr = channeledoutput(fout, 'e')
175 self.cout = channeledoutput(fout, 'o')
175 self.cout = channeledoutput(fout, 'o')
176 self.cin = channeledinput(fin, fout, 'I')
176 self.cin = channeledinput(fin, fout, 'I')
177 self.cresult = channeledoutput(fout, 'r')
177 self.cresult = channeledoutput(fout, 'r')
178
178
179 self.client = fin
179 self.client = fin
180
180
181 def _read(self, size):
181 def _read(self, size):
182 if not size:
182 if not size:
183 return ''
183 return ''
184
184
185 data = self.client.read(size)
185 data = self.client.read(size)
186
186
187 # is the other end closed?
187 # is the other end closed?
188 if not data:
188 if not data:
189 raise EOFError
189 raise EOFError
190
190
191 return data
191 return data
192
192
193 def runcommand(self):
193 def runcommand(self):
194 """ reads a list of \0 terminated arguments, executes
194 """ reads a list of \0 terminated arguments, executes
195 and writes the return code to the result channel """
195 and writes the return code to the result channel """
196 from . import dispatch # avoid cycle
196 from . import dispatch # avoid cycle
197
197
198 length = struct.unpack('>I', self._read(4))[0]
198 length = struct.unpack('>I', self._read(4))[0]
199 if not length:
199 if not length:
200 args = []
200 args = []
201 else:
201 else:
202 args = self._read(length).split('\0')
202 args = self._read(length).split('\0')
203
203
204 # copy the uis so changes (e.g. --config or --verbose) don't
204 # copy the uis so changes (e.g. --config or --verbose) don't
205 # persist between requests
205 # persist between requests
206 copiedui = self.ui.copy()
206 copiedui = self.ui.copy()
207 uis = [copiedui]
207 uis = [copiedui]
208 if self.repo:
208 if self.repo:
209 self.repo.baseui = copiedui
209 self.repo.baseui = copiedui
210 # clone ui without using ui.copy because this is protected
210 # clone ui without using ui.copy because this is protected
211 repoui = self.repoui.__class__(self.repoui)
211 repoui = self.repoui.__class__(self.repoui)
212 repoui.copy = copiedui.copy # redo copy protection
212 repoui.copy = copiedui.copy # redo copy protection
213 uis.append(repoui)
213 uis.append(repoui)
214 self.repo.ui = self.repo.dirstate._ui = repoui
214 self.repo.ui = self.repo.dirstate._ui = repoui
215 self.repo.invalidateall()
215 self.repo.invalidateall()
216
216
217 # reset last-print time of progress bar per command
217 # reset last-print time of progress bar per command
218 # (progbar is singleton, we don't have to do for all uis)
218 # (progbar is singleton, we don't have to do for all uis)
219 if copiedui._progbar:
219 if copiedui._progbar:
220 copiedui._progbar.resetstate()
220 copiedui._progbar.resetstate()
221
221
222 for ui in uis:
222 for ui in uis:
223 # any kind of interaction must use server channels, but chg may
223 # any kind of interaction must use server channels, but chg may
224 # replace channels by fully functional tty files. so nontty is
224 # replace channels by fully functional tty files. so nontty is
225 # enforced only if cin is a channel.
225 # enforced only if cin is a channel.
226 if not util.safehasattr(self.cin, 'fileno'):
226 if not util.safehasattr(self.cin, 'fileno'):
227 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
227 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
228
228
229 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
229 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
230 self.cout, self.cerr)
230 self.cout, self.cerr)
231
231
232 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
232 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
233
233
234 # restore old cwd
234 # restore old cwd
235 if '--cwd' in args:
235 if '--cwd' in args:
236 os.chdir(self.cwd)
236 os.chdir(self.cwd)
237
237
238 self.cresult.write(struct.pack('>i', int(ret)))
238 self.cresult.write(struct.pack('>i', int(ret)))
239
239
240 def getencoding(self):
240 def getencoding(self):
241 """ writes the current encoding to the result channel """
241 """ writes the current encoding to the result channel """
242 self.cresult.write(encoding.encoding)
242 self.cresult.write(encoding.encoding)
243
243
244 def serveone(self):
244 def serveone(self):
245 cmd = self.client.readline()[:-1]
245 cmd = self.client.readline()[:-1]
246 if cmd:
246 if cmd:
247 handler = self.capabilities.get(cmd)
247 handler = self.capabilities.get(cmd)
248 if handler:
248 if handler:
249 handler(self)
249 handler(self)
250 else:
250 else:
251 # clients are expected to check what commands are supported by
251 # clients are expected to check what commands are supported by
252 # looking at the servers capabilities
252 # looking at the servers capabilities
253 raise error.Abort(_('unknown command %s') % cmd)
253 raise error.Abort(_('unknown command %s') % cmd)
254
254
255 return cmd != ''
255 return cmd != ''
256
256
257 capabilities = {'runcommand' : runcommand,
257 capabilities = {'runcommand' : runcommand,
258 'getencoding' : getencoding}
258 'getencoding' : getencoding}
259
259
260 def serve(self):
260 def serve(self):
261 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
261 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
262 hellomsg += '\n'
262 hellomsg += '\n'
263 hellomsg += 'encoding: ' + encoding.encoding
263 hellomsg += 'encoding: ' + encoding.encoding
264 hellomsg += '\n'
264 hellomsg += '\n'
265 hellomsg += 'pid: %d' % os.getpid()
265 hellomsg += 'pid: %d' % os.getpid()
266
266
267 # write the hello msg in -one- chunk
267 # write the hello msg in -one- chunk
268 self.cout.write(hellomsg)
268 self.cout.write(hellomsg)
269
269
270 try:
270 try:
271 while self.serveone():
271 while self.serveone():
272 pass
272 pass
273 except EOFError:
273 except EOFError:
274 # we'll get here if the client disconnected while we were reading
274 # we'll get here if the client disconnected while we were reading
275 # its request
275 # its request
276 return 1
276 return 1
277
277
278 return 0
278 return 0
279
279
280 def _protectio(ui):
280 def _protectio(ui):
281 """ duplicates streams and redirect original to null if ui uses stdio """
281 """ duplicates streams and redirect original to null if ui uses stdio """
282 ui.flush()
282 ui.flush()
283 newfiles = []
283 newfiles = []
284 nullfd = os.open(os.devnull, os.O_RDWR)
284 nullfd = os.open(os.devnull, os.O_RDWR)
285 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
285 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
286 (ui.fout, sys.stdout, 'wb')]:
286 (ui.fout, sys.stdout, 'wb')]:
287 if f is sysf:
287 if f is sysf:
288 newfd = os.dup(f.fileno())
288 newfd = os.dup(f.fileno())
289 os.dup2(nullfd, f.fileno())
289 os.dup2(nullfd, f.fileno())
290 f = os.fdopen(newfd, mode)
290 f = os.fdopen(newfd, mode)
291 newfiles.append(f)
291 newfiles.append(f)
292 os.close(nullfd)
292 os.close(nullfd)
293 return tuple(newfiles)
293 return tuple(newfiles)
294
294
295 def _restoreio(ui, fin, fout):
295 def _restoreio(ui, fin, fout):
296 """ restores streams from duplicated ones """
296 """ restores streams from duplicated ones """
297 ui.flush()
297 ui.flush()
298 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
298 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
299 if f is not uif:
299 if f is not uif:
300 os.dup2(f.fileno(), uif.fileno())
300 os.dup2(f.fileno(), uif.fileno())
301 f.close()
301 f.close()
302
302
303 class pipeservice(object):
303 class pipeservice(object):
304 def __init__(self, ui, repo, opts):
304 def __init__(self, ui, repo, opts):
305 self.ui = ui
305 self.ui = ui
306 self.repo = repo
306 self.repo = repo
307
307
308 def init(self):
308 def init(self):
309 pass
309 pass
310
310
311 def run(self):
311 def run(self):
312 ui = self.ui
312 ui = self.ui
313 # redirect stdio to null device so that broken extensions or in-process
313 # redirect stdio to null device so that broken extensions or in-process
314 # hooks will never cause corruption of channel protocol.
314 # hooks will never cause corruption of channel protocol.
315 fin, fout = _protectio(ui)
315 fin, fout = _protectio(ui)
316 try:
316 try:
317 sv = server(ui, self.repo, fin, fout)
317 sv = server(ui, self.repo, fin, fout)
318 return sv.serve()
318 return sv.serve()
319 finally:
319 finally:
320 _restoreio(ui, fin, fout)
320 _restoreio(ui, fin, fout)
321
321
322 class _requesthandler(SocketServer.StreamRequestHandler):
322 class _requesthandler(SocketServer.StreamRequestHandler):
323 def handle(self):
323 def handle(self):
324 ui = self.server.ui
324 ui = self.server.ui
325 repo = self.server.repo
325 repo = self.server.repo
326 sv = server(ui, repo, self.rfile, self.wfile)
326 sv = server(ui, repo, self.rfile, self.wfile)
327 try:
327 try:
328 try:
328 try:
329 sv.serve()
329 sv.serve()
330 # handle exceptions that may be raised by command server. most of
330 # handle exceptions that may be raised by command server. most of
331 # known exceptions are caught by dispatch.
331 # known exceptions are caught by dispatch.
332 except error.Abort as inst:
332 except error.Abort as inst:
333 ui.warn(_('abort: %s\n') % inst)
333 ui.warn(_('abort: %s\n') % inst)
334 except IOError as inst:
334 except IOError as inst:
335 if inst.errno != errno.EPIPE:
335 if inst.errno != errno.EPIPE:
336 raise
336 raise
337 except KeyboardInterrupt:
337 except KeyboardInterrupt:
338 pass
338 pass
339 except: # re-raises
339 except: # re-raises
340 # also write traceback to error channel. otherwise client cannot
340 # also write traceback to error channel. otherwise client cannot
341 # see it because it is written to server's stderr by default.
341 # see it because it is written to server's stderr by default.
342 traceback.print_exc(file=sv.cerr)
342 traceback.print_exc(file=sv.cerr)
343 raise
343 raise
344
344
345 class unixservice(object):
345 class unixservice(object):
346 """
346 """
347 Listens on unix domain socket and forks server per connection
347 Listens on unix domain socket and forks server per connection
348 """
348 """
349 def __init__(self, ui, repo, opts):
349 def __init__(self, ui, repo, opts):
350 self.ui = ui
350 self.ui = ui
351 self.repo = repo
351 self.repo = repo
352 self.address = opts['address']
352 self.address = opts['address']
353 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
353 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
354 raise error.Abort(_('unsupported platform'))
354 raise error.Abort(_('unsupported platform'))
355 if not self.address:
355 if not self.address:
356 raise error.Abort(_('no socket path specified with --address'))
356 raise error.Abort(_('no socket path specified with --address'))
357
357
358 def init(self):
358 def init(self):
359 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
359 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
360 ui = self.ui
360 ui = self.ui
361 repo = self.repo
361 repo = self.repo
362 self.server = cls(self.address, _requesthandler)
362 self.server = cls(self.address, _requesthandler)
363 self.ui.status(_('listening at %s\n') % self.address)
363 self.ui.status(_('listening at %s\n') % self.address)
364 self.ui.flush() # avoid buffering of status message
364 self.ui.flush() # avoid buffering of status message
365
365
366 def run(self):
366 def run(self):
367 try:
367 try:
368 self.server.serve_forever()
368 self.server.serve_forever()
369 finally:
369 finally:
370 os.unlink(self.address)
370 os.unlink(self.address)
371
371
372 _servicemap = {
372 _servicemap = {
373 'pipe': pipeservice,
373 'pipe': pipeservice,
374 'unix': unixservice,
374 'unix': unixservice,
375 }
375 }
376
376
377 def createservice(ui, repo, opts):
377 def createservice(ui, repo, opts):
378 mode = opts['cmdserver']
378 mode = opts['cmdserver']
379 try:
379 try:
380 return _servicemap[mode](ui, repo, opts)
380 return _servicemap[mode](ui, repo, opts)
381 except KeyError:
381 except KeyError:
382 raise error.Abort(_('unknown mode %s') % mode)
382 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now