##// END OF EJS Templates
hgweb: handle CONTENT_LENGTH...
Gregory Szorc -
r36863:ed0456fd default
parent child Browse files
Show More
@@ -1,315 +1,322 b''
1 # hgweb/request.py - An http request from either CGI or the standalone server.
1 # hgweb/request.py - An http request from either CGI or the standalone server.
2 #
2 #
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from __future__ import absolute_import
9 from __future__ import absolute_import
10
10
11 import cgi
11 import cgi
12 import errno
12 import errno
13 import socket
13 import socket
14 import wsgiref.headers as wsgiheaders
14 import wsgiref.headers as wsgiheaders
15 #import wsgiref.validate
15 #import wsgiref.validate
16
16
17 from .common import (
17 from .common import (
18 ErrorResponse,
18 ErrorResponse,
19 HTTP_NOT_MODIFIED,
19 HTTP_NOT_MODIFIED,
20 statusmessage,
20 statusmessage,
21 )
21 )
22
22
23 from ..thirdparty import (
23 from ..thirdparty import (
24 attr,
24 attr,
25 )
25 )
26 from .. import (
26 from .. import (
27 pycompat,
27 pycompat,
28 util,
28 util,
29 )
29 )
30
30
31 shortcuts = {
31 shortcuts = {
32 'cl': [('cmd', ['changelog']), ('rev', None)],
32 'cl': [('cmd', ['changelog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
41 'tags': [('cmd', ['tags'])],
41 'tags': [('cmd', ['tags'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
43 'static': [('cmd', ['static']), ('file', None)]
43 'static': [('cmd', ['static']), ('file', None)]
44 }
44 }
45
45
46 def normalize(form):
46 def normalize(form):
47 # first expand the shortcuts
47 # first expand the shortcuts
48 for k in shortcuts:
48 for k in shortcuts:
49 if k in form:
49 if k in form:
50 for name, value in shortcuts[k]:
50 for name, value in shortcuts[k]:
51 if value is None:
51 if value is None:
52 value = form[k]
52 value = form[k]
53 form[name] = value
53 form[name] = value
54 del form[k]
54 del form[k]
55 # And strip the values
55 # And strip the values
56 bytesform = {}
56 bytesform = {}
57 for k, v in form.iteritems():
57 for k, v in form.iteritems():
58 bytesform[pycompat.bytesurl(k)] = [
58 bytesform[pycompat.bytesurl(k)] = [
59 pycompat.bytesurl(i.strip()) for i in v]
59 pycompat.bytesurl(i.strip()) for i in v]
60 return bytesform
60 return bytesform
61
61
62 @attr.s(frozen=True)
62 @attr.s(frozen=True)
63 class parsedrequest(object):
63 class parsedrequest(object):
64 """Represents a parsed WSGI request / static HTTP request parameters."""
64 """Represents a parsed WSGI request / static HTTP request parameters."""
65
65
66 # Full URL for this request.
66 # Full URL for this request.
67 url = attr.ib()
67 url = attr.ib()
68 # URL without any path components. Just <proto>://<host><port>.
68 # URL without any path components. Just <proto>://<host><port>.
69 baseurl = attr.ib()
69 baseurl = attr.ib()
70 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
70 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
71 # of HTTP: Host header for hostname. This is likely what clients used.
71 # of HTTP: Host header for hostname. This is likely what clients used.
72 advertisedurl = attr.ib()
72 advertisedurl = attr.ib()
73 advertisedbaseurl = attr.ib()
73 advertisedbaseurl = attr.ib()
74 # WSGI application path.
74 # WSGI application path.
75 apppath = attr.ib()
75 apppath = attr.ib()
76 # List of path parts to be used for dispatch.
76 # List of path parts to be used for dispatch.
77 dispatchparts = attr.ib()
77 dispatchparts = attr.ib()
78 # URL path component (no query string) used for dispatch.
78 # URL path component (no query string) used for dispatch.
79 dispatchpath = attr.ib()
79 dispatchpath = attr.ib()
80 # Whether there is a path component to this request. This can be true
80 # Whether there is a path component to this request. This can be true
81 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
81 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
82 havepathinfo = attr.ib()
82 havepathinfo = attr.ib()
83 # Raw query string (part after "?" in URL).
83 # Raw query string (part after "?" in URL).
84 querystring = attr.ib()
84 querystring = attr.ib()
85 # List of 2-tuples of query string arguments.
85 # List of 2-tuples of query string arguments.
86 querystringlist = attr.ib()
86 querystringlist = attr.ib()
87 # Dict of query string arguments. Values are lists with at least 1 item.
87 # Dict of query string arguments. Values are lists with at least 1 item.
88 querystringdict = attr.ib()
88 querystringdict = attr.ib()
89 # wsgiref.headers.Headers instance. Operates like a dict with case
89 # wsgiref.headers.Headers instance. Operates like a dict with case
90 # insensitive keys.
90 # insensitive keys.
91 headers = attr.ib()
91 headers = attr.ib()
92
92
93 def parserequestfromenv(env):
93 def parserequestfromenv(env):
94 """Parse URL components from environment variables.
94 """Parse URL components from environment variables.
95
95
96 WSGI defines request attributes via environment variables. This function
96 WSGI defines request attributes via environment variables. This function
97 parses the environment variables into a data structure.
97 parses the environment variables into a data structure.
98 """
98 """
99 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
99 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
100
100
101 # We first validate that the incoming object conforms with the WSGI spec.
101 # We first validate that the incoming object conforms with the WSGI spec.
102 # We only want to be dealing with spec-conforming WSGI implementations.
102 # We only want to be dealing with spec-conforming WSGI implementations.
103 # TODO enable this once we fix internal violations.
103 # TODO enable this once we fix internal violations.
104 #wsgiref.validate.check_environ(env)
104 #wsgiref.validate.check_environ(env)
105
105
106 # PEP-0333 states that environment keys and values are native strings
106 # PEP-0333 states that environment keys and values are native strings
107 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
107 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
108 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
108 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
109 # in Mercurial, so mass convert string keys and values to bytes.
109 # in Mercurial, so mass convert string keys and values to bytes.
110 if pycompat.ispy3:
110 if pycompat.ispy3:
111 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
111 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
112 env = {k: v.encode('latin-1') if isinstance(v, str) else v
112 env = {k: v.encode('latin-1') if isinstance(v, str) else v
113 for k, v in env.iteritems()}
113 for k, v in env.iteritems()}
114
114
115 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
115 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
116 # the environment variables.
116 # the environment variables.
117 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
117 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
118 # how URLs are reconstructed.
118 # how URLs are reconstructed.
119 fullurl = env['wsgi.url_scheme'] + '://'
119 fullurl = env['wsgi.url_scheme'] + '://'
120 advertisedfullurl = fullurl
120 advertisedfullurl = fullurl
121
121
122 def addport(s):
122 def addport(s):
123 if env['wsgi.url_scheme'] == 'https':
123 if env['wsgi.url_scheme'] == 'https':
124 if env['SERVER_PORT'] != '443':
124 if env['SERVER_PORT'] != '443':
125 s += ':' + env['SERVER_PORT']
125 s += ':' + env['SERVER_PORT']
126 else:
126 else:
127 if env['SERVER_PORT'] != '80':
127 if env['SERVER_PORT'] != '80':
128 s += ':' + env['SERVER_PORT']
128 s += ':' + env['SERVER_PORT']
129
129
130 return s
130 return s
131
131
132 if env.get('HTTP_HOST'):
132 if env.get('HTTP_HOST'):
133 fullurl += env['HTTP_HOST']
133 fullurl += env['HTTP_HOST']
134 else:
134 else:
135 fullurl += env['SERVER_NAME']
135 fullurl += env['SERVER_NAME']
136 fullurl = addport(fullurl)
136 fullurl = addport(fullurl)
137
137
138 advertisedfullurl += env['SERVER_NAME']
138 advertisedfullurl += env['SERVER_NAME']
139 advertisedfullurl = addport(advertisedfullurl)
139 advertisedfullurl = addport(advertisedfullurl)
140
140
141 baseurl = fullurl
141 baseurl = fullurl
142 advertisedbaseurl = advertisedfullurl
142 advertisedbaseurl = advertisedfullurl
143
143
144 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
144 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
145 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
145 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
146 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
146 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
148
148
149 if env.get('QUERY_STRING'):
149 if env.get('QUERY_STRING'):
150 fullurl += '?' + env['QUERY_STRING']
150 fullurl += '?' + env['QUERY_STRING']
151 advertisedfullurl += '?' + env['QUERY_STRING']
151 advertisedfullurl += '?' + env['QUERY_STRING']
152
152
153 # When dispatching requests, we look at the URL components (PATH_INFO
153 # When dispatching requests, we look at the URL components (PATH_INFO
154 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
154 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
155 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
155 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
156 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
156 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
157 # root. We also exclude its path components from PATH_INFO when resolving
157 # root. We also exclude its path components from PATH_INFO when resolving
158 # the dispatch path.
158 # the dispatch path.
159
159
160 apppath = env['SCRIPT_NAME']
160 apppath = env['SCRIPT_NAME']
161
161
162 if env.get('REPO_NAME'):
162 if env.get('REPO_NAME'):
163 if not apppath.endswith('/'):
163 if not apppath.endswith('/'):
164 apppath += '/'
164 apppath += '/'
165
165
166 apppath += env.get('REPO_NAME')
166 apppath += env.get('REPO_NAME')
167
167
168 if 'PATH_INFO' in env:
168 if 'PATH_INFO' in env:
169 dispatchparts = env['PATH_INFO'].strip('/').split('/')
169 dispatchparts = env['PATH_INFO'].strip('/').split('/')
170
170
171 # Strip out repo parts.
171 # Strip out repo parts.
172 repoparts = env.get('REPO_NAME', '').split('/')
172 repoparts = env.get('REPO_NAME', '').split('/')
173 if dispatchparts[:len(repoparts)] == repoparts:
173 if dispatchparts[:len(repoparts)] == repoparts:
174 dispatchparts = dispatchparts[len(repoparts):]
174 dispatchparts = dispatchparts[len(repoparts):]
175 else:
175 else:
176 dispatchparts = []
176 dispatchparts = []
177
177
178 dispatchpath = '/'.join(dispatchparts)
178 dispatchpath = '/'.join(dispatchparts)
179
179
180 querystring = env.get('QUERY_STRING', '')
180 querystring = env.get('QUERY_STRING', '')
181
181
182 # We store as a list so we have ordering information. We also store as
182 # We store as a list so we have ordering information. We also store as
183 # a dict to facilitate fast lookup.
183 # a dict to facilitate fast lookup.
184 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
184 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
185
185
186 querystringdict = {}
186 querystringdict = {}
187 for k, v in querystringlist:
187 for k, v in querystringlist:
188 if k in querystringdict:
188 if k in querystringdict:
189 querystringdict[k].append(v)
189 querystringdict[k].append(v)
190 else:
190 else:
191 querystringdict[k] = [v]
191 querystringdict[k] = [v]
192
192
193 # HTTP_* keys contain HTTP request headers. The Headers structure should
193 # HTTP_* keys contain HTTP request headers. The Headers structure should
194 # perform case normalization for us. We just rewrite underscore to dash
194 # perform case normalization for us. We just rewrite underscore to dash
195 # so keys match what likely went over the wire.
195 # so keys match what likely went over the wire.
196 headers = []
196 headers = []
197 for k, v in env.iteritems():
197 for k, v in env.iteritems():
198 if k.startswith('HTTP_'):
198 if k.startswith('HTTP_'):
199 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
199 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
200
200
201 headers = wsgiheaders.Headers(headers)
201 headers = wsgiheaders.Headers(headers)
202
202
203 # This is kind of a lie because the HTTP header wasn't explicitly
204 # sent. But for all intents and purposes it should be OK to lie about
205 # this, since a consumer will either either value to determine how many
206 # bytes are available to read.
207 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
208 headers['Content-Length'] = env['CONTENT_LENGTH']
209
203 return parsedrequest(url=fullurl, baseurl=baseurl,
210 return parsedrequest(url=fullurl, baseurl=baseurl,
204 advertisedurl=advertisedfullurl,
211 advertisedurl=advertisedfullurl,
205 advertisedbaseurl=advertisedbaseurl,
212 advertisedbaseurl=advertisedbaseurl,
206 apppath=apppath,
213 apppath=apppath,
207 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
214 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
208 havepathinfo='PATH_INFO' in env,
215 havepathinfo='PATH_INFO' in env,
209 querystring=querystring,
216 querystring=querystring,
210 querystringlist=querystringlist,
217 querystringlist=querystringlist,
211 querystringdict=querystringdict,
218 querystringdict=querystringdict,
212 headers=headers)
219 headers=headers)
213
220
214 class wsgirequest(object):
221 class wsgirequest(object):
215 """Higher-level API for a WSGI request.
222 """Higher-level API for a WSGI request.
216
223
217 WSGI applications are invoked with 2 arguments. They are used to
224 WSGI applications are invoked with 2 arguments. They are used to
218 instantiate instances of this class, which provides higher-level APIs
225 instantiate instances of this class, which provides higher-level APIs
219 for obtaining request parameters, writing HTTP output, etc.
226 for obtaining request parameters, writing HTTP output, etc.
220 """
227 """
221 def __init__(self, wsgienv, start_response):
228 def __init__(self, wsgienv, start_response):
222 version = wsgienv[r'wsgi.version']
229 version = wsgienv[r'wsgi.version']
223 if (version < (1, 0)) or (version >= (2, 0)):
230 if (version < (1, 0)) or (version >= (2, 0)):
224 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
231 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
225 % version)
232 % version)
226 self.inp = wsgienv[r'wsgi.input']
233 self.inp = wsgienv[r'wsgi.input']
227 self.err = wsgienv[r'wsgi.errors']
234 self.err = wsgienv[r'wsgi.errors']
228 self.threaded = wsgienv[r'wsgi.multithread']
235 self.threaded = wsgienv[r'wsgi.multithread']
229 self.multiprocess = wsgienv[r'wsgi.multiprocess']
236 self.multiprocess = wsgienv[r'wsgi.multiprocess']
230 self.run_once = wsgienv[r'wsgi.run_once']
237 self.run_once = wsgienv[r'wsgi.run_once']
231 self.env = wsgienv
238 self.env = wsgienv
232 self.form = normalize(cgi.parse(self.inp,
239 self.form = normalize(cgi.parse(self.inp,
233 self.env,
240 self.env,
234 keep_blank_values=1))
241 keep_blank_values=1))
235 self._start_response = start_response
242 self._start_response = start_response
236 self.server_write = None
243 self.server_write = None
237 self.headers = []
244 self.headers = []
238
245
239 def __iter__(self):
246 def __iter__(self):
240 return iter([])
247 return iter([])
241
248
242 def read(self, count=-1):
249 def read(self, count=-1):
243 return self.inp.read(count)
250 return self.inp.read(count)
244
251
245 def drain(self):
252 def drain(self):
246 '''need to read all data from request, httplib is half-duplex'''
253 '''need to read all data from request, httplib is half-duplex'''
247 length = int(self.env.get('CONTENT_LENGTH') or 0)
254 length = int(self.env.get('CONTENT_LENGTH') or 0)
248 for s in util.filechunkiter(self.inp, limit=length):
255 for s in util.filechunkiter(self.inp, limit=length):
249 pass
256 pass
250
257
251 def respond(self, status, type, filename=None, body=None):
258 def respond(self, status, type, filename=None, body=None):
252 if not isinstance(type, str):
259 if not isinstance(type, str):
253 type = pycompat.sysstr(type)
260 type = pycompat.sysstr(type)
254 if self._start_response is not None:
261 if self._start_response is not None:
255 self.headers.append((r'Content-Type', type))
262 self.headers.append((r'Content-Type', type))
256 if filename:
263 if filename:
257 filename = (filename.rpartition('/')[-1]
264 filename = (filename.rpartition('/')[-1]
258 .replace('\\', '\\\\').replace('"', '\\"'))
265 .replace('\\', '\\\\').replace('"', '\\"'))
259 self.headers.append(('Content-Disposition',
266 self.headers.append(('Content-Disposition',
260 'inline; filename="%s"' % filename))
267 'inline; filename="%s"' % filename))
261 if body is not None:
268 if body is not None:
262 self.headers.append((r'Content-Length', str(len(body))))
269 self.headers.append((r'Content-Length', str(len(body))))
263
270
264 for k, v in self.headers:
271 for k, v in self.headers:
265 if not isinstance(v, str):
272 if not isinstance(v, str):
266 raise TypeError('header value must be string: %r' % (v,))
273 raise TypeError('header value must be string: %r' % (v,))
267
274
268 if isinstance(status, ErrorResponse):
275 if isinstance(status, ErrorResponse):
269 self.headers.extend(status.headers)
276 self.headers.extend(status.headers)
270 if status.code == HTTP_NOT_MODIFIED:
277 if status.code == HTTP_NOT_MODIFIED:
271 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
278 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
272 # it MUST NOT include any headers other than these and no
279 # it MUST NOT include any headers other than these and no
273 # body
280 # body
274 self.headers = [(k, v) for (k, v) in self.headers if
281 self.headers = [(k, v) for (k, v) in self.headers if
275 k in ('Date', 'ETag', 'Expires',
282 k in ('Date', 'ETag', 'Expires',
276 'Cache-Control', 'Vary')]
283 'Cache-Control', 'Vary')]
277 status = statusmessage(status.code, pycompat.bytestr(status))
284 status = statusmessage(status.code, pycompat.bytestr(status))
278 elif status == 200:
285 elif status == 200:
279 status = '200 Script output follows'
286 status = '200 Script output follows'
280 elif isinstance(status, int):
287 elif isinstance(status, int):
281 status = statusmessage(status)
288 status = statusmessage(status)
282
289
283 self.server_write = self._start_response(
290 self.server_write = self._start_response(
284 pycompat.sysstr(status), self.headers)
291 pycompat.sysstr(status), self.headers)
285 self._start_response = None
292 self._start_response = None
286 self.headers = []
293 self.headers = []
287 if body is not None:
294 if body is not None:
288 self.write(body)
295 self.write(body)
289 self.server_write = None
296 self.server_write = None
290
297
291 def write(self, thing):
298 def write(self, thing):
292 if thing:
299 if thing:
293 try:
300 try:
294 self.server_write(thing)
301 self.server_write(thing)
295 except socket.error as inst:
302 except socket.error as inst:
296 if inst[0] != errno.ECONNRESET:
303 if inst[0] != errno.ECONNRESET:
297 raise
304 raise
298
305
299 def writelines(self, lines):
306 def writelines(self, lines):
300 for line in lines:
307 for line in lines:
301 self.write(line)
308 self.write(line)
302
309
303 def flush(self):
310 def flush(self):
304 return None
311 return None
305
312
306 def close(self):
313 def close(self):
307 return None
314 return None
308
315
309 def wsgiapplication(app_maker):
316 def wsgiapplication(app_maker):
310 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
317 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
311 can and should now be used as a WSGI application.'''
318 can and should now be used as a WSGI application.'''
312 application = app_maker()
319 application = app_maker()
313 def run_wsgi(env, respond):
320 def run_wsgi(env, respond):
314 return application(env, respond)
321 return application(env, respond)
315 return run_wsgi
322 return run_wsgi
@@ -1,668 +1,667 b''
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import contextlib
9 import contextlib
10 import struct
10 import struct
11 import sys
11 import sys
12 import threading
12 import threading
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 wireprototypes,
22 wireprototypes,
23 )
23 )
24
24
25 stringio = util.stringio
25 stringio = util.stringio
26
26
27 urlerr = util.urlerr
27 urlerr = util.urlerr
28 urlreq = util.urlreq
28 urlreq = util.urlreq
29
29
30 HTTP_OK = 200
30 HTTP_OK = 200
31
31
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
34 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
35
35
36 SSHV1 = wireprototypes.SSHV1
36 SSHV1 = wireprototypes.SSHV1
37 SSHV2 = wireprototypes.SSHV2
37 SSHV2 = wireprototypes.SSHV2
38
38
39 def decodevaluefromheaders(req, headerprefix):
39 def decodevaluefromheaders(req, headerprefix):
40 """Decode a long value from multiple HTTP request headers.
40 """Decode a long value from multiple HTTP request headers.
41
41
42 Returns the value as a bytes, not a str.
42 Returns the value as a bytes, not a str.
43 """
43 """
44 chunks = []
44 chunks = []
45 i = 1
45 i = 1
46 while True:
46 while True:
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
48 if v is None:
48 if v is None:
49 break
49 break
50 chunks.append(pycompat.bytesurl(v))
50 chunks.append(pycompat.bytesurl(v))
51 i += 1
51 i += 1
52
52
53 return ''.join(chunks)
53 return ''.join(chunks)
54
54
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
56 def __init__(self, wsgireq, req, ui, checkperm):
56 def __init__(self, wsgireq, req, ui, checkperm):
57 self._wsgireq = wsgireq
57 self._wsgireq = wsgireq
58 self._req = req
58 self._req = req
59 self._ui = ui
59 self._ui = ui
60 self._checkperm = checkperm
60 self._checkperm = checkperm
61
61
62 @property
62 @property
63 def name(self):
63 def name(self):
64 return 'http-v1'
64 return 'http-v1'
65
65
66 def getargs(self, args):
66 def getargs(self, args):
67 knownargs = self._args()
67 knownargs = self._args()
68 data = {}
68 data = {}
69 keys = args.split()
69 keys = args.split()
70 for k in keys:
70 for k in keys:
71 if k == '*':
71 if k == '*':
72 star = {}
72 star = {}
73 for key in knownargs.keys():
73 for key in knownargs.keys():
74 if key != 'cmd' and key not in keys:
74 if key != 'cmd' and key not in keys:
75 star[key] = knownargs[key][0]
75 star[key] = knownargs[key][0]
76 data['*'] = star
76 data['*'] = star
77 else:
77 else:
78 data[k] = knownargs[k][0]
78 data[k] = knownargs[k][0]
79 return [data[k] for k in keys]
79 return [data[k] for k in keys]
80
80
81 def _args(self):
81 def _args(self):
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
84 if postlen:
84 if postlen:
85 args.update(urlreq.parseqs(
85 args.update(urlreq.parseqs(
86 self._wsgireq.read(postlen), keep_blank_values=True))
86 self._wsgireq.read(postlen), keep_blank_values=True))
87 return args
87 return args
88
88
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
91 return args
91 return args
92
92
93 def forwardpayload(self, fp):
93 def forwardpayload(self, fp):
94 if b'Content-Length' in self._req.headers:
94 # Existing clients *always* send Content-Length.
95 length = int(self._req.headers[b'Content-Length'])
95 length = int(self._req.headers[b'Content-Length'])
96 else:
96
97 length = int(self._wsgireq.env[r'CONTENT_LENGTH'])
98 # If httppostargs is used, we need to read Content-Length
97 # If httppostargs is used, we need to read Content-Length
99 # minus the amount that was consumed by args.
98 # minus the amount that was consumed by args.
100 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
101 for s in util.filechunkiter(self._wsgireq, limit=length):
100 for s in util.filechunkiter(self._wsgireq, limit=length):
102 fp.write(s)
101 fp.write(s)
103
102
104 @contextlib.contextmanager
103 @contextlib.contextmanager
105 def mayberedirectstdio(self):
104 def mayberedirectstdio(self):
106 oldout = self._ui.fout
105 oldout = self._ui.fout
107 olderr = self._ui.ferr
106 olderr = self._ui.ferr
108
107
109 out = util.stringio()
108 out = util.stringio()
110
109
111 try:
110 try:
112 self._ui.fout = out
111 self._ui.fout = out
113 self._ui.ferr = out
112 self._ui.ferr = out
114 yield out
113 yield out
115 finally:
114 finally:
116 self._ui.fout = oldout
115 self._ui.fout = oldout
117 self._ui.ferr = olderr
116 self._ui.ferr = olderr
118
117
119 def client(self):
118 def client(self):
120 return 'remote:%s:%s:%s' % (
119 return 'remote:%s:%s:%s' % (
121 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
122 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
123 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
124
123
125 def addcapabilities(self, repo, caps):
124 def addcapabilities(self, repo, caps):
126 caps.append('httpheader=%d' %
125 caps.append('httpheader=%d' %
127 repo.ui.configint('server', 'maxhttpheaderlen'))
126 repo.ui.configint('server', 'maxhttpheaderlen'))
128 if repo.ui.configbool('experimental', 'httppostargs'):
127 if repo.ui.configbool('experimental', 'httppostargs'):
129 caps.append('httppostargs')
128 caps.append('httppostargs')
130
129
131 # FUTURE advertise 0.2rx once support is implemented
130 # FUTURE advertise 0.2rx once support is implemented
132 # FUTURE advertise minrx and mintx after consulting config option
131 # FUTURE advertise minrx and mintx after consulting config option
133 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
134
133
135 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
136 if compengines:
135 if compengines:
137 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
138 for e in compengines)
137 for e in compengines)
139 caps.append('compression=%s' % comptypes)
138 caps.append('compression=%s' % comptypes)
140
139
141 return caps
140 return caps
142
141
143 def checkperm(self, perm):
142 def checkperm(self, perm):
144 return self._checkperm(perm)
143 return self._checkperm(perm)
145
144
146 # This method exists mostly so that extensions like remotefilelog can
145 # This method exists mostly so that extensions like remotefilelog can
147 # disable a kludgey legacy method only over http. As of early 2018,
146 # disable a kludgey legacy method only over http. As of early 2018,
148 # there are no other known users, so with any luck we can discard this
147 # there are no other known users, so with any luck we can discard this
149 # hook if remotefilelog becomes a first-party extension.
148 # hook if remotefilelog becomes a first-party extension.
150 def iscmd(cmd):
149 def iscmd(cmd):
151 return cmd in wireproto.commands
150 return cmd in wireproto.commands
152
151
153 def handlewsgirequest(rctx, wsgireq, req, checkperm):
152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
154 """Possibly process a wire protocol request.
153 """Possibly process a wire protocol request.
155
154
156 If the current request is a wire protocol request, the request is
155 If the current request is a wire protocol request, the request is
157 processed by this function.
156 processed by this function.
158
157
159 ``wsgireq`` is a ``wsgirequest`` instance.
158 ``wsgireq`` is a ``wsgirequest`` instance.
160 ``req`` is a ``parsedrequest`` instance.
159 ``req`` is a ``parsedrequest`` instance.
161
160
162 Returns a 2-tuple of (bool, response) where the 1st element indicates
161 Returns a 2-tuple of (bool, response) where the 1st element indicates
163 whether the request was handled and the 2nd element is a return
162 whether the request was handled and the 2nd element is a return
164 value for a WSGI application (often a generator of bytes).
163 value for a WSGI application (often a generator of bytes).
165 """
164 """
166 # Avoid cycle involving hg module.
165 # Avoid cycle involving hg module.
167 from .hgweb import common as hgwebcommon
166 from .hgweb import common as hgwebcommon
168
167
169 repo = rctx.repo
168 repo = rctx.repo
170
169
171 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
172 # string parameter. If it isn't present, this isn't a wire protocol
171 # string parameter. If it isn't present, this isn't a wire protocol
173 # request.
172 # request.
174 if 'cmd' not in req.querystringdict:
173 if 'cmd' not in req.querystringdict:
175 return False, None
174 return False, None
176
175
177 cmd = req.querystringdict['cmd'][0]
176 cmd = req.querystringdict['cmd'][0]
178
177
179 # The "cmd" request parameter is used by both the wire protocol and hgweb.
178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
180 # While not all wire protocol commands are available for all transports,
179 # While not all wire protocol commands are available for all transports,
181 # if we see a "cmd" value that resembles a known wire protocol command, we
180 # if we see a "cmd" value that resembles a known wire protocol command, we
182 # route it to a protocol handler. This is better than routing possible
181 # route it to a protocol handler. This is better than routing possible
183 # wire protocol requests to hgweb because it prevents hgweb from using
182 # wire protocol requests to hgweb because it prevents hgweb from using
184 # known wire protocol commands and it is less confusing for machine
183 # known wire protocol commands and it is less confusing for machine
185 # clients.
184 # clients.
186 if not iscmd(cmd):
185 if not iscmd(cmd):
187 return False, None
186 return False, None
188
187
189 # The "cmd" query string argument is only valid on the root path of the
188 # The "cmd" query string argument is only valid on the root path of the
190 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
191 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
192 # in this case. We send an HTTP 404 for backwards compatibility reasons.
191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
193 if req.dispatchpath:
192 if req.dispatchpath:
194 res = _handlehttperror(
193 res = _handlehttperror(
195 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
196 req, cmd)
195 req, cmd)
197
196
198 return True, res
197 return True, res
199
198
200 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
201 lambda perm: checkperm(rctx, wsgireq, perm))
200 lambda perm: checkperm(rctx, wsgireq, perm))
202
201
203 # The permissions checker should be the only thing that can raise an
202 # The permissions checker should be the only thing that can raise an
204 # ErrorResponse. It is kind of a layer violation to catch an hgweb
203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
205 # exception here. So consider refactoring into a exception type that
204 # exception here. So consider refactoring into a exception type that
206 # is associated with the wire protocol.
205 # is associated with the wire protocol.
207 try:
206 try:
208 res = _callhttp(repo, wsgireq, req, proto, cmd)
207 res = _callhttp(repo, wsgireq, req, proto, cmd)
209 except hgwebcommon.ErrorResponse as e:
208 except hgwebcommon.ErrorResponse as e:
210 res = _handlehttperror(e, wsgireq, req, cmd)
209 res = _handlehttperror(e, wsgireq, req, cmd)
211
210
212 return True, res
211 return True, res
213
212
214 def _httpresponsetype(ui, req, prefer_uncompressed):
213 def _httpresponsetype(ui, req, prefer_uncompressed):
215 """Determine the appropriate response type and compression settings.
214 """Determine the appropriate response type and compression settings.
216
215
217 Returns a tuple of (mediatype, compengine, engineopts).
216 Returns a tuple of (mediatype, compengine, engineopts).
218 """
217 """
219 # Determine the response media type and compression engine based
218 # Determine the response media type and compression engine based
220 # on the request parameters.
219 # on the request parameters.
221 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
222
221
223 if '0.2' in protocaps:
222 if '0.2' in protocaps:
224 # All clients are expected to support uncompressed data.
223 # All clients are expected to support uncompressed data.
225 if prefer_uncompressed:
224 if prefer_uncompressed:
226 return HGTYPE2, util._noopengine(), {}
225 return HGTYPE2, util._noopengine(), {}
227
226
228 # Default as defined by wire protocol spec.
227 # Default as defined by wire protocol spec.
229 compformats = ['zlib', 'none']
228 compformats = ['zlib', 'none']
230 for cap in protocaps:
229 for cap in protocaps:
231 if cap.startswith('comp='):
230 if cap.startswith('comp='):
232 compformats = cap[5:].split(',')
231 compformats = cap[5:].split(',')
233 break
232 break
234
233
235 # Now find an agreed upon compression format.
234 # Now find an agreed upon compression format.
236 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
237 if engine.wireprotosupport().name in compformats:
236 if engine.wireprotosupport().name in compformats:
238 opts = {}
237 opts = {}
239 level = ui.configint('server', '%slevel' % engine.name())
238 level = ui.configint('server', '%slevel' % engine.name())
240 if level is not None:
239 if level is not None:
241 opts['level'] = level
240 opts['level'] = level
242
241
243 return HGTYPE2, engine, opts
242 return HGTYPE2, engine, opts
244
243
245 # No mutually supported compression format. Fall back to the
244 # No mutually supported compression format. Fall back to the
246 # legacy protocol.
245 # legacy protocol.
247
246
248 # Don't allow untrusted settings because disabling compression or
247 # Don't allow untrusted settings because disabling compression or
249 # setting a very high compression level could lead to flooding
248 # setting a very high compression level could lead to flooding
250 # the server's network or CPU.
249 # the server's network or CPU.
251 opts = {'level': ui.configint('server', 'zliblevel')}
250 opts = {'level': ui.configint('server', 'zliblevel')}
252 return HGTYPE, util.compengines['zlib'], opts
251 return HGTYPE, util.compengines['zlib'], opts
253
252
254 def _callhttp(repo, wsgireq, req, proto, cmd):
253 def _callhttp(repo, wsgireq, req, proto, cmd):
255 def genversion2(gen, engine, engineopts):
254 def genversion2(gen, engine, engineopts):
256 # application/mercurial-0.2 always sends a payload header
255 # application/mercurial-0.2 always sends a payload header
257 # identifying the compression engine.
256 # identifying the compression engine.
258 name = engine.wireprotosupport().name
257 name = engine.wireprotosupport().name
259 assert 0 < len(name) < 256
258 assert 0 < len(name) < 256
260 yield struct.pack('B', len(name))
259 yield struct.pack('B', len(name))
261 yield name
260 yield name
262
261
263 for chunk in gen:
262 for chunk in gen:
264 yield chunk
263 yield chunk
265
264
266 if not wireproto.commands.commandavailable(cmd, proto):
265 if not wireproto.commands.commandavailable(cmd, proto):
267 wsgireq.respond(HTTP_OK, HGERRTYPE,
266 wsgireq.respond(HTTP_OK, HGERRTYPE,
268 body=_('requested wire protocol command is not '
267 body=_('requested wire protocol command is not '
269 'available over HTTP'))
268 'available over HTTP'))
270 return []
269 return []
271
270
272 proto.checkperm(wireproto.commands[cmd].permission)
271 proto.checkperm(wireproto.commands[cmd].permission)
273
272
274 rsp = wireproto.dispatch(repo, proto, cmd)
273 rsp = wireproto.dispatch(repo, proto, cmd)
275
274
276 if isinstance(rsp, bytes):
275 if isinstance(rsp, bytes):
277 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
278 return []
277 return []
279 elif isinstance(rsp, wireprototypes.bytesresponse):
278 elif isinstance(rsp, wireprototypes.bytesresponse):
280 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
281 return []
280 return []
282 elif isinstance(rsp, wireprototypes.streamreslegacy):
281 elif isinstance(rsp, wireprototypes.streamreslegacy):
283 gen = rsp.gen
282 gen = rsp.gen
284 wsgireq.respond(HTTP_OK, HGTYPE)
283 wsgireq.respond(HTTP_OK, HGTYPE)
285 return gen
284 return gen
286 elif isinstance(rsp, wireprototypes.streamres):
285 elif isinstance(rsp, wireprototypes.streamres):
287 gen = rsp.gen
286 gen = rsp.gen
288
287
289 # This code for compression should not be streamres specific. It
288 # This code for compression should not be streamres specific. It
290 # is here because we only compress streamres at the moment.
289 # is here because we only compress streamres at the moment.
291 mediatype, engine, engineopts = _httpresponsetype(
290 mediatype, engine, engineopts = _httpresponsetype(
292 repo.ui, req, rsp.prefer_uncompressed)
291 repo.ui, req, rsp.prefer_uncompressed)
293 gen = engine.compressstream(gen, engineopts)
292 gen = engine.compressstream(gen, engineopts)
294
293
295 if mediatype == HGTYPE2:
294 if mediatype == HGTYPE2:
296 gen = genversion2(gen, engine, engineopts)
295 gen = genversion2(gen, engine, engineopts)
297
296
298 wsgireq.respond(HTTP_OK, mediatype)
297 wsgireq.respond(HTTP_OK, mediatype)
299 return gen
298 return gen
300 elif isinstance(rsp, wireprototypes.pushres):
299 elif isinstance(rsp, wireprototypes.pushres):
301 rsp = '%d\n%s' % (rsp.res, rsp.output)
300 rsp = '%d\n%s' % (rsp.res, rsp.output)
302 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
303 return []
302 return []
304 elif isinstance(rsp, wireprototypes.pusherr):
303 elif isinstance(rsp, wireprototypes.pusherr):
305 # This is the httplib workaround documented in _handlehttperror().
304 # This is the httplib workaround documented in _handlehttperror().
306 wsgireq.drain()
305 wsgireq.drain()
307
306
308 rsp = '0\n%s\n' % rsp.res
307 rsp = '0\n%s\n' % rsp.res
309 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
308 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
310 return []
309 return []
311 elif isinstance(rsp, wireprototypes.ooberror):
310 elif isinstance(rsp, wireprototypes.ooberror):
312 rsp = rsp.message
311 rsp = rsp.message
313 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
312 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
314 return []
313 return []
315 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
314 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
316
315
317 def _handlehttperror(e, wsgireq, req, cmd):
316 def _handlehttperror(e, wsgireq, req, cmd):
318 """Called when an ErrorResponse is raised during HTTP request processing."""
317 """Called when an ErrorResponse is raised during HTTP request processing."""
319
318
320 # Clients using Python's httplib are stateful: the HTTP client
319 # Clients using Python's httplib are stateful: the HTTP client
321 # won't process an HTTP response until all request data is
320 # won't process an HTTP response until all request data is
322 # sent to the server. The intent of this code is to ensure
321 # sent to the server. The intent of this code is to ensure
323 # we always read HTTP request data from the client, thus
322 # we always read HTTP request data from the client, thus
324 # ensuring httplib transitions to a state that allows it to read
323 # ensuring httplib transitions to a state that allows it to read
325 # the HTTP response. In other words, it helps prevent deadlocks
324 # the HTTP response. In other words, it helps prevent deadlocks
326 # on clients using httplib.
325 # on clients using httplib.
327
326
328 if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
327 if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
329 # But not if Expect: 100-continue is being used.
328 # But not if Expect: 100-continue is being used.
330 (req.headers.get('Expect', '').lower() != '100-continue')):
329 (req.headers.get('Expect', '').lower() != '100-continue')):
331 wsgireq.drain()
330 wsgireq.drain()
332 else:
331 else:
333 wsgireq.headers.append((r'Connection', r'Close'))
332 wsgireq.headers.append((r'Connection', r'Close'))
334
333
335 # TODO This response body assumes the failed command was
334 # TODO This response body assumes the failed command was
336 # "unbundle." That assumption is not always valid.
335 # "unbundle." That assumption is not always valid.
337 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
336 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
338
337
339 return ''
338 return ''
340
339
341 def _sshv1respondbytes(fout, value):
340 def _sshv1respondbytes(fout, value):
342 """Send a bytes response for protocol version 1."""
341 """Send a bytes response for protocol version 1."""
343 fout.write('%d\n' % len(value))
342 fout.write('%d\n' % len(value))
344 fout.write(value)
343 fout.write(value)
345 fout.flush()
344 fout.flush()
346
345
347 def _sshv1respondstream(fout, source):
346 def _sshv1respondstream(fout, source):
348 write = fout.write
347 write = fout.write
349 for chunk in source.gen:
348 for chunk in source.gen:
350 write(chunk)
349 write(chunk)
351 fout.flush()
350 fout.flush()
352
351
353 def _sshv1respondooberror(fout, ferr, rsp):
352 def _sshv1respondooberror(fout, ferr, rsp):
354 ferr.write(b'%s\n-\n' % rsp)
353 ferr.write(b'%s\n-\n' % rsp)
355 ferr.flush()
354 ferr.flush()
356 fout.write(b'\n')
355 fout.write(b'\n')
357 fout.flush()
356 fout.flush()
358
357
359 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
358 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
360 """Handler for requests services via version 1 of SSH protocol."""
359 """Handler for requests services via version 1 of SSH protocol."""
361 def __init__(self, ui, fin, fout):
360 def __init__(self, ui, fin, fout):
362 self._ui = ui
361 self._ui = ui
363 self._fin = fin
362 self._fin = fin
364 self._fout = fout
363 self._fout = fout
365
364
366 @property
365 @property
367 def name(self):
366 def name(self):
368 return wireprototypes.SSHV1
367 return wireprototypes.SSHV1
369
368
370 def getargs(self, args):
369 def getargs(self, args):
371 data = {}
370 data = {}
372 keys = args.split()
371 keys = args.split()
373 for n in xrange(len(keys)):
372 for n in xrange(len(keys)):
374 argline = self._fin.readline()[:-1]
373 argline = self._fin.readline()[:-1]
375 arg, l = argline.split()
374 arg, l = argline.split()
376 if arg not in keys:
375 if arg not in keys:
377 raise error.Abort(_("unexpected parameter %r") % arg)
376 raise error.Abort(_("unexpected parameter %r") % arg)
378 if arg == '*':
377 if arg == '*':
379 star = {}
378 star = {}
380 for k in xrange(int(l)):
379 for k in xrange(int(l)):
381 argline = self._fin.readline()[:-1]
380 argline = self._fin.readline()[:-1]
382 arg, l = argline.split()
381 arg, l = argline.split()
383 val = self._fin.read(int(l))
382 val = self._fin.read(int(l))
384 star[arg] = val
383 star[arg] = val
385 data['*'] = star
384 data['*'] = star
386 else:
385 else:
387 val = self._fin.read(int(l))
386 val = self._fin.read(int(l))
388 data[arg] = val
387 data[arg] = val
389 return [data[k] for k in keys]
388 return [data[k] for k in keys]
390
389
391 def forwardpayload(self, fpout):
390 def forwardpayload(self, fpout):
392 # We initially send an empty response. This tells the client it is
391 # We initially send an empty response. This tells the client it is
393 # OK to start sending data. If a client sees any other response, it
392 # OK to start sending data. If a client sees any other response, it
394 # interprets it as an error.
393 # interprets it as an error.
395 _sshv1respondbytes(self._fout, b'')
394 _sshv1respondbytes(self._fout, b'')
396
395
397 # The file is in the form:
396 # The file is in the form:
398 #
397 #
399 # <chunk size>\n<chunk>
398 # <chunk size>\n<chunk>
400 # ...
399 # ...
401 # 0\n
400 # 0\n
402 count = int(self._fin.readline())
401 count = int(self._fin.readline())
403 while count:
402 while count:
404 fpout.write(self._fin.read(count))
403 fpout.write(self._fin.read(count))
405 count = int(self._fin.readline())
404 count = int(self._fin.readline())
406
405
407 @contextlib.contextmanager
406 @contextlib.contextmanager
408 def mayberedirectstdio(self):
407 def mayberedirectstdio(self):
409 yield None
408 yield None
410
409
411 def client(self):
410 def client(self):
412 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
411 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
413 return 'remote:ssh:' + client
412 return 'remote:ssh:' + client
414
413
415 def addcapabilities(self, repo, caps):
414 def addcapabilities(self, repo, caps):
416 return caps
415 return caps
417
416
418 def checkperm(self, perm):
417 def checkperm(self, perm):
419 pass
418 pass
420
419
421 class sshv2protocolhandler(sshv1protocolhandler):
420 class sshv2protocolhandler(sshv1protocolhandler):
422 """Protocol handler for version 2 of the SSH protocol."""
421 """Protocol handler for version 2 of the SSH protocol."""
423
422
424 @property
423 @property
425 def name(self):
424 def name(self):
426 return wireprototypes.SSHV2
425 return wireprototypes.SSHV2
427
426
428 def _runsshserver(ui, repo, fin, fout, ev):
427 def _runsshserver(ui, repo, fin, fout, ev):
429 # This function operates like a state machine of sorts. The following
428 # This function operates like a state machine of sorts. The following
430 # states are defined:
429 # states are defined:
431 #
430 #
432 # protov1-serving
431 # protov1-serving
433 # Server is in protocol version 1 serving mode. Commands arrive on
432 # Server is in protocol version 1 serving mode. Commands arrive on
434 # new lines. These commands are processed in this state, one command
433 # new lines. These commands are processed in this state, one command
435 # after the other.
434 # after the other.
436 #
435 #
437 # protov2-serving
436 # protov2-serving
438 # Server is in protocol version 2 serving mode.
437 # Server is in protocol version 2 serving mode.
439 #
438 #
440 # upgrade-initial
439 # upgrade-initial
441 # The server is going to process an upgrade request.
440 # The server is going to process an upgrade request.
442 #
441 #
443 # upgrade-v2-filter-legacy-handshake
442 # upgrade-v2-filter-legacy-handshake
444 # The protocol is being upgraded to version 2. The server is expecting
443 # The protocol is being upgraded to version 2. The server is expecting
445 # the legacy handshake from version 1.
444 # the legacy handshake from version 1.
446 #
445 #
447 # upgrade-v2-finish
446 # upgrade-v2-finish
448 # The upgrade to version 2 of the protocol is imminent.
447 # The upgrade to version 2 of the protocol is imminent.
449 #
448 #
450 # shutdown
449 # shutdown
451 # The server is shutting down, possibly in reaction to a client event.
450 # The server is shutting down, possibly in reaction to a client event.
452 #
451 #
453 # And here are their transitions:
452 # And here are their transitions:
454 #
453 #
455 # protov1-serving -> shutdown
454 # protov1-serving -> shutdown
456 # When server receives an empty request or encounters another
455 # When server receives an empty request or encounters another
457 # error.
456 # error.
458 #
457 #
459 # protov1-serving -> upgrade-initial
458 # protov1-serving -> upgrade-initial
460 # An upgrade request line was seen.
459 # An upgrade request line was seen.
461 #
460 #
462 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
461 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
463 # Upgrade to version 2 in progress. Server is expecting to
462 # Upgrade to version 2 in progress. Server is expecting to
464 # process a legacy handshake.
463 # process a legacy handshake.
465 #
464 #
466 # upgrade-v2-filter-legacy-handshake -> shutdown
465 # upgrade-v2-filter-legacy-handshake -> shutdown
467 # Client did not fulfill upgrade handshake requirements.
466 # Client did not fulfill upgrade handshake requirements.
468 #
467 #
469 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
468 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
470 # Client fulfilled version 2 upgrade requirements. Finishing that
469 # Client fulfilled version 2 upgrade requirements. Finishing that
471 # upgrade.
470 # upgrade.
472 #
471 #
473 # upgrade-v2-finish -> protov2-serving
472 # upgrade-v2-finish -> protov2-serving
474 # Protocol upgrade to version 2 complete. Server can now speak protocol
473 # Protocol upgrade to version 2 complete. Server can now speak protocol
475 # version 2.
474 # version 2.
476 #
475 #
477 # protov2-serving -> protov1-serving
476 # protov2-serving -> protov1-serving
478 # Ths happens by default since protocol version 2 is the same as
477 # Ths happens by default since protocol version 2 is the same as
479 # version 1 except for the handshake.
478 # version 1 except for the handshake.
480
479
481 state = 'protov1-serving'
480 state = 'protov1-serving'
482 proto = sshv1protocolhandler(ui, fin, fout)
481 proto = sshv1protocolhandler(ui, fin, fout)
483 protoswitched = False
482 protoswitched = False
484
483
485 while not ev.is_set():
484 while not ev.is_set():
486 if state == 'protov1-serving':
485 if state == 'protov1-serving':
487 # Commands are issued on new lines.
486 # Commands are issued on new lines.
488 request = fin.readline()[:-1]
487 request = fin.readline()[:-1]
489
488
490 # Empty lines signal to terminate the connection.
489 # Empty lines signal to terminate the connection.
491 if not request:
490 if not request:
492 state = 'shutdown'
491 state = 'shutdown'
493 continue
492 continue
494
493
495 # It looks like a protocol upgrade request. Transition state to
494 # It looks like a protocol upgrade request. Transition state to
496 # handle it.
495 # handle it.
497 if request.startswith(b'upgrade '):
496 if request.startswith(b'upgrade '):
498 if protoswitched:
497 if protoswitched:
499 _sshv1respondooberror(fout, ui.ferr,
498 _sshv1respondooberror(fout, ui.ferr,
500 b'cannot upgrade protocols multiple '
499 b'cannot upgrade protocols multiple '
501 b'times')
500 b'times')
502 state = 'shutdown'
501 state = 'shutdown'
503 continue
502 continue
504
503
505 state = 'upgrade-initial'
504 state = 'upgrade-initial'
506 continue
505 continue
507
506
508 available = wireproto.commands.commandavailable(request, proto)
507 available = wireproto.commands.commandavailable(request, proto)
509
508
510 # This command isn't available. Send an empty response and go
509 # This command isn't available. Send an empty response and go
511 # back to waiting for a new command.
510 # back to waiting for a new command.
512 if not available:
511 if not available:
513 _sshv1respondbytes(fout, b'')
512 _sshv1respondbytes(fout, b'')
514 continue
513 continue
515
514
516 rsp = wireproto.dispatch(repo, proto, request)
515 rsp = wireproto.dispatch(repo, proto, request)
517
516
518 if isinstance(rsp, bytes):
517 if isinstance(rsp, bytes):
519 _sshv1respondbytes(fout, rsp)
518 _sshv1respondbytes(fout, rsp)
520 elif isinstance(rsp, wireprototypes.bytesresponse):
519 elif isinstance(rsp, wireprototypes.bytesresponse):
521 _sshv1respondbytes(fout, rsp.data)
520 _sshv1respondbytes(fout, rsp.data)
522 elif isinstance(rsp, wireprototypes.streamres):
521 elif isinstance(rsp, wireprototypes.streamres):
523 _sshv1respondstream(fout, rsp)
522 _sshv1respondstream(fout, rsp)
524 elif isinstance(rsp, wireprototypes.streamreslegacy):
523 elif isinstance(rsp, wireprototypes.streamreslegacy):
525 _sshv1respondstream(fout, rsp)
524 _sshv1respondstream(fout, rsp)
526 elif isinstance(rsp, wireprototypes.pushres):
525 elif isinstance(rsp, wireprototypes.pushres):
527 _sshv1respondbytes(fout, b'')
526 _sshv1respondbytes(fout, b'')
528 _sshv1respondbytes(fout, b'%d' % rsp.res)
527 _sshv1respondbytes(fout, b'%d' % rsp.res)
529 elif isinstance(rsp, wireprototypes.pusherr):
528 elif isinstance(rsp, wireprototypes.pusherr):
530 _sshv1respondbytes(fout, rsp.res)
529 _sshv1respondbytes(fout, rsp.res)
531 elif isinstance(rsp, wireprototypes.ooberror):
530 elif isinstance(rsp, wireprototypes.ooberror):
532 _sshv1respondooberror(fout, ui.ferr, rsp.message)
531 _sshv1respondooberror(fout, ui.ferr, rsp.message)
533 else:
532 else:
534 raise error.ProgrammingError('unhandled response type from '
533 raise error.ProgrammingError('unhandled response type from '
535 'wire protocol command: %s' % rsp)
534 'wire protocol command: %s' % rsp)
536
535
537 # For now, protocol version 2 serving just goes back to version 1.
536 # For now, protocol version 2 serving just goes back to version 1.
538 elif state == 'protov2-serving':
537 elif state == 'protov2-serving':
539 state = 'protov1-serving'
538 state = 'protov1-serving'
540 continue
539 continue
541
540
542 elif state == 'upgrade-initial':
541 elif state == 'upgrade-initial':
543 # We should never transition into this state if we've switched
542 # We should never transition into this state if we've switched
544 # protocols.
543 # protocols.
545 assert not protoswitched
544 assert not protoswitched
546 assert proto.name == wireprototypes.SSHV1
545 assert proto.name == wireprototypes.SSHV1
547
546
548 # Expected: upgrade <token> <capabilities>
547 # Expected: upgrade <token> <capabilities>
549 # If we get something else, the request is malformed. It could be
548 # If we get something else, the request is malformed. It could be
550 # from a future client that has altered the upgrade line content.
549 # from a future client that has altered the upgrade line content.
551 # We treat this as an unknown command.
550 # We treat this as an unknown command.
552 try:
551 try:
553 token, caps = request.split(b' ')[1:]
552 token, caps = request.split(b' ')[1:]
554 except ValueError:
553 except ValueError:
555 _sshv1respondbytes(fout, b'')
554 _sshv1respondbytes(fout, b'')
556 state = 'protov1-serving'
555 state = 'protov1-serving'
557 continue
556 continue
558
557
559 # Send empty response if we don't support upgrading protocols.
558 # Send empty response if we don't support upgrading protocols.
560 if not ui.configbool('experimental', 'sshserver.support-v2'):
559 if not ui.configbool('experimental', 'sshserver.support-v2'):
561 _sshv1respondbytes(fout, b'')
560 _sshv1respondbytes(fout, b'')
562 state = 'protov1-serving'
561 state = 'protov1-serving'
563 continue
562 continue
564
563
565 try:
564 try:
566 caps = urlreq.parseqs(caps)
565 caps = urlreq.parseqs(caps)
567 except ValueError:
566 except ValueError:
568 _sshv1respondbytes(fout, b'')
567 _sshv1respondbytes(fout, b'')
569 state = 'protov1-serving'
568 state = 'protov1-serving'
570 continue
569 continue
571
570
572 # We don't see an upgrade request to protocol version 2. Ignore
571 # We don't see an upgrade request to protocol version 2. Ignore
573 # the upgrade request.
572 # the upgrade request.
574 wantedprotos = caps.get(b'proto', [b''])[0]
573 wantedprotos = caps.get(b'proto', [b''])[0]
575 if SSHV2 not in wantedprotos:
574 if SSHV2 not in wantedprotos:
576 _sshv1respondbytes(fout, b'')
575 _sshv1respondbytes(fout, b'')
577 state = 'protov1-serving'
576 state = 'protov1-serving'
578 continue
577 continue
579
578
580 # It looks like we can honor this upgrade request to protocol 2.
579 # It looks like we can honor this upgrade request to protocol 2.
581 # Filter the rest of the handshake protocol request lines.
580 # Filter the rest of the handshake protocol request lines.
582 state = 'upgrade-v2-filter-legacy-handshake'
581 state = 'upgrade-v2-filter-legacy-handshake'
583 continue
582 continue
584
583
585 elif state == 'upgrade-v2-filter-legacy-handshake':
584 elif state == 'upgrade-v2-filter-legacy-handshake':
586 # Client should have sent legacy handshake after an ``upgrade``
585 # Client should have sent legacy handshake after an ``upgrade``
587 # request. Expected lines:
586 # request. Expected lines:
588 #
587 #
589 # hello
588 # hello
590 # between
589 # between
591 # pairs 81
590 # pairs 81
592 # 0000...-0000...
591 # 0000...-0000...
593
592
594 ok = True
593 ok = True
595 for line in (b'hello', b'between', b'pairs 81'):
594 for line in (b'hello', b'between', b'pairs 81'):
596 request = fin.readline()[:-1]
595 request = fin.readline()[:-1]
597
596
598 if request != line:
597 if request != line:
599 _sshv1respondooberror(fout, ui.ferr,
598 _sshv1respondooberror(fout, ui.ferr,
600 b'malformed handshake protocol: '
599 b'malformed handshake protocol: '
601 b'missing %s' % line)
600 b'missing %s' % line)
602 ok = False
601 ok = False
603 state = 'shutdown'
602 state = 'shutdown'
604 break
603 break
605
604
606 if not ok:
605 if not ok:
607 continue
606 continue
608
607
609 request = fin.read(81)
608 request = fin.read(81)
610 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
609 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
611 _sshv1respondooberror(fout, ui.ferr,
610 _sshv1respondooberror(fout, ui.ferr,
612 b'malformed handshake protocol: '
611 b'malformed handshake protocol: '
613 b'missing between argument value')
612 b'missing between argument value')
614 state = 'shutdown'
613 state = 'shutdown'
615 continue
614 continue
616
615
617 state = 'upgrade-v2-finish'
616 state = 'upgrade-v2-finish'
618 continue
617 continue
619
618
620 elif state == 'upgrade-v2-finish':
619 elif state == 'upgrade-v2-finish':
621 # Send the upgrade response.
620 # Send the upgrade response.
622 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
621 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
623 servercaps = wireproto.capabilities(repo, proto)
622 servercaps = wireproto.capabilities(repo, proto)
624 rsp = b'capabilities: %s' % servercaps.data
623 rsp = b'capabilities: %s' % servercaps.data
625 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
624 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
626 fout.flush()
625 fout.flush()
627
626
628 proto = sshv2protocolhandler(ui, fin, fout)
627 proto = sshv2protocolhandler(ui, fin, fout)
629 protoswitched = True
628 protoswitched = True
630
629
631 state = 'protov2-serving'
630 state = 'protov2-serving'
632 continue
631 continue
633
632
634 elif state == 'shutdown':
633 elif state == 'shutdown':
635 break
634 break
636
635
637 else:
636 else:
638 raise error.ProgrammingError('unhandled ssh server state: %s' %
637 raise error.ProgrammingError('unhandled ssh server state: %s' %
639 state)
638 state)
640
639
641 class sshserver(object):
640 class sshserver(object):
642 def __init__(self, ui, repo, logfh=None):
641 def __init__(self, ui, repo, logfh=None):
643 self._ui = ui
642 self._ui = ui
644 self._repo = repo
643 self._repo = repo
645 self._fin = ui.fin
644 self._fin = ui.fin
646 self._fout = ui.fout
645 self._fout = ui.fout
647
646
648 # Log write I/O to stdout and stderr if configured.
647 # Log write I/O to stdout and stderr if configured.
649 if logfh:
648 if logfh:
650 self._fout = util.makeloggingfileobject(
649 self._fout = util.makeloggingfileobject(
651 logfh, self._fout, 'o', logdata=True)
650 logfh, self._fout, 'o', logdata=True)
652 ui.ferr = util.makeloggingfileobject(
651 ui.ferr = util.makeloggingfileobject(
653 logfh, ui.ferr, 'e', logdata=True)
652 logfh, ui.ferr, 'e', logdata=True)
654
653
655 hook.redirect(True)
654 hook.redirect(True)
656 ui.fout = repo.ui.fout = ui.ferr
655 ui.fout = repo.ui.fout = ui.ferr
657
656
658 # Prevent insertion/deletion of CRs
657 # Prevent insertion/deletion of CRs
659 util.setbinary(self._fin)
658 util.setbinary(self._fin)
660 util.setbinary(self._fout)
659 util.setbinary(self._fout)
661
660
662 def serve_forever(self):
661 def serve_forever(self):
663 self.serveuntil(threading.Event())
662 self.serveuntil(threading.Event())
664 sys.exit(0)
663 sys.exit(0)
665
664
666 def serveuntil(self, ev):
665 def serveuntil(self, ev):
667 """Serve until a threading.Event is set."""
666 """Serve until a threading.Event is set."""
668 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
667 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
General Comments 0
You need to be logged in to leave comments. Login now