##// END OF EJS Templates
tests: add test coverage for parsing WSGI requests...
Gregory Szorc -
r36912:b2a3308d default
parent child Browse files
Show More
@@ -0,0 +1,255 b''
1 from __future__ import absolute_import, print_function
2
3 import unittest
4
5 from mercurial.hgweb import (
6 request as requestmod,
7 )
8
9 DEFAULT_ENV = {
10 r'REQUEST_METHOD': r'GET',
11 r'SERVER_NAME': r'testserver',
12 r'SERVER_PORT': r'80',
13 r'SERVER_PROTOCOL': r'http',
14 r'wsgi.version': (1, 0),
15 r'wsgi.url_scheme': r'http',
16 r'wsgi.input': None,
17 r'wsgi.errors': None,
18 r'wsgi.multithread': False,
19 r'wsgi.multiprocess': True,
20 r'wsgi.run_once': False,
21 }
22
23 def parse(env, bodyfh=None, extra=None):
24 env = dict(env)
25 env.update(extra or {})
26
27 return requestmod.parserequestfromenv(env, bodyfh)
28
29 class ParseRequestTests(unittest.TestCase):
30 def testdefault(self):
31 r = parse(DEFAULT_ENV)
32 self.assertEqual(r.url, b'http://testserver')
33 self.assertEqual(r.baseurl, b'http://testserver')
34 self.assertEqual(r.advertisedurl, b'http://testserver')
35 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
36 self.assertEqual(r.urlscheme, b'http')
37 self.assertEqual(r.method, b'GET')
38 self.assertIsNone(r.remoteuser)
39 self.assertIsNone(r.remotehost)
40 self.assertEqual(r.apppath, b'')
41 self.assertEqual(r.dispatchparts, [])
42 self.assertEqual(r.dispatchpath, b'')
43 self.assertFalse(r.havepathinfo)
44 self.assertIsNone(r.reponame)
45 self.assertEqual(r.querystring, b'')
46 self.assertEqual(len(r.qsparams), 0)
47 self.assertEqual(len(r.headers), 0)
48
49 def testcustomport(self):
50 r = parse(DEFAULT_ENV, extra={
51 r'SERVER_PORT': r'8000',
52 })
53
54 self.assertEqual(r.url, b'http://testserver:8000')
55 self.assertEqual(r.baseurl, b'http://testserver:8000')
56 self.assertEqual(r.advertisedurl, b'http://testserver:8000')
57 self.assertEqual(r.advertisedbaseurl, b'http://testserver:8000')
58
59 r = parse(DEFAULT_ENV, extra={
60 r'SERVER_PORT': r'4000',
61 r'wsgi.url_scheme': r'https',
62 })
63
64 self.assertEqual(r.url, b'https://testserver:4000')
65 self.assertEqual(r.baseurl, b'https://testserver:4000')
66 self.assertEqual(r.advertisedurl, b'https://testserver:4000')
67 self.assertEqual(r.advertisedbaseurl, b'https://testserver:4000')
68
69 def testhttphost(self):
70 r = parse(DEFAULT_ENV, extra={
71 r'HTTP_HOST': r'altserver',
72 })
73
74 self.assertEqual(r.url, b'http://altserver')
75 self.assertEqual(r.baseurl, b'http://altserver')
76 self.assertEqual(r.advertisedurl, b'http://testserver')
77 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
78
79 def testscriptname(self):
80 r = parse(DEFAULT_ENV, extra={
81 r'SCRIPT_NAME': r'',
82 })
83
84 self.assertEqual(r.url, b'http://testserver')
85 self.assertEqual(r.baseurl, b'http://testserver')
86 self.assertEqual(r.advertisedurl, b'http://testserver')
87 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
88 self.assertEqual(r.apppath, b'')
89 self.assertEqual(r.dispatchparts, [])
90 self.assertEqual(r.dispatchpath, b'')
91 self.assertFalse(r.havepathinfo)
92
93 r = parse(DEFAULT_ENV, extra={
94 r'SCRIPT_NAME': r'/script',
95 })
96
97 self.assertEqual(r.url, b'http://testserver/script')
98 self.assertEqual(r.baseurl, b'http://testserver')
99 self.assertEqual(r.advertisedurl, b'http://testserver/script')
100 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
101 self.assertEqual(r.apppath, b'/script')
102 self.assertEqual(r.dispatchparts, [])
103 self.assertEqual(r.dispatchpath, b'')
104 self.assertFalse(r.havepathinfo)
105
106 r = parse(DEFAULT_ENV, extra={
107 r'SCRIPT_NAME': r'/multiple words',
108 })
109
110 self.assertEqual(r.url, b'http://testserver/multiple%20words')
111 self.assertEqual(r.baseurl, b'http://testserver')
112 self.assertEqual(r.advertisedurl, b'http://testserver/multiple%20words')
113 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
114 self.assertEqual(r.apppath, b'/multiple words')
115 self.assertEqual(r.dispatchparts, [])
116 self.assertEqual(r.dispatchpath, b'')
117 self.assertFalse(r.havepathinfo)
118
119 def testpathinfo(self):
120 r = parse(DEFAULT_ENV, extra={
121 r'PATH_INFO': r'',
122 })
123
124 self.assertEqual(r.url, b'http://testserver')
125 self.assertEqual(r.baseurl, b'http://testserver')
126 self.assertEqual(r.advertisedurl, b'http://testserver')
127 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
128 self.assertEqual(r.apppath, b'')
129 self.assertEqual(r.dispatchparts, [])
130 self.assertEqual(r.dispatchpath, b'')
131 self.assertTrue(r.havepathinfo)
132
133 r = parse(DEFAULT_ENV, extra={
134 r'PATH_INFO': r'/pathinfo',
135 })
136
137 self.assertEqual(r.url, b'http://testserver/pathinfo')
138 self.assertEqual(r.baseurl, b'http://testserver')
139 self.assertEqual(r.advertisedurl, b'http://testserver/pathinfo')
140 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
141 self.assertEqual(r.apppath, b'')
142 self.assertEqual(r.dispatchparts, [b'pathinfo'])
143 self.assertEqual(r.dispatchpath, b'pathinfo')
144 self.assertTrue(r.havepathinfo)
145
146 r = parse(DEFAULT_ENV, extra={
147 r'PATH_INFO': r'/one/two/',
148 })
149
150 self.assertEqual(r.url, b'http://testserver/one/two/')
151 self.assertEqual(r.baseurl, b'http://testserver')
152 self.assertEqual(r.advertisedurl, b'http://testserver/one/two/')
153 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
154 self.assertEqual(r.apppath, b'')
155 self.assertEqual(r.dispatchparts, [b'one', b'two'])
156 self.assertEqual(r.dispatchpath, b'one/two')
157 self.assertTrue(r.havepathinfo)
158
159 def testscriptandpathinfo(self):
160 r = parse(DEFAULT_ENV, extra={
161 r'SCRIPT_NAME': r'/script',
162 r'PATH_INFO': r'/pathinfo',
163 })
164
165 self.assertEqual(r.url, b'http://testserver/script/pathinfo')
166 self.assertEqual(r.baseurl, b'http://testserver')
167 self.assertEqual(r.advertisedurl, b'http://testserver/script/pathinfo')
168 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
169 self.assertEqual(r.apppath, b'/script')
170 self.assertEqual(r.dispatchparts, [b'pathinfo'])
171 self.assertEqual(r.dispatchpath, b'pathinfo')
172 self.assertTrue(r.havepathinfo)
173
174 r = parse(DEFAULT_ENV, extra={
175 r'SCRIPT_NAME': r'/script1/script2',
176 r'PATH_INFO': r'/path1/path2',
177 })
178
179 self.assertEqual(r.url,
180 b'http://testserver/script1/script2/path1/path2')
181 self.assertEqual(r.baseurl, b'http://testserver')
182 self.assertEqual(r.advertisedurl,
183 b'http://testserver/script1/script2/path1/path2')
184 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
185 self.assertEqual(r.apppath, b'/script1/script2')
186 self.assertEqual(r.dispatchparts, [b'path1', b'path2'])
187 self.assertEqual(r.dispatchpath, b'path1/path2')
188 self.assertTrue(r.havepathinfo)
189
190 r = parse(DEFAULT_ENV, extra={
191 r'HTTP_HOST': r'hostserver',
192 r'SCRIPT_NAME': r'/script',
193 r'PATH_INFO': r'/pathinfo',
194 })
195
196 self.assertEqual(r.url, b'http://hostserver/script/pathinfo')
197 self.assertEqual(r.baseurl, b'http://hostserver')
198 self.assertEqual(r.advertisedurl, b'http://testserver/script/pathinfo')
199 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
200 self.assertEqual(r.apppath, b'/script')
201 self.assertEqual(r.dispatchparts, [b'pathinfo'])
202 self.assertEqual(r.dispatchpath, b'pathinfo')
203 self.assertTrue(r.havepathinfo)
204
205 def testreponame(self):
206 """REPO_NAME path components get stripped from URL."""
207 r = parse(DEFAULT_ENV, extra={
208 r'REPO_NAME': r'repo',
209 r'PATH_INFO': r'/path1/path2'
210 })
211
212 self.assertEqual(r.url, b'http://testserver/path1/path2')
213 self.assertEqual(r.baseurl, b'http://testserver')
214 self.assertEqual(r.advertisedurl, b'http://testserver/path1/path2')
215 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
216 self.assertEqual(r.apppath, b'/repo')
217 self.assertEqual(r.dispatchparts, [b'path1', b'path2'])
218 self.assertEqual(r.dispatchpath, b'path1/path2')
219 self.assertTrue(r.havepathinfo)
220 self.assertEqual(r.reponame, b'repo')
221
222 r = parse(DEFAULT_ENV, extra={
223 r'REPO_NAME': r'repo',
224 r'PATH_INFO': r'/repo/path1/path2',
225 })
226
227 self.assertEqual(r.url, b'http://testserver/repo/path1/path2')
228 self.assertEqual(r.baseurl, b'http://testserver')
229 self.assertEqual(r.advertisedurl, b'http://testserver/repo/path1/path2')
230 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
231 self.assertEqual(r.apppath, b'/repo')
232 self.assertEqual(r.dispatchparts, [b'path1', b'path2'])
233 self.assertEqual(r.dispatchpath, b'path1/path2')
234 self.assertTrue(r.havepathinfo)
235 self.assertEqual(r.reponame, b'repo')
236
237 r = parse(DEFAULT_ENV, extra={
238 r'REPO_NAME': r'prefix/repo',
239 r'PATH_INFO': r'/prefix/repo/path1/path2',
240 })
241
242 self.assertEqual(r.url, b'http://testserver/prefix/repo/path1/path2')
243 self.assertEqual(r.baseurl, b'http://testserver')
244 self.assertEqual(r.advertisedurl,
245 b'http://testserver/prefix/repo/path1/path2')
246 self.assertEqual(r.advertisedbaseurl, b'http://testserver')
247 self.assertEqual(r.apppath, b'/prefix/repo')
248 self.assertEqual(r.dispatchparts, [b'path1', b'path2'])
249 self.assertEqual(r.dispatchpath, b'path1/path2')
250 self.assertTrue(r.havepathinfo)
251 self.assertEqual(r.reponame, b'prefix/repo')
252
253 if __name__ == '__main__':
254 import silenttestrunner
255 silenttestrunner.main(__name__)
@@ -1,651 +1,651 b''
1 1 # hgweb/request.py - An http request from either CGI or the standalone server.
2 2 #
3 3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from __future__ import absolute_import
10 10
11 11 import errno
12 12 import socket
13 13 import wsgiref.headers as wsgiheaders
14 14 #import wsgiref.validate
15 15
16 16 from .common import (
17 17 ErrorResponse,
18 18 statusmessage,
19 19 )
20 20
21 21 from ..thirdparty import (
22 22 attr,
23 23 )
24 24 from .. import (
25 25 error,
26 26 pycompat,
27 27 util,
28 28 )
29 29
30 30 class multidict(object):
31 31 """A dict like object that can store multiple values for a key.
32 32
33 33 Used to store parsed request parameters.
34 34
35 35 This is inspired by WebOb's class of the same name.
36 36 """
37 37 def __init__(self):
38 38 # Stores (key, value) 2-tuples. This isn't the most efficient. But we
39 39 # don't rely on parameters that much, so it shouldn't be a perf issue.
40 40 # we can always add dict for fast lookups.
41 41 self._items = []
42 42
43 43 def __getitem__(self, key):
44 44 """Returns the last set value for a key."""
45 45 for k, v in reversed(self._items):
46 46 if k == key:
47 47 return v
48 48
49 49 raise KeyError(key)
50 50
51 51 def __setitem__(self, key, value):
52 52 """Replace a values for a key with a new value."""
53 53 try:
54 54 del self[key]
55 55 except KeyError:
56 56 pass
57 57
58 58 self._items.append((key, value))
59 59
60 60 def __delitem__(self, key):
61 61 """Delete all values for a key."""
62 62 oldlen = len(self._items)
63 63
64 64 self._items[:] = [(k, v) for k, v in self._items if k != key]
65 65
66 66 if oldlen == len(self._items):
67 67 raise KeyError(key)
68 68
69 69 def __contains__(self, key):
70 70 return any(k == key for k, v in self._items)
71 71
72 72 def __len__(self):
73 73 return len(self._items)
74 74
75 75 def get(self, key, default=None):
76 76 try:
77 77 return self.__getitem__(key)
78 78 except KeyError:
79 79 return default
80 80
81 81 def add(self, key, value):
82 82 """Add a new value for a key. Does not replace existing values."""
83 83 self._items.append((key, value))
84 84
85 85 def getall(self, key):
86 86 """Obtains all values for a key."""
87 87 return [v for k, v in self._items if k == key]
88 88
89 89 def getone(self, key):
90 90 """Obtain a single value for a key.
91 91
92 92 Raises KeyError if key not defined or it has multiple values set.
93 93 """
94 94 vals = self.getall(key)
95 95
96 96 if not vals:
97 97 raise KeyError(key)
98 98
99 99 if len(vals) > 1:
100 100 raise KeyError('multiple values for %r' % key)
101 101
102 102 return vals[0]
103 103
104 104 def asdictoflists(self):
105 105 d = {}
106 106 for k, v in self._items:
107 107 if k in d:
108 108 d[k].append(v)
109 109 else:
110 110 d[k] = [v]
111 111
112 112 return d
113 113
114 114 @attr.s(frozen=True)
115 115 class parsedrequest(object):
116 116 """Represents a parsed WSGI request.
117 117
118 118 Contains both parsed parameters as well as a handle on the input stream.
119 119 """
120 120
121 121 # Request method.
122 122 method = attr.ib()
123 123 # Full URL for this request.
124 124 url = attr.ib()
125 125 # URL without any path components. Just <proto>://<host><port>.
126 126 baseurl = attr.ib()
127 127 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
128 128 # of HTTP: Host header for hostname. This is likely what clients used.
129 129 advertisedurl = attr.ib()
130 130 advertisedbaseurl = attr.ib()
131 131 # URL scheme (part before ``://``). e.g. ``http`` or ``https``.
132 132 urlscheme = attr.ib()
133 133 # Value of REMOTE_USER, if set, or None.
134 134 remoteuser = attr.ib()
135 135 # Value of REMOTE_HOST, if set, or None.
136 136 remotehost = attr.ib()
137 137 # WSGI application path.
138 138 apppath = attr.ib()
139 139 # List of path parts to be used for dispatch.
140 140 dispatchparts = attr.ib()
141 141 # URL path component (no query string) used for dispatch.
142 142 dispatchpath = attr.ib()
143 143 # Whether there is a path component to this request. This can be true
144 144 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
145 145 havepathinfo = attr.ib()
146 146 # The name of the repository being accessed.
147 147 reponame = attr.ib()
148 148 # Raw query string (part after "?" in URL).
149 149 querystring = attr.ib()
150 150 # multidict of query string parameters.
151 151 qsparams = attr.ib()
152 152 # wsgiref.headers.Headers instance. Operates like a dict with case
153 153 # insensitive keys.
154 154 headers = attr.ib()
155 155 # Request body input stream.
156 156 bodyfh = attr.ib()
157 157
158 158 def parserequestfromenv(env, bodyfh):
159 159 """Parse URL components from environment variables.
160 160
161 161 WSGI defines request attributes via environment variables. This function
162 162 parses the environment variables into a data structure.
163 163 """
164 164 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
165 165
166 166 # We first validate that the incoming object conforms with the WSGI spec.
167 167 # We only want to be dealing with spec-conforming WSGI implementations.
168 168 # TODO enable this once we fix internal violations.
169 169 #wsgiref.validate.check_environ(env)
170 170
171 171 # PEP-0333 states that environment keys and values are native strings
172 172 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
173 173 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
174 174 # in Mercurial, so mass convert string keys and values to bytes.
175 175 if pycompat.ispy3:
176 176 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
177 177 env = {k: v.encode('latin-1') if isinstance(v, str) else v
178 178 for k, v in env.iteritems()}
179 179
180 180 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
181 181 # the environment variables.
182 182 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
183 183 # how URLs are reconstructed.
184 184 fullurl = env['wsgi.url_scheme'] + '://'
185 185 advertisedfullurl = fullurl
186 186
187 187 def addport(s):
188 188 if env['wsgi.url_scheme'] == 'https':
189 189 if env['SERVER_PORT'] != '443':
190 190 s += ':' + env['SERVER_PORT']
191 191 else:
192 192 if env['SERVER_PORT'] != '80':
193 193 s += ':' + env['SERVER_PORT']
194 194
195 195 return s
196 196
197 197 if env.get('HTTP_HOST'):
198 198 fullurl += env['HTTP_HOST']
199 199 else:
200 200 fullurl += env['SERVER_NAME']
201 201 fullurl = addport(fullurl)
202 202
203 203 advertisedfullurl += env['SERVER_NAME']
204 204 advertisedfullurl = addport(advertisedfullurl)
205 205
206 206 baseurl = fullurl
207 207 advertisedbaseurl = advertisedfullurl
208 208
209 209 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
210 210 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
211 211 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
212 212 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
213 213
214 214 if env.get('QUERY_STRING'):
215 215 fullurl += '?' + env['QUERY_STRING']
216 216 advertisedfullurl += '?' + env['QUERY_STRING']
217 217
218 218 # When dispatching requests, we look at the URL components (PATH_INFO
219 219 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
220 220 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
221 221 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
222 222 # root. We also exclude its path components from PATH_INFO when resolving
223 223 # the dispatch path.
224 224
225 apppath = env['SCRIPT_NAME']
225 apppath = env.get('SCRIPT_NAME', '')
226 226
227 227 if env.get('REPO_NAME'):
228 228 if not apppath.endswith('/'):
229 229 apppath += '/'
230 230
231 231 apppath += env.get('REPO_NAME')
232 232
233 233 if 'PATH_INFO' in env:
234 234 dispatchparts = env['PATH_INFO'].strip('/').split('/')
235 235
236 236 # Strip out repo parts.
237 237 repoparts = env.get('REPO_NAME', '').split('/')
238 238 if dispatchparts[:len(repoparts)] == repoparts:
239 239 dispatchparts = dispatchparts[len(repoparts):]
240 240 else:
241 241 dispatchparts = []
242 242
243 243 dispatchpath = '/'.join(dispatchparts)
244 244
245 245 querystring = env.get('QUERY_STRING', '')
246 246
247 247 # We store as a list so we have ordering information. We also store as
248 248 # a dict to facilitate fast lookup.
249 249 qsparams = multidict()
250 250 for k, v in util.urlreq.parseqsl(querystring, keep_blank_values=True):
251 251 qsparams.add(k, v)
252 252
253 253 # HTTP_* keys contain HTTP request headers. The Headers structure should
254 254 # perform case normalization for us. We just rewrite underscore to dash
255 255 # so keys match what likely went over the wire.
256 256 headers = []
257 257 for k, v in env.iteritems():
258 258 if k.startswith('HTTP_'):
259 259 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
260 260
261 261 headers = wsgiheaders.Headers(headers)
262 262
263 263 # This is kind of a lie because the HTTP header wasn't explicitly
264 264 # sent. But for all intents and purposes it should be OK to lie about
265 265 # this, since a consumer will either either value to determine how many
266 266 # bytes are available to read.
267 267 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
268 268 headers['Content-Length'] = env['CONTENT_LENGTH']
269 269
270 270 # TODO do this once we remove wsgirequest.inp, otherwise we could have
271 271 # multiple readers from the underlying input stream.
272 272 #bodyfh = env['wsgi.input']
273 273 #if 'Content-Length' in headers:
274 274 # bodyfh = util.cappedreader(bodyfh, int(headers['Content-Length']))
275 275
276 276 return parsedrequest(method=env['REQUEST_METHOD'],
277 277 url=fullurl, baseurl=baseurl,
278 278 advertisedurl=advertisedfullurl,
279 279 advertisedbaseurl=advertisedbaseurl,
280 280 urlscheme=env['wsgi.url_scheme'],
281 281 remoteuser=env.get('REMOTE_USER'),
282 282 remotehost=env.get('REMOTE_HOST'),
283 283 apppath=apppath,
284 284 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
285 285 havepathinfo='PATH_INFO' in env,
286 286 reponame=env.get('REPO_NAME'),
287 287 querystring=querystring,
288 288 qsparams=qsparams,
289 289 headers=headers,
290 290 bodyfh=bodyfh)
291 291
292 292 class offsettrackingwriter(object):
293 293 """A file object like object that is append only and tracks write count.
294 294
295 295 Instances are bound to a callable. This callable is called with data
296 296 whenever a ``write()`` is attempted.
297 297
298 298 Instances track the amount of written data so they can answer ``tell()``
299 299 requests.
300 300
301 301 The intent of this class is to wrap the ``write()`` function returned by
302 302 a WSGI ``start_response()`` function. Since ``write()`` is a callable and
303 303 not a file object, it doesn't implement other file object methods.
304 304 """
305 305 def __init__(self, writefn):
306 306 self._write = writefn
307 307 self._offset = 0
308 308
309 309 def write(self, s):
310 310 res = self._write(s)
311 311 # Some Python objects don't report the number of bytes written.
312 312 if res is None:
313 313 self._offset += len(s)
314 314 else:
315 315 self._offset += res
316 316
317 317 def flush(self):
318 318 pass
319 319
320 320 def tell(self):
321 321 return self._offset
322 322
323 323 class wsgiresponse(object):
324 324 """Represents a response to a WSGI request.
325 325
326 326 A response consists of a status line, headers, and a body.
327 327
328 328 Consumers must populate the ``status`` and ``headers`` fields and
329 329 make a call to a ``setbody*()`` method before the response can be
330 330 issued.
331 331
332 332 When it is time to start sending the response over the wire,
333 333 ``sendresponse()`` is called. It handles emitting the header portion
334 334 of the response message. It then yields chunks of body data to be
335 335 written to the peer. Typically, the WSGI application itself calls
336 336 and returns the value from ``sendresponse()``.
337 337 """
338 338
339 339 def __init__(self, req, startresponse):
340 340 """Create an empty response tied to a specific request.
341 341
342 342 ``req`` is a ``parsedrequest``. ``startresponse`` is the
343 343 ``start_response`` function passed to the WSGI application.
344 344 """
345 345 self._req = req
346 346 self._startresponse = startresponse
347 347
348 348 self.status = None
349 349 self.headers = wsgiheaders.Headers([])
350 350
351 351 self._bodybytes = None
352 352 self._bodygen = None
353 353 self._bodywillwrite = False
354 354 self._started = False
355 355 self._bodywritefn = None
356 356
357 357 def _verifybody(self):
358 358 if (self._bodybytes is not None or self._bodygen is not None
359 359 or self._bodywillwrite):
360 360 raise error.ProgrammingError('cannot define body multiple times')
361 361
362 362 def setbodybytes(self, b):
363 363 """Define the response body as static bytes.
364 364
365 365 The empty string signals that there is no response body.
366 366 """
367 367 self._verifybody()
368 368 self._bodybytes = b
369 369 self.headers['Content-Length'] = '%d' % len(b)
370 370
371 371 def setbodygen(self, gen):
372 372 """Define the response body as a generator of bytes."""
373 373 self._verifybody()
374 374 self._bodygen = gen
375 375
376 376 def setbodywillwrite(self):
377 377 """Signal an intent to use write() to emit the response body.
378 378
379 379 **This is the least preferred way to send a body.**
380 380
381 381 It is preferred for WSGI applications to emit a generator of chunks
382 382 constituting the response body. However, some consumers can't emit
383 383 data this way. So, WSGI provides a way to obtain a ``write(data)``
384 384 function that can be used to synchronously perform an unbuffered
385 385 write.
386 386
387 387 Calling this function signals an intent to produce the body in this
388 388 manner.
389 389 """
390 390 self._verifybody()
391 391 self._bodywillwrite = True
392 392
393 393 def sendresponse(self):
394 394 """Send the generated response to the client.
395 395
396 396 Before this is called, ``status`` must be set and one of
397 397 ``setbodybytes()`` or ``setbodygen()`` must be called.
398 398
399 399 Calling this method multiple times is not allowed.
400 400 """
401 401 if self._started:
402 402 raise error.ProgrammingError('sendresponse() called multiple times')
403 403
404 404 self._started = True
405 405
406 406 if not self.status:
407 407 raise error.ProgrammingError('status line not defined')
408 408
409 409 if (self._bodybytes is None and self._bodygen is None
410 410 and not self._bodywillwrite):
411 411 raise error.ProgrammingError('response body not defined')
412 412
413 413 # RFC 7232 Section 4.1 states that a 304 MUST generate one of
414 414 # {Cache-Control, Content-Location, Date, ETag, Expires, Vary}
415 415 # and SHOULD NOT generate other headers unless they could be used
416 416 # to guide cache updates. Furthermore, RFC 7230 Section 3.3.2
417 417 # states that no response body can be issued. Content-Length can
418 418 # be sent. But if it is present, it should be the size of the response
419 419 # that wasn't transferred.
420 420 if self.status.startswith('304 '):
421 421 # setbodybytes('') will set C-L to 0. This doesn't conform with the
422 422 # spec. So remove it.
423 423 if self.headers.get('Content-Length') == '0':
424 424 del self.headers['Content-Length']
425 425
426 426 # Strictly speaking, this is too strict. But until it causes
427 427 # problems, let's be strict.
428 428 badheaders = {k for k in self.headers.keys()
429 429 if k.lower() not in ('date', 'etag', 'expires',
430 430 'cache-control',
431 431 'content-location',
432 432 'vary')}
433 433 if badheaders:
434 434 raise error.ProgrammingError(
435 435 'illegal header on 304 response: %s' %
436 436 ', '.join(sorted(badheaders)))
437 437
438 438 if self._bodygen is not None or self._bodywillwrite:
439 439 raise error.ProgrammingError("must use setbodybytes('') with "
440 440 "304 responses")
441 441
442 442 # Various HTTP clients (notably httplib) won't read the HTTP response
443 443 # until the HTTP request has been sent in full. If servers (us) send a
444 444 # response before the HTTP request has been fully sent, the connection
445 445 # may deadlock because neither end is reading.
446 446 #
447 447 # We work around this by "draining" the request data before
448 448 # sending any response in some conditions.
449 449 drain = False
450 450 close = False
451 451
452 452 # If the client sent Expect: 100-continue, we assume it is smart enough
453 453 # to deal with the server sending a response before reading the request.
454 454 # (httplib doesn't do this.)
455 455 if self._req.headers.get('Expect', '').lower() == '100-continue':
456 456 pass
457 457 # Only tend to request methods that have bodies. Strictly speaking,
458 458 # we should sniff for a body. But this is fine for our existing
459 459 # WSGI applications.
460 460 elif self._req.method not in ('POST', 'PUT'):
461 461 pass
462 462 else:
463 463 # If we don't know how much data to read, there's no guarantee
464 464 # that we can drain the request responsibly. The WSGI
465 465 # specification only says that servers *should* ensure the
466 466 # input stream doesn't overrun the actual request. So there's
467 467 # no guarantee that reading until EOF won't corrupt the stream
468 468 # state.
469 469 if not isinstance(self._req.bodyfh, util.cappedreader):
470 470 close = True
471 471 else:
472 472 # We /could/ only drain certain HTTP response codes. But 200 and
473 473 # non-200 wire protocol responses both require draining. Since
474 474 # we have a capped reader in place for all situations where we
475 475 # drain, it is safe to read from that stream. We'll either do
476 476 # a drain or no-op if we're already at EOF.
477 477 drain = True
478 478
479 479 if close:
480 480 self.headers['Connection'] = 'Close'
481 481
482 482 if drain:
483 483 assert isinstance(self._req.bodyfh, util.cappedreader)
484 484 while True:
485 485 chunk = self._req.bodyfh.read(32768)
486 486 if not chunk:
487 487 break
488 488
489 489 write = self._startresponse(pycompat.sysstr(self.status),
490 490 self.headers.items())
491 491
492 492 if self._bodybytes:
493 493 yield self._bodybytes
494 494 elif self._bodygen:
495 495 for chunk in self._bodygen:
496 496 yield chunk
497 497 elif self._bodywillwrite:
498 498 self._bodywritefn = write
499 499 else:
500 500 error.ProgrammingError('do not know how to send body')
501 501
502 502 def getbodyfile(self):
503 503 """Obtain a file object like object representing the response body.
504 504
505 505 For this to work, you must call ``setbodywillwrite()`` and then
506 506 ``sendresponse()`` first. ``sendresponse()`` is a generator and the
507 507 function won't run to completion unless the generator is advanced. The
508 508 generator yields not items. The easiest way to consume it is with
509 509 ``list(res.sendresponse())``, which should resolve to an empty list -
510 510 ``[]``.
511 511 """
512 512 if not self._bodywillwrite:
513 513 raise error.ProgrammingError('must call setbodywillwrite() first')
514 514
515 515 if not self._started:
516 516 raise error.ProgrammingError('must call sendresponse() first; did '
517 517 'you remember to consume it since it '
518 518 'is a generator?')
519 519
520 520 assert self._bodywritefn
521 521 return offsettrackingwriter(self._bodywritefn)
522 522
523 523 class wsgirequest(object):
524 524 """Higher-level API for a WSGI request.
525 525
526 526 WSGI applications are invoked with 2 arguments. They are used to
527 527 instantiate instances of this class, which provides higher-level APIs
528 528 for obtaining request parameters, writing HTTP output, etc.
529 529 """
530 530 def __init__(self, wsgienv, start_response):
531 531 version = wsgienv[r'wsgi.version']
532 532 if (version < (1, 0)) or (version >= (2, 0)):
533 533 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
534 534 % version)
535 535
536 536 inp = wsgienv[r'wsgi.input']
537 537
538 538 if r'HTTP_CONTENT_LENGTH' in wsgienv:
539 539 inp = util.cappedreader(inp, int(wsgienv[r'HTTP_CONTENT_LENGTH']))
540 540 elif r'CONTENT_LENGTH' in wsgienv:
541 541 inp = util.cappedreader(inp, int(wsgienv[r'CONTENT_LENGTH']))
542 542
543 543 self.err = wsgienv[r'wsgi.errors']
544 544 self.threaded = wsgienv[r'wsgi.multithread']
545 545 self.multiprocess = wsgienv[r'wsgi.multiprocess']
546 546 self.run_once = wsgienv[r'wsgi.run_once']
547 547 self.env = wsgienv
548 548 self.req = parserequestfromenv(wsgienv, inp)
549 549 self.res = wsgiresponse(self.req, start_response)
550 550 self._start_response = start_response
551 551 self.server_write = None
552 552 self.headers = []
553 553
554 554 def respond(self, status, type, filename=None, body=None):
555 555 if not isinstance(type, str):
556 556 type = pycompat.sysstr(type)
557 557 if self._start_response is not None:
558 558 self.headers.append((r'Content-Type', type))
559 559 if filename:
560 560 filename = (filename.rpartition('/')[-1]
561 561 .replace('\\', '\\\\').replace('"', '\\"'))
562 562 self.headers.append(('Content-Disposition',
563 563 'inline; filename="%s"' % filename))
564 564 if body is not None:
565 565 self.headers.append((r'Content-Length', str(len(body))))
566 566
567 567 for k, v in self.headers:
568 568 if not isinstance(v, str):
569 569 raise TypeError('header value must be string: %r' % (v,))
570 570
571 571 if isinstance(status, ErrorResponse):
572 572 self.headers.extend(status.headers)
573 573 status = statusmessage(status.code, pycompat.bytestr(status))
574 574 elif status == 200:
575 575 status = '200 Script output follows'
576 576 elif isinstance(status, int):
577 577 status = statusmessage(status)
578 578
579 579 # Various HTTP clients (notably httplib) won't read the HTTP
580 580 # response until the HTTP request has been sent in full. If servers
581 581 # (us) send a response before the HTTP request has been fully sent,
582 582 # the connection may deadlock because neither end is reading.
583 583 #
584 584 # We work around this by "draining" the request data before
585 585 # sending any response in some conditions.
586 586 drain = False
587 587 close = False
588 588
589 589 # If the client sent Expect: 100-continue, we assume it is smart
590 590 # enough to deal with the server sending a response before reading
591 591 # the request. (httplib doesn't do this.)
592 592 if self.env.get(r'HTTP_EXPECT', r'').lower() == r'100-continue':
593 593 pass
594 594 # Only tend to request methods that have bodies. Strictly speaking,
595 595 # we should sniff for a body. But this is fine for our existing
596 596 # WSGI applications.
597 597 elif self.env[r'REQUEST_METHOD'] not in (r'POST', r'PUT'):
598 598 pass
599 599 else:
600 600 # If we don't know how much data to read, there's no guarantee
601 601 # that we can drain the request responsibly. The WSGI
602 602 # specification only says that servers *should* ensure the
603 603 # input stream doesn't overrun the actual request. So there's
604 604 # no guarantee that reading until EOF won't corrupt the stream
605 605 # state.
606 606 if not isinstance(self.req.bodyfh, util.cappedreader):
607 607 close = True
608 608 else:
609 609 # We /could/ only drain certain HTTP response codes. But 200
610 610 # and non-200 wire protocol responses both require draining.
611 611 # Since we have a capped reader in place for all situations
612 612 # where we drain, it is safe to read from that stream. We'll
613 613 # either do a drain or no-op if we're already at EOF.
614 614 drain = True
615 615
616 616 if close:
617 617 self.headers.append((r'Connection', r'Close'))
618 618
619 619 if drain:
620 620 assert isinstance(self.req.bodyfh, util.cappedreader)
621 621 while True:
622 622 chunk = self.req.bodyfh.read(32768)
623 623 if not chunk:
624 624 break
625 625
626 626 self.server_write = self._start_response(
627 627 pycompat.sysstr(status), self.headers)
628 628 self._start_response = None
629 629 self.headers = []
630 630 if body is not None:
631 631 self.write(body)
632 632 self.server_write = None
633 633
634 634 def write(self, thing):
635 635 if thing:
636 636 try:
637 637 self.server_write(thing)
638 638 except socket.error as inst:
639 639 if inst[0] != errno.ECONNRESET:
640 640 raise
641 641
642 642 def flush(self):
643 643 return None
644 644
645 645 def wsgiapplication(app_maker):
646 646 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
647 647 can and should now be used as a WSGI application.'''
648 648 application = app_maker()
649 649 def run_wsgi(env, respond):
650 650 return application(env, respond)
651 651 return run_wsgi
General Comments 0
You need to be logged in to leave comments. Login now