##// END OF EJS Templates
commandserver: do not set nontty flag if channel is replaced by a real file...
Yuya Nishihara -
r27565:e7937438 default
parent child Browse files
Show More
@@ -1,374 +1,377 b''
1 # commandserver.py - communicate with Mercurial's API over a pipe
1 # commandserver.py - communicate with Mercurial's API over a pipe
2 #
2 #
3 # Copyright Matt Mackall <mpm@selenic.com>
3 # Copyright Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import SocketServer
10 import SocketServer
11 import errno
11 import errno
12 import os
12 import os
13 import struct
13 import struct
14 import sys
14 import sys
15 import traceback
15 import traceback
16
16
17 from .i18n import _
17 from .i18n import _
18 from . import (
18 from . import (
19 encoding,
19 encoding,
20 error,
20 error,
21 util,
21 util,
22 )
22 )
23
23
24 logfile = None
24 logfile = None
25
25
26 def log(*args):
26 def log(*args):
27 if not logfile:
27 if not logfile:
28 return
28 return
29
29
30 for a in args:
30 for a in args:
31 logfile.write(str(a))
31 logfile.write(str(a))
32
32
33 logfile.flush()
33 logfile.flush()
34
34
35 class channeledoutput(object):
35 class channeledoutput(object):
36 """
36 """
37 Write data to out in the following format:
37 Write data to out in the following format:
38
38
39 data length (unsigned int),
39 data length (unsigned int),
40 data
40 data
41 """
41 """
42 def __init__(self, out, channel):
42 def __init__(self, out, channel):
43 self.out = out
43 self.out = out
44 self.channel = channel
44 self.channel = channel
45
45
46 @property
46 @property
47 def name(self):
47 def name(self):
48 return '<%c-channel>' % self.channel
48 return '<%c-channel>' % self.channel
49
49
50 def write(self, data):
50 def write(self, data):
51 if not data:
51 if not data:
52 return
52 return
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 self.out.write(data)
54 self.out.write(data)
55 self.out.flush()
55 self.out.flush()
56
56
57 def __getattr__(self, attr):
57 def __getattr__(self, attr):
58 if attr in ('isatty', 'fileno'):
58 if attr in ('isatty', 'fileno'):
59 raise AttributeError(attr)
59 raise AttributeError(attr)
60 return getattr(self.out, attr)
60 return getattr(self.out, attr)
61
61
62 class channeledinput(object):
62 class channeledinput(object):
63 """
63 """
64 Read data from in_.
64 Read data from in_.
65
65
66 Requests for input are written to out in the following format:
66 Requests for input are written to out in the following format:
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 how many bytes to send at most (unsigned int),
68 how many bytes to send at most (unsigned int),
69
69
70 The client replies with:
70 The client replies with:
71 data length (unsigned int), 0 meaning EOF
71 data length (unsigned int), 0 meaning EOF
72 data
72 data
73 """
73 """
74
74
75 maxchunksize = 4 * 1024
75 maxchunksize = 4 * 1024
76
76
77 def __init__(self, in_, out, channel):
77 def __init__(self, in_, out, channel):
78 self.in_ = in_
78 self.in_ = in_
79 self.out = out
79 self.out = out
80 self.channel = channel
80 self.channel = channel
81
81
82 @property
82 @property
83 def name(self):
83 def name(self):
84 return '<%c-channel>' % self.channel
84 return '<%c-channel>' % self.channel
85
85
86 def read(self, size=-1):
86 def read(self, size=-1):
87 if size < 0:
87 if size < 0:
88 # if we need to consume all the clients input, ask for 4k chunks
88 # if we need to consume all the clients input, ask for 4k chunks
89 # so the pipe doesn't fill up risking a deadlock
89 # so the pipe doesn't fill up risking a deadlock
90 size = self.maxchunksize
90 size = self.maxchunksize
91 s = self._read(size, self.channel)
91 s = self._read(size, self.channel)
92 buf = s
92 buf = s
93 while s:
93 while s:
94 s = self._read(size, self.channel)
94 s = self._read(size, self.channel)
95 buf += s
95 buf += s
96
96
97 return buf
97 return buf
98 else:
98 else:
99 return self._read(size, self.channel)
99 return self._read(size, self.channel)
100
100
101 def _read(self, size, channel):
101 def _read(self, size, channel):
102 if not size:
102 if not size:
103 return ''
103 return ''
104 assert size > 0
104 assert size > 0
105
105
106 # tell the client we need at most size bytes
106 # tell the client we need at most size bytes
107 self.out.write(struct.pack('>cI', channel, size))
107 self.out.write(struct.pack('>cI', channel, size))
108 self.out.flush()
108 self.out.flush()
109
109
110 length = self.in_.read(4)
110 length = self.in_.read(4)
111 length = struct.unpack('>I', length)[0]
111 length = struct.unpack('>I', length)[0]
112 if not length:
112 if not length:
113 return ''
113 return ''
114 else:
114 else:
115 return self.in_.read(length)
115 return self.in_.read(length)
116
116
117 def readline(self, size=-1):
117 def readline(self, size=-1):
118 if size < 0:
118 if size < 0:
119 size = self.maxchunksize
119 size = self.maxchunksize
120 s = self._read(size, 'L')
120 s = self._read(size, 'L')
121 buf = s
121 buf = s
122 # keep asking for more until there's either no more or
122 # keep asking for more until there's either no more or
123 # we got a full line
123 # we got a full line
124 while s and s[-1] != '\n':
124 while s and s[-1] != '\n':
125 s = self._read(size, 'L')
125 s = self._read(size, 'L')
126 buf += s
126 buf += s
127
127
128 return buf
128 return buf
129 else:
129 else:
130 return self._read(size, 'L')
130 return self._read(size, 'L')
131
131
132 def __iter__(self):
132 def __iter__(self):
133 return self
133 return self
134
134
135 def next(self):
135 def next(self):
136 l = self.readline()
136 l = self.readline()
137 if not l:
137 if not l:
138 raise StopIteration
138 raise StopIteration
139 return l
139 return l
140
140
141 def __getattr__(self, attr):
141 def __getattr__(self, attr):
142 if attr in ('isatty', 'fileno'):
142 if attr in ('isatty', 'fileno'):
143 raise AttributeError(attr)
143 raise AttributeError(attr)
144 return getattr(self.in_, attr)
144 return getattr(self.in_, attr)
145
145
146 class server(object):
146 class server(object):
147 """
147 """
148 Listens for commands on fin, runs them and writes the output on a channel
148 Listens for commands on fin, runs them and writes the output on a channel
149 based stream to fout.
149 based stream to fout.
150 """
150 """
151 def __init__(self, ui, repo, fin, fout):
151 def __init__(self, ui, repo, fin, fout):
152 self.cwd = os.getcwd()
152 self.cwd = os.getcwd()
153
153
154 # developer config: cmdserver.log
154 # developer config: cmdserver.log
155 logpath = ui.config("cmdserver", "log", None)
155 logpath = ui.config("cmdserver", "log", None)
156 if logpath:
156 if logpath:
157 global logfile
157 global logfile
158 if logpath == '-':
158 if logpath == '-':
159 # write log on a special 'd' (debug) channel
159 # write log on a special 'd' (debug) channel
160 logfile = channeledoutput(fout, 'd')
160 logfile = channeledoutput(fout, 'd')
161 else:
161 else:
162 logfile = open(logpath, 'a')
162 logfile = open(logpath, 'a')
163
163
164 if repo:
164 if repo:
165 # the ui here is really the repo ui so take its baseui so we don't
165 # the ui here is really the repo ui so take its baseui so we don't
166 # end up with its local configuration
166 # end up with its local configuration
167 self.ui = repo.baseui
167 self.ui = repo.baseui
168 self.repo = repo
168 self.repo = repo
169 self.repoui = repo.ui
169 self.repoui = repo.ui
170 else:
170 else:
171 self.ui = ui
171 self.ui = ui
172 self.repo = self.repoui = None
172 self.repo = self.repoui = None
173
173
174 self.cerr = channeledoutput(fout, 'e')
174 self.cerr = channeledoutput(fout, 'e')
175 self.cout = channeledoutput(fout, 'o')
175 self.cout = channeledoutput(fout, 'o')
176 self.cin = channeledinput(fin, fout, 'I')
176 self.cin = channeledinput(fin, fout, 'I')
177 self.cresult = channeledoutput(fout, 'r')
177 self.cresult = channeledoutput(fout, 'r')
178
178
179 self.client = fin
179 self.client = fin
180
180
181 def _read(self, size):
181 def _read(self, size):
182 if not size:
182 if not size:
183 return ''
183 return ''
184
184
185 data = self.client.read(size)
185 data = self.client.read(size)
186
186
187 # is the other end closed?
187 # is the other end closed?
188 if not data:
188 if not data:
189 raise EOFError
189 raise EOFError
190
190
191 return data
191 return data
192
192
193 def runcommand(self):
193 def runcommand(self):
194 """ reads a list of \0 terminated arguments, executes
194 """ reads a list of \0 terminated arguments, executes
195 and writes the return code to the result channel """
195 and writes the return code to the result channel """
196 from . import dispatch # avoid cycle
196 from . import dispatch # avoid cycle
197
197
198 length = struct.unpack('>I', self._read(4))[0]
198 length = struct.unpack('>I', self._read(4))[0]
199 if not length:
199 if not length:
200 args = []
200 args = []
201 else:
201 else:
202 args = self._read(length).split('\0')
202 args = self._read(length).split('\0')
203
203
204 # copy the uis so changes (e.g. --config or --verbose) don't
204 # copy the uis so changes (e.g. --config or --verbose) don't
205 # persist between requests
205 # persist between requests
206 copiedui = self.ui.copy()
206 copiedui = self.ui.copy()
207 uis = [copiedui]
207 uis = [copiedui]
208 if self.repo:
208 if self.repo:
209 self.repo.baseui = copiedui
209 self.repo.baseui = copiedui
210 # clone ui without using ui.copy because this is protected
210 # clone ui without using ui.copy because this is protected
211 repoui = self.repoui.__class__(self.repoui)
211 repoui = self.repoui.__class__(self.repoui)
212 repoui.copy = copiedui.copy # redo copy protection
212 repoui.copy = copiedui.copy # redo copy protection
213 uis.append(repoui)
213 uis.append(repoui)
214 self.repo.ui = self.repo.dirstate._ui = repoui
214 self.repo.ui = self.repo.dirstate._ui = repoui
215 self.repo.invalidateall()
215 self.repo.invalidateall()
216
216
217 for ui in uis:
217 for ui in uis:
218 # any kind of interaction must use server channels
218 # any kind of interaction must use server channels, but chg may
219 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
219 # replace channels by fully functional tty files. so nontty is
220 # enforced only if cin is a channel.
221 if not util.safehasattr(self.cin, 'fileno'):
222 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
220
223
221 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
224 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
222 self.cout, self.cerr)
225 self.cout, self.cerr)
223
226
224 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
227 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
225
228
226 # restore old cwd
229 # restore old cwd
227 if '--cwd' in args:
230 if '--cwd' in args:
228 os.chdir(self.cwd)
231 os.chdir(self.cwd)
229
232
230 self.cresult.write(struct.pack('>i', int(ret)))
233 self.cresult.write(struct.pack('>i', int(ret)))
231
234
232 def getencoding(self):
235 def getencoding(self):
233 """ writes the current encoding to the result channel """
236 """ writes the current encoding to the result channel """
234 self.cresult.write(encoding.encoding)
237 self.cresult.write(encoding.encoding)
235
238
236 def serveone(self):
239 def serveone(self):
237 cmd = self.client.readline()[:-1]
240 cmd = self.client.readline()[:-1]
238 if cmd:
241 if cmd:
239 handler = self.capabilities.get(cmd)
242 handler = self.capabilities.get(cmd)
240 if handler:
243 if handler:
241 handler(self)
244 handler(self)
242 else:
245 else:
243 # clients are expected to check what commands are supported by
246 # clients are expected to check what commands are supported by
244 # looking at the servers capabilities
247 # looking at the servers capabilities
245 raise error.Abort(_('unknown command %s') % cmd)
248 raise error.Abort(_('unknown command %s') % cmd)
246
249
247 return cmd != ''
250 return cmd != ''
248
251
249 capabilities = {'runcommand' : runcommand,
252 capabilities = {'runcommand' : runcommand,
250 'getencoding' : getencoding}
253 'getencoding' : getencoding}
251
254
252 def serve(self):
255 def serve(self):
253 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
256 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
254 hellomsg += '\n'
257 hellomsg += '\n'
255 hellomsg += 'encoding: ' + encoding.encoding
258 hellomsg += 'encoding: ' + encoding.encoding
256 hellomsg += '\n'
259 hellomsg += '\n'
257 hellomsg += 'pid: %d' % os.getpid()
260 hellomsg += 'pid: %d' % os.getpid()
258
261
259 # write the hello msg in -one- chunk
262 # write the hello msg in -one- chunk
260 self.cout.write(hellomsg)
263 self.cout.write(hellomsg)
261
264
262 try:
265 try:
263 while self.serveone():
266 while self.serveone():
264 pass
267 pass
265 except EOFError:
268 except EOFError:
266 # we'll get here if the client disconnected while we were reading
269 # we'll get here if the client disconnected while we were reading
267 # its request
270 # its request
268 return 1
271 return 1
269
272
270 return 0
273 return 0
271
274
272 def _protectio(ui):
275 def _protectio(ui):
273 """ duplicates streams and redirect original to null if ui uses stdio """
276 """ duplicates streams and redirect original to null if ui uses stdio """
274 ui.flush()
277 ui.flush()
275 newfiles = []
278 newfiles = []
276 nullfd = os.open(os.devnull, os.O_RDWR)
279 nullfd = os.open(os.devnull, os.O_RDWR)
277 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
280 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
278 (ui.fout, sys.stdout, 'wb')]:
281 (ui.fout, sys.stdout, 'wb')]:
279 if f is sysf:
282 if f is sysf:
280 newfd = os.dup(f.fileno())
283 newfd = os.dup(f.fileno())
281 os.dup2(nullfd, f.fileno())
284 os.dup2(nullfd, f.fileno())
282 f = os.fdopen(newfd, mode)
285 f = os.fdopen(newfd, mode)
283 newfiles.append(f)
286 newfiles.append(f)
284 os.close(nullfd)
287 os.close(nullfd)
285 return tuple(newfiles)
288 return tuple(newfiles)
286
289
287 def _restoreio(ui, fin, fout):
290 def _restoreio(ui, fin, fout):
288 """ restores streams from duplicated ones """
291 """ restores streams from duplicated ones """
289 ui.flush()
292 ui.flush()
290 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
293 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
291 if f is not uif:
294 if f is not uif:
292 os.dup2(f.fileno(), uif.fileno())
295 os.dup2(f.fileno(), uif.fileno())
293 f.close()
296 f.close()
294
297
295 class pipeservice(object):
298 class pipeservice(object):
296 def __init__(self, ui, repo, opts):
299 def __init__(self, ui, repo, opts):
297 self.ui = ui
300 self.ui = ui
298 self.repo = repo
301 self.repo = repo
299
302
300 def init(self):
303 def init(self):
301 pass
304 pass
302
305
303 def run(self):
306 def run(self):
304 ui = self.ui
307 ui = self.ui
305 # redirect stdio to null device so that broken extensions or in-process
308 # redirect stdio to null device so that broken extensions or in-process
306 # hooks will never cause corruption of channel protocol.
309 # hooks will never cause corruption of channel protocol.
307 fin, fout = _protectio(ui)
310 fin, fout = _protectio(ui)
308 try:
311 try:
309 sv = server(ui, self.repo, fin, fout)
312 sv = server(ui, self.repo, fin, fout)
310 return sv.serve()
313 return sv.serve()
311 finally:
314 finally:
312 _restoreio(ui, fin, fout)
315 _restoreio(ui, fin, fout)
313
316
314 class _requesthandler(SocketServer.StreamRequestHandler):
317 class _requesthandler(SocketServer.StreamRequestHandler):
315 def handle(self):
318 def handle(self):
316 ui = self.server.ui
319 ui = self.server.ui
317 repo = self.server.repo
320 repo = self.server.repo
318 sv = server(ui, repo, self.rfile, self.wfile)
321 sv = server(ui, repo, self.rfile, self.wfile)
319 try:
322 try:
320 try:
323 try:
321 sv.serve()
324 sv.serve()
322 # handle exceptions that may be raised by command server. most of
325 # handle exceptions that may be raised by command server. most of
323 # known exceptions are caught by dispatch.
326 # known exceptions are caught by dispatch.
324 except error.Abort as inst:
327 except error.Abort as inst:
325 ui.warn(_('abort: %s\n') % inst)
328 ui.warn(_('abort: %s\n') % inst)
326 except IOError as inst:
329 except IOError as inst:
327 if inst.errno != errno.EPIPE:
330 if inst.errno != errno.EPIPE:
328 raise
331 raise
329 except KeyboardInterrupt:
332 except KeyboardInterrupt:
330 pass
333 pass
331 except: # re-raises
334 except: # re-raises
332 # also write traceback to error channel. otherwise client cannot
335 # also write traceback to error channel. otherwise client cannot
333 # see it because it is written to server's stderr by default.
336 # see it because it is written to server's stderr by default.
334 traceback.print_exc(file=sv.cerr)
337 traceback.print_exc(file=sv.cerr)
335 raise
338 raise
336
339
337 class unixservice(object):
340 class unixservice(object):
338 """
341 """
339 Listens on unix domain socket and forks server per connection
342 Listens on unix domain socket and forks server per connection
340 """
343 """
341 def __init__(self, ui, repo, opts):
344 def __init__(self, ui, repo, opts):
342 self.ui = ui
345 self.ui = ui
343 self.repo = repo
346 self.repo = repo
344 self.address = opts['address']
347 self.address = opts['address']
345 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
348 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
346 raise error.Abort(_('unsupported platform'))
349 raise error.Abort(_('unsupported platform'))
347 if not self.address:
350 if not self.address:
348 raise error.Abort(_('no socket path specified with --address'))
351 raise error.Abort(_('no socket path specified with --address'))
349
352
350 def init(self):
353 def init(self):
351 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
354 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
352 ui = self.ui
355 ui = self.ui
353 repo = self.repo
356 repo = self.repo
354 self.server = cls(self.address, _requesthandler)
357 self.server = cls(self.address, _requesthandler)
355 self.ui.status(_('listening at %s\n') % self.address)
358 self.ui.status(_('listening at %s\n') % self.address)
356 self.ui.flush() # avoid buffering of status message
359 self.ui.flush() # avoid buffering of status message
357
360
358 def run(self):
361 def run(self):
359 try:
362 try:
360 self.server.serve_forever()
363 self.server.serve_forever()
361 finally:
364 finally:
362 os.unlink(self.address)
365 os.unlink(self.address)
363
366
364 _servicemap = {
367 _servicemap = {
365 'pipe': pipeservice,
368 'pipe': pipeservice,
366 'unix': unixservice,
369 'unix': unixservice,
367 }
370 }
368
371
369 def createservice(ui, repo, opts):
372 def createservice(ui, repo, opts):
370 mode = opts['cmdserver']
373 mode = opts['cmdserver']
371 try:
374 try:
372 return _servicemap[mode](ui, repo, opts)
375 return _servicemap[mode](ui, repo, opts)
373 except KeyError:
376 except KeyError:
374 raise error.Abort(_('unknown mode %s') % mode)
377 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now