##// END OF EJS Templates
wireprotoserver: move responsetype() out of http handler...
Gregory Szorc -
r36089:341c886e default
parent child Browse files
Show More
@@ -1,455 +1,453
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import abc
10 10 import cgi
11 11 import contextlib
12 12 import struct
13 13 import sys
14 14
15 15 from .i18n import _
16 16 from . import (
17 17 encoding,
18 18 error,
19 19 hook,
20 20 pycompat,
21 21 util,
22 22 wireproto,
23 23 )
24 24
25 25 stringio = util.stringio
26 26
27 27 urlerr = util.urlerr
28 28 urlreq = util.urlreq
29 29
30 30 HTTP_OK = 200
31 31
32 32 HGTYPE = 'application/mercurial-0.1'
33 33 HGTYPE2 = 'application/mercurial-0.2'
34 34 HGERRTYPE = 'application/hg-error'
35 35
36 36 # Names of the SSH protocol implementations.
37 37 SSHV1 = 'ssh-v1'
38 38 # This is advertised over the wire. Incremental the counter at the end
39 39 # to reflect BC breakages.
40 40 SSHV2 = 'exp-ssh-v2-0001'
41 41
42 42 class baseprotocolhandler(object):
43 43 """Abstract base class for wire protocol handlers.
44 44
45 45 A wire protocol handler serves as an interface between protocol command
46 46 handlers and the wire protocol transport layer. Protocol handlers provide
47 47 methods to read command arguments, redirect stdio for the duration of
48 48 the request, handle response types, etc.
49 49 """
50 50
51 51 __metaclass__ = abc.ABCMeta
52 52
53 53 @abc.abstractproperty
54 54 def name(self):
55 55 """The name of the protocol implementation.
56 56
57 57 Used for uniquely identifying the transport type.
58 58 """
59 59
60 60 @abc.abstractmethod
61 61 def getargs(self, args):
62 62 """return the value for arguments in <args>
63 63
64 64 returns a list of values (same order as <args>)"""
65 65
66 66 @abc.abstractmethod
67 67 def forwardpayload(self, fp):
68 68 """Read the raw payload and forward to a file.
69 69
70 70 The payload is read in full before the function returns.
71 71 """
72 72
73 73 @abc.abstractmethod
74 74 def mayberedirectstdio(self):
75 75 """Context manager to possibly redirect stdio.
76 76
77 77 The context manager yields a file-object like object that receives
78 78 stdout and stderr output when the context manager is active. Or it
79 79 yields ``None`` if no I/O redirection occurs.
80 80
81 81 The intent of this context manager is to capture stdio output
82 82 so it may be sent in the response. Some transports support streaming
83 83 stdio to the client in real time. For these transports, stdio output
84 84 won't be captured.
85 85 """
86 86
87 87 @abc.abstractmethod
88 88 def client(self):
89 89 """Returns a string representation of this client (as bytes)."""
90 90
91 91 def decodevaluefromheaders(req, headerprefix):
92 92 """Decode a long value from multiple HTTP request headers.
93 93
94 94 Returns the value as a bytes, not a str.
95 95 """
96 96 chunks = []
97 97 i = 1
98 98 prefix = headerprefix.upper().replace(r'-', r'_')
99 99 while True:
100 100 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
101 101 if v is None:
102 102 break
103 103 chunks.append(pycompat.bytesurl(v))
104 104 i += 1
105 105
106 106 return ''.join(chunks)
107 107
108 108 class webproto(baseprotocolhandler):
109 109 def __init__(self, req, ui):
110 110 self._req = req
111 111 self._ui = ui
112 112
113 113 @property
114 114 def name(self):
115 115 return 'http'
116 116
117 117 def getargs(self, args):
118 118 knownargs = self._args()
119 119 data = {}
120 120 keys = args.split()
121 121 for k in keys:
122 122 if k == '*':
123 123 star = {}
124 124 for key in knownargs.keys():
125 125 if key != 'cmd' and key not in keys:
126 126 star[key] = knownargs[key][0]
127 127 data['*'] = star
128 128 else:
129 129 data[k] = knownargs[k][0]
130 130 return [data[k] for k in keys]
131 131
132 132 def _args(self):
133 133 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
134 134 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
135 135 if postlen:
136 136 args.update(cgi.parse_qs(
137 137 self._req.read(postlen), keep_blank_values=True))
138 138 return args
139 139
140 140 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
141 141 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
142 142 return args
143 143
144 144 def forwardpayload(self, fp):
145 145 length = int(self._req.env[r'CONTENT_LENGTH'])
146 146 # If httppostargs is used, we need to read Content-Length
147 147 # minus the amount that was consumed by args.
148 148 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
149 149 for s in util.filechunkiter(self._req, limit=length):
150 150 fp.write(s)
151 151
152 152 @contextlib.contextmanager
153 153 def mayberedirectstdio(self):
154 154 oldout = self._ui.fout
155 155 olderr = self._ui.ferr
156 156
157 157 out = util.stringio()
158 158
159 159 try:
160 160 self._ui.fout = out
161 161 self._ui.ferr = out
162 162 yield out
163 163 finally:
164 164 self._ui.fout = oldout
165 165 self._ui.ferr = olderr
166 166
167 167 def client(self):
168 168 return 'remote:%s:%s:%s' % (
169 169 self._req.env.get('wsgi.url_scheme') or 'http',
170 170 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
171 171 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
172 172
173 def responsetype(self, prefer_uncompressed):
174 """Determine the appropriate response type and compression settings.
175
176 Returns a tuple of (mediatype, compengine, engineopts).
177 """
178 # Determine the response media type and compression engine based
179 # on the request parameters.
180 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
181
182 if '0.2' in protocaps:
183 # All clients are expected to support uncompressed data.
184 if prefer_uncompressed:
185 return HGTYPE2, util._noopengine(), {}
186
187 # Default as defined by wire protocol spec.
188 compformats = ['zlib', 'none']
189 for cap in protocaps:
190 if cap.startswith('comp='):
191 compformats = cap[5:].split(',')
192 break
193
194 # Now find an agreed upon compression format.
195 for engine in wireproto.supportedcompengines(self._ui,
196 util.SERVERROLE):
197 if engine.wireprotosupport().name in compformats:
198 opts = {}
199 level = self._ui.configint('server',
200 '%slevel' % engine.name())
201 if level is not None:
202 opts['level'] = level
203
204 return HGTYPE2, engine, opts
205
206 # No mutually supported compression format. Fall back to the
207 # legacy protocol.
208
209 # Don't allow untrusted settings because disabling compression or
210 # setting a very high compression level could lead to flooding
211 # the server's network or CPU.
212 opts = {'level': self._ui.configint('server', 'zliblevel')}
213 return HGTYPE, util.compengines['zlib'], opts
214
215 173 def iscmd(cmd):
216 174 return cmd in wireproto.commands
217 175
218 176 def parsehttprequest(repo, req, query):
219 177 """Parse the HTTP request for a wire protocol request.
220 178
221 179 If the current request appears to be a wire protocol request, this
222 180 function returns a dict with details about that request, including
223 181 an ``abstractprotocolserver`` instance suitable for handling the
224 182 request. Otherwise, ``None`` is returned.
225 183
226 184 ``req`` is a ``wsgirequest`` instance.
227 185 """
228 186 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
229 187 # string parameter. If it isn't present, this isn't a wire protocol
230 188 # request.
231 189 if r'cmd' not in req.form:
232 190 return None
233 191
234 192 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
235 193
236 194 # The "cmd" request parameter is used by both the wire protocol and hgweb.
237 195 # While not all wire protocol commands are available for all transports,
238 196 # if we see a "cmd" value that resembles a known wire protocol command, we
239 197 # route it to a protocol handler. This is better than routing possible
240 198 # wire protocol requests to hgweb because it prevents hgweb from using
241 199 # known wire protocol commands and it is less confusing for machine
242 200 # clients.
243 201 if cmd not in wireproto.commands:
244 202 return None
245 203
246 204 proto = webproto(req, repo.ui)
247 205
248 206 return {
249 207 'cmd': cmd,
250 208 'proto': proto,
251 209 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
252 210 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
253 211 }
254 212
213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 """Determine the appropriate response type and compression settings.
215
216 Returns a tuple of (mediatype, compengine, engineopts).
217 """
218 # Determine the response media type and compression engine based
219 # on the request parameters.
220 protocaps = decodevaluefromheaders(req, r'X-HgProto').split(' ')
221
222 if '0.2' in protocaps:
223 # All clients are expected to support uncompressed data.
224 if prefer_uncompressed:
225 return HGTYPE2, util._noopengine(), {}
226
227 # Default as defined by wire protocol spec.
228 compformats = ['zlib', 'none']
229 for cap in protocaps:
230 if cap.startswith('comp='):
231 compformats = cap[5:].split(',')
232 break
233
234 # Now find an agreed upon compression format.
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 if engine.wireprotosupport().name in compformats:
237 opts = {}
238 level = ui.configint('server', '%slevel' % engine.name())
239 if level is not None:
240 opts['level'] = level
241
242 return HGTYPE2, engine, opts
243
244 # No mutually supported compression format. Fall back to the
245 # legacy protocol.
246
247 # Don't allow untrusted settings because disabling compression or
248 # setting a very high compression level could lead to flooding
249 # the server's network or CPU.
250 opts = {'level': ui.configint('server', 'zliblevel')}
251 return HGTYPE, util.compengines['zlib'], opts
252
255 253 def _callhttp(repo, req, proto, cmd):
256 254 def genversion2(gen, engine, engineopts):
257 255 # application/mercurial-0.2 always sends a payload header
258 256 # identifying the compression engine.
259 257 name = engine.wireprotosupport().name
260 258 assert 0 < len(name) < 256
261 259 yield struct.pack('B', len(name))
262 260 yield name
263 261
264 262 for chunk in gen:
265 263 yield chunk
266 264
267 265 rsp = wireproto.dispatch(repo, proto, cmd)
268 266
269 267 if not wireproto.commands.commandavailable(cmd, proto):
270 268 req.respond(HTTP_OK, HGERRTYPE,
271 269 body=_('requested wire protocol command is not available '
272 270 'over HTTP'))
273 271 return []
274 272
275 273 if isinstance(rsp, bytes):
276 274 req.respond(HTTP_OK, HGTYPE, body=rsp)
277 275 return []
278 276 elif isinstance(rsp, wireproto.streamres_legacy):
279 277 gen = rsp.gen
280 278 req.respond(HTTP_OK, HGTYPE)
281 279 return gen
282 280 elif isinstance(rsp, wireproto.streamres):
283 281 gen = rsp.gen
284 282
285 283 # This code for compression should not be streamres specific. It
286 284 # is here because we only compress streamres at the moment.
287 mediatype, engine, engineopts = proto.responsetype(
288 rsp.prefer_uncompressed)
285 mediatype, engine, engineopts = _httpresponsetype(
286 repo.ui, req, rsp.prefer_uncompressed)
289 287 gen = engine.compressstream(gen, engineopts)
290 288
291 289 if mediatype == HGTYPE2:
292 290 gen = genversion2(gen, engine, engineopts)
293 291
294 292 req.respond(HTTP_OK, mediatype)
295 293 return gen
296 294 elif isinstance(rsp, wireproto.pushres):
297 295 rsp = '%d\n%s' % (rsp.res, rsp.output)
298 296 req.respond(HTTP_OK, HGTYPE, body=rsp)
299 297 return []
300 298 elif isinstance(rsp, wireproto.pusherr):
301 299 # This is the httplib workaround documented in _handlehttperror().
302 300 req.drain()
303 301
304 302 rsp = '0\n%s\n' % rsp.res
305 303 req.respond(HTTP_OK, HGTYPE, body=rsp)
306 304 return []
307 305 elif isinstance(rsp, wireproto.ooberror):
308 306 rsp = rsp.message
309 307 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
310 308 return []
311 309 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
312 310
313 311 def _handlehttperror(e, req, cmd):
314 312 """Called when an ErrorResponse is raised during HTTP request processing."""
315 313
316 314 # Clients using Python's httplib are stateful: the HTTP client
317 315 # won't process an HTTP response until all request data is
318 316 # sent to the server. The intent of this code is to ensure
319 317 # we always read HTTP request data from the client, thus
320 318 # ensuring httplib transitions to a state that allows it to read
321 319 # the HTTP response. In other words, it helps prevent deadlocks
322 320 # on clients using httplib.
323 321
324 322 if (req.env[r'REQUEST_METHOD'] == r'POST' and
325 323 # But not if Expect: 100-continue is being used.
326 324 (req.env.get('HTTP_EXPECT',
327 325 '').lower() != '100-continue') or
328 326 # Or the non-httplib HTTP library is being advertised by
329 327 # the client.
330 328 req.env.get('X-HgHttp2', '')):
331 329 req.drain()
332 330 else:
333 331 req.headers.append((r'Connection', r'Close'))
334 332
335 333 # TODO This response body assumes the failed command was
336 334 # "unbundle." That assumption is not always valid.
337 335 req.respond(e, HGTYPE, body='0\n%s\n' % e)
338 336
339 337 return ''
340 338
341 339 def _sshv1respondbytes(fout, value):
342 340 """Send a bytes response for protocol version 1."""
343 341 fout.write('%d\n' % len(value))
344 342 fout.write(value)
345 343 fout.flush()
346 344
347 345 def _sshv1respondstream(fout, source):
348 346 write = fout.write
349 347 for chunk in source.gen:
350 348 write(chunk)
351 349 fout.flush()
352 350
353 351 def _sshv1respondooberror(fout, ferr, rsp):
354 352 ferr.write(b'%s\n-\n' % rsp)
355 353 ferr.flush()
356 354 fout.write(b'\n')
357 355 fout.flush()
358 356
359 357 class sshv1protocolhandler(baseprotocolhandler):
360 358 """Handler for requests services via version 1 of SSH protocol."""
361 359 def __init__(self, ui, fin, fout):
362 360 self._ui = ui
363 361 self._fin = fin
364 362 self._fout = fout
365 363
366 364 @property
367 365 def name(self):
368 366 return 'ssh'
369 367
370 368 def getargs(self, args):
371 369 data = {}
372 370 keys = args.split()
373 371 for n in xrange(len(keys)):
374 372 argline = self._fin.readline()[:-1]
375 373 arg, l = argline.split()
376 374 if arg not in keys:
377 375 raise error.Abort(_("unexpected parameter %r") % arg)
378 376 if arg == '*':
379 377 star = {}
380 378 for k in xrange(int(l)):
381 379 argline = self._fin.readline()[:-1]
382 380 arg, l = argline.split()
383 381 val = self._fin.read(int(l))
384 382 star[arg] = val
385 383 data['*'] = star
386 384 else:
387 385 val = self._fin.read(int(l))
388 386 data[arg] = val
389 387 return [data[k] for k in keys]
390 388
391 389 def forwardpayload(self, fpout):
392 390 # The file is in the form:
393 391 #
394 392 # <chunk size>\n<chunk>
395 393 # ...
396 394 # 0\n
397 395 _sshv1respondbytes(self._fout, b'')
398 396 count = int(self._fin.readline())
399 397 while count:
400 398 fpout.write(self._fin.read(count))
401 399 count = int(self._fin.readline())
402 400
403 401 @contextlib.contextmanager
404 402 def mayberedirectstdio(self):
405 403 yield None
406 404
407 405 def client(self):
408 406 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
409 407 return 'remote:ssh:' + client
410 408
411 409 class sshserver(object):
412 410 def __init__(self, ui, repo):
413 411 self._ui = ui
414 412 self._repo = repo
415 413 self._fin = ui.fin
416 414 self._fout = ui.fout
417 415
418 416 hook.redirect(True)
419 417 ui.fout = repo.ui.fout = ui.ferr
420 418
421 419 # Prevent insertion/deletion of CRs
422 420 util.setbinary(self._fin)
423 421 util.setbinary(self._fout)
424 422
425 423 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
426 424
427 425 def serve_forever(self):
428 426 while self.serve_one():
429 427 pass
430 428 sys.exit(0)
431 429
432 430 def serve_one(self):
433 431 cmd = self._fin.readline()[:-1]
434 432 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
435 433 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
436 434
437 435 if isinstance(rsp, bytes):
438 436 _sshv1respondbytes(self._fout, rsp)
439 437 elif isinstance(rsp, wireproto.streamres):
440 438 _sshv1respondstream(self._fout, rsp)
441 439 elif isinstance(rsp, wireproto.streamres_legacy):
442 440 _sshv1respondstream(self._fout, rsp)
443 441 elif isinstance(rsp, wireproto.pushres):
444 442 _sshv1respondbytes(self._fout, b'')
445 443 _sshv1respondbytes(self._fout, bytes(rsp.res))
446 444 elif isinstance(rsp, wireproto.pusherr):
447 445 _sshv1respondbytes(self._fout, rsp.res)
448 446 elif isinstance(rsp, wireproto.ooberror):
449 447 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
450 448 else:
451 449 raise error.ProgrammingError('unhandled response type from '
452 450 'wire protocol command: %s' % rsp)
453 451 elif cmd:
454 452 _sshv1respondbytes(self._fout, b'')
455 453 return cmd != ''
General Comments 0
You need to be logged in to leave comments. Login now