diff --git a/mercurial/hgweb/request.py b/mercurial/hgweb/request.py --- a/mercurial/hgweb/request.py +++ b/mercurial/hgweb/request.py @@ -28,6 +28,90 @@ from .. import ( util, ) +class multidict(object): + """A dict-like object that can store multiple values for a key. + + Used to store parsed request parameters. + + This is inspired by WebOb's class of the same name. + """ + def __init__(self): + # Stores (key, value) 2-tuples. This isn't the most efficient. But we + # don't rely on parameters that much, so it shouldn't be a perf issue. + # We can always add a dict for fast lookups. + self._items = [] + + def __getitem__(self, key): + """Returns the last set value for a key. Raises KeyError if unset.""" + for k, v in reversed(self._items): + if k == key: + return v + + raise KeyError(key) + + def __setitem__(self, key, value): + """Replace all values for a key with a single new value.""" + try: + del self[key] + except KeyError: + pass + + self._items.append((key, value)) + + def __delitem__(self, key): + """Delete all values for a key. Raises KeyError if none were set.""" + oldlen = len(self._items) + + self._items[:] = [(k, v) for k, v in self._items if k != key] + + if oldlen == len(self._items): + raise KeyError(key) + + def __contains__(self, key): + return any(k == key for k, v in self._items) + + def __len__(self): + return len(self._items)  # Counts every (key, value) pair, not distinct keys (unlike dict). + + def get(self, key, default=None):  # Like __getitem__, but returns default instead of raising KeyError. + try: + return self.__getitem__(key) + except KeyError: + return default + + def add(self, key, value): + """Add a new value for a key. Does not replace existing values.""" + self._items.append((key, value)) + + def getall(self, key): + """Obtains all values for a key, in insertion order (empty list if none).""" + return [v for k, v in self._items if k == key] + + def getone(self, key): + """Obtain a single value for a key. + + Raises KeyError if the key is not defined or has multiple values set. 
+ """ + vals = self.getall(key) + + if not vals: + raise KeyError(key) + + if len(vals) > 1: + raise KeyError('multiple values for %r' % key) + + return vals[0] + + def asdictoflists(self): + d = {}  # Builds a new dict mapping each key to a list of all its values, in insertion order. + for k, v in self._items: + if k in d: + d[k].append(v) + else: + d[k] = [v] + + return d + @attr.s(frozen=True) class parsedrequest(object): """Represents a parsed WSGI request. @@ -56,10 +140,8 @@ class parsedrequest(object): havepathinfo = attr.ib() # Raw query string (part after "?" in URL). querystring = attr.ib() - # List of 2-tuples of query string arguments. - querystringlist = attr.ib() - # Dict of query string arguments. Values are lists with at least 1 item. - querystringdict = attr.ib() + # multidict of query string parameters. Repeated keys keep all their values, in order. + qsparams = attr.ib() # wsgiref.headers.Headers instance. Operates like a dict with case # insensitive keys. headers = attr.ib() @@ -157,14 +239,9 @@ def parserequestfromenv(env, bodyfh): # We store as a list so we have ordering information. We also store as # a dict to facilitate fast lookup. - querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True) - - querystringdict = {} - for k, v in querystringlist: - if k in querystringdict: - querystringdict[k].append(v) - else: - querystringdict[k] = [v] + qsparams = multidict()  # A single multidict now provides both the ordering and the lookup described above. + for k, v in util.urlreq.parseqsl(querystring, keep_blank_values=True): + qsparams.add(k, v) # HTTP_* keys contain HTTP request headers. The Headers structure should # perform case normalization for us. 
We just rewrite underscore to dash @@ -197,8 +274,7 @@ def parserequestfromenv(env, bodyfh): dispatchparts=dispatchparts, dispatchpath=dispatchpath, havepathinfo='PATH_INFO' in env, querystring=querystring, - querystringlist=querystringlist, - querystringdict=querystringdict, + qsparams=qsparams, headers=headers, bodyfh=bodyfh) @@ -350,7 +426,7 @@ class wsgirequest(object): self.run_once = wsgienv[r'wsgi.run_once'] self.env = wsgienv self.req = parserequestfromenv(wsgienv, inp) - self.form = self.req.querystringdict + self.form = self.req.qsparams.asdictoflists()  # A new dict: unlike before, mutating form no longer changes the request's parameters. self.res = wsgiresponse(self.req, start_response) self._start_response = start_response self.server_write = None diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -79,7 +79,7 @@ class httpv1protocolhandler(wireprototyp return [data[k] for k in keys] def _args(self): - args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy()) + args = self._req.qsparams.asdictoflists()  # NOTE(review): the removed line also mapped pycompat.bytesurl over keys/values; confirm parseqsl already yields bytes here. postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0)) if postlen: args.update(urlreq.parseqs( @@ -170,10 +170,10 @@ def handlewsgirequest(rctx, wsgireq, req # HTTP version 1 wire protocol requests are denoted by a "cmd" query # string parameter. If it isn't present, this isn't a wire protocol # request. - if 'cmd' not in req.querystringdict: + if 'cmd' not in req.qsparams: return False - cmd = req.querystringdict['cmd'][0] + cmd = req.qsparams['cmd']  # NOTE(review): multidict returns the *last* "cmd" value; the removed code used the first ([0]). # The "cmd" request parameter is used by both the wire protocol and hgweb. # While not all wire protocol commands are available for all transports,