##// END OF EJS Templates
commandserver: reset state of progress bar per command...
Yuya Nishihara -
r27566:5d6f984c default
parent child Browse files
Show More
@@ -1,377 +1,382
1 1 # commandserver.py - communicate with Mercurial's API over a pipe
2 2 #
3 3 # Copyright Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import SocketServer
11 11 import errno
12 12 import os
13 13 import struct
14 14 import sys
15 15 import traceback
16 16
17 17 from .i18n import _
18 18 from . import (
19 19 encoding,
20 20 error,
21 21 util,
22 22 )
23 23
24 24 logfile = None
25 25
26 26 def log(*args):
27 27 if not logfile:
28 28 return
29 29
30 30 for a in args:
31 31 logfile.write(str(a))
32 32
33 33 logfile.flush()
34 34
35 35 class channeledoutput(object):
36 36 """
37 37 Write data to out in the following format:
38 38
39 39 data length (unsigned int),
40 40 data
41 41 """
42 42 def __init__(self, out, channel):
43 43 self.out = out
44 44 self.channel = channel
45 45
46 46 @property
47 47 def name(self):
48 48 return '<%c-channel>' % self.channel
49 49
50 50 def write(self, data):
51 51 if not data:
52 52 return
53 53 self.out.write(struct.pack('>cI', self.channel, len(data)))
54 54 self.out.write(data)
55 55 self.out.flush()
56 56
57 57 def __getattr__(self, attr):
58 58 if attr in ('isatty', 'fileno'):
59 59 raise AttributeError(attr)
60 60 return getattr(self.out, attr)
61 61
62 62 class channeledinput(object):
63 63 """
64 64 Read data from in_.
65 65
66 66 Requests for input are written to out in the following format:
67 67 channel identifier - 'I' for plain input, 'L' line based (1 byte)
68 68 how many bytes to send at most (unsigned int),
69 69
70 70 The client replies with:
71 71 data length (unsigned int), 0 meaning EOF
72 72 data
73 73 """
74 74
75 75 maxchunksize = 4 * 1024
76 76
77 77 def __init__(self, in_, out, channel):
78 78 self.in_ = in_
79 79 self.out = out
80 80 self.channel = channel
81 81
82 82 @property
83 83 def name(self):
84 84 return '<%c-channel>' % self.channel
85 85
86 86 def read(self, size=-1):
87 87 if size < 0:
88 88 # if we need to consume all the clients input, ask for 4k chunks
89 89 # so the pipe doesn't fill up risking a deadlock
90 90 size = self.maxchunksize
91 91 s = self._read(size, self.channel)
92 92 buf = s
93 93 while s:
94 94 s = self._read(size, self.channel)
95 95 buf += s
96 96
97 97 return buf
98 98 else:
99 99 return self._read(size, self.channel)
100 100
101 101 def _read(self, size, channel):
102 102 if not size:
103 103 return ''
104 104 assert size > 0
105 105
106 106 # tell the client we need at most size bytes
107 107 self.out.write(struct.pack('>cI', channel, size))
108 108 self.out.flush()
109 109
110 110 length = self.in_.read(4)
111 111 length = struct.unpack('>I', length)[0]
112 112 if not length:
113 113 return ''
114 114 else:
115 115 return self.in_.read(length)
116 116
117 117 def readline(self, size=-1):
118 118 if size < 0:
119 119 size = self.maxchunksize
120 120 s = self._read(size, 'L')
121 121 buf = s
122 122 # keep asking for more until there's either no more or
123 123 # we got a full line
124 124 while s and s[-1] != '\n':
125 125 s = self._read(size, 'L')
126 126 buf += s
127 127
128 128 return buf
129 129 else:
130 130 return self._read(size, 'L')
131 131
132 132 def __iter__(self):
133 133 return self
134 134
135 135 def next(self):
136 136 l = self.readline()
137 137 if not l:
138 138 raise StopIteration
139 139 return l
140 140
141 141 def __getattr__(self, attr):
142 142 if attr in ('isatty', 'fileno'):
143 143 raise AttributeError(attr)
144 144 return getattr(self.in_, attr)
145 145
146 146 class server(object):
147 147 """
148 148 Listens for commands on fin, runs them and writes the output on a channel
149 149 based stream to fout.
150 150 """
151 151 def __init__(self, ui, repo, fin, fout):
152 152 self.cwd = os.getcwd()
153 153
154 154 # developer config: cmdserver.log
155 155 logpath = ui.config("cmdserver", "log", None)
156 156 if logpath:
157 157 global logfile
158 158 if logpath == '-':
159 159 # write log on a special 'd' (debug) channel
160 160 logfile = channeledoutput(fout, 'd')
161 161 else:
162 162 logfile = open(logpath, 'a')
163 163
164 164 if repo:
165 165 # the ui here is really the repo ui so take its baseui so we don't
166 166 # end up with its local configuration
167 167 self.ui = repo.baseui
168 168 self.repo = repo
169 169 self.repoui = repo.ui
170 170 else:
171 171 self.ui = ui
172 172 self.repo = self.repoui = None
173 173
174 174 self.cerr = channeledoutput(fout, 'e')
175 175 self.cout = channeledoutput(fout, 'o')
176 176 self.cin = channeledinput(fin, fout, 'I')
177 177 self.cresult = channeledoutput(fout, 'r')
178 178
179 179 self.client = fin
180 180
181 181 def _read(self, size):
182 182 if not size:
183 183 return ''
184 184
185 185 data = self.client.read(size)
186 186
187 187 # is the other end closed?
188 188 if not data:
189 189 raise EOFError
190 190
191 191 return data
192 192
193 193 def runcommand(self):
194 194 """ reads a list of \0 terminated arguments, executes
195 195 and writes the return code to the result channel """
196 196 from . import dispatch # avoid cycle
197 197
198 198 length = struct.unpack('>I', self._read(4))[0]
199 199 if not length:
200 200 args = []
201 201 else:
202 202 args = self._read(length).split('\0')
203 203
204 204 # copy the uis so changes (e.g. --config or --verbose) don't
205 205 # persist between requests
206 206 copiedui = self.ui.copy()
207 207 uis = [copiedui]
208 208 if self.repo:
209 209 self.repo.baseui = copiedui
210 210 # clone ui without using ui.copy because this is protected
211 211 repoui = self.repoui.__class__(self.repoui)
212 212 repoui.copy = copiedui.copy # redo copy protection
213 213 uis.append(repoui)
214 214 self.repo.ui = self.repo.dirstate._ui = repoui
215 215 self.repo.invalidateall()
216 216
217 # reset last-print time of progress bar per command
218 # (progbar is singleton, we don't have to do for all uis)
219 if copiedui._progbar:
220 copiedui._progbar.resetstate()
221
217 222 for ui in uis:
218 223 # any kind of interaction must use server channels, but chg may
219 224 # replace channels by fully functional tty files. so nontty is
220 225 # enforced only if cin is a channel.
221 226 if not util.safehasattr(self.cin, 'fileno'):
222 227 ui.setconfig('ui', 'nontty', 'true', 'commandserver')
223 228
224 229 req = dispatch.request(args[:], copiedui, self.repo, self.cin,
225 230 self.cout, self.cerr)
226 231
227 232 ret = (dispatch.dispatch(req) or 0) & 255 # might return None
228 233
229 234 # restore old cwd
230 235 if '--cwd' in args:
231 236 os.chdir(self.cwd)
232 237
233 238 self.cresult.write(struct.pack('>i', int(ret)))
234 239
235 240 def getencoding(self):
236 241 """ writes the current encoding to the result channel """
237 242 self.cresult.write(encoding.encoding)
238 243
239 244 def serveone(self):
240 245 cmd = self.client.readline()[:-1]
241 246 if cmd:
242 247 handler = self.capabilities.get(cmd)
243 248 if handler:
244 249 handler(self)
245 250 else:
246 251 # clients are expected to check what commands are supported by
247 252 # looking at the servers capabilities
248 253 raise error.Abort(_('unknown command %s') % cmd)
249 254
250 255 return cmd != ''
251 256
252 257 capabilities = {'runcommand' : runcommand,
253 258 'getencoding' : getencoding}
254 259
255 260 def serve(self):
256 261 hellomsg = 'capabilities: ' + ' '.join(sorted(self.capabilities))
257 262 hellomsg += '\n'
258 263 hellomsg += 'encoding: ' + encoding.encoding
259 264 hellomsg += '\n'
260 265 hellomsg += 'pid: %d' % os.getpid()
261 266
262 267 # write the hello msg in -one- chunk
263 268 self.cout.write(hellomsg)
264 269
265 270 try:
266 271 while self.serveone():
267 272 pass
268 273 except EOFError:
269 274 # we'll get here if the client disconnected while we were reading
270 275 # its request
271 276 return 1
272 277
273 278 return 0
274 279
275 280 def _protectio(ui):
276 281 """ duplicates streams and redirect original to null if ui uses stdio """
277 282 ui.flush()
278 283 newfiles = []
279 284 nullfd = os.open(os.devnull, os.O_RDWR)
280 285 for f, sysf, mode in [(ui.fin, sys.stdin, 'rb'),
281 286 (ui.fout, sys.stdout, 'wb')]:
282 287 if f is sysf:
283 288 newfd = os.dup(f.fileno())
284 289 os.dup2(nullfd, f.fileno())
285 290 f = os.fdopen(newfd, mode)
286 291 newfiles.append(f)
287 292 os.close(nullfd)
288 293 return tuple(newfiles)
289 294
290 295 def _restoreio(ui, fin, fout):
291 296 """ restores streams from duplicated ones """
292 297 ui.flush()
293 298 for f, uif in [(fin, ui.fin), (fout, ui.fout)]:
294 299 if f is not uif:
295 300 os.dup2(f.fileno(), uif.fileno())
296 301 f.close()
297 302
298 303 class pipeservice(object):
299 304 def __init__(self, ui, repo, opts):
300 305 self.ui = ui
301 306 self.repo = repo
302 307
303 308 def init(self):
304 309 pass
305 310
306 311 def run(self):
307 312 ui = self.ui
308 313 # redirect stdio to null device so that broken extensions or in-process
309 314 # hooks will never cause corruption of channel protocol.
310 315 fin, fout = _protectio(ui)
311 316 try:
312 317 sv = server(ui, self.repo, fin, fout)
313 318 return sv.serve()
314 319 finally:
315 320 _restoreio(ui, fin, fout)
316 321
317 322 class _requesthandler(SocketServer.StreamRequestHandler):
318 323 def handle(self):
319 324 ui = self.server.ui
320 325 repo = self.server.repo
321 326 sv = server(ui, repo, self.rfile, self.wfile)
322 327 try:
323 328 try:
324 329 sv.serve()
325 330 # handle exceptions that may be raised by command server. most of
326 331 # known exceptions are caught by dispatch.
327 332 except error.Abort as inst:
328 333 ui.warn(_('abort: %s\n') % inst)
329 334 except IOError as inst:
330 335 if inst.errno != errno.EPIPE:
331 336 raise
332 337 except KeyboardInterrupt:
333 338 pass
334 339 except: # re-raises
335 340 # also write traceback to error channel. otherwise client cannot
336 341 # see it because it is written to server's stderr by default.
337 342 traceback.print_exc(file=sv.cerr)
338 343 raise
339 344
340 345 class unixservice(object):
341 346 """
342 347 Listens on unix domain socket and forks server per connection
343 348 """
344 349 def __init__(self, ui, repo, opts):
345 350 self.ui = ui
346 351 self.repo = repo
347 352 self.address = opts['address']
348 353 if not util.safehasattr(SocketServer, 'UnixStreamServer'):
349 354 raise error.Abort(_('unsupported platform'))
350 355 if not self.address:
351 356 raise error.Abort(_('no socket path specified with --address'))
352 357
353 358 def init(self):
354 359 class cls(SocketServer.ForkingMixIn, SocketServer.UnixStreamServer):
355 360 ui = self.ui
356 361 repo = self.repo
357 362 self.server = cls(self.address, _requesthandler)
358 363 self.ui.status(_('listening at %s\n') % self.address)
359 364 self.ui.flush() # avoid buffering of status message
360 365
361 366 def run(self):
362 367 try:
363 368 self.server.serve_forever()
364 369 finally:
365 370 os.unlink(self.address)
366 371
367 372 _servicemap = {
368 373 'pipe': pipeservice,
369 374 'unix': unixservice,
370 375 }
371 376
372 377 def createservice(ui, repo, opts):
373 378 mode = opts['cmdserver']
374 379 try:
375 380 return _servicemap[mode](ui, repo, opts)
376 381 except KeyError:
377 382 raise error.Abort(_('unknown mode %s') % mode)
General Comments 0
You need to be logged in to leave comments. Login now