##// END OF EJS Templates
wireprotoserver: split ssh protocol handler and server...
Gregory Szorc -
r36082:bf676267 default
parent child Browse files
Show More
@@ -1,438 +1,447
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import abc
9 import abc
10 import cgi
10 import cgi
11 import struct
11 import struct
12 import sys
12 import sys
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 )
22 )
23
23
24 stringio = util.stringio
24 stringio = util.stringio
25
25
26 urlerr = util.urlerr
26 urlerr = util.urlerr
27 urlreq = util.urlreq
27 urlreq = util.urlreq
28
28
29 HTTP_OK = 200
29 HTTP_OK = 200
30
30
31 HGTYPE = 'application/mercurial-0.1'
31 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE2 = 'application/mercurial-0.2'
32 HGTYPE2 = 'application/mercurial-0.2'
33 HGERRTYPE = 'application/hg-error'
33 HGERRTYPE = 'application/hg-error'
34
34
35 # Names of the SSH protocol implementations.
35 # Names of the SSH protocol implementations.
36 SSHV1 = 'ssh-v1'
36 SSHV1 = 'ssh-v1'
37 # This is advertised over the wire. Incremental the counter at the end
37 # This is advertised over the wire. Incremental the counter at the end
38 # to reflect BC breakages.
38 # to reflect BC breakages.
39 SSHV2 = 'exp-ssh-v2-0001'
39 SSHV2 = 'exp-ssh-v2-0001'
40
40
41 class baseprotocolhandler(object):
41 class baseprotocolhandler(object):
42 """Abstract base class for wire protocol handlers.
42 """Abstract base class for wire protocol handlers.
43
43
44 A wire protocol handler serves as an interface between protocol command
44 A wire protocol handler serves as an interface between protocol command
45 handlers and the wire protocol transport layer. Protocol handlers provide
45 handlers and the wire protocol transport layer. Protocol handlers provide
46 methods to read command arguments, redirect stdio for the duration of
46 methods to read command arguments, redirect stdio for the duration of
47 the request, handle response types, etc.
47 the request, handle response types, etc.
48 """
48 """
49
49
50 __metaclass__ = abc.ABCMeta
50 __metaclass__ = abc.ABCMeta
51
51
52 @abc.abstractproperty
52 @abc.abstractproperty
53 def name(self):
53 def name(self):
54 """The name of the protocol implementation.
54 """The name of the protocol implementation.
55
55
56 Used for uniquely identifying the transport type.
56 Used for uniquely identifying the transport type.
57 """
57 """
58
58
59 @abc.abstractmethod
59 @abc.abstractmethod
60 def getargs(self, args):
60 def getargs(self, args):
61 """return the value for arguments in <args>
61 """return the value for arguments in <args>
62
62
63 returns a list of values (same order as <args>)"""
63 returns a list of values (same order as <args>)"""
64
64
65 @abc.abstractmethod
65 @abc.abstractmethod
66 def getfile(self, fp):
66 def getfile(self, fp):
67 """write the whole content of a file into a file like object
67 """write the whole content of a file into a file like object
68
68
69 The file is in the form::
69 The file is in the form::
70
70
71 (<chunk-size>\n<chunk>)+0\n
71 (<chunk-size>\n<chunk>)+0\n
72
72
73 chunk size is the ascii version of the int.
73 chunk size is the ascii version of the int.
74 """
74 """
75
75
76 @abc.abstractmethod
76 @abc.abstractmethod
77 def redirect(self):
77 def redirect(self):
78 """may setup interception for stdout and stderr
78 """may setup interception for stdout and stderr
79
79
80 See also the `restore` method."""
80 See also the `restore` method."""
81
81
82 # If the `redirect` function does install interception, the `restore`
82 # If the `redirect` function does install interception, the `restore`
83 # function MUST be defined. If interception is not used, this function
83 # function MUST be defined. If interception is not used, this function
84 # MUST NOT be defined.
84 # MUST NOT be defined.
85 #
85 #
86 # left commented here on purpose
86 # left commented here on purpose
87 #
87 #
88 #def restore(self):
88 #def restore(self):
89 # """reinstall previous stdout and stderr and return intercepted stdout
89 # """reinstall previous stdout and stderr and return intercepted stdout
90 # """
90 # """
91 # raise NotImplementedError()
91 # raise NotImplementedError()
92
92
93 def decodevaluefromheaders(req, headerprefix):
93 def decodevaluefromheaders(req, headerprefix):
94 """Decode a long value from multiple HTTP request headers.
94 """Decode a long value from multiple HTTP request headers.
95
95
96 Returns the value as a bytes, not a str.
96 Returns the value as a bytes, not a str.
97 """
97 """
98 chunks = []
98 chunks = []
99 i = 1
99 i = 1
100 prefix = headerprefix.upper().replace(r'-', r'_')
100 prefix = headerprefix.upper().replace(r'-', r'_')
101 while True:
101 while True:
102 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
102 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
103 if v is None:
103 if v is None:
104 break
104 break
105 chunks.append(pycompat.bytesurl(v))
105 chunks.append(pycompat.bytesurl(v))
106 i += 1
106 i += 1
107
107
108 return ''.join(chunks)
108 return ''.join(chunks)
109
109
110 class webproto(baseprotocolhandler):
110 class webproto(baseprotocolhandler):
111 def __init__(self, req, ui):
111 def __init__(self, req, ui):
112 self._req = req
112 self._req = req
113 self._ui = ui
113 self._ui = ui
114
114
115 @property
115 @property
116 def name(self):
116 def name(self):
117 return 'http'
117 return 'http'
118
118
119 def getargs(self, args):
119 def getargs(self, args):
120 knownargs = self._args()
120 knownargs = self._args()
121 data = {}
121 data = {}
122 keys = args.split()
122 keys = args.split()
123 for k in keys:
123 for k in keys:
124 if k == '*':
124 if k == '*':
125 star = {}
125 star = {}
126 for key in knownargs.keys():
126 for key in knownargs.keys():
127 if key != 'cmd' and key not in keys:
127 if key != 'cmd' and key not in keys:
128 star[key] = knownargs[key][0]
128 star[key] = knownargs[key][0]
129 data['*'] = star
129 data['*'] = star
130 else:
130 else:
131 data[k] = knownargs[k][0]
131 data[k] = knownargs[k][0]
132 return [data[k] for k in keys]
132 return [data[k] for k in keys]
133
133
134 def _args(self):
134 def _args(self):
135 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
135 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
136 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
136 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
137 if postlen:
137 if postlen:
138 args.update(cgi.parse_qs(
138 args.update(cgi.parse_qs(
139 self._req.read(postlen), keep_blank_values=True))
139 self._req.read(postlen), keep_blank_values=True))
140 return args
140 return args
141
141
142 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
142 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
143 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
143 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
144 return args
144 return args
145
145
146 def getfile(self, fp):
146 def getfile(self, fp):
147 length = int(self._req.env[r'CONTENT_LENGTH'])
147 length = int(self._req.env[r'CONTENT_LENGTH'])
148 # If httppostargs is used, we need to read Content-Length
148 # If httppostargs is used, we need to read Content-Length
149 # minus the amount that was consumed by args.
149 # minus the amount that was consumed by args.
150 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
150 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
151 for s in util.filechunkiter(self._req, limit=length):
151 for s in util.filechunkiter(self._req, limit=length):
152 fp.write(s)
152 fp.write(s)
153
153
154 def redirect(self):
154 def redirect(self):
155 self._oldio = self._ui.fout, self._ui.ferr
155 self._oldio = self._ui.fout, self._ui.ferr
156 self._ui.ferr = self._ui.fout = stringio()
156 self._ui.ferr = self._ui.fout = stringio()
157
157
158 def restore(self):
158 def restore(self):
159 val = self._ui.fout.getvalue()
159 val = self._ui.fout.getvalue()
160 self._ui.ferr, self._ui.fout = self._oldio
160 self._ui.ferr, self._ui.fout = self._oldio
161 return val
161 return val
162
162
163 def _client(self):
163 def _client(self):
164 return 'remote:%s:%s:%s' % (
164 return 'remote:%s:%s:%s' % (
165 self._req.env.get('wsgi.url_scheme') or 'http',
165 self._req.env.get('wsgi.url_scheme') or 'http',
166 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
166 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
167 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
167 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
168
168
169 def responsetype(self, prefer_uncompressed):
169 def responsetype(self, prefer_uncompressed):
170 """Determine the appropriate response type and compression settings.
170 """Determine the appropriate response type and compression settings.
171
171
172 Returns a tuple of (mediatype, compengine, engineopts).
172 Returns a tuple of (mediatype, compengine, engineopts).
173 """
173 """
174 # Determine the response media type and compression engine based
174 # Determine the response media type and compression engine based
175 # on the request parameters.
175 # on the request parameters.
176 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
176 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
177
177
178 if '0.2' in protocaps:
178 if '0.2' in protocaps:
179 # All clients are expected to support uncompressed data.
179 # All clients are expected to support uncompressed data.
180 if prefer_uncompressed:
180 if prefer_uncompressed:
181 return HGTYPE2, util._noopengine(), {}
181 return HGTYPE2, util._noopengine(), {}
182
182
183 # Default as defined by wire protocol spec.
183 # Default as defined by wire protocol spec.
184 compformats = ['zlib', 'none']
184 compformats = ['zlib', 'none']
185 for cap in protocaps:
185 for cap in protocaps:
186 if cap.startswith('comp='):
186 if cap.startswith('comp='):
187 compformats = cap[5:].split(',')
187 compformats = cap[5:].split(',')
188 break
188 break
189
189
190 # Now find an agreed upon compression format.
190 # Now find an agreed upon compression format.
191 for engine in wireproto.supportedcompengines(self._ui, self,
191 for engine in wireproto.supportedcompengines(self._ui, self,
192 util.SERVERROLE):
192 util.SERVERROLE):
193 if engine.wireprotosupport().name in compformats:
193 if engine.wireprotosupport().name in compformats:
194 opts = {}
194 opts = {}
195 level = self._ui.configint('server',
195 level = self._ui.configint('server',
196 '%slevel' % engine.name())
196 '%slevel' % engine.name())
197 if level is not None:
197 if level is not None:
198 opts['level'] = level
198 opts['level'] = level
199
199
200 return HGTYPE2, engine, opts
200 return HGTYPE2, engine, opts
201
201
202 # No mutually supported compression format. Fall back to the
202 # No mutually supported compression format. Fall back to the
203 # legacy protocol.
203 # legacy protocol.
204
204
205 # Don't allow untrusted settings because disabling compression or
205 # Don't allow untrusted settings because disabling compression or
206 # setting a very high compression level could lead to flooding
206 # setting a very high compression level could lead to flooding
207 # the server's network or CPU.
207 # the server's network or CPU.
208 opts = {'level': self._ui.configint('server', 'zliblevel')}
208 opts = {'level': self._ui.configint('server', 'zliblevel')}
209 return HGTYPE, util.compengines['zlib'], opts
209 return HGTYPE, util.compengines['zlib'], opts
210
210
211 def iscmd(cmd):
211 def iscmd(cmd):
212 return cmd in wireproto.commands
212 return cmd in wireproto.commands
213
213
214 def parsehttprequest(repo, req, query):
214 def parsehttprequest(repo, req, query):
215 """Parse the HTTP request for a wire protocol request.
215 """Parse the HTTP request for a wire protocol request.
216
216
217 If the current request appears to be a wire protocol request, this
217 If the current request appears to be a wire protocol request, this
218 function returns a dict with details about that request, including
218 function returns a dict with details about that request, including
219 an ``abstractprotocolserver`` instance suitable for handling the
219 an ``abstractprotocolserver`` instance suitable for handling the
220 request. Otherwise, ``None`` is returned.
220 request. Otherwise, ``None`` is returned.
221
221
222 ``req`` is a ``wsgirequest`` instance.
222 ``req`` is a ``wsgirequest`` instance.
223 """
223 """
224 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
224 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
225 # string parameter. If it isn't present, this isn't a wire protocol
225 # string parameter. If it isn't present, this isn't a wire protocol
226 # request.
226 # request.
227 if r'cmd' not in req.form:
227 if r'cmd' not in req.form:
228 return None
228 return None
229
229
230 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
230 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
231
231
232 # The "cmd" request parameter is used by both the wire protocol and hgweb.
232 # The "cmd" request parameter is used by both the wire protocol and hgweb.
233 # While not all wire protocol commands are available for all transports,
233 # While not all wire protocol commands are available for all transports,
234 # if we see a "cmd" value that resembles a known wire protocol command, we
234 # if we see a "cmd" value that resembles a known wire protocol command, we
235 # route it to a protocol handler. This is better than routing possible
235 # route it to a protocol handler. This is better than routing possible
236 # wire protocol requests to hgweb because it prevents hgweb from using
236 # wire protocol requests to hgweb because it prevents hgweb from using
237 # known wire protocol commands and it is less confusing for machine
237 # known wire protocol commands and it is less confusing for machine
238 # clients.
238 # clients.
239 if cmd not in wireproto.commands:
239 if cmd not in wireproto.commands:
240 return None
240 return None
241
241
242 proto = webproto(req, repo.ui)
242 proto = webproto(req, repo.ui)
243
243
244 return {
244 return {
245 'cmd': cmd,
245 'cmd': cmd,
246 'proto': proto,
246 'proto': proto,
247 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
247 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
248 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
248 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
249 }
249 }
250
250
251 def _callhttp(repo, req, proto, cmd):
251 def _callhttp(repo, req, proto, cmd):
252 def genversion2(gen, engine, engineopts):
252 def genversion2(gen, engine, engineopts):
253 # application/mercurial-0.2 always sends a payload header
253 # application/mercurial-0.2 always sends a payload header
254 # identifying the compression engine.
254 # identifying the compression engine.
255 name = engine.wireprotosupport().name
255 name = engine.wireprotosupport().name
256 assert 0 < len(name) < 256
256 assert 0 < len(name) < 256
257 yield struct.pack('B', len(name))
257 yield struct.pack('B', len(name))
258 yield name
258 yield name
259
259
260 for chunk in gen:
260 for chunk in gen:
261 yield chunk
261 yield chunk
262
262
263 rsp = wireproto.dispatch(repo, proto, cmd)
263 rsp = wireproto.dispatch(repo, proto, cmd)
264
264
265 if not wireproto.commands.commandavailable(cmd, proto):
265 if not wireproto.commands.commandavailable(cmd, proto):
266 req.respond(HTTP_OK, HGERRTYPE,
266 req.respond(HTTP_OK, HGERRTYPE,
267 body=_('requested wire protocol command is not available '
267 body=_('requested wire protocol command is not available '
268 'over HTTP'))
268 'over HTTP'))
269 return []
269 return []
270
270
271 if isinstance(rsp, bytes):
271 if isinstance(rsp, bytes):
272 req.respond(HTTP_OK, HGTYPE, body=rsp)
272 req.respond(HTTP_OK, HGTYPE, body=rsp)
273 return []
273 return []
274 elif isinstance(rsp, wireproto.streamres_legacy):
274 elif isinstance(rsp, wireproto.streamres_legacy):
275 gen = rsp.gen
275 gen = rsp.gen
276 req.respond(HTTP_OK, HGTYPE)
276 req.respond(HTTP_OK, HGTYPE)
277 return gen
277 return gen
278 elif isinstance(rsp, wireproto.streamres):
278 elif isinstance(rsp, wireproto.streamres):
279 gen = rsp.gen
279 gen = rsp.gen
280
280
281 # This code for compression should not be streamres specific. It
281 # This code for compression should not be streamres specific. It
282 # is here because we only compress streamres at the moment.
282 # is here because we only compress streamres at the moment.
283 mediatype, engine, engineopts = proto.responsetype(
283 mediatype, engine, engineopts = proto.responsetype(
284 rsp.prefer_uncompressed)
284 rsp.prefer_uncompressed)
285 gen = engine.compressstream(gen, engineopts)
285 gen = engine.compressstream(gen, engineopts)
286
286
287 if mediatype == HGTYPE2:
287 if mediatype == HGTYPE2:
288 gen = genversion2(gen, engine, engineopts)
288 gen = genversion2(gen, engine, engineopts)
289
289
290 req.respond(HTTP_OK, mediatype)
290 req.respond(HTTP_OK, mediatype)
291 return gen
291 return gen
292 elif isinstance(rsp, wireproto.pushres):
292 elif isinstance(rsp, wireproto.pushres):
293 val = proto.restore()
293 val = proto.restore()
294 rsp = '%d\n%s' % (rsp.res, val)
294 rsp = '%d\n%s' % (rsp.res, val)
295 req.respond(HTTP_OK, HGTYPE, body=rsp)
295 req.respond(HTTP_OK, HGTYPE, body=rsp)
296 return []
296 return []
297 elif isinstance(rsp, wireproto.pusherr):
297 elif isinstance(rsp, wireproto.pusherr):
298 # This is the httplib workaround documented in _handlehttperror().
298 # This is the httplib workaround documented in _handlehttperror().
299 req.drain()
299 req.drain()
300
300
301 proto.restore()
301 proto.restore()
302 rsp = '0\n%s\n' % rsp.res
302 rsp = '0\n%s\n' % rsp.res
303 req.respond(HTTP_OK, HGTYPE, body=rsp)
303 req.respond(HTTP_OK, HGTYPE, body=rsp)
304 return []
304 return []
305 elif isinstance(rsp, wireproto.ooberror):
305 elif isinstance(rsp, wireproto.ooberror):
306 rsp = rsp.message
306 rsp = rsp.message
307 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
307 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
308 return []
308 return []
309 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
309 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
310
310
311 def _handlehttperror(e, req, cmd):
311 def _handlehttperror(e, req, cmd):
312 """Called when an ErrorResponse is raised during HTTP request processing."""
312 """Called when an ErrorResponse is raised during HTTP request processing."""
313
313
314 # Clients using Python's httplib are stateful: the HTTP client
314 # Clients using Python's httplib are stateful: the HTTP client
315 # won't process an HTTP response until all request data is
315 # won't process an HTTP response until all request data is
316 # sent to the server. The intent of this code is to ensure
316 # sent to the server. The intent of this code is to ensure
317 # we always read HTTP request data from the client, thus
317 # we always read HTTP request data from the client, thus
318 # ensuring httplib transitions to a state that allows it to read
318 # ensuring httplib transitions to a state that allows it to read
319 # the HTTP response. In other words, it helps prevent deadlocks
319 # the HTTP response. In other words, it helps prevent deadlocks
320 # on clients using httplib.
320 # on clients using httplib.
321
321
322 if (req.env[r'REQUEST_METHOD'] == r'POST' and
322 if (req.env[r'REQUEST_METHOD'] == r'POST' and
323 # But not if Expect: 100-continue is being used.
323 # But not if Expect: 100-continue is being used.
324 (req.env.get('HTTP_EXPECT',
324 (req.env.get('HTTP_EXPECT',
325 '').lower() != '100-continue') or
325 '').lower() != '100-continue') or
326 # Or the non-httplib HTTP library is being advertised by
326 # Or the non-httplib HTTP library is being advertised by
327 # the client.
327 # the client.
328 req.env.get('X-HgHttp2', '')):
328 req.env.get('X-HgHttp2', '')):
329 req.drain()
329 req.drain()
330 else:
330 else:
331 req.headers.append((r'Connection', r'Close'))
331 req.headers.append((r'Connection', r'Close'))
332
332
333 # TODO This response body assumes the failed command was
333 # TODO This response body assumes the failed command was
334 # "unbundle." That assumption is not always valid.
334 # "unbundle." That assumption is not always valid.
335 req.respond(e, HGTYPE, body='0\n%s\n' % e)
335 req.respond(e, HGTYPE, body='0\n%s\n' % e)
336
336
337 return ''
337 return ''
338
338
339 def _sshv1respondbytes(fout, value):
339 def _sshv1respondbytes(fout, value):
340 """Send a bytes response for protocol version 1."""
340 """Send a bytes response for protocol version 1."""
341 fout.write('%d\n' % len(value))
341 fout.write('%d\n' % len(value))
342 fout.write(value)
342 fout.write(value)
343 fout.flush()
343 fout.flush()
344
344
345 def _sshv1respondstream(fout, source):
345 def _sshv1respondstream(fout, source):
346 write = fout.write
346 write = fout.write
347 for chunk in source.gen:
347 for chunk in source.gen:
348 write(chunk)
348 write(chunk)
349 fout.flush()
349 fout.flush()
350
350
351 def _sshv1respondooberror(fout, ferr, rsp):
351 def _sshv1respondooberror(fout, ferr, rsp):
352 ferr.write(b'%s\n-\n' % rsp)
352 ferr.write(b'%s\n-\n' % rsp)
353 ferr.flush()
353 ferr.flush()
354 fout.write(b'\n')
354 fout.write(b'\n')
355 fout.flush()
355 fout.flush()
356
356
357 class sshserver(baseprotocolhandler):
357 class sshv1protocolhandler(baseprotocolhandler):
358 def __init__(self, ui, repo):
358 """Handler for requests services via version 1 of SSH protocol."""
359 def __init__(self, ui, fin, fout):
359 self._ui = ui
360 self._ui = ui
360 self._repo = repo
361 self._fin = fin
361 self._fin = ui.fin
362 self._fout = fout
362 self._fout = ui.fout
363
364 hook.redirect(True)
365 ui.fout = repo.ui.fout = ui.ferr
366
367 # Prevent insertion/deletion of CRs
368 util.setbinary(self._fin)
369 util.setbinary(self._fout)
370
363
371 @property
364 @property
372 def name(self):
365 def name(self):
373 return 'ssh'
366 return 'ssh'
374
367
375 def getargs(self, args):
368 def getargs(self, args):
376 data = {}
369 data = {}
377 keys = args.split()
370 keys = args.split()
378 for n in xrange(len(keys)):
371 for n in xrange(len(keys)):
379 argline = self._fin.readline()[:-1]
372 argline = self._fin.readline()[:-1]
380 arg, l = argline.split()
373 arg, l = argline.split()
381 if arg not in keys:
374 if arg not in keys:
382 raise error.Abort(_("unexpected parameter %r") % arg)
375 raise error.Abort(_("unexpected parameter %r") % arg)
383 if arg == '*':
376 if arg == '*':
384 star = {}
377 star = {}
385 for k in xrange(int(l)):
378 for k in xrange(int(l)):
386 argline = self._fin.readline()[:-1]
379 argline = self._fin.readline()[:-1]
387 arg, l = argline.split()
380 arg, l = argline.split()
388 val = self._fin.read(int(l))
381 val = self._fin.read(int(l))
389 star[arg] = val
382 star[arg] = val
390 data['*'] = star
383 data['*'] = star
391 else:
384 else:
392 val = self._fin.read(int(l))
385 val = self._fin.read(int(l))
393 data[arg] = val
386 data[arg] = val
394 return [data[k] for k in keys]
387 return [data[k] for k in keys]
395
388
396 def getfile(self, fpout):
389 def getfile(self, fpout):
397 _sshv1respondbytes(self._fout, b'')
390 _sshv1respondbytes(self._fout, b'')
398 count = int(self._fin.readline())
391 count = int(self._fin.readline())
399 while count:
392 while count:
400 fpout.write(self._fin.read(count))
393 fpout.write(self._fin.read(count))
401 count = int(self._fin.readline())
394 count = int(self._fin.readline())
402
395
403 def redirect(self):
396 def redirect(self):
404 pass
397 pass
405
398
399 def _client(self):
400 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
401 return 'remote:ssh:' + client
402
403 class sshserver(object):
404 def __init__(self, ui, repo):
405 self._ui = ui
406 self._repo = repo
407 self._fin = ui.fin
408 self._fout = ui.fout
409
410 hook.redirect(True)
411 ui.fout = repo.ui.fout = ui.ferr
412
413 # Prevent insertion/deletion of CRs
414 util.setbinary(self._fin)
415 util.setbinary(self._fout)
416
417 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
418
406 def serve_forever(self):
419 def serve_forever(self):
407 while self.serve_one():
420 while self.serve_one():
408 pass
421 pass
409 sys.exit(0)
422 sys.exit(0)
410
423
411 def serve_one(self):
424 def serve_one(self):
412 cmd = self._fin.readline()[:-1]
425 cmd = self._fin.readline()[:-1]
413 if cmd and wireproto.commands.commandavailable(cmd, self):
426 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
414 rsp = wireproto.dispatch(self._repo, self, cmd)
427 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
415
428
416 if isinstance(rsp, bytes):
429 if isinstance(rsp, bytes):
417 _sshv1respondbytes(self._fout, rsp)
430 _sshv1respondbytes(self._fout, rsp)
418 elif isinstance(rsp, wireproto.streamres):
431 elif isinstance(rsp, wireproto.streamres):
419 _sshv1respondstream(self._fout, rsp)
432 _sshv1respondstream(self._fout, rsp)
420 elif isinstance(rsp, wireproto.streamres_legacy):
433 elif isinstance(rsp, wireproto.streamres_legacy):
421 _sshv1respondstream(self._fout, rsp)
434 _sshv1respondstream(self._fout, rsp)
422 elif isinstance(rsp, wireproto.pushres):
435 elif isinstance(rsp, wireproto.pushres):
423 _sshv1respondbytes(self._fout, b'')
436 _sshv1respondbytes(self._fout, b'')
424 _sshv1respondbytes(self._fout, bytes(rsp.res))
437 _sshv1respondbytes(self._fout, bytes(rsp.res))
425 elif isinstance(rsp, wireproto.pusherr):
438 elif isinstance(rsp, wireproto.pusherr):
426 _sshv1respondbytes(self._fout, rsp.res)
439 _sshv1respondbytes(self._fout, rsp.res)
427 elif isinstance(rsp, wireproto.ooberror):
440 elif isinstance(rsp, wireproto.ooberror):
428 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
441 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
429 else:
442 else:
430 raise error.ProgrammingError('unhandled response type from '
443 raise error.ProgrammingError('unhandled response type from '
431 'wire protocol command: %s' % rsp)
444 'wire protocol command: %s' % rsp)
432 elif cmd:
445 elif cmd:
433 _sshv1respondbytes(self._fout, b'')
446 _sshv1respondbytes(self._fout, b'')
434 return cmd != ''
447 return cmd != ''
435
436 def _client(self):
437 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
438 return 'remote:ssh:' + client
@@ -1,127 +1,127
1 # sshprotoext.py - Extension to test behavior of SSH protocol
1 # sshprotoext.py - Extension to test behavior of SSH protocol
2 #
2 #
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 # This extension replaces the SSH server started via `hg serve --stdio`.
8 # This extension replaces the SSH server started via `hg serve --stdio`.
9 # The server behaves differently depending on environment variables.
9 # The server behaves differently depending on environment variables.
10
10
11 from __future__ import absolute_import
11 from __future__ import absolute_import
12
12
13 from mercurial import (
13 from mercurial import (
14 error,
14 error,
15 extensions,
15 extensions,
16 registrar,
16 registrar,
17 sshpeer,
17 sshpeer,
18 wireproto,
18 wireproto,
19 wireprotoserver,
19 wireprotoserver,
20 )
20 )
21
21
22 configtable = {}
22 configtable = {}
23 configitem = registrar.configitem(configtable)
23 configitem = registrar.configitem(configtable)
24
24
25 configitem('sshpeer', 'mode', default=None)
25 configitem('sshpeer', 'mode', default=None)
26 configitem('sshpeer', 'handshake-mode', default=None)
26 configitem('sshpeer', 'handshake-mode', default=None)
27
27
28 class bannerserver(wireprotoserver.sshserver):
28 class bannerserver(wireprotoserver.sshserver):
29 """Server that sends a banner to stdout."""
29 """Server that sends a banner to stdout."""
30 def serve_forever(self):
30 def serve_forever(self):
31 for i in range(10):
31 for i in range(10):
32 self._fout.write(b'banner: line %d\n' % i)
32 self._fout.write(b'banner: line %d\n' % i)
33
33
34 super(bannerserver, self).serve_forever()
34 super(bannerserver, self).serve_forever()
35
35
36 class prehelloserver(wireprotoserver.sshserver):
36 class prehelloserver(wireprotoserver.sshserver):
37 """Tests behavior when connecting to <0.9.1 servers.
37 """Tests behavior when connecting to <0.9.1 servers.
38
38
39 The ``hello`` wire protocol command was introduced in Mercurial
39 The ``hello`` wire protocol command was introduced in Mercurial
40 0.9.1. Modern clients send the ``hello`` command when connecting
40 0.9.1. Modern clients send the ``hello`` command when connecting
41 to SSH servers. This mock server tests behavior of the handshake
41 to SSH servers. This mock server tests behavior of the handshake
42 when ``hello`` is not supported.
42 when ``hello`` is not supported.
43 """
43 """
44 def serve_forever(self):
44 def serve_forever(self):
45 l = self._fin.readline()
45 l = self._fin.readline()
46 assert l == b'hello\n'
46 assert l == b'hello\n'
47 # Respond to unknown commands with an empty reply.
47 # Respond to unknown commands with an empty reply.
48 wireprotoserver._sshv1respondbytes(self._fout, b'')
48 wireprotoserver._sshv1respondbytes(self._fout, b'')
49 l = self._fin.readline()
49 l = self._fin.readline()
50 assert l == b'between\n'
50 assert l == b'between\n'
51 rsp = wireproto.dispatch(self._repo, self, b'between')
51 rsp = wireproto.dispatch(self._repo, self._proto, b'between')
52 wireprotoserver._sshv1respondbytes(self._fout, rsp)
52 wireprotoserver._sshv1respondbytes(self._fout, rsp)
53
53
54 super(prehelloserver, self).serve_forever()
54 super(prehelloserver, self).serve_forever()
55
55
56 class upgradev2server(wireprotoserver.sshserver):
56 class upgradev2server(wireprotoserver.sshserver):
57 """Tests behavior for clients that issue upgrade to version 2."""
57 """Tests behavior for clients that issue upgrade to version 2."""
58 def serve_forever(self):
58 def serve_forever(self):
59 name = wireprotoserver.SSHV2
59 name = wireprotoserver.SSHV2
60 l = self._fin.readline()
60 l = self._fin.readline()
61 assert l.startswith(b'upgrade ')
61 assert l.startswith(b'upgrade ')
62 token, caps = l[:-1].split(b' ')[1:]
62 token, caps = l[:-1].split(b' ')[1:]
63 assert caps == b'proto=%s' % name
63 assert caps == b'proto=%s' % name
64
64
65 # Filter hello and between requests.
65 # Filter hello and between requests.
66 l = self._fin.readline()
66 l = self._fin.readline()
67 assert l == b'hello\n'
67 assert l == b'hello\n'
68 l = self._fin.readline()
68 l = self._fin.readline()
69 assert l == b'between\n'
69 assert l == b'between\n'
70 l = self._fin.readline()
70 l = self._fin.readline()
71 assert l == 'pairs 81\n'
71 assert l == 'pairs 81\n'
72 self._fin.read(81)
72 self._fin.read(81)
73
73
74 # Send the upgrade response.
74 # Send the upgrade response.
75 self._fout.write(b'upgraded %s %s\n' % (token, name))
75 self._fout.write(b'upgraded %s %s\n' % (token, name))
76 servercaps = wireproto.capabilities(self._repo, self)
76 servercaps = wireproto.capabilities(self._repo, self._proto)
77 rsp = b'capabilities: %s' % servercaps
77 rsp = b'capabilities: %s' % servercaps
78 self._fout.write(b'%d\n' % len(rsp))
78 self._fout.write(b'%d\n' % len(rsp))
79 self._fout.write(rsp)
79 self._fout.write(rsp)
80 self._fout.write(b'\n')
80 self._fout.write(b'\n')
81 self._fout.flush()
81 self._fout.flush()
82
82
83 super(upgradev2server, self).serve_forever()
83 super(upgradev2server, self).serve_forever()
84
84
85 def performhandshake(orig, ui, stdin, stdout, stderr):
85 def performhandshake(orig, ui, stdin, stdout, stderr):
86 """Wrapped version of sshpeer._performhandshake to send extra commands."""
86 """Wrapped version of sshpeer._performhandshake to send extra commands."""
87 mode = ui.config(b'sshpeer', b'handshake-mode')
87 mode = ui.config(b'sshpeer', b'handshake-mode')
88 if mode == b'pre-no-args':
88 if mode == b'pre-no-args':
89 ui.debug(b'sending no-args command\n')
89 ui.debug(b'sending no-args command\n')
90 stdin.write(b'no-args\n')
90 stdin.write(b'no-args\n')
91 stdin.flush()
91 stdin.flush()
92 return orig(ui, stdin, stdout, stderr)
92 return orig(ui, stdin, stdout, stderr)
93 elif mode == b'pre-multiple-no-args':
93 elif mode == b'pre-multiple-no-args':
94 ui.debug(b'sending unknown1 command\n')
94 ui.debug(b'sending unknown1 command\n')
95 stdin.write(b'unknown1\n')
95 stdin.write(b'unknown1\n')
96 ui.debug(b'sending unknown2 command\n')
96 ui.debug(b'sending unknown2 command\n')
97 stdin.write(b'unknown2\n')
97 stdin.write(b'unknown2\n')
98 ui.debug(b'sending unknown3 command\n')
98 ui.debug(b'sending unknown3 command\n')
99 stdin.write(b'unknown3\n')
99 stdin.write(b'unknown3\n')
100 stdin.flush()
100 stdin.flush()
101 return orig(ui, stdin, stdout, stderr)
101 return orig(ui, stdin, stdout, stderr)
102 else:
102 else:
103 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
103 raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
104 mode)
104 mode)
105
105
106 def extsetup(ui):
106 def extsetup(ui):
107 # It's easier for tests to define the server behavior via environment
107 # It's easier for tests to define the server behavior via environment
108 # variables than config options. This is because `hg serve --stdio`
108 # variables than config options. This is because `hg serve --stdio`
109 # has to be invoked with a certain form for security reasons and
109 # has to be invoked with a certain form for security reasons and
110 # `dummyssh` can't just add `--config` flags to the command line.
110 # `dummyssh` can't just add `--config` flags to the command line.
111 servermode = ui.environ.get(b'SSHSERVERMODE')
111 servermode = ui.environ.get(b'SSHSERVERMODE')
112
112
113 if servermode == b'banner':
113 if servermode == b'banner':
114 wireprotoserver.sshserver = bannerserver
114 wireprotoserver.sshserver = bannerserver
115 elif servermode == b'no-hello':
115 elif servermode == b'no-hello':
116 wireprotoserver.sshserver = prehelloserver
116 wireprotoserver.sshserver = prehelloserver
117 elif servermode == b'upgradev2':
117 elif servermode == b'upgradev2':
118 wireprotoserver.sshserver = upgradev2server
118 wireprotoserver.sshserver = upgradev2server
119 elif servermode:
119 elif servermode:
120 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
120 raise error.ProgrammingError(b'unknown server mode: %s' % servermode)
121
121
122 peermode = ui.config(b'sshpeer', b'mode')
122 peermode = ui.config(b'sshpeer', b'mode')
123
123
124 if peermode == b'extra-handshake-commands':
124 if peermode == b'extra-handshake-commands':
125 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
125 extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
126 elif peermode:
126 elif peermode:
127 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
127 raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)
@@ -1,47 +1,47
1 from __future__ import absolute_import, print_function
1 from __future__ import absolute_import, print_function
2
2
3 import io
3 import io
4 import unittest
4 import unittest
5
5
6 import silenttestrunner
6 import silenttestrunner
7
7
8 from mercurial import (
8 from mercurial import (
9 util,
9 util,
10 wireproto,
10 wireproto,
11 wireprotoserver,
11 wireprotoserver,
12 )
12 )
13
13
14 class SSHServerGetArgsTests(unittest.TestCase):
14 class SSHServerGetArgsTests(unittest.TestCase):
15 def testparseknown(self):
15 def testparseknown(self):
16 tests = [
16 tests = [
17 ('* 0\nnodes 0\n', ['', {}]),
17 ('* 0\nnodes 0\n', ['', {}]),
18 ('* 0\nnodes 40\n1111111111111111111111111111111111111111\n',
18 ('* 0\nnodes 40\n1111111111111111111111111111111111111111\n',
19 ['1111111111111111111111111111111111111111', {}]),
19 ['1111111111111111111111111111111111111111', {}]),
20 ]
20 ]
21 for input, expected in tests:
21 for input, expected in tests:
22 self.assertparse('known', input, expected)
22 self.assertparse('known', input, expected)
23
23
24 def assertparse(self, cmd, input, expected):
24 def assertparse(self, cmd, input, expected):
25 server = mockserver(input)
25 server = mockserver(input)
26 _func, spec = wireproto.commands[cmd]
26 _func, spec = wireproto.commands[cmd]
27 self.assertEqual(server.getargs(spec), expected)
27 self.assertEqual(server._proto.getargs(spec), expected)
28
28
29 def mockserver(inbytes):
29 def mockserver(inbytes):
30 ui = mockui(inbytes)
30 ui = mockui(inbytes)
31 repo = mockrepo(ui)
31 repo = mockrepo(ui)
32 return wireprotoserver.sshserver(ui, repo)
32 return wireprotoserver.sshserver(ui, repo)
33
33
34 class mockrepo(object):
34 class mockrepo(object):
35 def __init__(self, ui):
35 def __init__(self, ui):
36 self.ui = ui
36 self.ui = ui
37
37
38 class mockui(object):
38 class mockui(object):
39 def __init__(self, inbytes):
39 def __init__(self, inbytes):
40 self.fin = io.BytesIO(inbytes)
40 self.fin = io.BytesIO(inbytes)
41 self.fout = io.BytesIO()
41 self.fout = io.BytesIO()
42 self.ferr = io.BytesIO()
42 self.ferr = io.BytesIO()
43
43
44 if __name__ == '__main__':
44 if __name__ == '__main__':
45 # Don't call into msvcrt to set BytesIO to binary mode
45 # Don't call into msvcrt to set BytesIO to binary mode
46 util.setbinary = lambda fp: True
46 util.setbinary = lambda fp: True
47 silenttestrunner.main(__name__)
47 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now