diff --git a/mercurial/hgweb/hgwebdir_mod.py b/mercurial/hgweb/hgwebdir_mod.py --- a/mercurial/hgweb/hgwebdir_mod.py +++ b/mercurial/hgweb/hgwebdir_mod.py @@ -291,7 +291,8 @@ class hgwebdir(object): # variable. # TODO this is kind of hacky and we should have a better # way of doing this than with REPO_NAME side-effects. - wsgireq.req = requestmod.parserequestfromenv(wsgireq.env) + wsgireq.req = requestmod.parserequestfromenv( + wsgireq.env, wsgireq.req.bodyfh) try: # ensure caller gets private copy of ui repo = hg.repository(self.ui.copy(), real) diff --git a/mercurial/hgweb/request.py b/mercurial/hgweb/request.py --- a/mercurial/hgweb/request.py +++ b/mercurial/hgweb/request.py @@ -61,7 +61,10 @@ def normalize(form): @attr.s(frozen=True) class parsedrequest(object): - """Represents a parsed WSGI request / static HTTP request parameters.""" + """Represents a parsed WSGI request. + + Contains both parsed parameters and a handle on the request body stream. + """ # Request method. method = attr.ib() @@ -91,8 +94,10 @@ class parsedrequest(object): # wsgiref.headers.Headers instance. Operates like a dict with case # insensitive keys. headers = attr.ib() + # Request body input stream (file-like; may be a util.cappedreader). + bodyfh = attr.ib() -def parserequestfromenv(env): +def parserequestfromenv(env, bodyfh): """Parse URL components from environment variables. WSGI defines request attributes via environment variables. This function @@ -209,6 +214,12 @@ def parserequestfromenv(env): if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env: headers['Content-Length'] = env['CONTENT_LENGTH'] + # TODO build bodyfh here once wsgirequest stops wrapping wsgi.input + # itself; otherwise the input stream would have multiple readers.
+ #bodyfh = env['wsgi.input'] + #if 'Content-Length' in headers: + # bodyfh = util.cappedreader(bodyfh, int(headers['Content-Length'])) + return parsedrequest(method=env['REQUEST_METHOD'], url=fullurl, baseurl=baseurl, advertisedurl=advertisedfullurl, @@ -219,7 +230,8 @@ def parserequestfromenv(env): querystring=querystring, querystringlist=querystringlist, querystringdict=querystringdict, - headers=headers) + headers=headers, + bodyfh=bodyfh) class wsgirequest(object): """Higher-level API for a WSGI request. @@ -233,28 +245,27 @@ class wsgirequest(object): if (version < (1, 0)) or (version >= (2, 0)): raise RuntimeError("Unknown and unsupported WSGI version %d.%d" % version) - self.inp = wsgienv[r'wsgi.input'] + + inp = wsgienv[r'wsgi.input'] if r'HTTP_CONTENT_LENGTH' in wsgienv: - self.inp = util.cappedreader(self.inp, - int(wsgienv[r'HTTP_CONTENT_LENGTH'])) + inp = util.cappedreader(inp, int(wsgienv[r'HTTP_CONTENT_LENGTH'])) elif r'CONTENT_LENGTH' in wsgienv: - self.inp = util.cappedreader(self.inp, - int(wsgienv[r'CONTENT_LENGTH'])) + inp = util.cappedreader(inp, int(wsgienv[r'CONTENT_LENGTH'])) self.err = wsgienv[r'wsgi.errors'] self.threaded = wsgienv[r'wsgi.multithread'] self.multiprocess = wsgienv[r'wsgi.multiprocess'] self.run_once = wsgienv[r'wsgi.run_once'] self.env = wsgienv - self.form = normalize(cgi.parse(self.inp, + self.form = normalize(cgi.parse(inp, self.env, keep_blank_values=1)) self._start_response = start_response self.server_write = None self.headers = [] - self.req = parserequestfromenv(wsgienv) + self.req = parserequestfromenv(wsgienv, inp) def respond(self, status, type, filename=None, body=None): if not isinstance(type, str): @@ -315,7 +326,7 @@ class wsgirequest(object): # input stream doesn't overrun the actual request. So there's # no guarantee that reading until EOF won't corrupt the stream # state.
- if not isinstance(self.inp, util.cappedreader): + if not isinstance(self.req.bodyfh, util.cappedreader): close = True else: # We /could/ only drain certain HTTP response codes. But 200 @@ -329,9 +340,9 @@ class wsgirequest(object): self.headers.append((r'Connection', r'Close')) if drain: - assert isinstance(self.inp, util.cappedreader) + assert isinstance(self.req.bodyfh, util.cappedreader) while True: - chunk = self.inp.read(32768) + chunk = self.req.bodyfh.read(32768) if not chunk: break diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -83,7 +83,7 @@ class httpv1protocolhandler(wireprototyp postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0)) if postlen: args.update(urlreq.parseqs( - self._wsgireq.inp.read(postlen), keep_blank_values=True)) + self._req.bodyfh.read(postlen), keep_blank_values=True)) return args argvalue = decodevaluefromheaders(self._req, b'X-HgArg') @@ -97,7 +97,7 @@ class httpv1protocolhandler(wireprototyp # If httppostargs is used, we need to read Content-Length # minus the amount that was consumed by args. length -= int(self._req.headers.get(b'X-HgArgs-Post', 0)) - for s in util.filechunkiter(self._wsgireq.inp, limit=length): + for s in util.filechunkiter(self._req.bodyfh, limit=length): fp.write(s) @contextlib.contextmanager