##// END OF EJS Templates
hgweb: refactor the request draining code...
Gregory Szorc -
r36871:2cdf47e1 default
parent child Browse files
Show More
@@ -1,320 +1,361
1 # hgweb/request.py - An http request from either CGI or the standalone server.
1 # hgweb/request.py - An http request from either CGI or the standalone server.
2 #
2 #
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from __future__ import absolute_import
9 from __future__ import absolute_import
10
10
11 import cgi
11 import cgi
12 import errno
12 import errno
13 import socket
13 import socket
14 import wsgiref.headers as wsgiheaders
14 import wsgiref.headers as wsgiheaders
15 #import wsgiref.validate
15 #import wsgiref.validate
16
16
17 from .common import (
17 from .common import (
18 ErrorResponse,
18 ErrorResponse,
19 HTTP_NOT_MODIFIED,
19 HTTP_NOT_MODIFIED,
20 statusmessage,
20 statusmessage,
21 )
21 )
22
22
23 from ..thirdparty import (
23 from ..thirdparty import (
24 attr,
24 attr,
25 )
25 )
26 from .. import (
26 from .. import (
27 pycompat,
27 pycompat,
28 util,
28 util,
29 )
29 )
30
30
31 shortcuts = {
31 shortcuts = {
32 'cl': [('cmd', ['changelog']), ('rev', None)],
32 'cl': [('cmd', ['changelog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
41 'tags': [('cmd', ['tags'])],
41 'tags': [('cmd', ['tags'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
43 'static': [('cmd', ['static']), ('file', None)]
43 'static': [('cmd', ['static']), ('file', None)]
44 }
44 }
45
45
46 def normalize(form):
46 def normalize(form):
47 # first expand the shortcuts
47 # first expand the shortcuts
48 for k in shortcuts:
48 for k in shortcuts:
49 if k in form:
49 if k in form:
50 for name, value in shortcuts[k]:
50 for name, value in shortcuts[k]:
51 if value is None:
51 if value is None:
52 value = form[k]
52 value = form[k]
53 form[name] = value
53 form[name] = value
54 del form[k]
54 del form[k]
55 # And strip the values
55 # And strip the values
56 bytesform = {}
56 bytesform = {}
57 for k, v in form.iteritems():
57 for k, v in form.iteritems():
58 bytesform[pycompat.bytesurl(k)] = [
58 bytesform[pycompat.bytesurl(k)] = [
59 pycompat.bytesurl(i.strip()) for i in v]
59 pycompat.bytesurl(i.strip()) for i in v]
60 return bytesform
60 return bytesform
61
61
62 @attr.s(frozen=True)
62 @attr.s(frozen=True)
63 class parsedrequest(object):
63 class parsedrequest(object):
64 """Represents a parsed WSGI request / static HTTP request parameters."""
64 """Represents a parsed WSGI request / static HTTP request parameters."""
65
65
66 # Request method.
66 # Request method.
67 method = attr.ib()
67 method = attr.ib()
68 # Full URL for this request.
68 # Full URL for this request.
69 url = attr.ib()
69 url = attr.ib()
70 # URL without any path components. Just <proto>://<host><port>.
70 # URL without any path components. Just <proto>://<host><port>.
71 baseurl = attr.ib()
71 baseurl = attr.ib()
72 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
72 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
73 # of HTTP: Host header for hostname. This is likely what clients used.
73 # of HTTP: Host header for hostname. This is likely what clients used.
74 advertisedurl = attr.ib()
74 advertisedurl = attr.ib()
75 advertisedbaseurl = attr.ib()
75 advertisedbaseurl = attr.ib()
76 # WSGI application path.
76 # WSGI application path.
77 apppath = attr.ib()
77 apppath = attr.ib()
78 # List of path parts to be used for dispatch.
78 # List of path parts to be used for dispatch.
79 dispatchparts = attr.ib()
79 dispatchparts = attr.ib()
80 # URL path component (no query string) used for dispatch.
80 # URL path component (no query string) used for dispatch.
81 dispatchpath = attr.ib()
81 dispatchpath = attr.ib()
82 # Whether there is a path component to this request. This can be true
82 # Whether there is a path component to this request. This can be true
83 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
83 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
84 havepathinfo = attr.ib()
84 havepathinfo = attr.ib()
85 # Raw query string (part after "?" in URL).
85 # Raw query string (part after "?" in URL).
86 querystring = attr.ib()
86 querystring = attr.ib()
87 # List of 2-tuples of query string arguments.
87 # List of 2-tuples of query string arguments.
88 querystringlist = attr.ib()
88 querystringlist = attr.ib()
89 # Dict of query string arguments. Values are lists with at least 1 item.
89 # Dict of query string arguments. Values are lists with at least 1 item.
90 querystringdict = attr.ib()
90 querystringdict = attr.ib()
91 # wsgiref.headers.Headers instance. Operates like a dict with case
91 # wsgiref.headers.Headers instance. Operates like a dict with case
92 # insensitive keys.
92 # insensitive keys.
93 headers = attr.ib()
93 headers = attr.ib()
94
94
95 def parserequestfromenv(env):
95 def parserequestfromenv(env):
96 """Parse URL components from environment variables.
96 """Parse URL components from environment variables.
97
97
98 WSGI defines request attributes via environment variables. This function
98 WSGI defines request attributes via environment variables. This function
99 parses the environment variables into a data structure.
99 parses the environment variables into a data structure.
100 """
100 """
101 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
101 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
102
102
103 # We first validate that the incoming object conforms with the WSGI spec.
103 # We first validate that the incoming object conforms with the WSGI spec.
104 # We only want to be dealing with spec-conforming WSGI implementations.
104 # We only want to be dealing with spec-conforming WSGI implementations.
105 # TODO enable this once we fix internal violations.
105 # TODO enable this once we fix internal violations.
106 #wsgiref.validate.check_environ(env)
106 #wsgiref.validate.check_environ(env)
107
107
108 # PEP-0333 states that environment keys and values are native strings
108 # PEP-0333 states that environment keys and values are native strings
109 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
109 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
110 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
110 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
111 # in Mercurial, so mass convert string keys and values to bytes.
111 # in Mercurial, so mass convert string keys and values to bytes.
112 if pycompat.ispy3:
112 if pycompat.ispy3:
113 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
113 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
114 env = {k: v.encode('latin-1') if isinstance(v, str) else v
114 env = {k: v.encode('latin-1') if isinstance(v, str) else v
115 for k, v in env.iteritems()}
115 for k, v in env.iteritems()}
116
116
117 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
117 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
118 # the environment variables.
118 # the environment variables.
119 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
119 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
120 # how URLs are reconstructed.
120 # how URLs are reconstructed.
121 fullurl = env['wsgi.url_scheme'] + '://'
121 fullurl = env['wsgi.url_scheme'] + '://'
122 advertisedfullurl = fullurl
122 advertisedfullurl = fullurl
123
123
124 def addport(s):
124 def addport(s):
125 if env['wsgi.url_scheme'] == 'https':
125 if env['wsgi.url_scheme'] == 'https':
126 if env['SERVER_PORT'] != '443':
126 if env['SERVER_PORT'] != '443':
127 s += ':' + env['SERVER_PORT']
127 s += ':' + env['SERVER_PORT']
128 else:
128 else:
129 if env['SERVER_PORT'] != '80':
129 if env['SERVER_PORT'] != '80':
130 s += ':' + env['SERVER_PORT']
130 s += ':' + env['SERVER_PORT']
131
131
132 return s
132 return s
133
133
134 if env.get('HTTP_HOST'):
134 if env.get('HTTP_HOST'):
135 fullurl += env['HTTP_HOST']
135 fullurl += env['HTTP_HOST']
136 else:
136 else:
137 fullurl += env['SERVER_NAME']
137 fullurl += env['SERVER_NAME']
138 fullurl = addport(fullurl)
138 fullurl = addport(fullurl)
139
139
140 advertisedfullurl += env['SERVER_NAME']
140 advertisedfullurl += env['SERVER_NAME']
141 advertisedfullurl = addport(advertisedfullurl)
141 advertisedfullurl = addport(advertisedfullurl)
142
142
143 baseurl = fullurl
143 baseurl = fullurl
144 advertisedbaseurl = advertisedfullurl
144 advertisedbaseurl = advertisedfullurl
145
145
146 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
146 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
148 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
148 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
149 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
149 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
150
150
151 if env.get('QUERY_STRING'):
151 if env.get('QUERY_STRING'):
152 fullurl += '?' + env['QUERY_STRING']
152 fullurl += '?' + env['QUERY_STRING']
153 advertisedfullurl += '?' + env['QUERY_STRING']
153 advertisedfullurl += '?' + env['QUERY_STRING']
154
154
155 # When dispatching requests, we look at the URL components (PATH_INFO
155 # When dispatching requests, we look at the URL components (PATH_INFO
156 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
156 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
157 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
157 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
158 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
158 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
159 # root. We also exclude its path components from PATH_INFO when resolving
159 # root. We also exclude its path components from PATH_INFO when resolving
160 # the dispatch path.
160 # the dispatch path.
161
161
162 apppath = env['SCRIPT_NAME']
162 apppath = env['SCRIPT_NAME']
163
163
164 if env.get('REPO_NAME'):
164 if env.get('REPO_NAME'):
165 if not apppath.endswith('/'):
165 if not apppath.endswith('/'):
166 apppath += '/'
166 apppath += '/'
167
167
168 apppath += env.get('REPO_NAME')
168 apppath += env.get('REPO_NAME')
169
169
170 if 'PATH_INFO' in env:
170 if 'PATH_INFO' in env:
171 dispatchparts = env['PATH_INFO'].strip('/').split('/')
171 dispatchparts = env['PATH_INFO'].strip('/').split('/')
172
172
173 # Strip out repo parts.
173 # Strip out repo parts.
174 repoparts = env.get('REPO_NAME', '').split('/')
174 repoparts = env.get('REPO_NAME', '').split('/')
175 if dispatchparts[:len(repoparts)] == repoparts:
175 if dispatchparts[:len(repoparts)] == repoparts:
176 dispatchparts = dispatchparts[len(repoparts):]
176 dispatchparts = dispatchparts[len(repoparts):]
177 else:
177 else:
178 dispatchparts = []
178 dispatchparts = []
179
179
180 dispatchpath = '/'.join(dispatchparts)
180 dispatchpath = '/'.join(dispatchparts)
181
181
182 querystring = env.get('QUERY_STRING', '')
182 querystring = env.get('QUERY_STRING', '')
183
183
184 # We store as a list so we have ordering information. We also store as
184 # We store as a list so we have ordering information. We also store as
185 # a dict to facilitate fast lookup.
185 # a dict to facilitate fast lookup.
186 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
186 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
187
187
188 querystringdict = {}
188 querystringdict = {}
189 for k, v in querystringlist:
189 for k, v in querystringlist:
190 if k in querystringdict:
190 if k in querystringdict:
191 querystringdict[k].append(v)
191 querystringdict[k].append(v)
192 else:
192 else:
193 querystringdict[k] = [v]
193 querystringdict[k] = [v]
194
194
195 # HTTP_* keys contain HTTP request headers. The Headers structure should
195 # HTTP_* keys contain HTTP request headers. The Headers structure should
196 # perform case normalization for us. We just rewrite underscore to dash
196 # perform case normalization for us. We just rewrite underscore to dash
197 # so keys match what likely went over the wire.
197 # so keys match what likely went over the wire.
198 headers = []
198 headers = []
199 for k, v in env.iteritems():
199 for k, v in env.iteritems():
200 if k.startswith('HTTP_'):
200 if k.startswith('HTTP_'):
201 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
201 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
202
202
203 headers = wsgiheaders.Headers(headers)
203 headers = wsgiheaders.Headers(headers)
204
204
205 # This is kind of a lie because the HTTP header wasn't explicitly
205 # This is kind of a lie because the HTTP header wasn't explicitly
206 # sent. But for all intents and purposes it should be OK to lie about
206 # sent. But for all intents and purposes it should be OK to lie about
207 # this, since a consumer will either either value to determine how many
207 # this, since a consumer will either either value to determine how many
208 # bytes are available to read.
208 # bytes are available to read.
209 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
209 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
210 headers['Content-Length'] = env['CONTENT_LENGTH']
210 headers['Content-Length'] = env['CONTENT_LENGTH']
211
211
212 return parsedrequest(method=env['REQUEST_METHOD'],
212 return parsedrequest(method=env['REQUEST_METHOD'],
213 url=fullurl, baseurl=baseurl,
213 url=fullurl, baseurl=baseurl,
214 advertisedurl=advertisedfullurl,
214 advertisedurl=advertisedfullurl,
215 advertisedbaseurl=advertisedbaseurl,
215 advertisedbaseurl=advertisedbaseurl,
216 apppath=apppath,
216 apppath=apppath,
217 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
217 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
218 havepathinfo='PATH_INFO' in env,
218 havepathinfo='PATH_INFO' in env,
219 querystring=querystring,
219 querystring=querystring,
220 querystringlist=querystringlist,
220 querystringlist=querystringlist,
221 querystringdict=querystringdict,
221 querystringdict=querystringdict,
222 headers=headers)
222 headers=headers)
223
223
224 class wsgirequest(object):
224 class wsgirequest(object):
225 """Higher-level API for a WSGI request.
225 """Higher-level API for a WSGI request.
226
226
227 WSGI applications are invoked with 2 arguments. They are used to
227 WSGI applications are invoked with 2 arguments. They are used to
228 instantiate instances of this class, which provides higher-level APIs
228 instantiate instances of this class, which provides higher-level APIs
229 for obtaining request parameters, writing HTTP output, etc.
229 for obtaining request parameters, writing HTTP output, etc.
230 """
230 """
231 def __init__(self, wsgienv, start_response):
231 def __init__(self, wsgienv, start_response):
232 version = wsgienv[r'wsgi.version']
232 version = wsgienv[r'wsgi.version']
233 if (version < (1, 0)) or (version >= (2, 0)):
233 if (version < (1, 0)) or (version >= (2, 0)):
234 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
234 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
235 % version)
235 % version)
236 self.inp = wsgienv[r'wsgi.input']
236 self.inp = wsgienv[r'wsgi.input']
237
237
238 if r'HTTP_CONTENT_LENGTH' in wsgienv:
238 if r'HTTP_CONTENT_LENGTH' in wsgienv:
239 self.inp = util.cappedreader(self.inp,
239 self.inp = util.cappedreader(self.inp,
240 int(wsgienv[r'HTTP_CONTENT_LENGTH']))
240 int(wsgienv[r'HTTP_CONTENT_LENGTH']))
241 elif r'CONTENT_LENGTH' in wsgienv:
241 elif r'CONTENT_LENGTH' in wsgienv:
242 self.inp = util.cappedreader(self.inp,
242 self.inp = util.cappedreader(self.inp,
243 int(wsgienv[r'CONTENT_LENGTH']))
243 int(wsgienv[r'CONTENT_LENGTH']))
244
244
245 self.err = wsgienv[r'wsgi.errors']
245 self.err = wsgienv[r'wsgi.errors']
246 self.threaded = wsgienv[r'wsgi.multithread']
246 self.threaded = wsgienv[r'wsgi.multithread']
247 self.multiprocess = wsgienv[r'wsgi.multiprocess']
247 self.multiprocess = wsgienv[r'wsgi.multiprocess']
248 self.run_once = wsgienv[r'wsgi.run_once']
248 self.run_once = wsgienv[r'wsgi.run_once']
249 self.env = wsgienv
249 self.env = wsgienv
250 self.form = normalize(cgi.parse(self.inp,
250 self.form = normalize(cgi.parse(self.inp,
251 self.env,
251 self.env,
252 keep_blank_values=1))
252 keep_blank_values=1))
253 self._start_response = start_response
253 self._start_response = start_response
254 self.server_write = None
254 self.server_write = None
255 self.headers = []
255 self.headers = []
256
256
257 def drain(self):
258 '''need to read all data from request, httplib is half-duplex'''
259 length = int(self.env.get('CONTENT_LENGTH') or 0)
260 for s in util.filechunkiter(self.inp, limit=length):
261 pass
262
263 def respond(self, status, type, filename=None, body=None):
257 def respond(self, status, type, filename=None, body=None):
264 if not isinstance(type, str):
258 if not isinstance(type, str):
265 type = pycompat.sysstr(type)
259 type = pycompat.sysstr(type)
266 if self._start_response is not None:
260 if self._start_response is not None:
267 self.headers.append((r'Content-Type', type))
261 self.headers.append((r'Content-Type', type))
268 if filename:
262 if filename:
269 filename = (filename.rpartition('/')[-1]
263 filename = (filename.rpartition('/')[-1]
270 .replace('\\', '\\\\').replace('"', '\\"'))
264 .replace('\\', '\\\\').replace('"', '\\"'))
271 self.headers.append(('Content-Disposition',
265 self.headers.append(('Content-Disposition',
272 'inline; filename="%s"' % filename))
266 'inline; filename="%s"' % filename))
273 if body is not None:
267 if body is not None:
274 self.headers.append((r'Content-Length', str(len(body))))
268 self.headers.append((r'Content-Length', str(len(body))))
275
269
276 for k, v in self.headers:
270 for k, v in self.headers:
277 if not isinstance(v, str):
271 if not isinstance(v, str):
278 raise TypeError('header value must be string: %r' % (v,))
272 raise TypeError('header value must be string: %r' % (v,))
279
273
280 if isinstance(status, ErrorResponse):
274 if isinstance(status, ErrorResponse):
281 self.headers.extend(status.headers)
275 self.headers.extend(status.headers)
282 if status.code == HTTP_NOT_MODIFIED:
276 if status.code == HTTP_NOT_MODIFIED:
283 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
277 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
284 # it MUST NOT include any headers other than these and no
278 # it MUST NOT include any headers other than these and no
285 # body
279 # body
286 self.headers = [(k, v) for (k, v) in self.headers if
280 self.headers = [(k, v) for (k, v) in self.headers if
287 k in ('Date', 'ETag', 'Expires',
281 k in ('Date', 'ETag', 'Expires',
288 'Cache-Control', 'Vary')]
282 'Cache-Control', 'Vary')]
289 status = statusmessage(status.code, pycompat.bytestr(status))
283 status = statusmessage(status.code, pycompat.bytestr(status))
290 elif status == 200:
284 elif status == 200:
291 status = '200 Script output follows'
285 status = '200 Script output follows'
292 elif isinstance(status, int):
286 elif isinstance(status, int):
293 status = statusmessage(status)
287 status = statusmessage(status)
294
288
289 # Various HTTP clients (notably httplib) won't read the HTTP
290 # response until the HTTP request has been sent in full. If servers
291 # (us) send a response before the HTTP request has been fully sent,
292 # the connection may deadlock because neither end is reading.
293 #
294 # We work around this by "draining" the request data before
295 # sending any response in some conditions.
296 drain = False
297 close = False
298
299 # If the client sent Expect: 100-continue, we assume it is smart
300 # enough to deal with the server sending a response before reading
301 # the request. (httplib doesn't do this.)
302 if self.env.get(r'HTTP_EXPECT', r'').lower() == r'100-continue':
303 pass
304 # Only tend to request methods that have bodies. Strictly speaking,
305 # we should sniff for a body. But this is fine for our existing
306 # WSGI applications.
307 elif self.env[r'REQUEST_METHOD'] not in (r'POST', r'PUT'):
308 pass
309 else:
310 # If we don't know how much data to read, there's no guarantee
311 # that we can drain the request responsibly. The WSGI
312 # specification only says that servers *should* ensure the
313 # input stream doesn't overrun the actual request. So there's
314 # no guarantee that reading until EOF won't corrupt the stream
315 # state.
316 if not isinstance(self.inp, util.cappedreader):
317 close = True
318 else:
319 # We /could/ only drain certain HTTP response codes. But 200
320 # and non-200 wire protocol responses both require draining.
321 # Since we have a capped reader in place for all situations
322 # where we drain, it is safe to read from that stream. We'll
323 # either do a drain or no-op if we're already at EOF.
324 drain = True
325
326 if close:
327 self.headers.append((r'Connection', r'Close'))
328
329 if drain:
330 assert isinstance(self.inp, util.cappedreader)
331 while True:
332 chunk = self.inp.read(32768)
333 if not chunk:
334 break
335
295 self.server_write = self._start_response(
336 self.server_write = self._start_response(
296 pycompat.sysstr(status), self.headers)
337 pycompat.sysstr(status), self.headers)
297 self._start_response = None
338 self._start_response = None
298 self.headers = []
339 self.headers = []
299 if body is not None:
340 if body is not None:
300 self.write(body)
341 self.write(body)
301 self.server_write = None
342 self.server_write = None
302
343
303 def write(self, thing):
344 def write(self, thing):
304 if thing:
345 if thing:
305 try:
346 try:
306 self.server_write(thing)
347 self.server_write(thing)
307 except socket.error as inst:
348 except socket.error as inst:
308 if inst[0] != errno.ECONNRESET:
349 if inst[0] != errno.ECONNRESET:
309 raise
350 raise
310
351
311 def flush(self):
352 def flush(self):
312 return None
353 return None
313
354
314 def wsgiapplication(app_maker):
355 def wsgiapplication(app_maker):
315 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
356 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
316 can and should now be used as a WSGI application.'''
357 can and should now be used as a WSGI application.'''
317 application = app_maker()
358 application = app_maker()
318 def run_wsgi(env, respond):
359 def run_wsgi(env, respond):
319 return application(env, respond)
360 return application(env, respond)
320 return run_wsgi
361 return run_wsgi
@@ -1,667 +1,649
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import contextlib
9 import contextlib
10 import struct
10 import struct
11 import sys
11 import sys
12 import threading
12 import threading
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 wireprototypes,
22 wireprototypes,
23 )
23 )
24
24
25 stringio = util.stringio
25 stringio = util.stringio
26
26
27 urlerr = util.urlerr
27 urlerr = util.urlerr
28 urlreq = util.urlreq
28 urlreq = util.urlreq
29
29
30 HTTP_OK = 200
30 HTTP_OK = 200
31
31
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
34 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
35
35
36 SSHV1 = wireprototypes.SSHV1
36 SSHV1 = wireprototypes.SSHV1
37 SSHV2 = wireprototypes.SSHV2
37 SSHV2 = wireprototypes.SSHV2
38
38
39 def decodevaluefromheaders(req, headerprefix):
39 def decodevaluefromheaders(req, headerprefix):
40 """Decode a long value from multiple HTTP request headers.
40 """Decode a long value from multiple HTTP request headers.
41
41
42 Returns the value as a bytes, not a str.
42 Returns the value as a bytes, not a str.
43 """
43 """
44 chunks = []
44 chunks = []
45 i = 1
45 i = 1
46 while True:
46 while True:
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
48 if v is None:
48 if v is None:
49 break
49 break
50 chunks.append(pycompat.bytesurl(v))
50 chunks.append(pycompat.bytesurl(v))
51 i += 1
51 i += 1
52
52
53 return ''.join(chunks)
53 return ''.join(chunks)
54
54
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
56 def __init__(self, wsgireq, req, ui, checkperm):
56 def __init__(self, wsgireq, req, ui, checkperm):
57 self._wsgireq = wsgireq
57 self._wsgireq = wsgireq
58 self._req = req
58 self._req = req
59 self._ui = ui
59 self._ui = ui
60 self._checkperm = checkperm
60 self._checkperm = checkperm
61
61
62 @property
62 @property
63 def name(self):
63 def name(self):
64 return 'http-v1'
64 return 'http-v1'
65
65
66 def getargs(self, args):
66 def getargs(self, args):
67 knownargs = self._args()
67 knownargs = self._args()
68 data = {}
68 data = {}
69 keys = args.split()
69 keys = args.split()
70 for k in keys:
70 for k in keys:
71 if k == '*':
71 if k == '*':
72 star = {}
72 star = {}
73 for key in knownargs.keys():
73 for key in knownargs.keys():
74 if key != 'cmd' and key not in keys:
74 if key != 'cmd' and key not in keys:
75 star[key] = knownargs[key][0]
75 star[key] = knownargs[key][0]
76 data['*'] = star
76 data['*'] = star
77 else:
77 else:
78 data[k] = knownargs[k][0]
78 data[k] = knownargs[k][0]
79 return [data[k] for k in keys]
79 return [data[k] for k in keys]
80
80
81 def _args(self):
81 def _args(self):
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
84 if postlen:
84 if postlen:
85 args.update(urlreq.parseqs(
85 args.update(urlreq.parseqs(
86 self._wsgireq.inp.read(postlen), keep_blank_values=True))
86 self._wsgireq.inp.read(postlen), keep_blank_values=True))
87 return args
87 return args
88
88
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
91 return args
91 return args
92
92
93 def forwardpayload(self, fp):
93 def forwardpayload(self, fp):
94 # Existing clients *always* send Content-Length.
94 # Existing clients *always* send Content-Length.
95 length = int(self._req.headers[b'Content-Length'])
95 length = int(self._req.headers[b'Content-Length'])
96
96
97 # If httppostargs is used, we need to read Content-Length
97 # If httppostargs is used, we need to read Content-Length
98 # minus the amount that was consumed by args.
98 # minus the amount that was consumed by args.
99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
100 for s in util.filechunkiter(self._wsgireq.inp, limit=length):
100 for s in util.filechunkiter(self._wsgireq.inp, limit=length):
101 fp.write(s)
101 fp.write(s)
102
102
103 @contextlib.contextmanager
103 @contextlib.contextmanager
104 def mayberedirectstdio(self):
104 def mayberedirectstdio(self):
105 oldout = self._ui.fout
105 oldout = self._ui.fout
106 olderr = self._ui.ferr
106 olderr = self._ui.ferr
107
107
108 out = util.stringio()
108 out = util.stringio()
109
109
110 try:
110 try:
111 self._ui.fout = out
111 self._ui.fout = out
112 self._ui.ferr = out
112 self._ui.ferr = out
113 yield out
113 yield out
114 finally:
114 finally:
115 self._ui.fout = oldout
115 self._ui.fout = oldout
116 self._ui.ferr = olderr
116 self._ui.ferr = olderr
117
117
118 def client(self):
118 def client(self):
119 return 'remote:%s:%s:%s' % (
119 return 'remote:%s:%s:%s' % (
120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
123
123
124 def addcapabilities(self, repo, caps):
124 def addcapabilities(self, repo, caps):
125 caps.append('httpheader=%d' %
125 caps.append('httpheader=%d' %
126 repo.ui.configint('server', 'maxhttpheaderlen'))
126 repo.ui.configint('server', 'maxhttpheaderlen'))
127 if repo.ui.configbool('experimental', 'httppostargs'):
127 if repo.ui.configbool('experimental', 'httppostargs'):
128 caps.append('httppostargs')
128 caps.append('httppostargs')
129
129
130 # FUTURE advertise 0.2rx once support is implemented
130 # FUTURE advertise 0.2rx once support is implemented
131 # FUTURE advertise minrx and mintx after consulting config option
131 # FUTURE advertise minrx and mintx after consulting config option
132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
133
133
134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
135 if compengines:
135 if compengines:
136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
137 for e in compengines)
137 for e in compengines)
138 caps.append('compression=%s' % comptypes)
138 caps.append('compression=%s' % comptypes)
139
139
140 return caps
140 return caps
141
141
142 def checkperm(self, perm):
142 def checkperm(self, perm):
143 return self._checkperm(perm)
143 return self._checkperm(perm)
144
144
145 # This method exists mostly so that extensions like remotefilelog can
145 # This method exists mostly so that extensions like remotefilelog can
146 # disable a kludgey legacy method only over http. As of early 2018,
146 # disable a kludgey legacy method only over http. As of early 2018,
147 # there are no other known users, so with any luck we can discard this
147 # there are no other known users, so with any luck we can discard this
148 # hook if remotefilelog becomes a first-party extension.
148 # hook if remotefilelog becomes a first-party extension.
149 def iscmd(cmd):
149 def iscmd(cmd):
150 return cmd in wireproto.commands
150 return cmd in wireproto.commands
151
151
152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
153 """Possibly process a wire protocol request.
153 """Possibly process a wire protocol request.
154
154
155 If the current request is a wire protocol request, the request is
155 If the current request is a wire protocol request, the request is
156 processed by this function.
156 processed by this function.
157
157
158 ``wsgireq`` is a ``wsgirequest`` instance.
158 ``wsgireq`` is a ``wsgirequest`` instance.
159 ``req`` is a ``parsedrequest`` instance.
159 ``req`` is a ``parsedrequest`` instance.
160
160
161 Returns a 2-tuple of (bool, response) where the 1st element indicates
161 Returns a 2-tuple of (bool, response) where the 1st element indicates
162 whether the request was handled and the 2nd element is a return
162 whether the request was handled and the 2nd element is a return
163 value for a WSGI application (often a generator of bytes).
163 value for a WSGI application (often a generator of bytes).
164 """
164 """
165 # Avoid cycle involving hg module.
165 # Avoid cycle involving hg module.
166 from .hgweb import common as hgwebcommon
166 from .hgweb import common as hgwebcommon
167
167
168 repo = rctx.repo
168 repo = rctx.repo
169
169
170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
171 # string parameter. If it isn't present, this isn't a wire protocol
171 # string parameter. If it isn't present, this isn't a wire protocol
172 # request.
172 # request.
173 if 'cmd' not in req.querystringdict:
173 if 'cmd' not in req.querystringdict:
174 return False, None
174 return False, None
175
175
176 cmd = req.querystringdict['cmd'][0]
176 cmd = req.querystringdict['cmd'][0]
177
177
178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
179 # While not all wire protocol commands are available for all transports,
179 # While not all wire protocol commands are available for all transports,
180 # if we see a "cmd" value that resembles a known wire protocol command, we
180 # if we see a "cmd" value that resembles a known wire protocol command, we
181 # route it to a protocol handler. This is better than routing possible
181 # route it to a protocol handler. This is better than routing possible
182 # wire protocol requests to hgweb because it prevents hgweb from using
182 # wire protocol requests to hgweb because it prevents hgweb from using
183 # known wire protocol commands and it is less confusing for machine
183 # known wire protocol commands and it is less confusing for machine
184 # clients.
184 # clients.
185 if not iscmd(cmd):
185 if not iscmd(cmd):
186 return False, None
186 return False, None
187
187
188 # The "cmd" query string argument is only valid on the root path of the
188 # The "cmd" query string argument is only valid on the root path of the
189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
192 if req.dispatchpath:
192 if req.dispatchpath:
193 res = _handlehttperror(
193 res = _handlehttperror(
194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
195 req)
195 req)
196
196
197 return True, res
197 return True, res
198
198
199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
200 lambda perm: checkperm(rctx, wsgireq, perm))
200 lambda perm: checkperm(rctx, wsgireq, perm))
201
201
202 # The permissions checker should be the only thing that can raise an
202 # The permissions checker should be the only thing that can raise an
203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
204 # exception here. So consider refactoring into a exception type that
204 # exception here. So consider refactoring into a exception type that
205 # is associated with the wire protocol.
205 # is associated with the wire protocol.
206 try:
206 try:
207 res = _callhttp(repo, wsgireq, req, proto, cmd)
207 res = _callhttp(repo, wsgireq, req, proto, cmd)
208 except hgwebcommon.ErrorResponse as e:
208 except hgwebcommon.ErrorResponse as e:
209 res = _handlehttperror(e, wsgireq, req)
209 res = _handlehttperror(e, wsgireq, req)
210
210
211 return True, res
211 return True, res
212
212
213 def _httpresponsetype(ui, req, prefer_uncompressed):
213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 """Determine the appropriate response type and compression settings.
214 """Determine the appropriate response type and compression settings.
215
215
216 Returns a tuple of (mediatype, compengine, engineopts).
216 Returns a tuple of (mediatype, compengine, engineopts).
217 """
217 """
218 # Determine the response media type and compression engine based
218 # Determine the response media type and compression engine based
219 # on the request parameters.
219 # on the request parameters.
220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
221
221
222 if '0.2' in protocaps:
222 if '0.2' in protocaps:
223 # All clients are expected to support uncompressed data.
223 # All clients are expected to support uncompressed data.
224 if prefer_uncompressed:
224 if prefer_uncompressed:
225 return HGTYPE2, util._noopengine(), {}
225 return HGTYPE2, util._noopengine(), {}
226
226
227 # Default as defined by wire protocol spec.
227 # Default as defined by wire protocol spec.
228 compformats = ['zlib', 'none']
228 compformats = ['zlib', 'none']
229 for cap in protocaps:
229 for cap in protocaps:
230 if cap.startswith('comp='):
230 if cap.startswith('comp='):
231 compformats = cap[5:].split(',')
231 compformats = cap[5:].split(',')
232 break
232 break
233
233
234 # Now find an agreed upon compression format.
234 # Now find an agreed upon compression format.
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 if engine.wireprotosupport().name in compformats:
236 if engine.wireprotosupport().name in compformats:
237 opts = {}
237 opts = {}
238 level = ui.configint('server', '%slevel' % engine.name())
238 level = ui.configint('server', '%slevel' % engine.name())
239 if level is not None:
239 if level is not None:
240 opts['level'] = level
240 opts['level'] = level
241
241
242 return HGTYPE2, engine, opts
242 return HGTYPE2, engine, opts
243
243
244 # No mutually supported compression format. Fall back to the
244 # No mutually supported compression format. Fall back to the
245 # legacy protocol.
245 # legacy protocol.
246
246
247 # Don't allow untrusted settings because disabling compression or
247 # Don't allow untrusted settings because disabling compression or
248 # setting a very high compression level could lead to flooding
248 # setting a very high compression level could lead to flooding
249 # the server's network or CPU.
249 # the server's network or CPU.
250 opts = {'level': ui.configint('server', 'zliblevel')}
250 opts = {'level': ui.configint('server', 'zliblevel')}
251 return HGTYPE, util.compengines['zlib'], opts
251 return HGTYPE, util.compengines['zlib'], opts
252
252
253 def _callhttp(repo, wsgireq, req, proto, cmd):
253 def _callhttp(repo, wsgireq, req, proto, cmd):
254 def genversion2(gen, engine, engineopts):
254 def genversion2(gen, engine, engineopts):
255 # application/mercurial-0.2 always sends a payload header
255 # application/mercurial-0.2 always sends a payload header
256 # identifying the compression engine.
256 # identifying the compression engine.
257 name = engine.wireprotosupport().name
257 name = engine.wireprotosupport().name
258 assert 0 < len(name) < 256
258 assert 0 < len(name) < 256
259 yield struct.pack('B', len(name))
259 yield struct.pack('B', len(name))
260 yield name
260 yield name
261
261
262 for chunk in gen:
262 for chunk in gen:
263 yield chunk
263 yield chunk
264
264
265 if not wireproto.commands.commandavailable(cmd, proto):
265 if not wireproto.commands.commandavailable(cmd, proto):
266 wsgireq.respond(HTTP_OK, HGERRTYPE,
266 wsgireq.respond(HTTP_OK, HGERRTYPE,
267 body=_('requested wire protocol command is not '
267 body=_('requested wire protocol command is not '
268 'available over HTTP'))
268 'available over HTTP'))
269 return []
269 return []
270
270
271 proto.checkperm(wireproto.commands[cmd].permission)
271 proto.checkperm(wireproto.commands[cmd].permission)
272
272
273 rsp = wireproto.dispatch(repo, proto, cmd)
273 rsp = wireproto.dispatch(repo, proto, cmd)
274
274
275 if isinstance(rsp, bytes):
275 if isinstance(rsp, bytes):
276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
277 return []
277 return []
278 elif isinstance(rsp, wireprototypes.bytesresponse):
278 elif isinstance(rsp, wireprototypes.bytesresponse):
279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
280 return []
280 return []
281 elif isinstance(rsp, wireprototypes.streamreslegacy):
281 elif isinstance(rsp, wireprototypes.streamreslegacy):
282 gen = rsp.gen
282 gen = rsp.gen
283 wsgireq.respond(HTTP_OK, HGTYPE)
283 wsgireq.respond(HTTP_OK, HGTYPE)
284 return gen
284 return gen
285 elif isinstance(rsp, wireprototypes.streamres):
285 elif isinstance(rsp, wireprototypes.streamres):
286 gen = rsp.gen
286 gen = rsp.gen
287
287
288 # This code for compression should not be streamres specific. It
288 # This code for compression should not be streamres specific. It
289 # is here because we only compress streamres at the moment.
289 # is here because we only compress streamres at the moment.
290 mediatype, engine, engineopts = _httpresponsetype(
290 mediatype, engine, engineopts = _httpresponsetype(
291 repo.ui, req, rsp.prefer_uncompressed)
291 repo.ui, req, rsp.prefer_uncompressed)
292 gen = engine.compressstream(gen, engineopts)
292 gen = engine.compressstream(gen, engineopts)
293
293
294 if mediatype == HGTYPE2:
294 if mediatype == HGTYPE2:
295 gen = genversion2(gen, engine, engineopts)
295 gen = genversion2(gen, engine, engineopts)
296
296
297 wsgireq.respond(HTTP_OK, mediatype)
297 wsgireq.respond(HTTP_OK, mediatype)
298 return gen
298 return gen
299 elif isinstance(rsp, wireprototypes.pushres):
299 elif isinstance(rsp, wireprototypes.pushres):
300 rsp = '%d\n%s' % (rsp.res, rsp.output)
300 rsp = '%d\n%s' % (rsp.res, rsp.output)
301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
302 return []
302 return []
303 elif isinstance(rsp, wireprototypes.pusherr):
303 elif isinstance(rsp, wireprototypes.pusherr):
304 # This is the httplib workaround documented in _handlehttperror().
305 wsgireq.drain()
306
307 rsp = '0\n%s\n' % rsp.res
304 rsp = '0\n%s\n' % rsp.res
308 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
305 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
309 return []
306 return []
310 elif isinstance(rsp, wireprototypes.ooberror):
307 elif isinstance(rsp, wireprototypes.ooberror):
311 rsp = rsp.message
308 rsp = rsp.message
312 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
309 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
313 return []
310 return []
314 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
311 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
315
312
316 def _handlehttperror(e, wsgireq, req):
313 def _handlehttperror(e, wsgireq, req):
317 """Called when an ErrorResponse is raised during HTTP request processing."""
314 """Called when an ErrorResponse is raised during HTTP request processing."""
318
315
319 # Clients using Python's httplib are stateful: the HTTP client
320 # won't process an HTTP response until all request data is
321 # sent to the server. The intent of this code is to ensure
322 # we always read HTTP request data from the client, thus
323 # ensuring httplib transitions to a state that allows it to read
324 # the HTTP response. In other words, it helps prevent deadlocks
325 # on clients using httplib.
326
327 if (req.method == 'POST' and
328 # But not if Expect: 100-continue is being used.
329 (req.headers.get('Expect', '').lower() != '100-continue')):
330 wsgireq.drain()
331 else:
332 wsgireq.headers.append((r'Connection', r'Close'))
333
334 # TODO This response body assumes the failed command was
316 # TODO This response body assumes the failed command was
335 # "unbundle." That assumption is not always valid.
317 # "unbundle." That assumption is not always valid.
336 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
318 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
337
319
338 return ''
320 return ''
339
321
340 def _sshv1respondbytes(fout, value):
322 def _sshv1respondbytes(fout, value):
341 """Send a bytes response for protocol version 1."""
323 """Send a bytes response for protocol version 1."""
342 fout.write('%d\n' % len(value))
324 fout.write('%d\n' % len(value))
343 fout.write(value)
325 fout.write(value)
344 fout.flush()
326 fout.flush()
345
327
346 def _sshv1respondstream(fout, source):
328 def _sshv1respondstream(fout, source):
347 write = fout.write
329 write = fout.write
348 for chunk in source.gen:
330 for chunk in source.gen:
349 write(chunk)
331 write(chunk)
350 fout.flush()
332 fout.flush()
351
333
352 def _sshv1respondooberror(fout, ferr, rsp):
334 def _sshv1respondooberror(fout, ferr, rsp):
353 ferr.write(b'%s\n-\n' % rsp)
335 ferr.write(b'%s\n-\n' % rsp)
354 ferr.flush()
336 ferr.flush()
355 fout.write(b'\n')
337 fout.write(b'\n')
356 fout.flush()
338 fout.flush()
357
339
358 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
340 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
359 """Handler for requests services via version 1 of SSH protocol."""
341 """Handler for requests services via version 1 of SSH protocol."""
360 def __init__(self, ui, fin, fout):
342 def __init__(self, ui, fin, fout):
361 self._ui = ui
343 self._ui = ui
362 self._fin = fin
344 self._fin = fin
363 self._fout = fout
345 self._fout = fout
364
346
365 @property
347 @property
366 def name(self):
348 def name(self):
367 return wireprototypes.SSHV1
349 return wireprototypes.SSHV1
368
350
369 def getargs(self, args):
351 def getargs(self, args):
370 data = {}
352 data = {}
371 keys = args.split()
353 keys = args.split()
372 for n in xrange(len(keys)):
354 for n in xrange(len(keys)):
373 argline = self._fin.readline()[:-1]
355 argline = self._fin.readline()[:-1]
374 arg, l = argline.split()
356 arg, l = argline.split()
375 if arg not in keys:
357 if arg not in keys:
376 raise error.Abort(_("unexpected parameter %r") % arg)
358 raise error.Abort(_("unexpected parameter %r") % arg)
377 if arg == '*':
359 if arg == '*':
378 star = {}
360 star = {}
379 for k in xrange(int(l)):
361 for k in xrange(int(l)):
380 argline = self._fin.readline()[:-1]
362 argline = self._fin.readline()[:-1]
381 arg, l = argline.split()
363 arg, l = argline.split()
382 val = self._fin.read(int(l))
364 val = self._fin.read(int(l))
383 star[arg] = val
365 star[arg] = val
384 data['*'] = star
366 data['*'] = star
385 else:
367 else:
386 val = self._fin.read(int(l))
368 val = self._fin.read(int(l))
387 data[arg] = val
369 data[arg] = val
388 return [data[k] for k in keys]
370 return [data[k] for k in keys]
389
371
390 def forwardpayload(self, fpout):
372 def forwardpayload(self, fpout):
391 # We initially send an empty response. This tells the client it is
373 # We initially send an empty response. This tells the client it is
392 # OK to start sending data. If a client sees any other response, it
374 # OK to start sending data. If a client sees any other response, it
393 # interprets it as an error.
375 # interprets it as an error.
394 _sshv1respondbytes(self._fout, b'')
376 _sshv1respondbytes(self._fout, b'')
395
377
396 # The file is in the form:
378 # The file is in the form:
397 #
379 #
398 # <chunk size>\n<chunk>
380 # <chunk size>\n<chunk>
399 # ...
381 # ...
400 # 0\n
382 # 0\n
401 count = int(self._fin.readline())
383 count = int(self._fin.readline())
402 while count:
384 while count:
403 fpout.write(self._fin.read(count))
385 fpout.write(self._fin.read(count))
404 count = int(self._fin.readline())
386 count = int(self._fin.readline())
405
387
406 @contextlib.contextmanager
388 @contextlib.contextmanager
407 def mayberedirectstdio(self):
389 def mayberedirectstdio(self):
408 yield None
390 yield None
409
391
410 def client(self):
392 def client(self):
411 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
393 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
412 return 'remote:ssh:' + client
394 return 'remote:ssh:' + client
413
395
414 def addcapabilities(self, repo, caps):
396 def addcapabilities(self, repo, caps):
415 return caps
397 return caps
416
398
417 def checkperm(self, perm):
399 def checkperm(self, perm):
418 pass
400 pass
419
401
420 class sshv2protocolhandler(sshv1protocolhandler):
402 class sshv2protocolhandler(sshv1protocolhandler):
421 """Protocol handler for version 2 of the SSH protocol."""
403 """Protocol handler for version 2 of the SSH protocol."""
422
404
423 @property
405 @property
424 def name(self):
406 def name(self):
425 return wireprototypes.SSHV2
407 return wireprototypes.SSHV2
426
408
427 def _runsshserver(ui, repo, fin, fout, ev):
409 def _runsshserver(ui, repo, fin, fout, ev):
428 # This function operates like a state machine of sorts. The following
410 # This function operates like a state machine of sorts. The following
429 # states are defined:
411 # states are defined:
430 #
412 #
431 # protov1-serving
413 # protov1-serving
432 # Server is in protocol version 1 serving mode. Commands arrive on
414 # Server is in protocol version 1 serving mode. Commands arrive on
433 # new lines. These commands are processed in this state, one command
415 # new lines. These commands are processed in this state, one command
434 # after the other.
416 # after the other.
435 #
417 #
436 # protov2-serving
418 # protov2-serving
437 # Server is in protocol version 2 serving mode.
419 # Server is in protocol version 2 serving mode.
438 #
420 #
439 # upgrade-initial
421 # upgrade-initial
440 # The server is going to process an upgrade request.
422 # The server is going to process an upgrade request.
441 #
423 #
442 # upgrade-v2-filter-legacy-handshake
424 # upgrade-v2-filter-legacy-handshake
443 # The protocol is being upgraded to version 2. The server is expecting
425 # The protocol is being upgraded to version 2. The server is expecting
444 # the legacy handshake from version 1.
426 # the legacy handshake from version 1.
445 #
427 #
446 # upgrade-v2-finish
428 # upgrade-v2-finish
447 # The upgrade to version 2 of the protocol is imminent.
429 # The upgrade to version 2 of the protocol is imminent.
448 #
430 #
449 # shutdown
431 # shutdown
450 # The server is shutting down, possibly in reaction to a client event.
432 # The server is shutting down, possibly in reaction to a client event.
451 #
433 #
452 # And here are their transitions:
434 # And here are their transitions:
453 #
435 #
454 # protov1-serving -> shutdown
436 # protov1-serving -> shutdown
455 # When server receives an empty request or encounters another
437 # When server receives an empty request or encounters another
456 # error.
438 # error.
457 #
439 #
458 # protov1-serving -> upgrade-initial
440 # protov1-serving -> upgrade-initial
459 # An upgrade request line was seen.
441 # An upgrade request line was seen.
460 #
442 #
461 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
443 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
462 # Upgrade to version 2 in progress. Server is expecting to
444 # Upgrade to version 2 in progress. Server is expecting to
463 # process a legacy handshake.
445 # process a legacy handshake.
464 #
446 #
465 # upgrade-v2-filter-legacy-handshake -> shutdown
447 # upgrade-v2-filter-legacy-handshake -> shutdown
466 # Client did not fulfill upgrade handshake requirements.
448 # Client did not fulfill upgrade handshake requirements.
467 #
449 #
468 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
450 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
469 # Client fulfilled version 2 upgrade requirements. Finishing that
451 # Client fulfilled version 2 upgrade requirements. Finishing that
470 # upgrade.
452 # upgrade.
471 #
453 #
472 # upgrade-v2-finish -> protov2-serving
454 # upgrade-v2-finish -> protov2-serving
473 # Protocol upgrade to version 2 complete. Server can now speak protocol
455 # Protocol upgrade to version 2 complete. Server can now speak protocol
474 # version 2.
456 # version 2.
475 #
457 #
476 # protov2-serving -> protov1-serving
458 # protov2-serving -> protov1-serving
477 # Ths happens by default since protocol version 2 is the same as
459 # Ths happens by default since protocol version 2 is the same as
478 # version 1 except for the handshake.
460 # version 1 except for the handshake.
479
461
480 state = 'protov1-serving'
462 state = 'protov1-serving'
481 proto = sshv1protocolhandler(ui, fin, fout)
463 proto = sshv1protocolhandler(ui, fin, fout)
482 protoswitched = False
464 protoswitched = False
483
465
484 while not ev.is_set():
466 while not ev.is_set():
485 if state == 'protov1-serving':
467 if state == 'protov1-serving':
486 # Commands are issued on new lines.
468 # Commands are issued on new lines.
487 request = fin.readline()[:-1]
469 request = fin.readline()[:-1]
488
470
489 # Empty lines signal to terminate the connection.
471 # Empty lines signal to terminate the connection.
490 if not request:
472 if not request:
491 state = 'shutdown'
473 state = 'shutdown'
492 continue
474 continue
493
475
494 # It looks like a protocol upgrade request. Transition state to
476 # It looks like a protocol upgrade request. Transition state to
495 # handle it.
477 # handle it.
496 if request.startswith(b'upgrade '):
478 if request.startswith(b'upgrade '):
497 if protoswitched:
479 if protoswitched:
498 _sshv1respondooberror(fout, ui.ferr,
480 _sshv1respondooberror(fout, ui.ferr,
499 b'cannot upgrade protocols multiple '
481 b'cannot upgrade protocols multiple '
500 b'times')
482 b'times')
501 state = 'shutdown'
483 state = 'shutdown'
502 continue
484 continue
503
485
504 state = 'upgrade-initial'
486 state = 'upgrade-initial'
505 continue
487 continue
506
488
507 available = wireproto.commands.commandavailable(request, proto)
489 available = wireproto.commands.commandavailable(request, proto)
508
490
509 # This command isn't available. Send an empty response and go
491 # This command isn't available. Send an empty response and go
510 # back to waiting for a new command.
492 # back to waiting for a new command.
511 if not available:
493 if not available:
512 _sshv1respondbytes(fout, b'')
494 _sshv1respondbytes(fout, b'')
513 continue
495 continue
514
496
515 rsp = wireproto.dispatch(repo, proto, request)
497 rsp = wireproto.dispatch(repo, proto, request)
516
498
517 if isinstance(rsp, bytes):
499 if isinstance(rsp, bytes):
518 _sshv1respondbytes(fout, rsp)
500 _sshv1respondbytes(fout, rsp)
519 elif isinstance(rsp, wireprototypes.bytesresponse):
501 elif isinstance(rsp, wireprototypes.bytesresponse):
520 _sshv1respondbytes(fout, rsp.data)
502 _sshv1respondbytes(fout, rsp.data)
521 elif isinstance(rsp, wireprototypes.streamres):
503 elif isinstance(rsp, wireprototypes.streamres):
522 _sshv1respondstream(fout, rsp)
504 _sshv1respondstream(fout, rsp)
523 elif isinstance(rsp, wireprototypes.streamreslegacy):
505 elif isinstance(rsp, wireprototypes.streamreslegacy):
524 _sshv1respondstream(fout, rsp)
506 _sshv1respondstream(fout, rsp)
525 elif isinstance(rsp, wireprototypes.pushres):
507 elif isinstance(rsp, wireprototypes.pushres):
526 _sshv1respondbytes(fout, b'')
508 _sshv1respondbytes(fout, b'')
527 _sshv1respondbytes(fout, b'%d' % rsp.res)
509 _sshv1respondbytes(fout, b'%d' % rsp.res)
528 elif isinstance(rsp, wireprototypes.pusherr):
510 elif isinstance(rsp, wireprototypes.pusherr):
529 _sshv1respondbytes(fout, rsp.res)
511 _sshv1respondbytes(fout, rsp.res)
530 elif isinstance(rsp, wireprototypes.ooberror):
512 elif isinstance(rsp, wireprototypes.ooberror):
531 _sshv1respondooberror(fout, ui.ferr, rsp.message)
513 _sshv1respondooberror(fout, ui.ferr, rsp.message)
532 else:
514 else:
533 raise error.ProgrammingError('unhandled response type from '
515 raise error.ProgrammingError('unhandled response type from '
534 'wire protocol command: %s' % rsp)
516 'wire protocol command: %s' % rsp)
535
517
536 # For now, protocol version 2 serving just goes back to version 1.
518 # For now, protocol version 2 serving just goes back to version 1.
537 elif state == 'protov2-serving':
519 elif state == 'protov2-serving':
538 state = 'protov1-serving'
520 state = 'protov1-serving'
539 continue
521 continue
540
522
541 elif state == 'upgrade-initial':
523 elif state == 'upgrade-initial':
542 # We should never transition into this state if we've switched
524 # We should never transition into this state if we've switched
543 # protocols.
525 # protocols.
544 assert not protoswitched
526 assert not protoswitched
545 assert proto.name == wireprototypes.SSHV1
527 assert proto.name == wireprototypes.SSHV1
546
528
547 # Expected: upgrade <token> <capabilities>
529 # Expected: upgrade <token> <capabilities>
548 # If we get something else, the request is malformed. It could be
530 # If we get something else, the request is malformed. It could be
549 # from a future client that has altered the upgrade line content.
531 # from a future client that has altered the upgrade line content.
550 # We treat this as an unknown command.
532 # We treat this as an unknown command.
551 try:
533 try:
552 token, caps = request.split(b' ')[1:]
534 token, caps = request.split(b' ')[1:]
553 except ValueError:
535 except ValueError:
554 _sshv1respondbytes(fout, b'')
536 _sshv1respondbytes(fout, b'')
555 state = 'protov1-serving'
537 state = 'protov1-serving'
556 continue
538 continue
557
539
558 # Send empty response if we don't support upgrading protocols.
540 # Send empty response if we don't support upgrading protocols.
559 if not ui.configbool('experimental', 'sshserver.support-v2'):
541 if not ui.configbool('experimental', 'sshserver.support-v2'):
560 _sshv1respondbytes(fout, b'')
542 _sshv1respondbytes(fout, b'')
561 state = 'protov1-serving'
543 state = 'protov1-serving'
562 continue
544 continue
563
545
564 try:
546 try:
565 caps = urlreq.parseqs(caps)
547 caps = urlreq.parseqs(caps)
566 except ValueError:
548 except ValueError:
567 _sshv1respondbytes(fout, b'')
549 _sshv1respondbytes(fout, b'')
568 state = 'protov1-serving'
550 state = 'protov1-serving'
569 continue
551 continue
570
552
571 # We don't see an upgrade request to protocol version 2. Ignore
553 # We don't see an upgrade request to protocol version 2. Ignore
572 # the upgrade request.
554 # the upgrade request.
573 wantedprotos = caps.get(b'proto', [b''])[0]
555 wantedprotos = caps.get(b'proto', [b''])[0]
574 if SSHV2 not in wantedprotos:
556 if SSHV2 not in wantedprotos:
575 _sshv1respondbytes(fout, b'')
557 _sshv1respondbytes(fout, b'')
576 state = 'protov1-serving'
558 state = 'protov1-serving'
577 continue
559 continue
578
560
579 # It looks like we can honor this upgrade request to protocol 2.
561 # It looks like we can honor this upgrade request to protocol 2.
580 # Filter the rest of the handshake protocol request lines.
562 # Filter the rest of the handshake protocol request lines.
581 state = 'upgrade-v2-filter-legacy-handshake'
563 state = 'upgrade-v2-filter-legacy-handshake'
582 continue
564 continue
583
565
584 elif state == 'upgrade-v2-filter-legacy-handshake':
566 elif state == 'upgrade-v2-filter-legacy-handshake':
585 # Client should have sent legacy handshake after an ``upgrade``
567 # Client should have sent legacy handshake after an ``upgrade``
586 # request. Expected lines:
568 # request. Expected lines:
587 #
569 #
588 # hello
570 # hello
589 # between
571 # between
590 # pairs 81
572 # pairs 81
591 # 0000...-0000...
573 # 0000...-0000...
592
574
593 ok = True
575 ok = True
594 for line in (b'hello', b'between', b'pairs 81'):
576 for line in (b'hello', b'between', b'pairs 81'):
595 request = fin.readline()[:-1]
577 request = fin.readline()[:-1]
596
578
597 if request != line:
579 if request != line:
598 _sshv1respondooberror(fout, ui.ferr,
580 _sshv1respondooberror(fout, ui.ferr,
599 b'malformed handshake protocol: '
581 b'malformed handshake protocol: '
600 b'missing %s' % line)
582 b'missing %s' % line)
601 ok = False
583 ok = False
602 state = 'shutdown'
584 state = 'shutdown'
603 break
585 break
604
586
605 if not ok:
587 if not ok:
606 continue
588 continue
607
589
608 request = fin.read(81)
590 request = fin.read(81)
609 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
591 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
610 _sshv1respondooberror(fout, ui.ferr,
592 _sshv1respondooberror(fout, ui.ferr,
611 b'malformed handshake protocol: '
593 b'malformed handshake protocol: '
612 b'missing between argument value')
594 b'missing between argument value')
613 state = 'shutdown'
595 state = 'shutdown'
614 continue
596 continue
615
597
616 state = 'upgrade-v2-finish'
598 state = 'upgrade-v2-finish'
617 continue
599 continue
618
600
619 elif state == 'upgrade-v2-finish':
601 elif state == 'upgrade-v2-finish':
620 # Send the upgrade response.
602 # Send the upgrade response.
621 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
603 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
622 servercaps = wireproto.capabilities(repo, proto)
604 servercaps = wireproto.capabilities(repo, proto)
623 rsp = b'capabilities: %s' % servercaps.data
605 rsp = b'capabilities: %s' % servercaps.data
624 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
606 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
625 fout.flush()
607 fout.flush()
626
608
627 proto = sshv2protocolhandler(ui, fin, fout)
609 proto = sshv2protocolhandler(ui, fin, fout)
628 protoswitched = True
610 protoswitched = True
629
611
630 state = 'protov2-serving'
612 state = 'protov2-serving'
631 continue
613 continue
632
614
633 elif state == 'shutdown':
615 elif state == 'shutdown':
634 break
616 break
635
617
636 else:
618 else:
637 raise error.ProgrammingError('unhandled ssh server state: %s' %
619 raise error.ProgrammingError('unhandled ssh server state: %s' %
638 state)
620 state)
639
621
640 class sshserver(object):
622 class sshserver(object):
641 def __init__(self, ui, repo, logfh=None):
623 def __init__(self, ui, repo, logfh=None):
642 self._ui = ui
624 self._ui = ui
643 self._repo = repo
625 self._repo = repo
644 self._fin = ui.fin
626 self._fin = ui.fin
645 self._fout = ui.fout
627 self._fout = ui.fout
646
628
647 # Log write I/O to stdout and stderr if configured.
629 # Log write I/O to stdout and stderr if configured.
648 if logfh:
630 if logfh:
649 self._fout = util.makeloggingfileobject(
631 self._fout = util.makeloggingfileobject(
650 logfh, self._fout, 'o', logdata=True)
632 logfh, self._fout, 'o', logdata=True)
651 ui.ferr = util.makeloggingfileobject(
633 ui.ferr = util.makeloggingfileobject(
652 logfh, ui.ferr, 'e', logdata=True)
634 logfh, ui.ferr, 'e', logdata=True)
653
635
654 hook.redirect(True)
636 hook.redirect(True)
655 ui.fout = repo.ui.fout = ui.ferr
637 ui.fout = repo.ui.fout = ui.ferr
656
638
657 # Prevent insertion/deletion of CRs
639 # Prevent insertion/deletion of CRs
658 util.setbinary(self._fin)
640 util.setbinary(self._fin)
659 util.setbinary(self._fout)
641 util.setbinary(self._fout)
660
642
661 def serve_forever(self):
643 def serve_forever(self):
662 self.serveuntil(threading.Event())
644 self.serveuntil(threading.Event())
663 sys.exit(0)
645 sys.exit(0)
664
646
665 def serveuntil(self, ev):
647 def serveuntil(self, ev):
666 """Serve until a threading.Event is set."""
648 """Serve until a threading.Event is set."""
667 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
649 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
General Comments 0
You need to be logged in to leave comments. Login now