##// END OF EJS Templates
wireprotoserver: move SSH server operation to a standalone function...
Gregory Szorc -
r36232:3b3a987b default
parent child Browse files
Show More
@@ -1,458 +1,478 b''
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import abc
10 10 import contextlib
11 11 import struct
12 12 import sys
13 13
14 14 from .i18n import _
15 15 from . import (
16 16 encoding,
17 17 error,
18 18 hook,
19 19 pycompat,
20 20 util,
21 21 wireproto,
22 22 wireprototypes,
23 23 )
24 24
25 25 stringio = util.stringio
26 26
27 27 urlerr = util.urlerr
28 28 urlreq = util.urlreq
29 29
30 30 HTTP_OK = 200
31 31
32 32 HGTYPE = 'application/mercurial-0.1'
33 33 HGTYPE2 = 'application/mercurial-0.2'
34 34 HGERRTYPE = 'application/hg-error'
35 35
36 36 # Names of the SSH protocol implementations.
37 37 SSHV1 = 'ssh-v1'
38 38 # This is advertised over the wire. Incremental the counter at the end
39 39 # to reflect BC breakages.
40 40 SSHV2 = 'exp-ssh-v2-0001'
41 41
42 42 class baseprotocolhandler(object):
43 43 """Abstract base class for wire protocol handlers.
44 44
45 45 A wire protocol handler serves as an interface between protocol command
46 46 handlers and the wire protocol transport layer. Protocol handlers provide
47 47 methods to read command arguments, redirect stdio for the duration of
48 48 the request, handle response types, etc.
49 49 """
50 50
51 51 __metaclass__ = abc.ABCMeta
52 52
53 53 @abc.abstractproperty
54 54 def name(self):
55 55 """The name of the protocol implementation.
56 56
57 57 Used for uniquely identifying the transport type.
58 58 """
59 59
60 60 @abc.abstractmethod
61 61 def getargs(self, args):
62 62 """return the value for arguments in <args>
63 63
64 64 returns a list of values (same order as <args>)"""
65 65
66 66 @abc.abstractmethod
67 67 def forwardpayload(self, fp):
68 68 """Read the raw payload and forward to a file.
69 69
70 70 The payload is read in full before the function returns.
71 71 """
72 72
73 73 @abc.abstractmethod
74 74 def mayberedirectstdio(self):
75 75 """Context manager to possibly redirect stdio.
76 76
77 77 The context manager yields a file-object like object that receives
78 78 stdout and stderr output when the context manager is active. Or it
79 79 yields ``None`` if no I/O redirection occurs.
80 80
81 81 The intent of this context manager is to capture stdio output
82 82 so it may be sent in the response. Some transports support streaming
83 83 stdio to the client in real time. For these transports, stdio output
84 84 won't be captured.
85 85 """
86 86
87 87 @abc.abstractmethod
88 88 def client(self):
89 89 """Returns a string representation of this client (as bytes)."""
90 90
91 91 def decodevaluefromheaders(req, headerprefix):
92 92 """Decode a long value from multiple HTTP request headers.
93 93
94 94 Returns the value as a bytes, not a str.
95 95 """
96 96 chunks = []
97 97 i = 1
98 98 prefix = headerprefix.upper().replace(r'-', r'_')
99 99 while True:
100 100 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
101 101 if v is None:
102 102 break
103 103 chunks.append(pycompat.bytesurl(v))
104 104 i += 1
105 105
106 106 return ''.join(chunks)
107 107
108 108 class webproto(baseprotocolhandler):
109 109 def __init__(self, req, ui):
110 110 self._req = req
111 111 self._ui = ui
112 112
113 113 @property
114 114 def name(self):
115 115 return 'http'
116 116
117 117 def getargs(self, args):
118 118 knownargs = self._args()
119 119 data = {}
120 120 keys = args.split()
121 121 for k in keys:
122 122 if k == '*':
123 123 star = {}
124 124 for key in knownargs.keys():
125 125 if key != 'cmd' and key not in keys:
126 126 star[key] = knownargs[key][0]
127 127 data['*'] = star
128 128 else:
129 129 data[k] = knownargs[k][0]
130 130 return [data[k] for k in keys]
131 131
132 132 def _args(self):
133 133 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
134 134 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
135 135 if postlen:
136 136 args.update(urlreq.parseqs(
137 137 self._req.read(postlen), keep_blank_values=True))
138 138 return args
139 139
140 140 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
141 141 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
142 142 return args
143 143
144 144 def forwardpayload(self, fp):
145 145 length = int(self._req.env[r'CONTENT_LENGTH'])
146 146 # If httppostargs is used, we need to read Content-Length
147 147 # minus the amount that was consumed by args.
148 148 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
149 149 for s in util.filechunkiter(self._req, limit=length):
150 150 fp.write(s)
151 151
152 152 @contextlib.contextmanager
153 153 def mayberedirectstdio(self):
154 154 oldout = self._ui.fout
155 155 olderr = self._ui.ferr
156 156
157 157 out = util.stringio()
158 158
159 159 try:
160 160 self._ui.fout = out
161 161 self._ui.ferr = out
162 162 yield out
163 163 finally:
164 164 self._ui.fout = oldout
165 165 self._ui.ferr = olderr
166 166
167 167 def client(self):
168 168 return 'remote:%s:%s:%s' % (
169 169 self._req.env.get('wsgi.url_scheme') or 'http',
170 170 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
171 171 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
172 172
173 173 def iscmd(cmd):
174 174 return cmd in wireproto.commands
175 175
176 176 def parsehttprequest(repo, req, query):
177 177 """Parse the HTTP request for a wire protocol request.
178 178
179 179 If the current request appears to be a wire protocol request, this
180 180 function returns a dict with details about that request, including
181 181 an ``abstractprotocolserver`` instance suitable for handling the
182 182 request. Otherwise, ``None`` is returned.
183 183
184 184 ``req`` is a ``wsgirequest`` instance.
185 185 """
186 186 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
187 187 # string parameter. If it isn't present, this isn't a wire protocol
188 188 # request.
189 189 if r'cmd' not in req.form:
190 190 return None
191 191
192 192 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
193 193
194 194 # The "cmd" request parameter is used by both the wire protocol and hgweb.
195 195 # While not all wire protocol commands are available for all transports,
196 196 # if we see a "cmd" value that resembles a known wire protocol command, we
197 197 # route it to a protocol handler. This is better than routing possible
198 198 # wire protocol requests to hgweb because it prevents hgweb from using
199 199 # known wire protocol commands and it is less confusing for machine
200 200 # clients.
201 201 if cmd not in wireproto.commands:
202 202 return None
203 203
204 204 proto = webproto(req, repo.ui)
205 205
206 206 return {
207 207 'cmd': cmd,
208 208 'proto': proto,
209 209 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
210 210 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
211 211 }
212 212
213 213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 214 """Determine the appropriate response type and compression settings.
215 215
216 216 Returns a tuple of (mediatype, compengine, engineopts).
217 217 """
218 218 # Determine the response media type and compression engine based
219 219 # on the request parameters.
220 220 protocaps = decodevaluefromheaders(req, r'X-HgProto').split(' ')
221 221
222 222 if '0.2' in protocaps:
223 223 # All clients are expected to support uncompressed data.
224 224 if prefer_uncompressed:
225 225 return HGTYPE2, util._noopengine(), {}
226 226
227 227 # Default as defined by wire protocol spec.
228 228 compformats = ['zlib', 'none']
229 229 for cap in protocaps:
230 230 if cap.startswith('comp='):
231 231 compformats = cap[5:].split(',')
232 232 break
233 233
234 234 # Now find an agreed upon compression format.
235 235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 236 if engine.wireprotosupport().name in compformats:
237 237 opts = {}
238 238 level = ui.configint('server', '%slevel' % engine.name())
239 239 if level is not None:
240 240 opts['level'] = level
241 241
242 242 return HGTYPE2, engine, opts
243 243
244 244 # No mutually supported compression format. Fall back to the
245 245 # legacy protocol.
246 246
247 247 # Don't allow untrusted settings because disabling compression or
248 248 # setting a very high compression level could lead to flooding
249 249 # the server's network or CPU.
250 250 opts = {'level': ui.configint('server', 'zliblevel')}
251 251 return HGTYPE, util.compengines['zlib'], opts
252 252
253 253 def _callhttp(repo, req, proto, cmd):
254 254 def genversion2(gen, engine, engineopts):
255 255 # application/mercurial-0.2 always sends a payload header
256 256 # identifying the compression engine.
257 257 name = engine.wireprotosupport().name
258 258 assert 0 < len(name) < 256
259 259 yield struct.pack('B', len(name))
260 260 yield name
261 261
262 262 for chunk in gen:
263 263 yield chunk
264 264
265 265 rsp = wireproto.dispatch(repo, proto, cmd)
266 266
267 267 if not wireproto.commands.commandavailable(cmd, proto):
268 268 req.respond(HTTP_OK, HGERRTYPE,
269 269 body=_('requested wire protocol command is not available '
270 270 'over HTTP'))
271 271 return []
272 272
273 273 if isinstance(rsp, bytes):
274 274 req.respond(HTTP_OK, HGTYPE, body=rsp)
275 275 return []
276 276 elif isinstance(rsp, wireprototypes.bytesresponse):
277 277 req.respond(HTTP_OK, HGTYPE, body=rsp.data)
278 278 return []
279 279 elif isinstance(rsp, wireprototypes.streamreslegacy):
280 280 gen = rsp.gen
281 281 req.respond(HTTP_OK, HGTYPE)
282 282 return gen
283 283 elif isinstance(rsp, wireprototypes.streamres):
284 284 gen = rsp.gen
285 285
286 286 # This code for compression should not be streamres specific. It
287 287 # is here because we only compress streamres at the moment.
288 288 mediatype, engine, engineopts = _httpresponsetype(
289 289 repo.ui, req, rsp.prefer_uncompressed)
290 290 gen = engine.compressstream(gen, engineopts)
291 291
292 292 if mediatype == HGTYPE2:
293 293 gen = genversion2(gen, engine, engineopts)
294 294
295 295 req.respond(HTTP_OK, mediatype)
296 296 return gen
297 297 elif isinstance(rsp, wireprototypes.pushres):
298 298 rsp = '%d\n%s' % (rsp.res, rsp.output)
299 299 req.respond(HTTP_OK, HGTYPE, body=rsp)
300 300 return []
301 301 elif isinstance(rsp, wireprototypes.pusherr):
302 302 # This is the httplib workaround documented in _handlehttperror().
303 303 req.drain()
304 304
305 305 rsp = '0\n%s\n' % rsp.res
306 306 req.respond(HTTP_OK, HGTYPE, body=rsp)
307 307 return []
308 308 elif isinstance(rsp, wireprototypes.ooberror):
309 309 rsp = rsp.message
310 310 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
311 311 return []
312 312 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
313 313
314 314 def _handlehttperror(e, req, cmd):
315 315 """Called when an ErrorResponse is raised during HTTP request processing."""
316 316
317 317 # Clients using Python's httplib are stateful: the HTTP client
318 318 # won't process an HTTP response until all request data is
319 319 # sent to the server. The intent of this code is to ensure
320 320 # we always read HTTP request data from the client, thus
321 321 # ensuring httplib transitions to a state that allows it to read
322 322 # the HTTP response. In other words, it helps prevent deadlocks
323 323 # on clients using httplib.
324 324
325 325 if (req.env[r'REQUEST_METHOD'] == r'POST' and
326 326 # But not if Expect: 100-continue is being used.
327 327 (req.env.get('HTTP_EXPECT',
328 328 '').lower() != '100-continue') or
329 329 # Or the non-httplib HTTP library is being advertised by
330 330 # the client.
331 331 req.env.get('X-HgHttp2', '')):
332 332 req.drain()
333 333 else:
334 334 req.headers.append((r'Connection', r'Close'))
335 335
336 336 # TODO This response body assumes the failed command was
337 337 # "unbundle." That assumption is not always valid.
338 338 req.respond(e, HGTYPE, body='0\n%s\n' % e)
339 339
340 340 return ''
341 341
342 342 def _sshv1respondbytes(fout, value):
343 343 """Send a bytes response for protocol version 1."""
344 344 fout.write('%d\n' % len(value))
345 345 fout.write(value)
346 346 fout.flush()
347 347
348 348 def _sshv1respondstream(fout, source):
349 349 write = fout.write
350 350 for chunk in source.gen:
351 351 write(chunk)
352 352 fout.flush()
353 353
354 354 def _sshv1respondooberror(fout, ferr, rsp):
355 355 ferr.write(b'%s\n-\n' % rsp)
356 356 ferr.flush()
357 357 fout.write(b'\n')
358 358 fout.flush()
359 359
360 360 class sshv1protocolhandler(baseprotocolhandler):
361 361 """Handler for requests services via version 1 of SSH protocol."""
362 362 def __init__(self, ui, fin, fout):
363 363 self._ui = ui
364 364 self._fin = fin
365 365 self._fout = fout
366 366
367 367 @property
368 368 def name(self):
369 369 return SSHV1
370 370
371 371 def getargs(self, args):
372 372 data = {}
373 373 keys = args.split()
374 374 for n in xrange(len(keys)):
375 375 argline = self._fin.readline()[:-1]
376 376 arg, l = argline.split()
377 377 if arg not in keys:
378 378 raise error.Abort(_("unexpected parameter %r") % arg)
379 379 if arg == '*':
380 380 star = {}
381 381 for k in xrange(int(l)):
382 382 argline = self._fin.readline()[:-1]
383 383 arg, l = argline.split()
384 384 val = self._fin.read(int(l))
385 385 star[arg] = val
386 386 data['*'] = star
387 387 else:
388 388 val = self._fin.read(int(l))
389 389 data[arg] = val
390 390 return [data[k] for k in keys]
391 391
392 392 def forwardpayload(self, fpout):
393 393 # The file is in the form:
394 394 #
395 395 # <chunk size>\n<chunk>
396 396 # ...
397 397 # 0\n
398 398 _sshv1respondbytes(self._fout, b'')
399 399 count = int(self._fin.readline())
400 400 while count:
401 401 fpout.write(self._fin.read(count))
402 402 count = int(self._fin.readline())
403 403
404 404 @contextlib.contextmanager
405 405 def mayberedirectstdio(self):
406 406 yield None
407 407
408 408 def client(self):
409 409 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
410 410 return 'remote:ssh:' + client
411 411
412 def _runsshserver(ui, repo, fin, fout):
413 state = 'protov1-serving'
414 proto = sshv1protocolhandler(ui, fin, fout)
415
416 while True:
417 if state == 'protov1-serving':
418 # Commands are issued on new lines.
419 request = fin.readline()[:-1]
420
421 # Empty lines signal to terminate the connection.
422 if not request:
423 state = 'shutdown'
424 continue
425
426 available = wireproto.commands.commandavailable(request, proto)
427
428 # This command isn't available. Send an empty response and go
429 # back to waiting for a new command.
430 if not available:
431 _sshv1respondbytes(fout, b'')
432 continue
433
434 rsp = wireproto.dispatch(repo, proto, request)
435
436 if isinstance(rsp, bytes):
437 _sshv1respondbytes(fout, rsp)
438 elif isinstance(rsp, wireprototypes.bytesresponse):
439 _sshv1respondbytes(fout, rsp.data)
440 elif isinstance(rsp, wireprototypes.streamres):
441 _sshv1respondstream(fout, rsp)
442 elif isinstance(rsp, wireprototypes.streamreslegacy):
443 _sshv1respondstream(fout, rsp)
444 elif isinstance(rsp, wireprototypes.pushres):
445 _sshv1respondbytes(fout, b'')
446 _sshv1respondbytes(fout, b'%d' % rsp.res)
447 elif isinstance(rsp, wireprototypes.pusherr):
448 _sshv1respondbytes(fout, rsp.res)
449 elif isinstance(rsp, wireprototypes.ooberror):
450 _sshv1respondooberror(fout, ui.ferr, rsp.message)
451 else:
452 raise error.ProgrammingError('unhandled response type from '
453 'wire protocol command: %s' % rsp)
454
455 elif state == 'shutdown':
456 break
457
458 else:
459 raise error.ProgrammingError('unhandled ssh server state: %s' %
460 state)
461
412 462 class sshserver(object):
413 463 def __init__(self, ui, repo):
414 464 self._ui = ui
415 465 self._repo = repo
416 466 self._fin = ui.fin
417 467 self._fout = ui.fout
418 468
419 469 hook.redirect(True)
420 470 ui.fout = repo.ui.fout = ui.ferr
421 471
422 472 # Prevent insertion/deletion of CRs
423 473 util.setbinary(self._fin)
424 474 util.setbinary(self._fout)
425 475
426 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
427
428 476 def serve_forever(self):
429 while self.serve_one():
430 pass
477 _runsshserver(self._ui, self._repo, self._fin, self._fout)
431 478 sys.exit(0)
432
433 def serve_one(self):
434 cmd = self._fin.readline()[:-1]
435 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
436 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
437
438 if isinstance(rsp, bytes):
439 _sshv1respondbytes(self._fout, rsp)
440 elif isinstance(rsp, wireprototypes.bytesresponse):
441 _sshv1respondbytes(self._fout, rsp.data)
442 elif isinstance(rsp, wireprototypes.streamres):
443 _sshv1respondstream(self._fout, rsp)
444 elif isinstance(rsp, wireprototypes.streamreslegacy):
445 _sshv1respondstream(self._fout, rsp)
446 elif isinstance(rsp, wireprototypes.pushres):
447 _sshv1respondbytes(self._fout, b'')
448 _sshv1respondbytes(self._fout, b'%d' % rsp.res)
449 elif isinstance(rsp, wireprototypes.pusherr):
450 _sshv1respondbytes(self._fout, rsp.res)
451 elif isinstance(rsp, wireprototypes.ooberror):
452 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
453 else:
454 raise error.ProgrammingError('unhandled response type from '
455 'wire protocol command: %s' % rsp)
456 elif cmd:
457 _sshv1respondbytes(self._fout, b'')
458 return cmd != ''
@@ -1,127 +1,131 b''
1 1 # sshprotoext.py - Extension to test behavior of SSH protocol
2 2 #
3 3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 # This extension replaces the SSH server started via `hg serve --stdio`.
9 9 # The server behaves differently depending on environment variables.
10 10
11 11 from __future__ import absolute_import
12 12
13 13 from mercurial import (
14 14 error,
15 15 extensions,
16 16 registrar,
17 17 sshpeer,
18 18 wireproto,
19 19 wireprotoserver,
20 20 )
21 21
22 22 configtable = {}
23 23 configitem = registrar.configitem(configtable)
24 24
25 25 configitem(b'sshpeer', b'mode', default=None)
26 26 configitem(b'sshpeer', b'handshake-mode', default=None)
27 27
28 28 class bannerserver(wireprotoserver.sshserver):
29 29 """Server that sends a banner to stdout."""
30 30 def serve_forever(self):
31 31 for i in range(10):
32 32 self._fout.write(b'banner: line %d\n' % i)
33 33
34 34 super(bannerserver, self).serve_forever()
35 35
36 36 class prehelloserver(wireprotoserver.sshserver):
37 37 """Tests behavior when connecting to <0.9.1 servers.
38 38
39 39 The ``hello`` wire protocol command was introduced in Mercurial
40 40 0.9.1. Modern clients send the ``hello`` command when connecting
41 41 to SSH servers. This mock server tests behavior of the handshake
42 42 when ``hello`` is not supported.
43 43 """
44 44 def serve_forever(self):
45 45 l = self._fin.readline()
46 46 assert l == b'hello\n'
47 47 # Respond to unknown commands with an empty reply.
48 48 wireprotoserver._sshv1respondbytes(self._fout, b'')
49 49 l = self._fin.readline()
50 50 assert l == b'between\n'
51 rsp = wireproto.dispatch(self._repo, self._proto, b'between')
51 proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
52 self._fout)
53 rsp = wireproto.dispatch(self._repo, proto, b'between')
52 54 wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
53 55
54 56 super(prehelloserver, self).serve_forever()
55 57
56 58 class upgradev2server(wireprotoserver.sshserver):
57 59 """Tests behavior for clients that issue upgrade to version 2."""
58 60 def serve_forever(self):
59 61 name = wireprotoserver.SSHV2
60 62 l = self._fin.readline()
61 63 assert l.startswith(b'upgrade ')
62 64 token, caps = l[:-1].split(b' ')[1:]
63 65 assert caps == b'proto=%s' % name
64 66
65 67 # Filter hello and between requests.
66 68 l = self._fin.readline()
67 69 assert l == b'hello\n'
68 70 l = self._fin.readline()
69 71 assert l == b'between\n'
70 72 l = self._fin.readline()
71 73 assert l == b'pairs 81\n'
72 74 self._fin.read(81)
73 75
74 76 # Send the upgrade response.
77 proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
78 self._fout)
75 79 self._fout.write(b'upgraded %s %s\n' % (token, name))
76 servercaps = wireproto.capabilities(self._repo, self._proto)
80 servercaps = wireproto.capabilities(self._repo, proto)
77 81 rsp = b'capabilities: %s' % servercaps.data
78 82 self._fout.write(b'%d\n' % len(rsp))
79 83 self._fout.write(rsp)
80 84 self._fout.write(b'\n')
81 85 self._fout.flush()
82 86
83 87 super(upgradev2server, self).serve_forever()
84 88
85 89 def performhandshake(orig, ui, stdin, stdout, stderr):
86 90 """Wrapped version of sshpeer._performhandshake to send extra commands."""
87 91 mode = ui.config(b'sshpeer', b'handshake-mode')
88 92 if mode == b'pre-no-args':
89 93 ui.debug(b'sending no-args command\n')
90 94 stdin.write(b'no-args\n')
91 95 stdin.flush()
92 96 return orig(ui, stdin, stdout, stderr)
93 97 elif mode == b'pre-multiple-no-args':
94 98 ui.debug(b'sending unknown1 command\n')
95 99 stdin.write(b'unknown1\n')
96 100 ui.debug(b'sending unknown2 command\n')
97 101 stdin.write(b'unknown2\n')
98 102 ui.debug(b'sending unknown3 command\n')
99 103 stdin.write(b'unknown3\n')
100 104 stdin.flush()
101 105 return orig(ui, stdin, stdout, stderr)
102 106 else:
103 107 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
104 108 mode)
105 109
106 110 def extsetup(ui):
107 111 # It's easier for tests to define the server behavior via environment
108 112 # variables than config options. This is because `hg serve --stdio`
109 113 # has to be invoked with a certain form for security reasons and
110 114 # `dummyssh` can't just add `--config` flags to the command line.
111 115 servermode = ui.environ.get(b'SSHSERVERMODE')
112 116
113 117 if servermode == b'banner':
114 118 wireprotoserver.sshserver = bannerserver
115 119 elif servermode == b'no-hello':
116 120 wireprotoserver.sshserver = prehelloserver
117 121 elif servermode == b'upgradev2':
118 122 wireprotoserver.sshserver = upgradev2server
119 123 elif servermode:
120 124 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
121 125
122 126 peermode = ui.config(b'sshpeer', b'mode')
123 127
124 128 if peermode == b'extra-handshake-commands':
125 129 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
126 130 elif peermode:
127 131 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
@@ -1,47 +1,50 b''
1 1 from __future__ import absolute_import, print_function
2 2
3 3 import io
4 4 import unittest
5 5
6 6 import silenttestrunner
7 7
8 8 from mercurial import (
9 9 util,
10 10 wireproto,
11 11 wireprotoserver,
12 12 )
13 13
14 14 class SSHServerGetArgsTests(unittest.TestCase):
15 15 def testparseknown(self):
16 16 tests = [
17 17 (b'* 0\nnodes 0\n', [b'', {}]),
18 18 (b'* 0\nnodes 40\n1111111111111111111111111111111111111111\n',
19 19 [b'1111111111111111111111111111111111111111', {}]),
20 20 ]
21 21 for input, expected in tests:
22 22 self.assertparse(b'known', input, expected)
23 23
24 24 def assertparse(self, cmd, input, expected):
25 25 server = mockserver(input)
26 proto = wireprotoserver.sshv1protocolhandler(server._ui,
27 server._fin,
28 server._fout)
26 29 _func, spec = wireproto.commands[cmd]
27 self.assertEqual(server._proto.getargs(spec), expected)
30 self.assertEqual(proto.getargs(spec), expected)
28 31
29 32 def mockserver(inbytes):
30 33 ui = mockui(inbytes)
31 34 repo = mockrepo(ui)
32 35 return wireprotoserver.sshserver(ui, repo)
33 36
34 37 class mockrepo(object):
35 38 def __init__(self, ui):
36 39 self.ui = ui
37 40
38 41 class mockui(object):
39 42 def __init__(self, inbytes):
40 43 self.fin = io.BytesIO(inbytes)
41 44 self.fout = io.BytesIO()
42 45 self.ferr = io.BytesIO()
43 46
44 47 if __name__ == '__main__':
45 48 # Don't call into msvcrt to set BytesIO to binary mode
46 49 util.setbinary = lambda fp: True
47 50 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now