##// END OF EJS Templates
commandserver: add _readstr and _readlist...
Jun Wu -
r28156:75f586a1 default
parent child Browse files
Show More
@@ -1,382 +1,401 b''
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import SocketServer
11 11 import errno
12 12 import os
13 13 import struct
14 14 import sys
15 15 import traceback
16 16
17 17 from .i18n import _
18 18 from . import (
19 19 encoding,
20 20 error,
21 21 util,
22 22 )
23 23
24 24 logfile = None
25 25
26 26 def log(*args):
27 27 if not logfile:
28 28 return
29 29
30 30 for a in args:
31 31 logfile.write(str(a))
32 32
33 33 logfile.flush()
34 34
35 35 class channeledoutput(object):
36 36 """
37 37 Write data to out in the following format:
38 38
39 39 data length (unsigned int),
40 40 data
41 41 """
42 42 def __init__(self, out, channel):
43 43 self.out = out
44 44 self.channel = channel
45 45
46 46 @property
47 47 def name(self):
48 48 return '<%c-channel>' % self.channel
49 49
50 50 def write(self, data):
51 51 if not data:
52 52 return
53 53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 54 self.out.write(data)
55 55 self.out.flush()
56 56
57 57 def __getattr__(self, attr):
58 58 if attr in ('isatty', 'fileno', 'tell', 'seek'):
59 59 raise AttributeError(attr)
60 60 return getattr(self.out, attr)
61 61
62 62 class channeledinput(object):
63 63 """
64 64 Read data from in_.
65 65
66 66 Requests for input are written to out in the following format:
67 67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 68 how many bytes to send at most (unsigned int),
69 69
70 70 The client replies with:
71 71 data length (unsigned int), 0 meaning EOF
72 72 data
73 73 """
74 74
75 75 maxchunksize = 4 * 1024
76 76
77 77 def __init__(self, in_, out, channel):
78 78 self.in_ = in_
79 79 self.out = out
80 80 self.channel = channel
81 81
82 82 @property
83 83 def name(self):
84 84 return '<%c-channel>' % self.channel
85 85
86 86 def read(self, size=-1):
87 87 if size < 0:
88 88 # if we need to consume all the clients input, ask for 4k chunks
89 89 # so the pipe doesn't fill up risking a deadlock
90 90 size = self.maxchunksize
91 91 s = self._read(size, self.channel)
92 92 buf = s
93 93 while s:
94 94 s = self._read(size, self.channel)
95 95 buf += s
96 96
97 97 return buf
98 98 else:
99 99 return self._read(size, self.channel)
100 100
101 101 def _read(self, size, channel):
102 102 if not size:
103 103 return ''
104 104 assert size > 0
105 105
106 106 # tell the client we need at most size bytes
107 107 self.out.write(struct.pack('>cI', channel, size))
108 108 self.out.flush()
109 109
110 110 length = self.in_.read(4)
111 111 length = struct.unpack('>I', length)[0]
112 112 if not length:
113 113 return ''
114 114 else:
115 115 return self.in_.read(length)
116 116
117 117 def readline(self, size=-1):
118 118 if size < 0:
119 119 size = self.maxchunksize
120 120 s = self._read(size, 'L')
121 121 buf = s
122 122 # keep asking for more until there's either no more or
123 123 # we got a full line
124 124 while s and s[-1] != '\n':
125 125 s = self._read(size, 'L')
126 126 buf += s
127 127
128 128 return buf
129 129 else:
130 130 return self._read(size, 'L')
131 131
132 132 def __iter__(self):
133 133 return self
134 134
135 135 def next(self):
136 136 l = self.readline()
137 137 if not l:
138 138 raise StopIteration
139 139 return l
140 140
141 141 def __getattr__(self, attr):
142 142 if attr in ('isatty', 'fileno', 'tell', 'seek'):
143 143 raise AttributeError(attr)
144 144 return getattr(self.in_, attr)
145 145
146 146 class server(object):
147 147 """
148 148 Listens for commands on fin, runs them and writes the output on a channel
149 149 based stream to fout.
150 150 """
151 151 def __init__(self, ui, repo, fin, fout):
152 152 self.cwd = os.getcwd()
153 153
154 154 # developer config: cmdserver.log
155 155 logpath = ui.config("cmdserver", "log", None)
156 156 if logpath:
157 157 global logfile
158 158 if logpath == '-':
159 159 # write log on a special 'd' (debug) channel
160 160 logfile = channeledoutput(fout, 'd')
161 161 else:
162 162 logfile = open(logpath, 'a')
163 163
164 164 if repo:
165 165 # the ui here is really the repo ui so take its baseui so we don't
166 166 # end up with its local configuration
167 167 self.ui = repo.baseui
168 168 self.repo = repo
169 169 self.repoui = repo.ui
170 170 else:
171 171 self.ui = ui
172 172 self.repo = self.repoui = None
173 173
174 174 self.cerr = channeledoutput(fout, 'e')
175 175 self.cout = channeledoutput(fout, 'o')
176 176 self.cin = channeledinput(fin, fout, 'I')
177 177 self.cresult = channeledoutput(fout, 'r')
178 178
179 179 self.client = fin
180 180
181 181 def _read(self, size):
182 182 if not size:
183 183 return ''
184 184
185 185 data = self.client.read(size)
186 186
187 187 # is the other end closed?
188 188 if not data:
189 189 raise EOFError
190 190
191 191 return data
192 192
193 def _readstr(self):
194 """read a string from the channel
195
196 format:
197 data length (uint32), data
198 """
199 length = struct.unpack('>I', self._read(4))[0]
200 if not length:
201 return ''
202 return self._read(length)
203
204 def _readlist(self):
205 """read a list of NULL separated strings from the channel"""
206 s = self._readstr()
207 if s:
208 return s.split('\0')
209 else:
210 return []
211
193 212 def runcommand(self):
194 213 """ reads a list of \0 terminated arguments, executes
195 214 and writes the return code to the result channel """
196 215 from . import dispatch # avoid cycle
197 216
198 217 length = struct.unpack('>I', self._read(4))[0]
199 218 if not length:
200 219 args = []
201 220 else:
202 221 args = self._read(length).split('\0')
203 222
204 223 # copy the uis so changes (e.g. --config or --verbose) don't
205 224 # persist between requests
206 225 copiedui = self.ui.copy()
207 226 uis = [copiedui]
208 227 if self.repo:
209 228 self.repo.baseui = copiedui
210 229 # clone ui without using ui.copy because this is protected
211 230 repoui = self.repoui.__class__(self.repoui)
212 231 repoui.copy = copiedui.copy # redo copy protection
213 232 uis.append(repoui)
214 233 self.repo.ui = self.repo.dirstate._ui = repoui
215 234 self.repo.invalidateall()
216 235
217 236 # reset last-print time of progress bar per command
218 237 # (progbar is singleton, we don't have to do for all uis)
219 238 if copiedui._progbar:
220 239 copiedui._progbar.resetstate()
221 240
222 241 for ui in uis:
223 242 # any kind of interaction must use server channels, but chg may
224 243 # replace channels by fully functional tty files. so nontty is
225 244 # enforced only if cin is a channel.
226 245 if not util.safehasattr(self.cin, 'fileno'):
227 246 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
228 247
229 248 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
230 249 self.cout, self.cerr)
231 250
232 251 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
233 252
234 253 # restore old cwd
235 254 if '--cwd' in args:
236 255 os.chdir(self.cwd)
237 256
238 257 self.cresult.write(struct.pack('>i', int(ret)))
239 258
240 259 def getencoding(self):
241 260 """ writes the current encoding to the result channel """
242 261 self.cresult.write(encoding.encoding)
243 262
244 263 def serveone(self):
245 264 cmd = self.client.readline()[:-1]
246 265 if cmd:
247 266 handler = self.capabilities.get(cmd)
248 267 if handler:
249 268 handler(self)
250 269 else:
251 270 # clients are expected to check what commands are supported by
252 271 # looking at the servers capabilities
253 272 raise error.Abort(_('unknown command %s') % cmd)
254 273
255 274 return cmd != ''
256 275
257 276 capabilities = {'runcommand' : runcommand,
258 277 'getencoding' : getencoding}
259 278
260 279 def serve(self):
261 280 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
262 281 hellomsg += '\n'
263 282 hellomsg += 'encoding: ' + encoding.encoding
264 283 hellomsg += '\n'
265 284 hellomsg += 'pid: %d' % util.getpid()
266 285
267 286 # write the hello msg in -one- chunk
268 287 self.cout.write(hellomsg)
269 288
270 289 try:
271 290 while self.serveone():
272 291 pass
273 292 except EOFError:
274 293 # we'll get here if the client disconnected while we were reading
275 294 # its request
276 295 return 1
277 296
278 297 return 0
279 298
280 299 def _protectio(ui):
281 300 """ duplicates streams and redirect original to null if ui uses stdio """
282 301 ui.flush()
283 302 newfiles = []
284 303 nullfd = os.open(os.devnull, os.O_RDWR)
285 304 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
286 305 (ui.fout, sys.stdout, 'wb')]:
287 306 if f is sysf:
288 307 newfd = os.dup(f.fileno())
289 308 os.dup2(nullfd, f.fileno())
290 309 f = os.fdopen(newfd, mode)
291 310 newfiles.append(f)
292 311 os.close(nullfd)
293 312 return tuple(newfiles)
294 313
295 314 def _restoreio(ui, fin, fout):
296 315 """ restores streams from duplicated ones """
297 316 ui.flush()
298 317 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
299 318 if f is not uif:
300 319 os.dup2(f.fileno(), uif.fileno())
301 320 f.close()
302 321
303 322 class pipeservice(object):
304 323 def __init__(self, ui, repo, opts):
305 324 self.ui = ui
306 325 self.repo = repo
307 326
308 327 def init(self):
309 328 pass
310 329
311 330 def run(self):
312 331 ui = self.ui
313 332 # redirect stdio to null device so that broken extensions or in-process
314 333 # hooks will never cause corruption of channel protocol.
315 334 fin, fout = _protectio(ui)
316 335 try:
317 336 sv = server(ui, self.repo, fin, fout)
318 337 return sv.serve()
319 338 finally:
320 339 _restoreio(ui, fin, fout)
321 340
322 341 class _requesthandler(SocketServer.StreamRequestHandler):
323 342 def handle(self):
324 343 ui = self.server.ui
325 344 repo = self.server.repo
326 345 sv = server(ui, repo, self.rfile, self.wfile)
327 346 try:
328 347 try:
329 348 sv.serve()
330 349 # handle exceptions that may be raised by command server. most of
331 350 # known exceptions are caught by dispatch.
332 351 except error.Abort as inst:
333 352 ui.warn(_('abort: %s\n') % inst)
334 353 except IOError as inst:
335 354 if inst.errno != errno.EPIPE:
336 355 raise
337 356 except KeyboardInterrupt:
338 357 pass
339 358 except: # re-raises
340 359 # also write traceback to error channel. otherwise client cannot
341 360 # see it because it is written to server's stderr by default.
342 361 traceback.print_exc(file=sv.cerr)
343 362 raise
344 363
345 364 class unixservice(object):
346 365 """
347 366 Listens on unix domain socket and forks server per connection
348 367 """
349 368 def __init__(self, ui, repo, opts):
350 369 self.ui = ui
351 370 self.repo = repo
352 371 self.address = opts['address']
353 372 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
354 373 raise error.Abort(_('unsupported platform'))
355 374 if not self.address:
356 375 raise error.Abort(_('no socket path specified with --address'))
357 376
358 377 def init(self):
359 378 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
360 379 ui = self.ui
361 380 repo = self.repo
362 381 self.server = cls(self.address, _requesthandler)
363 382 self.ui.status(_('listening at %s\n') % self.address)
364 383 self.ui.flush() # avoid buffering of status message
365 384
366 385 def run(self):
367 386 try:
368 387 self.server.serve_forever()
369 388 finally:
370 389 os.unlink(self.address)
371 390
372 391 _servicemap = {
373 392 'pipe': pipeservice,
374 393 'unix': unixservice,
375 394 }
376 395
377 396 def createservice(ui, repo, opts):
378 397 mode = opts['cmdserver']
379 398 try:
380 399 return _servicemap[mode](ui, repo, opts)
381 400 except KeyError:
382 401 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now