##// END OF EJS Templates
wireprotoserver: move SSH server operation to a standalone function...
Gregory Szorc -
r36232:3b3a987b default
parent child Browse files
Show More
@@ -1,458 +1,478 b''
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import abc
9 import abc
10 import contextlib
10 import contextlib
11 import struct
11 import struct
12 import sys
12 import sys
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 wireprototypes,
22 wireprototypes,
23 )
23 )
24
24
25 stringio = util.stringio
25 stringio = util.stringio
26
26
27 urlerr = util.urlerr
27 urlerr = util.urlerr
28 urlreq = util.urlreq
28 urlreq = util.urlreq
29
29
30 HTTP_OK = 200
30 HTTP_OK = 200
31
31
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
34 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
35
35
36 # Names of the SSH protocol implementations.
36 # Names of the SSH protocol implementations.
37 SSHV1 = 'ssh-v1'
37 SSHV1 = 'ssh-v1'
38 # This is advertised over the wire. Incremental the counter at the end
38 # This is advertised over the wire. Incremental the counter at the end
39 # to reflect BC breakages.
39 # to reflect BC breakages.
40 SSHV2 = 'exp-ssh-v2-0001'
40 SSHV2 = 'exp-ssh-v2-0001'
41
41
42 class baseprotocolhandler(object):
42 class baseprotocolhandler(object):
43 """Abstract base class for wire protocol handlers.
43 """Abstract base class for wire protocol handlers.
44
44
45 A wire protocol handler serves as an interface between protocol command
45 A wire protocol handler serves as an interface between protocol command
46 handlers and the wire protocol transport layer. Protocol handlers provide
46 handlers and the wire protocol transport layer. Protocol handlers provide
47 methods to read command arguments, redirect stdio for the duration of
47 methods to read command arguments, redirect stdio for the duration of
48 the request, handle response types, etc.
48 the request, handle response types, etc.
49 """
49 """
50
50
51 __metaclass__ = abc.ABCMeta
51 __metaclass__ = abc.ABCMeta
52
52
53 @abc.abstractproperty
53 @abc.abstractproperty
54 def name(self):
54 def name(self):
55 """The name of the protocol implementation.
55 """The name of the protocol implementation.
56
56
57 Used for uniquely identifying the transport type.
57 Used for uniquely identifying the transport type.
58 """
58 """
59
59
60 @abc.abstractmethod
60 @abc.abstractmethod
61 def getargs(self, args):
61 def getargs(self, args):
62 """return the value for arguments in <args>
62 """return the value for arguments in <args>
63
63
64 returns a list of values (same order as <args>)"""
64 returns a list of values (same order as <args>)"""
65
65
66 @abc.abstractmethod
66 @abc.abstractmethod
67 def forwardpayload(self, fp):
67 def forwardpayload(self, fp):
68 """Read the raw payload and forward to a file.
68 """Read the raw payload and forward to a file.
69
69
70 The payload is read in full before the function returns.
70 The payload is read in full before the function returns.
71 """
71 """
72
72
73 @abc.abstractmethod
73 @abc.abstractmethod
74 def mayberedirectstdio(self):
74 def mayberedirectstdio(self):
75 """Context manager to possibly redirect stdio.
75 """Context manager to possibly redirect stdio.
76
76
77 The context manager yields a file-object like object that receives
77 The context manager yields a file-object like object that receives
78 stdout and stderr output when the context manager is active. Or it
78 stdout and stderr output when the context manager is active. Or it
79 yields ``None`` if no I/O redirection occurs.
79 yields ``None`` if no I/O redirection occurs.
80
80
81 The intent of this context manager is to capture stdio output
81 The intent of this context manager is to capture stdio output
82 so it may be sent in the response. Some transports support streaming
82 so it may be sent in the response. Some transports support streaming
83 stdio to the client in real time. For these transports, stdio output
83 stdio to the client in real time. For these transports, stdio output
84 won't be captured.
84 won't be captured.
85 """
85 """
86
86
87 @abc.abstractmethod
87 @abc.abstractmethod
88 def client(self):
88 def client(self):
89 """Returns a string representation of this client (as bytes)."""
89 """Returns a string representation of this client (as bytes)."""
90
90
91 def decodevaluefromheaders(req, headerprefix):
91 def decodevaluefromheaders(req, headerprefix):
92 """Decode a long value from multiple HTTP request headers.
92 """Decode a long value from multiple HTTP request headers.
93
93
94 Returns the value as a bytes, not a str.
94 Returns the value as a bytes, not a str.
95 """
95 """
96 chunks = []
96 chunks = []
97 i = 1
97 i = 1
98 prefix = headerprefix.upper().replace(r'-', r'_')
98 prefix = headerprefix.upper().replace(r'-', r'_')
99 while True:
99 while True:
100 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
100 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
101 if v is None:
101 if v is None:
102 break
102 break
103 chunks.append(pycompat.bytesurl(v))
103 chunks.append(pycompat.bytesurl(v))
104 i += 1
104 i += 1
105
105
106 return ''.join(chunks)
106 return ''.join(chunks)
107
107
108 class webproto(baseprotocolhandler):
108 class webproto(baseprotocolhandler):
109 def __init__(self, req, ui):
109 def __init__(self, req, ui):
110 self._req = req
110 self._req = req
111 self._ui = ui
111 self._ui = ui
112
112
113 @property
113 @property
114 def name(self):
114 def name(self):
115 return 'http'
115 return 'http'
116
116
117 def getargs(self, args):
117 def getargs(self, args):
118 knownargs = self._args()
118 knownargs = self._args()
119 data = {}
119 data = {}
120 keys = args.split()
120 keys = args.split()
121 for k in keys:
121 for k in keys:
122 if k == '*':
122 if k == '*':
123 star = {}
123 star = {}
124 for key in knownargs.keys():
124 for key in knownargs.keys():
125 if key != 'cmd' and key not in keys:
125 if key != 'cmd' and key not in keys:
126 star[key] = knownargs[key][0]
126 star[key] = knownargs[key][0]
127 data['*'] = star
127 data['*'] = star
128 else:
128 else:
129 data[k] = knownargs[k][0]
129 data[k] = knownargs[k][0]
130 return [data[k] for k in keys]
130 return [data[k] for k in keys]
131
131
132 def _args(self):
132 def _args(self):
133 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
133 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
134 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
134 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
135 if postlen:
135 if postlen:
136 args.update(urlreq.parseqs(
136 args.update(urlreq.parseqs(
137 self._req.read(postlen), keep_blank_values=True))
137 self._req.read(postlen), keep_blank_values=True))
138 return args
138 return args
139
139
140 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
140 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
141 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
141 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
142 return args
142 return args
143
143
144 def forwardpayload(self, fp):
144 def forwardpayload(self, fp):
145 length = int(self._req.env[r'CONTENT_LENGTH'])
145 length = int(self._req.env[r'CONTENT_LENGTH'])
146 # If httppostargs is used, we need to read Content-Length
146 # If httppostargs is used, we need to read Content-Length
147 # minus the amount that was consumed by args.
147 # minus the amount that was consumed by args.
148 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
148 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
149 for s in util.filechunkiter(self._req, limit=length):
149 for s in util.filechunkiter(self._req, limit=length):
150 fp.write(s)
150 fp.write(s)
151
151
152 @contextlib.contextmanager
152 @contextlib.contextmanager
153 def mayberedirectstdio(self):
153 def mayberedirectstdio(self):
154 oldout = self._ui.fout
154 oldout = self._ui.fout
155 olderr = self._ui.ferr
155 olderr = self._ui.ferr
156
156
157 out = util.stringio()
157 out = util.stringio()
158
158
159 try:
159 try:
160 self._ui.fout = out
160 self._ui.fout = out
161 self._ui.ferr = out
161 self._ui.ferr = out
162 yield out
162 yield out
163 finally:
163 finally:
164 self._ui.fout = oldout
164 self._ui.fout = oldout
165 self._ui.ferr = olderr
165 self._ui.ferr = olderr
166
166
167 def client(self):
167 def client(self):
168 return 'remote:%s:%s:%s' % (
168 return 'remote:%s:%s:%s' % (
169 self._req.env.get('wsgi.url_scheme') or 'http',
169 self._req.env.get('wsgi.url_scheme') or 'http',
170 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
170 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
171 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
171 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
172
172
173 def iscmd(cmd):
173 def iscmd(cmd):
174 return cmd in wireproto.commands
174 return cmd in wireproto.commands
175
175
176 def parsehttprequest(repo, req, query):
176 def parsehttprequest(repo, req, query):
177 """Parse the HTTP request for a wire protocol request.
177 """Parse the HTTP request for a wire protocol request.
178
178
179 If the current request appears to be a wire protocol request, this
179 If the current request appears to be a wire protocol request, this
180 function returns a dict with details about that request, including
180 function returns a dict with details about that request, including
181 an ``abstractprotocolserver`` instance suitable for handling the
181 an ``abstractprotocolserver`` instance suitable for handling the
182 request. Otherwise, ``None`` is returned.
182 request. Otherwise, ``None`` is returned.
183
183
184 ``req`` is a ``wsgirequest`` instance.
184 ``req`` is a ``wsgirequest`` instance.
185 """
185 """
186 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
186 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
187 # string parameter. If it isn't present, this isn't a wire protocol
187 # string parameter. If it isn't present, this isn't a wire protocol
188 # request.
188 # request.
189 if r'cmd' not in req.form:
189 if r'cmd' not in req.form:
190 return None
190 return None
191
191
192 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
192 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
193
193
194 # The "cmd" request parameter is used by both the wire protocol and hgweb.
194 # The "cmd" request parameter is used by both the wire protocol and hgweb.
195 # While not all wire protocol commands are available for all transports,
195 # While not all wire protocol commands are available for all transports,
196 # if we see a "cmd" value that resembles a known wire protocol command, we
196 # if we see a "cmd" value that resembles a known wire protocol command, we
197 # route it to a protocol handler. This is better than routing possible
197 # route it to a protocol handler. This is better than routing possible
198 # wire protocol requests to hgweb because it prevents hgweb from using
198 # wire protocol requests to hgweb because it prevents hgweb from using
199 # known wire protocol commands and it is less confusing for machine
199 # known wire protocol commands and it is less confusing for machine
200 # clients.
200 # clients.
201 if cmd not in wireproto.commands:
201 if cmd not in wireproto.commands:
202 return None
202 return None
203
203
204 proto = webproto(req, repo.ui)
204 proto = webproto(req, repo.ui)
205
205
206 return {
206 return {
207 'cmd': cmd,
207 'cmd': cmd,
208 'proto': proto,
208 'proto': proto,
209 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
209 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
210 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
210 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
211 }
211 }
212
212
213 def _httpresponsetype(ui, req, prefer_uncompressed):
213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 """Determine the appropriate response type and compression settings.
214 """Determine the appropriate response type and compression settings.
215
215
216 Returns a tuple of (mediatype, compengine, engineopts).
216 Returns a tuple of (mediatype, compengine, engineopts).
217 """
217 """
218 # Determine the response media type and compression engine based
218 # Determine the response media type and compression engine based
219 # on the request parameters.
219 # on the request parameters.
220 protocaps = decodevaluefromheaders(req, r'X-HgProto').split(' ')
220 protocaps = decodevaluefromheaders(req, r'X-HgProto').split(' ')
221
221
222 if '0.2' in protocaps:
222 if '0.2' in protocaps:
223 # All clients are expected to support uncompressed data.
223 # All clients are expected to support uncompressed data.
224 if prefer_uncompressed:
224 if prefer_uncompressed:
225 return HGTYPE2, util._noopengine(), {}
225 return HGTYPE2, util._noopengine(), {}
226
226
227 # Default as defined by wire protocol spec.
227 # Default as defined by wire protocol spec.
228 compformats = ['zlib', 'none']
228 compformats = ['zlib', 'none']
229 for cap in protocaps:
229 for cap in protocaps:
230 if cap.startswith('comp='):
230 if cap.startswith('comp='):
231 compformats = cap[5:].split(',')
231 compformats = cap[5:].split(',')
232 break
232 break
233
233
234 # Now find an agreed upon compression format.
234 # Now find an agreed upon compression format.
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 if engine.wireprotosupport().name in compformats:
236 if engine.wireprotosupport().name in compformats:
237 opts = {}
237 opts = {}
238 level = ui.configint('server', '%slevel' % engine.name())
238 level = ui.configint('server', '%slevel' % engine.name())
239 if level is not None:
239 if level is not None:
240 opts['level'] = level
240 opts['level'] = level
241
241
242 return HGTYPE2, engine, opts
242 return HGTYPE2, engine, opts
243
243
244 # No mutually supported compression format. Fall back to the
244 # No mutually supported compression format. Fall back to the
245 # legacy protocol.
245 # legacy protocol.
246
246
247 # Don't allow untrusted settings because disabling compression or
247 # Don't allow untrusted settings because disabling compression or
248 # setting a very high compression level could lead to flooding
248 # setting a very high compression level could lead to flooding
249 # the server's network or CPU.
249 # the server's network or CPU.
250 opts = {'level': ui.configint('server', 'zliblevel')}
250 opts = {'level': ui.configint('server', 'zliblevel')}
251 return HGTYPE, util.compengines['zlib'], opts
251 return HGTYPE, util.compengines['zlib'], opts
252
252
253 def _callhttp(repo, req, proto, cmd):
253 def _callhttp(repo, req, proto, cmd):
254 def genversion2(gen, engine, engineopts):
254 def genversion2(gen, engine, engineopts):
255 # application/mercurial-0.2 always sends a payload header
255 # application/mercurial-0.2 always sends a payload header
256 # identifying the compression engine.
256 # identifying the compression engine.
257 name = engine.wireprotosupport().name
257 name = engine.wireprotosupport().name
258 assert 0 < len(name) < 256
258 assert 0 < len(name) < 256
259 yield struct.pack('B', len(name))
259 yield struct.pack('B', len(name))
260 yield name
260 yield name
261
261
262 for chunk in gen:
262 for chunk in gen:
263 yield chunk
263 yield chunk
264
264
265 rsp = wireproto.dispatch(repo, proto, cmd)
265 rsp = wireproto.dispatch(repo, proto, cmd)
266
266
267 if not wireproto.commands.commandavailable(cmd, proto):
267 if not wireproto.commands.commandavailable(cmd, proto):
268 req.respond(HTTP_OK, HGERRTYPE,
268 req.respond(HTTP_OK, HGERRTYPE,
269 body=_('requested wire protocol command is not available '
269 body=_('requested wire protocol command is not available '
270 'over HTTP'))
270 'over HTTP'))
271 return []
271 return []
272
272
273 if isinstance(rsp, bytes):
273 if isinstance(rsp, bytes):
274 req.respond(HTTP_OK, HGTYPE, body=rsp)
274 req.respond(HTTP_OK, HGTYPE, body=rsp)
275 return []
275 return []
276 elif isinstance(rsp, wireprototypes.bytesresponse):
276 elif isinstance(rsp, wireprototypes.bytesresponse):
277 req.respond(HTTP_OK, HGTYPE, body=rsp.data)
277 req.respond(HTTP_OK, HGTYPE, body=rsp.data)
278 return []
278 return []
279 elif isinstance(rsp, wireprototypes.streamreslegacy):
279 elif isinstance(rsp, wireprototypes.streamreslegacy):
280 gen = rsp.gen
280 gen = rsp.gen
281 req.respond(HTTP_OK, HGTYPE)
281 req.respond(HTTP_OK, HGTYPE)
282 return gen
282 return gen
283 elif isinstance(rsp, wireprototypes.streamres):
283 elif isinstance(rsp, wireprototypes.streamres):
284 gen = rsp.gen
284 gen = rsp.gen
285
285
286 # This code for compression should not be streamres specific. It
286 # This code for compression should not be streamres specific. It
287 # is here because we only compress streamres at the moment.
287 # is here because we only compress streamres at the moment.
288 mediatype, engine, engineopts = _httpresponsetype(
288 mediatype, engine, engineopts = _httpresponsetype(
289 repo.ui, req, rsp.prefer_uncompressed)
289 repo.ui, req, rsp.prefer_uncompressed)
290 gen = engine.compressstream(gen, engineopts)
290 gen = engine.compressstream(gen, engineopts)
291
291
292 if mediatype == HGTYPE2:
292 if mediatype == HGTYPE2:
293 gen = genversion2(gen, engine, engineopts)
293 gen = genversion2(gen, engine, engineopts)
294
294
295 req.respond(HTTP_OK, mediatype)
295 req.respond(HTTP_OK, mediatype)
296 return gen
296 return gen
297 elif isinstance(rsp, wireprototypes.pushres):
297 elif isinstance(rsp, wireprototypes.pushres):
298 rsp = '%d\n%s' % (rsp.res, rsp.output)
298 rsp = '%d\n%s' % (rsp.res, rsp.output)
299 req.respond(HTTP_OK, HGTYPE, body=rsp)
299 req.respond(HTTP_OK, HGTYPE, body=rsp)
300 return []
300 return []
301 elif isinstance(rsp, wireprototypes.pusherr):
301 elif isinstance(rsp, wireprototypes.pusherr):
302 # This is the httplib workaround documented in _handlehttperror().
302 # This is the httplib workaround documented in _handlehttperror().
303 req.drain()
303 req.drain()
304
304
305 rsp = '0\n%s\n' % rsp.res
305 rsp = '0\n%s\n' % rsp.res
306 req.respond(HTTP_OK, HGTYPE, body=rsp)
306 req.respond(HTTP_OK, HGTYPE, body=rsp)
307 return []
307 return []
308 elif isinstance(rsp, wireprototypes.ooberror):
308 elif isinstance(rsp, wireprototypes.ooberror):
309 rsp = rsp.message
309 rsp = rsp.message
310 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
310 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
311 return []
311 return []
312 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
312 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
313
313
314 def _handlehttperror(e, req, cmd):
314 def _handlehttperror(e, req, cmd):
315 """Called when an ErrorResponse is raised during HTTP request processing."""
315 """Called when an ErrorResponse is raised during HTTP request processing."""
316
316
317 # Clients using Python's httplib are stateful: the HTTP client
317 # Clients using Python's httplib are stateful: the HTTP client
318 # won't process an HTTP response until all request data is
318 # won't process an HTTP response until all request data is
319 # sent to the server. The intent of this code is to ensure
319 # sent to the server. The intent of this code is to ensure
320 # we always read HTTP request data from the client, thus
320 # we always read HTTP request data from the client, thus
321 # ensuring httplib transitions to a state that allows it to read
321 # ensuring httplib transitions to a state that allows it to read
322 # the HTTP response. In other words, it helps prevent deadlocks
322 # the HTTP response. In other words, it helps prevent deadlocks
323 # on clients using httplib.
323 # on clients using httplib.
324
324
325 if (req.env[r'REQUEST_METHOD'] == r'POST' and
325 if (req.env[r'REQUEST_METHOD'] == r'POST' and
326 # But not if Expect: 100-continue is being used.
326 # But not if Expect: 100-continue is being used.
327 (req.env.get('HTTP_EXPECT',
327 (req.env.get('HTTP_EXPECT',
328 '').lower() != '100-continue') or
328 '').lower() != '100-continue') or
329 # Or the non-httplib HTTP library is being advertised by
329 # Or the non-httplib HTTP library is being advertised by
330 # the client.
330 # the client.
331 req.env.get('X-HgHttp2', '')):
331 req.env.get('X-HgHttp2', '')):
332 req.drain()
332 req.drain()
333 else:
333 else:
334 req.headers.append((r'Connection', r'Close'))
334 req.headers.append((r'Connection', r'Close'))
335
335
336 # TODO This response body assumes the failed command was
336 # TODO This response body assumes the failed command was
337 # "unbundle." That assumption is not always valid.
337 # "unbundle." That assumption is not always valid.
338 req.respond(e, HGTYPE, body='0\n%s\n' % e)
338 req.respond(e, HGTYPE, body='0\n%s\n' % e)
339
339
340 return ''
340 return ''
341
341
342 def _sshv1respondbytes(fout, value):
342 def _sshv1respondbytes(fout, value):
343 """Send a bytes response for protocol version 1."""
343 """Send a bytes response for protocol version 1."""
344 fout.write('%d\n' % len(value))
344 fout.write('%d\n' % len(value))
345 fout.write(value)
345 fout.write(value)
346 fout.flush()
346 fout.flush()
347
347
348 def _sshv1respondstream(fout, source):
348 def _sshv1respondstream(fout, source):
349 write = fout.write
349 write = fout.write
350 for chunk in source.gen:
350 for chunk in source.gen:
351 write(chunk)
351 write(chunk)
352 fout.flush()
352 fout.flush()
353
353
354 def _sshv1respondooberror(fout, ferr, rsp):
354 def _sshv1respondooberror(fout, ferr, rsp):
355 ferr.write(b'%s\n-\n' % rsp)
355 ferr.write(b'%s\n-\n' % rsp)
356 ferr.flush()
356 ferr.flush()
357 fout.write(b'\n')
357 fout.write(b'\n')
358 fout.flush()
358 fout.flush()
359
359
360 class sshv1protocolhandler(baseprotocolhandler):
360 class sshv1protocolhandler(baseprotocolhandler):
361 """Handler for requests services via version 1 of SSH protocol."""
361 """Handler for requests services via version 1 of SSH protocol."""
362 def __init__(self, ui, fin, fout):
362 def __init__(self, ui, fin, fout):
363 self._ui = ui
363 self._ui = ui
364 self._fin = fin
364 self._fin = fin
365 self._fout = fout
365 self._fout = fout
366
366
367 @property
367 @property
368 def name(self):
368 def name(self):
369 return SSHV1
369 return SSHV1
370
370
371 def getargs(self, args):
371 def getargs(self, args):
372 data = {}
372 data = {}
373 keys = args.split()
373 keys = args.split()
374 for n in xrange(len(keys)):
374 for n in xrange(len(keys)):
375 argline = self._fin.readline()[:-1]
375 argline = self._fin.readline()[:-1]
376 arg, l = argline.split()
376 arg, l = argline.split()
377 if arg not in keys:
377 if arg not in keys:
378 raise error.Abort(_("unexpected parameter %r") % arg)
378 raise error.Abort(_("unexpected parameter %r") % arg)
379 if arg == '*':
379 if arg == '*':
380 star = {}
380 star = {}
381 for k in xrange(int(l)):
381 for k in xrange(int(l)):
382 argline = self._fin.readline()[:-1]
382 argline = self._fin.readline()[:-1]
383 arg, l = argline.split()
383 arg, l = argline.split()
384 val = self._fin.read(int(l))
384 val = self._fin.read(int(l))
385 star[arg] = val
385 star[arg] = val
386 data['*'] = star
386 data['*'] = star
387 else:
387 else:
388 val = self._fin.read(int(l))
388 val = self._fin.read(int(l))
389 data[arg] = val
389 data[arg] = val
390 return [data[k] for k in keys]
390 return [data[k] for k in keys]
391
391
392 def forwardpayload(self, fpout):
392 def forwardpayload(self, fpout):
393 # The file is in the form:
393 # The file is in the form:
394 #
394 #
395 # <chunk size>\n<chunk>
395 # <chunk size>\n<chunk>
396 # ...
396 # ...
397 # 0\n
397 # 0\n
398 _sshv1respondbytes(self._fout, b'')
398 _sshv1respondbytes(self._fout, b'')
399 count = int(self._fin.readline())
399 count = int(self._fin.readline())
400 while count:
400 while count:
401 fpout.write(self._fin.read(count))
401 fpout.write(self._fin.read(count))
402 count = int(self._fin.readline())
402 count = int(self._fin.readline())
403
403
404 @contextlib.contextmanager
404 @contextlib.contextmanager
405 def mayberedirectstdio(self):
405 def mayberedirectstdio(self):
406 yield None
406 yield None
407
407
408 def client(self):
408 def client(self):
409 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
409 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
410 return 'remote:ssh:' + client
410 return 'remote:ssh:' + client
411
411
412 def _runsshserver(ui, repo, fin, fout):
413 state = 'protov1-serving'
414 proto = sshv1protocolhandler(ui, fin, fout)
415
416 while True:
417 if state == 'protov1-serving':
418 # Commands are issued on new lines.
419 request = fin.readline()[:-1]
420
421 # Empty lines signal to terminate the connection.
422 if not request:
423 state = 'shutdown'
424 continue
425
426 available = wireproto.commands.commandavailable(request, proto)
427
428 # This command isn't available. Send an empty response and go
429 # back to waiting for a new command.
430 if not available:
431 _sshv1respondbytes(fout, b'')
432 continue
433
434 rsp = wireproto.dispatch(repo, proto, request)
435
436 if isinstance(rsp, bytes):
437 _sshv1respondbytes(fout, rsp)
438 elif isinstance(rsp, wireprototypes.bytesresponse):
439 _sshv1respondbytes(fout, rsp.data)
440 elif isinstance(rsp, wireprototypes.streamres):
441 _sshv1respondstream(fout, rsp)
442 elif isinstance(rsp, wireprototypes.streamreslegacy):
443 _sshv1respondstream(fout, rsp)
444 elif isinstance(rsp, wireprototypes.pushres):
445 _sshv1respondbytes(fout, b'')
446 _sshv1respondbytes(fout, b'%d' % rsp.res)
447 elif isinstance(rsp, wireprototypes.pusherr):
448 _sshv1respondbytes(fout, rsp.res)
449 elif isinstance(rsp, wireprototypes.ooberror):
450 _sshv1respondooberror(fout, ui.ferr, rsp.message)
451 else:
452 raise error.ProgrammingError('unhandled response type from '
453 'wire protocol command: %s' % rsp)
454
455 elif state == 'shutdown':
456 break
457
458 else:
459 raise error.ProgrammingError('unhandled ssh server state: %s' %
460 state)
461
412 class sshserver(object):
462 class sshserver(object):
413 def __init__(self, ui, repo):
463 def __init__(self, ui, repo):
414 self._ui = ui
464 self._ui = ui
415 self._repo = repo
465 self._repo = repo
416 self._fin = ui.fin
466 self._fin = ui.fin
417 self._fout = ui.fout
467 self._fout = ui.fout
418
468
419 hook.redirect(True)
469 hook.redirect(True)
420 ui.fout = repo.ui.fout = ui.ferr
470 ui.fout = repo.ui.fout = ui.ferr
421
471
422 # Prevent insertion/deletion of CRs
472 # Prevent insertion/deletion of CRs
423 util.setbinary(self._fin)
473 util.setbinary(self._fin)
424 util.setbinary(self._fout)
474 util.setbinary(self._fout)
425
475
426 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
427
428 def serve_forever(self):
476 def serve_forever(self):
429 while self.serve_one():
477 _runsshserver(self._ui, self._repo, self._fin, self._fout)
430 pass
431 sys.exit(0)
478 sys.exit(0)
432
433 def serve_one(self):
434 cmd = self._fin.readline()[:-1]
435 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
436 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
437
438 if isinstance(rsp, bytes):
439 _sshv1respondbytes(self._fout, rsp)
440 elif isinstance(rsp, wireprototypes.bytesresponse):
441 _sshv1respondbytes(self._fout, rsp.data)
442 elif isinstance(rsp, wireprototypes.streamres):
443 _sshv1respondstream(self._fout, rsp)
444 elif isinstance(rsp, wireprototypes.streamreslegacy):
445 _sshv1respondstream(self._fout, rsp)
446 elif isinstance(rsp, wireprototypes.pushres):
447 _sshv1respondbytes(self._fout, b'')
448 _sshv1respondbytes(self._fout, b'%d' % rsp.res)
449 elif isinstance(rsp, wireprototypes.pusherr):
450 _sshv1respondbytes(self._fout, rsp.res)
451 elif isinstance(rsp, wireprototypes.ooberror):
452 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
453 else:
454 raise error.ProgrammingError('unhandled response type from '
455 'wire protocol command: %s' % rsp)
456 elif cmd:
457 _sshv1respondbytes(self._fout, b'')
458 return cmd != ''
@@ -1,127 +1,131 b''
1 # sshprotoext.py - Extension to test behavior of SSH protocol
1 # sshprotoext.py - Extension to test behavior of SSH protocol
2 #
2 #
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 # This extension replaces the SSH server started via `hg serve --stdio`.
8 # This extension replaces the SSH server started via `hg serve --stdio`.
9 # The server behaves differently depending on environment variables.
9 # The server behaves differently depending on environment variables.
10
10
11 from __future__ import absolute_import
11 from __future__ import absolute_import
12
12
13 from mercurial import (
13 from mercurial import (
14 error,
14 error,
15 extensions,
15 extensions,
16 registrar,
16 registrar,
17 sshpeer,
17 sshpeer,
18 wireproto,
18 wireproto,
19 wireprotoserver,
19 wireprotoserver,
20 )
20 )
21
21
22 configtable = {}
22 configtable = {}
23 configitem = registrar.configitem(configtable)
23 configitem = registrar.configitem(configtable)
24
24
25 configitem(b'sshpeer', b'mode', default=None)
25 configitem(b'sshpeer', b'mode', default=None)
26 configitem(b'sshpeer', b'handshake-mode', default=None)
26 configitem(b'sshpeer', b'handshake-mode', default=None)
27
27
28 class bannerserver(wireprotoserver.sshserver):
28 class bannerserver(wireprotoserver.sshserver):
29 """Server that sends a banner to stdout."""
29 """Server that sends a banner to stdout."""
30 def serve_forever(self):
30 def serve_forever(self):
31 for i in range(10):
31 for i in range(10):
32 self._fout.write(b'banner: line %d\n' % i)
32 self._fout.write(b'banner: line %d\n' % i)
33
33
34 super(bannerserver, self).serve_forever()
34 super(bannerserver, self).serve_forever()
35
35
36 class prehelloserver(wireprotoserver.sshserver):
36 class prehelloserver(wireprotoserver.sshserver):
37 """Tests behavior when connecting to <0.9.1 servers.
37 """Tests behavior when connecting to <0.9.1 servers.
38
38
39 The ``hello`` wire protocol command was introduced in Mercurial
39 The ``hello`` wire protocol command was introduced in Mercurial
40 0.9.1. Modern clients send the ``hello`` command when connecting
40 0.9.1. Modern clients send the ``hello`` command when connecting
41 to SSH servers. This mock server tests behavior of the handshake
41 to SSH servers. This mock server tests behavior of the handshake
42 when ``hello`` is not supported.
42 when ``hello`` is not supported.
43 """
43 """
44 def serve_forever(self):
44 def serve_forever(self):
45 l = self._fin.readline()
45 l = self._fin.readline()
46 assert l == b'hello\n'
46 assert l == b'hello\n'
47 # Respond to unknown commands with an empty reply.
47 # Respond to unknown commands with an empty reply.
48 wireprotoserver._sshv1respondbytes(self._fout, b'')
48 wireprotoserver._sshv1respondbytes(self._fout, b'')
49 l = self._fin.readline()
49 l = self._fin.readline()
50 assert l == b'between\n'
50 assert l == b'between\n'
51 rsp = wireproto.dispatch(self._repo, self._proto, b'between')
51 proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
52 self._fout)
53 rsp = wireproto.dispatch(self._repo, proto, b'between')
52 wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
54 wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
53
55
54 super(prehelloserver, self).serve_forever()
56 super(prehelloserver, self).serve_forever()
55
57
56 class upgradev2server(wireprotoserver.sshserver):
58 class upgradev2server(wireprotoserver.sshserver):
57 """Tests behavior for clients that issue upgrade to version 2."""
59 """Tests behavior for clients that issue upgrade to version 2."""
58 def serve_forever(self):
60 def serve_forever(self):
59 name = wireprotoserver.SSHV2
61 name = wireprotoserver.SSHV2
60 l = self._fin.readline()
62 l = self._fin.readline()
61 assert l.startswith(b'upgrade ')
63 assert l.startswith(b'upgrade ')
62 token, caps = l[:-1].split(b' ')[1:]
64 token, caps = l[:-1].split(b' ')[1:]
63 assert caps == b'proto=%s' % name
65 assert caps == b'proto=%s' % name
64
66
65 # Filter hello and between requests.
67 # Filter hello and between requests.
66 l = self._fin.readline()
68 l = self._fin.readline()
67 assert l == b'hello\n'
69 assert l == b'hello\n'
68 l = self._fin.readline()
70 l = self._fin.readline()
69 assert l == b'between\n'
71 assert l == b'between\n'
70 l = self._fin.readline()
72 l = self._fin.readline()
71 assert l == b'pairs 81\n'
73 assert l == b'pairs 81\n'
72 self._fin.read(81)
74 self._fin.read(81)
73
75
74 # Send the upgrade response.
76 # Send the upgrade response.
77 proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin,
78 self._fout)
75 self._fout.write(b'upgraded %s %s\n' % (token, name))
79 self._fout.write(b'upgraded %s %s\n' % (token, name))
76 servercaps = wireproto.capabilities(self._repo, self._proto)
80 servercaps = wireproto.capabilities(self._repo, proto)
77 rsp = b'capabilities: %s' % servercaps.data
81 rsp = b'capabilities: %s' % servercaps.data
78 self._fout.write(b'%d\n' % len(rsp))
82 self._fout.write(b'%d\n' % len(rsp))
79 self._fout.write(rsp)
83 self._fout.write(rsp)
80 self._fout.write(b'\n')
84 self._fout.write(b'\n')
81 self._fout.flush()
85 self._fout.flush()
82
86
83 super(upgradev2server, self).serve_forever()
87 super(upgradev2server, self).serve_forever()
84
88
85 def performhandshake(orig, ui, stdin, stdout, stderr):
89 def performhandshake(orig, ui, stdin, stdout, stderr):
86 """Wrapped version of sshpeer._performhandshake to send extra commands."""
90 """Wrapped version of sshpeer._performhandshake to send extra commands."""
87 mode = ui.config(b'sshpeer', b'handshake-mode')
91 mode = ui.config(b'sshpeer', b'handshake-mode')
88 if mode == b'pre-no-args':
92 if mode == b'pre-no-args':
89 ui.debug(b'sending no-args command\n')
93 ui.debug(b'sending no-args command\n')
90 stdin.write(b'no-args\n')
94 stdin.write(b'no-args\n')
91 stdin.flush()
95 stdin.flush()
92 return orig(ui, stdin, stdout, stderr)
96 return orig(ui, stdin, stdout, stderr)
93 elif mode == b'pre-multiple-no-args':
97 elif mode == b'pre-multiple-no-args':
94 ui.debug(b'sending unknown1 command\n')
98 ui.debug(b'sending unknown1 command\n')
95 stdin.write(b'unknown1\n')
99 stdin.write(b'unknown1\n')
96 ui.debug(b'sending unknown2 command\n')
100 ui.debug(b'sending unknown2 command\n')
97 stdin.write(b'unknown2\n')
101 stdin.write(b'unknown2\n')
98 ui.debug(b'sending unknown3 command\n')
102 ui.debug(b'sending unknown3 command\n')
99 stdin.write(b'unknown3\n')
103 stdin.write(b'unknown3\n')
100 stdin.flush()
104 stdin.flush()
101 return orig(ui, stdin, stdout, stderr)
105 return orig(ui, stdin, stdout, stderr)
102 else:
106 else:
103 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
107 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
104 mode)
108 mode)
105
109
106 def extsetup(ui):
110 def extsetup(ui):
107 # It's easier for tests to define the server behavior via environment
111 # It's easier for tests to define the server behavior via environment
108 # variables than config options. This is because `hg serve --stdio`
112 # variables than config options. This is because `hg serve --stdio`
109 # has to be invoked with a certain form for security reasons and
113 # has to be invoked with a certain form for security reasons and
110 # `dummyssh` can't just add `--config` flags to the command line.
114 # `dummyssh` can't just add `--config` flags to the command line.
111 servermode = ui.environ.get(b'SSHSERVERMODE')
115 servermode = ui.environ.get(b'SSHSERVERMODE')
112
116
113 if servermode == b'banner':
117 if servermode == b'banner':
114 wireprotoserver.sshserver = bannerserver
118 wireprotoserver.sshserver = bannerserver
115 elif servermode == b'no-hello':
119 elif servermode == b'no-hello':
116 wireprotoserver.sshserver = prehelloserver
120 wireprotoserver.sshserver = prehelloserver
117 elif servermode == b'upgradev2':
121 elif servermode == b'upgradev2':
118 wireprotoserver.sshserver = upgradev2server
122 wireprotoserver.sshserver = upgradev2server
119 elif servermode:
123 elif servermode:
120 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
124 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
121
125
122 peermode = ui.config(b'sshpeer', b'mode')
126 peermode = ui.config(b'sshpeer', b'mode')
123
127
124 if peermode == b'extra-handshake-commands':
128 if peermode == b'extra-handshake-commands':
125 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
129 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
126 elif peermode:
130 elif peermode:
127 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
131 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
@@ -1,47 +1,50 b''
1 from __future__ import absolute_import, print_function
1 from __future__ import absolute_import, print_function
2
2
3 import io
3 import io
4 import unittest
4 import unittest
5
5
6 import silenttestrunner
6 import silenttestrunner
7
7
8 from mercurial import (
8 from mercurial import (
9 util,
9 util,
10 wireproto,
10 wireproto,
11 wireprotoserver,
11 wireprotoserver,
12 )
12 )
13
13
14 class SSHServerGetArgsTests(unittest.TestCase):
14 class SSHServerGetArgsTests(unittest.TestCase):
15 def testparseknown(self):
15 def testparseknown(self):
16 tests = [
16 tests = [
17 (b'* 0\nnodes 0\n', [b'', {}]),
17 (b'* 0\nnodes 0\n', [b'', {}]),
18 (b'* 0\nnodes 40\n1111111111111111111111111111111111111111\n',
18 (b'* 0\nnodes 40\n1111111111111111111111111111111111111111\n',
19 [b'1111111111111111111111111111111111111111', {}]),
19 [b'1111111111111111111111111111111111111111', {}]),
20 ]
20 ]
21 for input, expected in tests:
21 for input, expected in tests:
22 self.assertparse(b'known', input, expected)
22 self.assertparse(b'known', input, expected)
23
23
24 def assertparse(self, cmd, input, expected):
24 def assertparse(self, cmd, input, expected):
25 server = mockserver(input)
25 server = mockserver(input)
26 proto = wireprotoserver.sshv1protocolhandler(server._ui,
27 server._fin,
28 server._fout)
26 _func, spec = wireproto.commands[cmd]
29 _func, spec = wireproto.commands[cmd]
27 self.assertEqual(server._proto.getargs(spec), expected)
30 self.assertEqual(proto.getargs(spec), expected)
28
31
29 def mockserver(inbytes):
32 def mockserver(inbytes):
30 ui = mockui(inbytes)
33 ui = mockui(inbytes)
31 repo = mockrepo(ui)
34 repo = mockrepo(ui)
32 return wireprotoserver.sshserver(ui, repo)
35 return wireprotoserver.sshserver(ui, repo)
33
36
34 class mockrepo(object):
37 class mockrepo(object):
35 def __init__(self, ui):
38 def __init__(self, ui):
36 self.ui = ui
39 self.ui = ui
37
40
38 class mockui(object):
41 class mockui(object):
39 def __init__(self, inbytes):
42 def __init__(self, inbytes):
40 self.fin = io.BytesIO(inbytes)
43 self.fin = io.BytesIO(inbytes)
41 self.fout = io.BytesIO()
44 self.fout = io.BytesIO()
42 self.ferr = io.BytesIO()
45 self.ferr = io.BytesIO()
43
46
44 if __name__ == '__main__':
47 if __name__ == '__main__':
45 # Don't call into msvcrt to set BytesIO to binary mode
48 # Don't call into msvcrt to set BytesIO to binary mode
46 util.setbinary = lambda fp: True
49 util.setbinary = lambda fp: True
47 silenttestrunner.main(__name__)
50 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now