##// END OF EJS Templates
wireprotoserver: extract SSH response handling functions...
Gregory Szorc -
r36081:5767664d default
parent child Browse files
Show More
@@ -1,437 +1,438
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import abc
10 10 import cgi
11 11 import struct
12 12 import sys
13 13
14 14 from .i18n import _
15 15 from . import (
16 16 encoding,
17 17 error,
18 18 hook,
19 19 pycompat,
20 20 util,
21 21 wireproto,
22 22 )
23 23
24 24 stringio = util.stringio
25 25
26 26 urlerr = util.urlerr
27 27 urlreq = util.urlreq
28 28
29 29 HTTP_OK = 200
30 30
31 31 HGTYPE = 'application/mercurial-0.1'
32 32 HGTYPE2 = 'application/mercurial-0.2'
33 33 HGERRTYPE = 'application/hg-error'
34 34
35 35 # Names of the SSH protocol implementations.
36 36 SSHV1 = 'ssh-v1'
37 37 # This is advertised over the wire. Incremental the counter at the end
38 38 # to reflect BC breakages.
39 39 SSHV2 = 'exp-ssh-v2-0001'
40 40
41 41 class baseprotocolhandler(object):
42 42 """Abstract base class for wire protocol handlers.
43 43
44 44 A wire protocol handler serves as an interface between protocol command
45 45 handlers and the wire protocol transport layer. Protocol handlers provide
46 46 methods to read command arguments, redirect stdio for the duration of
47 47 the request, handle response types, etc.
48 48 """
49 49
50 50 __metaclass__ = abc.ABCMeta
51 51
52 52 @abc.abstractproperty
53 53 def name(self):
54 54 """The name of the protocol implementation.
55 55
56 56 Used for uniquely identifying the transport type.
57 57 """
58 58
59 59 @abc.abstractmethod
60 60 def getargs(self, args):
61 61 """return the value for arguments in <args>
62 62
63 63 returns a list of values (same order as <args>)"""
64 64
65 65 @abc.abstractmethod
66 66 def getfile(self, fp):
67 67 """write the whole content of a file into a file like object
68 68
69 69 The file is in the form::
70 70
71 71 (<chunk-size>\n<chunk>)+0\n
72 72
73 73 chunk size is the ascii version of the int.
74 74 """
75 75
76 76 @abc.abstractmethod
77 77 def redirect(self):
78 78 """may setup interception for stdout and stderr
79 79
80 80 See also the `restore` method."""
81 81
82 82 # If the `redirect` function does install interception, the `restore`
83 83 # function MUST be defined. If interception is not used, this function
84 84 # MUST NOT be defined.
85 85 #
86 86 # left commented here on purpose
87 87 #
88 88 #def restore(self):
89 89 # """reinstall previous stdout and stderr and return intercepted stdout
90 90 # """
91 91 # raise NotImplementedError()
92 92
93 93 def decodevaluefromheaders(req, headerprefix):
94 94 """Decode a long value from multiple HTTP request headers.
95 95
96 96 Returns the value as a bytes, not a str.
97 97 """
98 98 chunks = []
99 99 i = 1
100 100 prefix = headerprefix.upper().replace(r'-', r'_')
101 101 while True:
102 102 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
103 103 if v is None:
104 104 break
105 105 chunks.append(pycompat.bytesurl(v))
106 106 i += 1
107 107
108 108 return ''.join(chunks)
109 109
110 110 class webproto(baseprotocolhandler):
111 111 def __init__(self, req, ui):
112 112 self._req = req
113 113 self._ui = ui
114 114
115 115 @property
116 116 def name(self):
117 117 return 'http'
118 118
119 119 def getargs(self, args):
120 120 knownargs = self._args()
121 121 data = {}
122 122 keys = args.split()
123 123 for k in keys:
124 124 if k == '*':
125 125 star = {}
126 126 for key in knownargs.keys():
127 127 if key != 'cmd' and key not in keys:
128 128 star[key] = knownargs[key][0]
129 129 data['*'] = star
130 130 else:
131 131 data[k] = knownargs[k][0]
132 132 return [data[k] for k in keys]
133 133
134 134 def _args(self):
135 135 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
136 136 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
137 137 if postlen:
138 138 args.update(cgi.parse_qs(
139 139 self._req.read(postlen), keep_blank_values=True))
140 140 return args
141 141
142 142 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
143 143 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
144 144 return args
145 145
146 146 def getfile(self, fp):
147 147 length = int(self._req.env[r'CONTENT_LENGTH'])
148 148 # If httppostargs is used, we need to read Content-Length
149 149 # minus the amount that was consumed by args.
150 150 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
151 151 for s in util.filechunkiter(self._req, limit=length):
152 152 fp.write(s)
153 153
154 154 def redirect(self):
155 155 self._oldio = self._ui.fout, self._ui.ferr
156 156 self._ui.ferr = self._ui.fout = stringio()
157 157
158 158 def restore(self):
159 159 val = self._ui.fout.getvalue()
160 160 self._ui.ferr, self._ui.fout = self._oldio
161 161 return val
162 162
163 163 def _client(self):
164 164 return 'remote:%s:%s:%s' % (
165 165 self._req.env.get('wsgi.url_scheme') or 'http',
166 166 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
167 167 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
168 168
169 169 def responsetype(self, prefer_uncompressed):
170 170 """Determine the appropriate response type and compression settings.
171 171
172 172 Returns a tuple of (mediatype, compengine, engineopts).
173 173 """
174 174 # Determine the response media type and compression engine based
175 175 # on the request parameters.
176 176 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
177 177
178 178 if '0.2' in protocaps:
179 179 # All clients are expected to support uncompressed data.
180 180 if prefer_uncompressed:
181 181 return HGTYPE2, util._noopengine(), {}
182 182
183 183 # Default as defined by wire protocol spec.
184 184 compformats = ['zlib', 'none']
185 185 for cap in protocaps:
186 186 if cap.startswith('comp='):
187 187 compformats = cap[5:].split(',')
188 188 break
189 189
190 190 # Now find an agreed upon compression format.
191 191 for engine in wireproto.supportedcompengines(self._ui, self,
192 192 util.SERVERROLE):
193 193 if engine.wireprotosupport().name in compformats:
194 194 opts = {}
195 195 level = self._ui.configint('server',
196 196 '%slevel' % engine.name())
197 197 if level is not None:
198 198 opts['level'] = level
199 199
200 200 return HGTYPE2, engine, opts
201 201
202 202 # No mutually supported compression format. Fall back to the
203 203 # legacy protocol.
204 204
205 205 # Don't allow untrusted settings because disabling compression or
206 206 # setting a very high compression level could lead to flooding
207 207 # the server's network or CPU.
208 208 opts = {'level': self._ui.configint('server', 'zliblevel')}
209 209 return HGTYPE, util.compengines['zlib'], opts
210 210
211 211 def iscmd(cmd):
212 212 return cmd in wireproto.commands
213 213
214 214 def parsehttprequest(repo, req, query):
215 215 """Parse the HTTP request for a wire protocol request.
216 216
217 217 If the current request appears to be a wire protocol request, this
218 218 function returns a dict with details about that request, including
219 219 an ``abstractprotocolserver`` instance suitable for handling the
220 220 request. Otherwise, ``None`` is returned.
221 221
222 222 ``req`` is a ``wsgirequest`` instance.
223 223 """
224 224 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
225 225 # string parameter. If it isn't present, this isn't a wire protocol
226 226 # request.
227 227 if r'cmd' not in req.form:
228 228 return None
229 229
230 230 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
231 231
232 232 # The "cmd" request parameter is used by both the wire protocol and hgweb.
233 233 # While not all wire protocol commands are available for all transports,
234 234 # if we see a "cmd" value that resembles a known wire protocol command, we
235 235 # route it to a protocol handler. This is better than routing possible
236 236 # wire protocol requests to hgweb because it prevents hgweb from using
237 237 # known wire protocol commands and it is less confusing for machine
238 238 # clients.
239 239 if cmd not in wireproto.commands:
240 240 return None
241 241
242 242 proto = webproto(req, repo.ui)
243 243
244 244 return {
245 245 'cmd': cmd,
246 246 'proto': proto,
247 247 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
248 248 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
249 249 }
250 250
251 251 def _callhttp(repo, req, proto, cmd):
252 252 def genversion2(gen, engine, engineopts):
253 253 # application/mercurial-0.2 always sends a payload header
254 254 # identifying the compression engine.
255 255 name = engine.wireprotosupport().name
256 256 assert 0 < len(name) < 256
257 257 yield struct.pack('B', len(name))
258 258 yield name
259 259
260 260 for chunk in gen:
261 261 yield chunk
262 262
263 263 rsp = wireproto.dispatch(repo, proto, cmd)
264 264
265 265 if not wireproto.commands.commandavailable(cmd, proto):
266 266 req.respond(HTTP_OK, HGERRTYPE,
267 267 body=_('requested wire protocol command is not available '
268 268 'over HTTP'))
269 269 return []
270 270
271 271 if isinstance(rsp, bytes):
272 272 req.respond(HTTP_OK, HGTYPE, body=rsp)
273 273 return []
274 274 elif isinstance(rsp, wireproto.streamres_legacy):
275 275 gen = rsp.gen
276 276 req.respond(HTTP_OK, HGTYPE)
277 277 return gen
278 278 elif isinstance(rsp, wireproto.streamres):
279 279 gen = rsp.gen
280 280
281 281 # This code for compression should not be streamres specific. It
282 282 # is here because we only compress streamres at the moment.
283 283 mediatype, engine, engineopts = proto.responsetype(
284 284 rsp.prefer_uncompressed)
285 285 gen = engine.compressstream(gen, engineopts)
286 286
287 287 if mediatype == HGTYPE2:
288 288 gen = genversion2(gen, engine, engineopts)
289 289
290 290 req.respond(HTTP_OK, mediatype)
291 291 return gen
292 292 elif isinstance(rsp, wireproto.pushres):
293 293 val = proto.restore()
294 294 rsp = '%d\n%s' % (rsp.res, val)
295 295 req.respond(HTTP_OK, HGTYPE, body=rsp)
296 296 return []
297 297 elif isinstance(rsp, wireproto.pusherr):
298 298 # This is the httplib workaround documented in _handlehttperror().
299 299 req.drain()
300 300
301 301 proto.restore()
302 302 rsp = '0\n%s\n' % rsp.res
303 303 req.respond(HTTP_OK, HGTYPE, body=rsp)
304 304 return []
305 305 elif isinstance(rsp, wireproto.ooberror):
306 306 rsp = rsp.message
307 307 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
308 308 return []
309 309 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
310 310
311 311 def _handlehttperror(e, req, cmd):
312 312 """Called when an ErrorResponse is raised during HTTP request processing."""
313 313
314 314 # Clients using Python's httplib are stateful: the HTTP client
315 315 # won't process an HTTP response until all request data is
316 316 # sent to the server. The intent of this code is to ensure
317 317 # we always read HTTP request data from the client, thus
318 318 # ensuring httplib transitions to a state that allows it to read
319 319 # the HTTP response. In other words, it helps prevent deadlocks
320 320 # on clients using httplib.
321 321
322 322 if (req.env[r'REQUEST_METHOD'] == r'POST' and
323 323 # But not if Expect: 100-continue is being used.
324 324 (req.env.get('HTTP_EXPECT',
325 325 '').lower() != '100-continue') or
326 326 # Or the non-httplib HTTP library is being advertised by
327 327 # the client.
328 328 req.env.get('X-HgHttp2', '')):
329 329 req.drain()
330 330 else:
331 331 req.headers.append((r'Connection', r'Close'))
332 332
333 333 # TODO This response body assumes the failed command was
334 334 # "unbundle." That assumption is not always valid.
335 335 req.respond(e, HGTYPE, body='0\n%s\n' % e)
336 336
337 337 return ''
338 338
339 def _sshv1respondbytes(fout, value):
340 """Send a bytes response for protocol version 1."""
341 fout.write('%d\n' % len(value))
342 fout.write(value)
343 fout.flush()
344
345 def _sshv1respondstream(fout, source):
346 write = fout.write
347 for chunk in source.gen:
348 write(chunk)
349 fout.flush()
350
351 def _sshv1respondooberror(fout, ferr, rsp):
352 ferr.write(b'%s\n-\n' % rsp)
353 ferr.flush()
354 fout.write(b'\n')
355 fout.flush()
356
339 357 class sshserver(baseprotocolhandler):
340 358 def __init__(self, ui, repo):
341 359 self._ui = ui
342 360 self._repo = repo
343 361 self._fin = ui.fin
344 362 self._fout = ui.fout
345 363
346 364 hook.redirect(True)
347 365 ui.fout = repo.ui.fout = ui.ferr
348 366
349 367 # Prevent insertion/deletion of CRs
350 368 util.setbinary(self._fin)
351 369 util.setbinary(self._fout)
352 370
353 371 @property
354 372 def name(self):
355 373 return 'ssh'
356 374
357 375 def getargs(self, args):
358 376 data = {}
359 377 keys = args.split()
360 378 for n in xrange(len(keys)):
361 379 argline = self._fin.readline()[:-1]
362 380 arg, l = argline.split()
363 381 if arg not in keys:
364 382 raise error.Abort(_("unexpected parameter %r") % arg)
365 383 if arg == '*':
366 384 star = {}
367 385 for k in xrange(int(l)):
368 386 argline = self._fin.readline()[:-1]
369 387 arg, l = argline.split()
370 388 val = self._fin.read(int(l))
371 389 star[arg] = val
372 390 data['*'] = star
373 391 else:
374 392 val = self._fin.read(int(l))
375 393 data[arg] = val
376 394 return [data[k] for k in keys]
377 395
378 396 def getfile(self, fpout):
379 self._sendresponse('')
397 _sshv1respondbytes(self._fout, b'')
380 398 count = int(self._fin.readline())
381 399 while count:
382 400 fpout.write(self._fin.read(count))
383 401 count = int(self._fin.readline())
384 402
385 403 def redirect(self):
386 404 pass
387 405
388 def _sendresponse(self, v):
389 self._fout.write("%d\n" % len(v))
390 self._fout.write(v)
391 self._fout.flush()
392
393 def _sendstream(self, source):
394 write = self._fout.write
395 for chunk in source.gen:
396 write(chunk)
397 self._fout.flush()
398
399 def _sendpushresponse(self, rsp):
400 self._sendresponse('')
401 self._sendresponse(str(rsp.res))
402
403 def _sendpusherror(self, rsp):
404 self._sendresponse(rsp.res)
405
406 def _sendooberror(self, rsp):
407 self._ui.ferr.write('%s\n-\n' % rsp.message)
408 self._ui.ferr.flush()
409 self._fout.write('\n')
410 self._fout.flush()
411
412 406 def serve_forever(self):
413 407 while self.serve_one():
414 408 pass
415 409 sys.exit(0)
416 410
417 _handlers = {
418 str: _sendresponse,
419 wireproto.streamres: _sendstream,
420 wireproto.streamres_legacy: _sendstream,
421 wireproto.pushres: _sendpushresponse,
422 wireproto.pusherr: _sendpusherror,
423 wireproto.ooberror: _sendooberror,
424 }
425
426 411 def serve_one(self):
427 412 cmd = self._fin.readline()[:-1]
428 413 if cmd and wireproto.commands.commandavailable(cmd, self):
429 414 rsp = wireproto.dispatch(self._repo, self, cmd)
430 self._handlers[rsp.__class__](self, rsp)
415
416 if isinstance(rsp, bytes):
417 _sshv1respondbytes(self._fout, rsp)
418 elif isinstance(rsp, wireproto.streamres):
419 _sshv1respondstream(self._fout, rsp)
420 elif isinstance(rsp, wireproto.streamres_legacy):
421 _sshv1respondstream(self._fout, rsp)
422 elif isinstance(rsp, wireproto.pushres):
423 _sshv1respondbytes(self._fout, b'')
424 _sshv1respondbytes(self._fout, bytes(rsp.res))
425 elif isinstance(rsp, wireproto.pusherr):
426 _sshv1respondbytes(self._fout, rsp.res)
427 elif isinstance(rsp, wireproto.ooberror):
428 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
429 else:
430 raise error.ProgrammingError('unhandled response type from '
431 'wire protocol command: %s' % rsp)
431 432 elif cmd:
432 self._sendresponse("")
433 _sshv1respondbytes(self._fout, b'')
433 434 return cmd != ''
434 435
435 436 def _client(self):
436 437 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
437 438 return 'remote:ssh:' + client
@@ -1,127 +1,127
1 1 # sshprotoext.py - Extension to test behavior of SSH protocol
2 2 #
3 3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 # This extension replaces the SSH server started via `hg serve --stdio`.
9 9 # The server behaves differently depending on environment variables.
10 10
11 11 from __future__ import absolute_import
12 12
13 13 from mercurial import (
14 14 error,
15 15 extensions,
16 16 registrar,
17 17 sshpeer,
18 18 wireproto,
19 19 wireprotoserver,
20 20 )
21 21
22 22 configtable = {}
23 23 configitem = registrar.configitem(configtable)
24 24
25 25 configitem('sshpeer', 'mode', default=None)
26 26 configitem('sshpeer', 'handshake-mode', default=None)
27 27
28 28 class bannerserver(wireprotoserver.sshserver):
29 29 """Server that sends a banner to stdout."""
30 30 def serve_forever(self):
31 31 for i in range(10):
32 32 self._fout.write(b'banner: line %d\n' % i)
33 33
34 34 super(bannerserver, self).serve_forever()
35 35
36 36 class prehelloserver(wireprotoserver.sshserver):
37 37 """Tests behavior when connecting to <0.9.1 servers.
38 38
39 39 The ``hello`` wire protocol command was introduced in Mercurial
40 40 0.9.1. Modern clients send the ``hello`` command when connecting
41 41 to SSH servers. This mock server tests behavior of the handshake
42 42 when ``hello`` is not supported.
43 43 """
44 44 def serve_forever(self):
45 45 l = self._fin.readline()
46 46 assert l == b'hello\n'
47 47 # Respond to unknown commands with an empty reply.
48 self._sendresponse(b'')
48 wireprotoserver._sshv1respondbytes(self._fout, b'')
49 49 l = self._fin.readline()
50 50 assert l == b'between\n'
51 51 rsp = wireproto.dispatch(self._repo, self, b'between')
52 self._handlers[rsp.__class__](self, rsp)
52 wireprotoserver._sshv1respondbytes(self._fout, rsp)
53 53
54 54 super(prehelloserver, self).serve_forever()
55 55
56 56 class upgradev2server(wireprotoserver.sshserver):
57 57 """Tests behavior for clients that issue upgrade to version 2."""
58 58 def serve_forever(self):
59 59 name = wireprotoserver.SSHV2
60 60 l = self._fin.readline()
61 61 assert l.startswith(b'upgrade ')
62 62 token, caps = l[:-1].split(b' ')[1:]
63 63 assert caps == b'proto=%s' % name
64 64
65 65 # Filter hello and between requests.
66 66 l = self._fin.readline()
67 67 assert l == b'hello\n'
68 68 l = self._fin.readline()
69 69 assert l == b'between\n'
70 70 l = self._fin.readline()
71 71 assert l == 'pairs 81\n'
72 72 self._fin.read(81)
73 73
74 74 # Send the upgrade response.
75 75 self._fout.write(b'upgraded %s %s\n' % (token, name))
76 76 servercaps = wireproto.capabilities(self._repo, self)
77 77 rsp = b'capabilities: %s' % servercaps
78 78 self._fout.write(b'%d\n' % len(rsp))
79 79 self._fout.write(rsp)
80 80 self._fout.write(b'\n')
81 81 self._fout.flush()
82 82
83 83 super(upgradev2server, self).serve_forever()
84 84
85 85 def performhandshake(orig, ui, stdin, stdout, stderr):
86 86 """Wrapped version of sshpeer._performhandshake to send extra commands."""
87 87 mode = ui.config(b'sshpeer', b'handshake-mode')
88 88 if mode == b'pre-no-args':
89 89 ui.debug(b'sending no-args command\n')
90 90 stdin.write(b'no-args\n')
91 91 stdin.flush()
92 92 return orig(ui, stdin, stdout, stderr)
93 93 elif mode == b'pre-multiple-no-args':
94 94 ui.debug(b'sending unknown1 command\n')
95 95 stdin.write(b'unknown1\n')
96 96 ui.debug(b'sending unknown2 command\n')
97 97 stdin.write(b'unknown2\n')
98 98 ui.debug(b'sending unknown3 command\n')
99 99 stdin.write(b'unknown3\n')
100 100 stdin.flush()
101 101 return orig(ui, stdin, stdout, stderr)
102 102 else:
103 103 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
104 104 mode)
105 105
106 106 def extsetup(ui):
107 107 # It's easier for tests to define the server behavior via environment
108 108 # variables than config options. This is because `hg serve --stdio`
109 109 # has to be invoked with a certain form for security reasons and
110 110 # `dummyssh` can't just add `--config` flags to the command line.
111 111 servermode = ui.environ.get(b'SSHSERVERMODE')
112 112
113 113 if servermode == b'banner':
114 114 wireprotoserver.sshserver = bannerserver
115 115 elif servermode == b'no-hello':
116 116 wireprotoserver.sshserver = prehelloserver
117 117 elif servermode == b'upgradev2':
118 118 wireprotoserver.sshserver = upgradev2server
119 119 elif servermode:
120 120 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
121 121
122 122 peermode = ui.config(b'sshpeer', b'mode')
123 123
124 124 if peermode == b'extra-handshake-commands':
125 125 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
126 126 elif peermode:
127 127 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
General Comments 0
You need to be logged in to leave comments. Login now