##// END OF EJS Templates
commandserver: drop tell() and seek() from channels (issue5049)...
Yuya Nishihara -
r27915:5f2a308b stable
parent child Browse files
Show More
@@ -1,382 +1,382 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import SocketServer
11 11 import errno
12 12 import os
13 13 import struct
14 14 import sys
15 15 import traceback
16 16
17 17 from .i18n import _
18 18 from . import (
19 19 encoding,
20 20 error,
21 21 util,
22 22 )
23 23
24 24 logfile = None
25 25
26 26 def log(*args):
27 27 if not logfile:
28 28 return
29 29
30 30 for a in args:
31 31 logfile.write(str(a))
32 32
33 33 logfile.flush()
34 34
35 35 class channeledoutput(object):
36 36 """
37 37 Write data to out in the following format:
38 38
39 39 data length (unsigned int),
40 40 data
41 41 """
42 42 def __init__(self, out, channel):
43 43 self.out = out
44 44 self.channel = channel
45 45
46 46 @property
47 47 def name(self):
48 48 return '<%c-channel>' % self.channel
49 49
50 50 def write(self, data):
51 51 if not data:
52 52 return
53 53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 54 self.out.write(data)
55 55 self.out.flush()
56 56
57 57 def __getattr__(self, attr):
58 if attr in ('isatty', 'fileno'):
58 if attr in ('isatty', 'fileno', 'tell', 'seek'):
59 59 raise AttributeError(attr)
60 60 return getattr(self.out, attr)
61 61
62 62 class channeledinput(object):
63 63 """
64 64 Read data from in_.
65 65
66 66 Requests for input are written to out in the following format:
67 67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 68 how many bytes to send at most (unsigned int),
69 69
70 70 The client replies with:
71 71 data length (unsigned int), 0 meaning EOF
72 72 data
73 73 """
74 74
75 75 maxchunksize = 4 * 1024
76 76
77 77 def __init__(self, in_, out, channel):
78 78 self.in_ = in_
79 79 self.out = out
80 80 self.channel = channel
81 81
82 82 @property
83 83 def name(self):
84 84 return '<%c-channel>' % self.channel
85 85
86 86 def read(self, size=-1):
87 87 if size < 0:
88 88 # if we need to consume all the clients input, ask for 4k chunks
89 89 # so the pipe doesn't fill up risking a deadlock
90 90 size = self.maxchunksize
91 91 s = self._read(size, self.channel)
92 92 buf = s
93 93 while s:
94 94 s = self._read(size, self.channel)
95 95 buf += s
96 96
97 97 return buf
98 98 else:
99 99 return self._read(size, self.channel)
100 100
101 101 def _read(self, size, channel):
102 102 if not size:
103 103 return ''
104 104 assert size > 0
105 105
106 106 # tell the client we need at most size bytes
107 107 self.out.write(struct.pack('>cI', channel, size))
108 108 self.out.flush()
109 109
110 110 length = self.in_.read(4)
111 111 length = struct.unpack('>I', length)[0]
112 112 if not length:
113 113 return ''
114 114 else:
115 115 return self.in_.read(length)
116 116
117 117 def readline(self, size=-1):
118 118 if size < 0:
119 119 size = self.maxchunksize
120 120 s = self._read(size, 'L')
121 121 buf = s
122 122 # keep asking for more until there's either no more or
123 123 # we got a full line
124 124 while s and s[-1] != '\n':
125 125 s = self._read(size, 'L')
126 126 buf += s
127 127
128 128 return buf
129 129 else:
130 130 return self._read(size, 'L')
131 131
132 132 def __iter__(self):
133 133 return self
134 134
135 135 def next(self):
136 136 l = self.readline()
137 137 if not l:
138 138 raise StopIteration
139 139 return l
140 140
141 141 def __getattr__(self, attr):
142 if attr in ('isatty', 'fileno'):
142 if attr in ('isatty', 'fileno', 'tell', 'seek'):
143 143 raise AttributeError(attr)
144 144 return getattr(self.in_, attr)
145 145
146 146 class server(object):
147 147 """
148 148 Listens for commands on fin, runs them and writes the output on a channel
149 149 based stream to fout.
150 150 """
151 151 def __init__(self, ui, repo, fin, fout):
152 152 self.cwd = os.getcwd()
153 153
154 154 # developer config: cmdserver.log
155 155 logpath = ui.config("cmdserver", "log", None)
156 156 if logpath:
157 157 global logfile
158 158 if logpath == '-':
159 159 # write log on a special 'd' (debug) channel
160 160 logfile = channeledoutput(fout, 'd')
161 161 else:
162 162 logfile = open(logpath, 'a')
163 163
164 164 if repo:
165 165 # the ui here is really the repo ui so take its baseui so we don't
166 166 # end up with its local configuration
167 167 self.ui = repo.baseui
168 168 self.repo = repo
169 169 self.repoui = repo.ui
170 170 else:
171 171 self.ui = ui
172 172 self.repo = self.repoui = None
173 173
174 174 self.cerr = channeledoutput(fout, 'e')
175 175 self.cout = channeledoutput(fout, 'o')
176 176 self.cin = channeledinput(fin, fout, 'I')
177 177 self.cresult = channeledoutput(fout, 'r')
178 178
179 179 self.client = fin
180 180
181 181 def _read(self, size):
182 182 if not size:
183 183 return ''
184 184
185 185 data = self.client.read(size)
186 186
187 187 # is the other end closed?
188 188 if not data:
189 189 raise EOFError
190 190
191 191 return data
192 192
193 193 def runcommand(self):
194 194 """ reads a list of \0 terminated arguments, executes
195 195 and writes the return code to the result channel """
196 196 from . import dispatch # avoid cycle
197 197
198 198 length = struct.unpack('>I', self._read(4))[0]
199 199 if not length:
200 200 args = []
201 201 else:
202 202 args = self._read(length).split('\0')
203 203
204 204 # copy the uis so changes (e.g. --config or --verbose) don't
205 205 # persist between requests
206 206 copiedui = self.ui.copy()
207 207 uis = [copiedui]
208 208 if self.repo:
209 209 self.repo.baseui = copiedui
210 210 # clone ui without using ui.copy because this is protected
211 211 repoui = self.repoui.__class__(self.repoui)
212 212 repoui.copy = copiedui.copy # redo copy protection
213 213 uis.append(repoui)
214 214 self.repo.ui = self.repo.dirstate._ui = repoui
215 215 self.repo.invalidateall()
216 216
217 217 # reset last-print time of progress bar per command
218 218 # (progbar is singleton, we don't have to do for all uis)
219 219 if copiedui._progbar:
220 220 copiedui._progbar.resetstate()
221 221
222 222 for ui in uis:
223 223 # any kind of interaction must use server channels, but chg may
224 224 # replace channels by fully functional tty files. so nontty is
225 225 # enforced only if cin is a channel.
226 226 if not util.safehasattr(self.cin, 'fileno'):
227 227 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
228 228
229 229 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
230 230 self.cout, self.cerr)
231 231
232 232 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
233 233
234 234 # restore old cwd
235 235 if '--cwd' in args:
236 236 os.chdir(self.cwd)
237 237
238 238 self.cresult.write(struct.pack('>i', int(ret)))
239 239
240 240 def getencoding(self):
241 241 """ writes the current encoding to the result channel """
242 242 self.cresult.write(encoding.encoding)
243 243
244 244 def serveone(self):
245 245 cmd = self.client.readline()[:-1]
246 246 if cmd:
247 247 handler = self.capabilities.get(cmd)
248 248 if handler:
249 249 handler(self)
250 250 else:
251 251 # clients are expected to check what commands are supported by
252 252 # looking at the servers capabilities
253 253 raise error.Abort(_('unknown command %s') % cmd)
254 254
255 255 return cmd != ''
256 256
257 257 capabilities = {'runcommand' : runcommand,
258 258 'getencoding' : getencoding}
259 259
260 260 def serve(self):
261 261 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
262 262 hellomsg += '\n'
263 263 hellomsg += 'encoding: ' + encoding.encoding
264 264 hellomsg += '\n'
265 265 hellomsg += 'pid: %d' % os.getpid()
266 266
267 267 # write the hello msg in -one- chunk
268 268 self.cout.write(hellomsg)
269 269
270 270 try:
271 271 while self.serveone():
272 272 pass
273 273 except EOFError:
274 274 # we'll get here if the client disconnected while we were reading
275 275 # its request
276 276 return 1
277 277
278 278 return 0
279 279
280 280 def _protectio(ui):
281 281 """ duplicates streams and redirect original to null if ui uses stdio """
282 282 ui.flush()
283 283 newfiles = []
284 284 nullfd = os.open(os.devnull, os.O_RDWR)
285 285 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
286 286 (ui.fout, sys.stdout, 'wb')]:
287 287 if f is sysf:
288 288 newfd = os.dup(f.fileno())
289 289 os.dup2(nullfd, f.fileno())
290 290 f = os.fdopen(newfd, mode)
291 291 newfiles.append(f)
292 292 os.close(nullfd)
293 293 return tuple(newfiles)
294 294
295 295 def _restoreio(ui, fin, fout):
296 296 """ restores streams from duplicated ones """
297 297 ui.flush()
298 298 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
299 299 if f is not uif:
300 300 os.dup2(f.fileno(), uif.fileno())
301 301 f.close()
302 302
303 303 class pipeservice(object):
304 304 def __init__(self, ui, repo, opts):
305 305 self.ui = ui
306 306 self.repo = repo
307 307
308 308 def init(self):
309 309 pass
310 310
311 311 def run(self):
312 312 ui = self.ui
313 313 # redirect stdio to null device so that broken extensions or in-process
314 314 # hooks will never cause corruption of channel protocol.
315 315 fin, fout = _protectio(ui)
316 316 try:
317 317 sv = server(ui, self.repo, fin, fout)
318 318 return sv.serve()
319 319 finally:
320 320 _restoreio(ui, fin, fout)
321 321
322 322 class _requesthandler(SocketServer.StreamRequestHandler):
323 323 def handle(self):
324 324 ui = self.server.ui
325 325 repo = self.server.repo
326 326 sv = server(ui, repo, self.rfile, self.wfile)
327 327 try:
328 328 try:
329 329 sv.serve()
330 330 # handle exceptions that may be raised by command server. most of
331 331 # known exceptions are caught by dispatch.
332 332 except error.Abort as inst:
333 333 ui.warn(_('abort: %s\n') % inst)
334 334 except IOError as inst:
335 335 if inst.errno != errno.EPIPE:
336 336 raise
337 337 except KeyboardInterrupt:
338 338 pass
339 339 except: # re-raises
340 340 # also write traceback to error channel. otherwise client cannot
341 341 # see it because it is written to server's stderr by default.
342 342 traceback.print_exc(file=sv.cerr)
343 343 raise
344 344
345 345 class unixservice(object):
346 346 """
347 347 Listens on unix domain socket and forks server per connection
348 348 """
349 349 def __init__(self, ui, repo, opts):
350 350 self.ui = ui
351 351 self.repo = repo
352 352 self.address = opts['address']
353 353 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
354 354 raise error.Abort(_('unsupported platform'))
355 355 if not self.address:
356 356 raise error.Abort(_('no socket path specified with --address'))
357 357
358 358 def init(self):
359 359 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
360 360 ui = self.ui
361 361 repo = self.repo
362 362 self.server = cls(self.address, _requesthandler)
363 363 self.ui.status(_('listening at %s\n') % self.address)
364 364 self.ui.flush() # avoid buffering of status message
365 365
366 366 def run(self):
367 367 try:
368 368 self.server.serve_forever()
369 369 finally:
370 370 os.unlink(self.address)
371 371
372 372 _servicemap = {
373 373 'pipe': pipeservice,
374 374 'unix': unixservice,
375 375 }
376 376
377 377 def createservice(ui, repo, opts):
378 378 mode = opts['cmdserver']
379 379 try:
380 380 return _servicemap[mode](ui, repo, opts)
381 381 except KeyError:
382 382 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now