##// END OF EJS Templates
hgweb: refactor the request draining code...
Gregory Szorc -
r36871:2cdf47e1 default
parent child Browse files
Show More
@@ -1,320 +1,361
1 1 # hgweb/request.py - An http request from either CGI or the standalone server.
2 2 #
3 3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from __future__ import absolute_import
10 10
11 11 import cgi
12 12 import errno
13 13 import socket
14 14 import wsgiref.headers as wsgiheaders
15 15 #import wsgiref.validate
16 16
17 17 from .common import (
18 18 ErrorResponse,
19 19 HTTP_NOT_MODIFIED,
20 20 statusmessage,
21 21 )
22 22
23 23 from ..thirdparty import (
24 24 attr,
25 25 )
26 26 from .. import (
27 27 pycompat,
28 28 util,
29 29 )
30 30
31 31 shortcuts = {
32 32 'cl': [('cmd', ['changelog']), ('rev', None)],
33 33 'sl': [('cmd', ['shortlog']), ('rev', None)],
34 34 'cs': [('cmd', ['changeset']), ('node', None)],
35 35 'f': [('cmd', ['file']), ('filenode', None)],
36 36 'fl': [('cmd', ['filelog']), ('filenode', None)],
37 37 'fd': [('cmd', ['filediff']), ('node', None)],
38 38 'fa': [('cmd', ['annotate']), ('filenode', None)],
39 39 'mf': [('cmd', ['manifest']), ('manifest', None)],
40 40 'ca': [('cmd', ['archive']), ('node', None)],
41 41 'tags': [('cmd', ['tags'])],
42 42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
43 43 'static': [('cmd', ['static']), ('file', None)]
44 44 }
45 45
46 46 def normalize(form):
47 47 # first expand the shortcuts
48 48 for k in shortcuts:
49 49 if k in form:
50 50 for name, value in shortcuts[k]:
51 51 if value is None:
52 52 value = form[k]
53 53 form[name] = value
54 54 del form[k]
55 55 # And strip the values
56 56 bytesform = {}
57 57 for k, v in form.iteritems():
58 58 bytesform[pycompat.bytesurl(k)] = [
59 59 pycompat.bytesurl(i.strip()) for i in v]
60 60 return bytesform
61 61
62 62 @attr.s(frozen=True)
63 63 class parsedrequest(object):
64 64 """Represents a parsed WSGI request / static HTTP request parameters."""
65 65
66 66 # Request method.
67 67 method = attr.ib()
68 68 # Full URL for this request.
69 69 url = attr.ib()
70 70 # URL without any path components. Just <proto>://<host><port>.
71 71 baseurl = attr.ib()
72 72 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
73 73 # of HTTP: Host header for hostname. This is likely what clients used.
74 74 advertisedurl = attr.ib()
75 75 advertisedbaseurl = attr.ib()
76 76 # WSGI application path.
77 77 apppath = attr.ib()
78 78 # List of path parts to be used for dispatch.
79 79 dispatchparts = attr.ib()
80 80 # URL path component (no query string) used for dispatch.
81 81 dispatchpath = attr.ib()
82 82 # Whether there is a path component to this request. This can be true
83 83 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
84 84 havepathinfo = attr.ib()
85 85 # Raw query string (part after "?" in URL).
86 86 querystring = attr.ib()
87 87 # List of 2-tuples of query string arguments.
88 88 querystringlist = attr.ib()
89 89 # Dict of query string arguments. Values are lists with at least 1 item.
90 90 querystringdict = attr.ib()
91 91 # wsgiref.headers.Headers instance. Operates like a dict with case
92 92 # insensitive keys.
93 93 headers = attr.ib()
94 94
95 95 def parserequestfromenv(env):
96 96 """Parse URL components from environment variables.
97 97
98 98 WSGI defines request attributes via environment variables. This function
99 99 parses the environment variables into a data structure.
100 100 """
101 101 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
102 102
103 103 # We first validate that the incoming object conforms with the WSGI spec.
104 104 # We only want to be dealing with spec-conforming WSGI implementations.
105 105 # TODO enable this once we fix internal violations.
106 106 #wsgiref.validate.check_environ(env)
107 107
108 108 # PEP-0333 states that environment keys and values are native strings
109 109 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
110 110 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
111 111 # in Mercurial, so mass convert string keys and values to bytes.
112 112 if pycompat.ispy3:
113 113 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
114 114 env = {k: v.encode('latin-1') if isinstance(v, str) else v
115 115 for k, v in env.iteritems()}
116 116
117 117 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
118 118 # the environment variables.
119 119 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
120 120 # how URLs are reconstructed.
121 121 fullurl = env['wsgi.url_scheme'] + '://'
122 122 advertisedfullurl = fullurl
123 123
124 124 def addport(s):
125 125 if env['wsgi.url_scheme'] == 'https':
126 126 if env['SERVER_PORT'] != '443':
127 127 s += ':' + env['SERVER_PORT']
128 128 else:
129 129 if env['SERVER_PORT'] != '80':
130 130 s += ':' + env['SERVER_PORT']
131 131
132 132 return s
133 133
134 134 if env.get('HTTP_HOST'):
135 135 fullurl += env['HTTP_HOST']
136 136 else:
137 137 fullurl += env['SERVER_NAME']
138 138 fullurl = addport(fullurl)
139 139
140 140 advertisedfullurl += env['SERVER_NAME']
141 141 advertisedfullurl = addport(advertisedfullurl)
142 142
143 143 baseurl = fullurl
144 144 advertisedbaseurl = advertisedfullurl
145 145
146 146 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
147 147 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
148 148 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
149 149 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
150 150
151 151 if env.get('QUERY_STRING'):
152 152 fullurl += '?' + env['QUERY_STRING']
153 153 advertisedfullurl += '?' + env['QUERY_STRING']
154 154
155 155 # When dispatching requests, we look at the URL components (PATH_INFO
156 156 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
157 157 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
158 158 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
159 159 # root. We also exclude its path components from PATH_INFO when resolving
160 160 # the dispatch path.
161 161
162 162 apppath = env['SCRIPT_NAME']
163 163
164 164 if env.get('REPO_NAME'):
165 165 if not apppath.endswith('/'):
166 166 apppath += '/'
167 167
168 168 apppath += env.get('REPO_NAME')
169 169
170 170 if 'PATH_INFO' in env:
171 171 dispatchparts = env['PATH_INFO'].strip('/').split('/')
172 172
173 173 # Strip out repo parts.
174 174 repoparts = env.get('REPO_NAME', '').split('/')
175 175 if dispatchparts[:len(repoparts)] == repoparts:
176 176 dispatchparts = dispatchparts[len(repoparts):]
177 177 else:
178 178 dispatchparts = []
179 179
180 180 dispatchpath = '/'.join(dispatchparts)
181 181
182 182 querystring = env.get('QUERY_STRING', '')
183 183
184 184 # We store as a list so we have ordering information. We also store as
185 185 # a dict to facilitate fast lookup.
186 186 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
187 187
188 188 querystringdict = {}
189 189 for k, v in querystringlist:
190 190 if k in querystringdict:
191 191 querystringdict[k].append(v)
192 192 else:
193 193 querystringdict[k] = [v]
194 194
195 195 # HTTP_* keys contain HTTP request headers. The Headers structure should
196 196 # perform case normalization for us. We just rewrite underscore to dash
197 197 # so keys match what likely went over the wire.
198 198 headers = []
199 199 for k, v in env.iteritems():
200 200 if k.startswith('HTTP_'):
201 201 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
202 202
203 203 headers = wsgiheaders.Headers(headers)
204 204
205 205 # This is kind of a lie because the HTTP header wasn't explicitly
206 206 # sent. But for all intents and purposes it should be OK to lie about
207 207 # this, since a consumer will either either value to determine how many
208 208 # bytes are available to read.
209 209 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
210 210 headers['Content-Length'] = env['CONTENT_LENGTH']
211 211
212 212 return parsedrequest(method=env['REQUEST_METHOD'],
213 213 url=fullurl, baseurl=baseurl,
214 214 advertisedurl=advertisedfullurl,
215 215 advertisedbaseurl=advertisedbaseurl,
216 216 apppath=apppath,
217 217 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
218 218 havepathinfo='PATH_INFO' in env,
219 219 querystring=querystring,
220 220 querystringlist=querystringlist,
221 221 querystringdict=querystringdict,
222 222 headers=headers)
223 223
224 224 class wsgirequest(object):
225 225 """Higher-level API for a WSGI request.
226 226
227 227 WSGI applications are invoked with 2 arguments. They are used to
228 228 instantiate instances of this class, which provides higher-level APIs
229 229 for obtaining request parameters, writing HTTP output, etc.
230 230 """
231 231 def __init__(self, wsgienv, start_response):
232 232 version = wsgienv[r'wsgi.version']
233 233 if (version < (1, 0)) or (version >= (2, 0)):
234 234 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
235 235 % version)
236 236 self.inp = wsgienv[r'wsgi.input']
237 237
238 238 if r'HTTP_CONTENT_LENGTH' in wsgienv:
239 239 self.inp = util.cappedreader(self.inp,
240 240 int(wsgienv[r'HTTP_CONTENT_LENGTH']))
241 241 elif r'CONTENT_LENGTH' in wsgienv:
242 242 self.inp = util.cappedreader(self.inp,
243 243 int(wsgienv[r'CONTENT_LENGTH']))
244 244
245 245 self.err = wsgienv[r'wsgi.errors']
246 246 self.threaded = wsgienv[r'wsgi.multithread']
247 247 self.multiprocess = wsgienv[r'wsgi.multiprocess']
248 248 self.run_once = wsgienv[r'wsgi.run_once']
249 249 self.env = wsgienv
250 250 self.form = normalize(cgi.parse(self.inp,
251 251 self.env,
252 252 keep_blank_values=1))
253 253 self._start_response = start_response
254 254 self.server_write = None
255 255 self.headers = []
256 256
257 def drain(self):
258 '''need to read all data from request, httplib is half-duplex'''
259 length = int(self.env.get('CONTENT_LENGTH') or 0)
260 for s in util.filechunkiter(self.inp, limit=length):
261 pass
262
263 257 def respond(self, status, type, filename=None, body=None):
264 258 if not isinstance(type, str):
265 259 type = pycompat.sysstr(type)
266 260 if self._start_response is not None:
267 261 self.headers.append((r'Content-Type', type))
268 262 if filename:
269 263 filename = (filename.rpartition('/')[-1]
270 264 .replace('\\', '\\\\').replace('"', '\\"'))
271 265 self.headers.append(('Content-Disposition',
272 266 'inline; filename="%s"' % filename))
273 267 if body is not None:
274 268 self.headers.append((r'Content-Length', str(len(body))))
275 269
276 270 for k, v in self.headers:
277 271 if not isinstance(v, str):
278 272 raise TypeError('header value must be string: %r' % (v,))
279 273
280 274 if isinstance(status, ErrorResponse):
281 275 self.headers.extend(status.headers)
282 276 if status.code == HTTP_NOT_MODIFIED:
283 277 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
284 278 # it MUST NOT include any headers other than these and no
285 279 # body
286 280 self.headers = [(k, v) for (k, v) in self.headers if
287 281 k in ('Date', 'ETag', 'Expires',
288 282 'Cache-Control', 'Vary')]
289 283 status = statusmessage(status.code, pycompat.bytestr(status))
290 284 elif status == 200:
291 285 status = '200 Script output follows'
292 286 elif isinstance(status, int):
293 287 status = statusmessage(status)
294 288
289 # Various HTTP clients (notably httplib) won't read the HTTP
290 # response until the HTTP request has been sent in full. If servers
291 # (us) send a response before the HTTP request has been fully sent,
292 # the connection may deadlock because neither end is reading.
293 #
294 # We work around this by "draining" the request data before
295 # sending any response in some conditions.
296 drain = False
297 close = False
298
299 # If the client sent Expect: 100-continue, we assume it is smart
300 # enough to deal with the server sending a response before reading
301 # the request. (httplib doesn't do this.)
302 if self.env.get(r'HTTP_EXPECT', r'').lower() == r'100-continue':
303 pass
304 # Only tend to request methods that have bodies. Strictly speaking,
305 # we should sniff for a body. But this is fine for our existing
306 # WSGI applications.
307 elif self.env[r'REQUEST_METHOD'] not in (r'POST', r'PUT'):
308 pass
309 else:
310 # If we don't know how much data to read, there's no guarantee
311 # that we can drain the request responsibly. The WSGI
312 # specification only says that servers *should* ensure the
313 # input stream doesn't overrun the actual request. So there's
314 # no guarantee that reading until EOF won't corrupt the stream
315 # state.
316 if not isinstance(self.inp, util.cappedreader):
317 close = True
318 else:
319 # We /could/ only drain certain HTTP response codes. But 200
320 # and non-200 wire protocol responses both require draining.
321 # Since we have a capped reader in place for all situations
322 # where we drain, it is safe to read from that stream. We'll
323 # either do a drain or no-op if we're already at EOF.
324 drain = True
325
326 if close:
327 self.headers.append((r'Connection', r'Close'))
328
329 if drain:
330 assert isinstance(self.inp, util.cappedreader)
331 while True:
332 chunk = self.inp.read(32768)
333 if not chunk:
334 break
335
295 336 self.server_write = self._start_response(
296 337 pycompat.sysstr(status), self.headers)
297 338 self._start_response = None
298 339 self.headers = []
299 340 if body is not None:
300 341 self.write(body)
301 342 self.server_write = None
302 343
303 344 def write(self, thing):
304 345 if thing:
305 346 try:
306 347 self.server_write(thing)
307 348 except socket.error as inst:
308 349 if inst[0] != errno.ECONNRESET:
309 350 raise
310 351
311 352 def flush(self):
312 353 return None
313 354
314 355 def wsgiapplication(app_maker):
315 356 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
316 357 can and should now be used as a WSGI application.'''
317 358 application = app_maker()
318 359 def run_wsgi(env, respond):
319 360 return application(env, respond)
320 361 return run_wsgi
@@ -1,667 +1,649
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import contextlib
10 10 import struct
11 11 import sys
12 12 import threading
13 13
14 14 from .i18n import _
15 15 from . import (
16 16 encoding,
17 17 error,
18 18 hook,
19 19 pycompat,
20 20 util,
21 21 wireproto,
22 22 wireprototypes,
23 23 )
24 24
25 25 stringio = util.stringio
26 26
27 27 urlerr = util.urlerr
28 28 urlreq = util.urlreq
29 29
30 30 HTTP_OK = 200
31 31
32 32 HGTYPE = 'application/mercurial-0.1'
33 33 HGTYPE2 = 'application/mercurial-0.2'
34 34 HGERRTYPE = 'application/hg-error'
35 35
36 36 SSHV1 = wireprototypes.SSHV1
37 37 SSHV2 = wireprototypes.SSHV2
38 38
39 39 def decodevaluefromheaders(req, headerprefix):
40 40 """Decode a long value from multiple HTTP request headers.
41 41
42 42 Returns the value as a bytes, not a str.
43 43 """
44 44 chunks = []
45 45 i = 1
46 46 while True:
47 47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
48 48 if v is None:
49 49 break
50 50 chunks.append(pycompat.bytesurl(v))
51 51 i += 1
52 52
53 53 return ''.join(chunks)
54 54
55 55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
56 56 def __init__(self, wsgireq, req, ui, checkperm):
57 57 self._wsgireq = wsgireq
58 58 self._req = req
59 59 self._ui = ui
60 60 self._checkperm = checkperm
61 61
62 62 @property
63 63 def name(self):
64 64 return 'http-v1'
65 65
66 66 def getargs(self, args):
67 67 knownargs = self._args()
68 68 data = {}
69 69 keys = args.split()
70 70 for k in keys:
71 71 if k == '*':
72 72 star = {}
73 73 for key in knownargs.keys():
74 74 if key != 'cmd' and key not in keys:
75 75 star[key] = knownargs[key][0]
76 76 data['*'] = star
77 77 else:
78 78 data[k] = knownargs[k][0]
79 79 return [data[k] for k in keys]
80 80
81 81 def _args(self):
82 82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
83 83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
84 84 if postlen:
85 85 args.update(urlreq.parseqs(
86 86 self._wsgireq.inp.read(postlen), keep_blank_values=True))
87 87 return args
88 88
89 89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
90 90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
91 91 return args
92 92
93 93 def forwardpayload(self, fp):
94 94 # Existing clients *always* send Content-Length.
95 95 length = int(self._req.headers[b'Content-Length'])
96 96
97 97 # If httppostargs is used, we need to read Content-Length
98 98 # minus the amount that was consumed by args.
99 99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
100 100 for s in util.filechunkiter(self._wsgireq.inp, limit=length):
101 101 fp.write(s)
102 102
103 103 @contextlib.contextmanager
104 104 def mayberedirectstdio(self):
105 105 oldout = self._ui.fout
106 106 olderr = self._ui.ferr
107 107
108 108 out = util.stringio()
109 109
110 110 try:
111 111 self._ui.fout = out
112 112 self._ui.ferr = out
113 113 yield out
114 114 finally:
115 115 self._ui.fout = oldout
116 116 self._ui.ferr = olderr
117 117
118 118 def client(self):
119 119 return 'remote:%s:%s:%s' % (
120 120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
121 121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
122 122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
123 123
124 124 def addcapabilities(self, repo, caps):
125 125 caps.append('httpheader=%d' %
126 126 repo.ui.configint('server', 'maxhttpheaderlen'))
127 127 if repo.ui.configbool('experimental', 'httppostargs'):
128 128 caps.append('httppostargs')
129 129
130 130 # FUTURE advertise 0.2rx once support is implemented
131 131 # FUTURE advertise minrx and mintx after consulting config option
132 132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
133 133
134 134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
135 135 if compengines:
136 136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
137 137 for e in compengines)
138 138 caps.append('compression=%s' % comptypes)
139 139
140 140 return caps
141 141
142 142 def checkperm(self, perm):
143 143 return self._checkperm(perm)
144 144
145 145 # This method exists mostly so that extensions like remotefilelog can
146 146 # disable a kludgey legacy method only over http. As of early 2018,
147 147 # there are no other known users, so with any luck we can discard this
148 148 # hook if remotefilelog becomes a first-party extension.
149 149 def iscmd(cmd):
150 150 return cmd in wireproto.commands
151 151
152 152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
153 153 """Possibly process a wire protocol request.
154 154
155 155 If the current request is a wire protocol request, the request is
156 156 processed by this function.
157 157
158 158 ``wsgireq`` is a ``wsgirequest`` instance.
159 159 ``req`` is a ``parsedrequest`` instance.
160 160
161 161 Returns a 2-tuple of (bool, response) where the 1st element indicates
162 162 whether the request was handled and the 2nd element is a return
163 163 value for a WSGI application (often a generator of bytes).
164 164 """
165 165 # Avoid cycle involving hg module.
166 166 from .hgweb import common as hgwebcommon
167 167
168 168 repo = rctx.repo
169 169
170 170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
171 171 # string parameter. If it isn't present, this isn't a wire protocol
172 172 # request.
173 173 if 'cmd' not in req.querystringdict:
174 174 return False, None
175 175
176 176 cmd = req.querystringdict['cmd'][0]
177 177
178 178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
179 179 # While not all wire protocol commands are available for all transports,
180 180 # if we see a "cmd" value that resembles a known wire protocol command, we
181 181 # route it to a protocol handler. This is better than routing possible
182 182 # wire protocol requests to hgweb because it prevents hgweb from using
183 183 # known wire protocol commands and it is less confusing for machine
184 184 # clients.
185 185 if not iscmd(cmd):
186 186 return False, None
187 187
188 188 # The "cmd" query string argument is only valid on the root path of the
189 189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
190 190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
191 191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
192 192 if req.dispatchpath:
193 193 res = _handlehttperror(
194 194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
195 195 req)
196 196
197 197 return True, res
198 198
199 199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
200 200 lambda perm: checkperm(rctx, wsgireq, perm))
201 201
202 202 # The permissions checker should be the only thing that can raise an
203 203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
204 204 # exception here. So consider refactoring into a exception type that
205 205 # is associated with the wire protocol.
206 206 try:
207 207 res = _callhttp(repo, wsgireq, req, proto, cmd)
208 208 except hgwebcommon.ErrorResponse as e:
209 209 res = _handlehttperror(e, wsgireq, req)
210 210
211 211 return True, res
212 212
213 213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 214 """Determine the appropriate response type and compression settings.
215 215
216 216 Returns a tuple of (mediatype, compengine, engineopts).
217 217 """
218 218 # Determine the response media type and compression engine based
219 219 # on the request parameters.
220 220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
221 221
222 222 if '0.2' in protocaps:
223 223 # All clients are expected to support uncompressed data.
224 224 if prefer_uncompressed:
225 225 return HGTYPE2, util._noopengine(), {}
226 226
227 227 # Default as defined by wire protocol spec.
228 228 compformats = ['zlib', 'none']
229 229 for cap in protocaps:
230 230 if cap.startswith('comp='):
231 231 compformats = cap[5:].split(',')
232 232 break
233 233
234 234 # Now find an agreed upon compression format.
235 235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 236 if engine.wireprotosupport().name in compformats:
237 237 opts = {}
238 238 level = ui.configint('server', '%slevel' % engine.name())
239 239 if level is not None:
240 240 opts['level'] = level
241 241
242 242 return HGTYPE2, engine, opts
243 243
244 244 # No mutually supported compression format. Fall back to the
245 245 # legacy protocol.
246 246
247 247 # Don't allow untrusted settings because disabling compression or
248 248 # setting a very high compression level could lead to flooding
249 249 # the server's network or CPU.
250 250 opts = {'level': ui.configint('server', 'zliblevel')}
251 251 return HGTYPE, util.compengines['zlib'], opts
252 252
253 253 def _callhttp(repo, wsgireq, req, proto, cmd):
254 254 def genversion2(gen, engine, engineopts):
255 255 # application/mercurial-0.2 always sends a payload header
256 256 # identifying the compression engine.
257 257 name = engine.wireprotosupport().name
258 258 assert 0 < len(name) < 256
259 259 yield struct.pack('B', len(name))
260 260 yield name
261 261
262 262 for chunk in gen:
263 263 yield chunk
264 264
265 265 if not wireproto.commands.commandavailable(cmd, proto):
266 266 wsgireq.respond(HTTP_OK, HGERRTYPE,
267 267 body=_('requested wire protocol command is not '
268 268 'available over HTTP'))
269 269 return []
270 270
271 271 proto.checkperm(wireproto.commands[cmd].permission)
272 272
273 273 rsp = wireproto.dispatch(repo, proto, cmd)
274 274
275 275 if isinstance(rsp, bytes):
276 276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
277 277 return []
278 278 elif isinstance(rsp, wireprototypes.bytesresponse):
279 279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
280 280 return []
281 281 elif isinstance(rsp, wireprototypes.streamreslegacy):
282 282 gen = rsp.gen
283 283 wsgireq.respond(HTTP_OK, HGTYPE)
284 284 return gen
285 285 elif isinstance(rsp, wireprototypes.streamres):
286 286 gen = rsp.gen
287 287
288 288 # This code for compression should not be streamres specific. It
289 289 # is here because we only compress streamres at the moment.
290 290 mediatype, engine, engineopts = _httpresponsetype(
291 291 repo.ui, req, rsp.prefer_uncompressed)
292 292 gen = engine.compressstream(gen, engineopts)
293 293
294 294 if mediatype == HGTYPE2:
295 295 gen = genversion2(gen, engine, engineopts)
296 296
297 297 wsgireq.respond(HTTP_OK, mediatype)
298 298 return gen
299 299 elif isinstance(rsp, wireprototypes.pushres):
300 300 rsp = '%d\n%s' % (rsp.res, rsp.output)
301 301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
302 302 return []
303 303 elif isinstance(rsp, wireprototypes.pusherr):
304 # This is the httplib workaround documented in _handlehttperror().
305 wsgireq.drain()
306
307 304 rsp = '0\n%s\n' % rsp.res
308 305 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
309 306 return []
310 307 elif isinstance(rsp, wireprototypes.ooberror):
311 308 rsp = rsp.message
312 309 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
313 310 return []
314 311 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
315 312
316 313 def _handlehttperror(e, wsgireq, req):
317 314 """Called when an ErrorResponse is raised during HTTP request processing."""
318 315
319 # Clients using Python's httplib are stateful: the HTTP client
320 # won't process an HTTP response until all request data is
321 # sent to the server. The intent of this code is to ensure
322 # we always read HTTP request data from the client, thus
323 # ensuring httplib transitions to a state that allows it to read
324 # the HTTP response. In other words, it helps prevent deadlocks
325 # on clients using httplib.
326
327 if (req.method == 'POST' and
328 # But not if Expect: 100-continue is being used.
329 (req.headers.get('Expect', '').lower() != '100-continue')):
330 wsgireq.drain()
331 else:
332 wsgireq.headers.append((r'Connection', r'Close'))
333
334 316 # TODO This response body assumes the failed command was
335 317 # "unbundle." That assumption is not always valid.
336 318 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
337 319
338 320 return ''
339 321
340 322 def _sshv1respondbytes(fout, value):
341 323 """Send a bytes response for protocol version 1."""
342 324 fout.write('%d\n' % len(value))
343 325 fout.write(value)
344 326 fout.flush()
345 327
346 328 def _sshv1respondstream(fout, source):
347 329 write = fout.write
348 330 for chunk in source.gen:
349 331 write(chunk)
350 332 fout.flush()
351 333
352 334 def _sshv1respondooberror(fout, ferr, rsp):
353 335 ferr.write(b'%s\n-\n' % rsp)
354 336 ferr.flush()
355 337 fout.write(b'\n')
356 338 fout.flush()
357 339
358 340 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
359 341 """Handler for requests services via version 1 of SSH protocol."""
360 342 def __init__(self, ui, fin, fout):
361 343 self._ui = ui
362 344 self._fin = fin
363 345 self._fout = fout
364 346
365 347 @property
366 348 def name(self):
367 349 return wireprototypes.SSHV1
368 350
369 351 def getargs(self, args):
370 352 data = {}
371 353 keys = args.split()
372 354 for n in xrange(len(keys)):
373 355 argline = self._fin.readline()[:-1]
374 356 arg, l = argline.split()
375 357 if arg not in keys:
376 358 raise error.Abort(_("unexpected parameter %r") % arg)
377 359 if arg == '*':
378 360 star = {}
379 361 for k in xrange(int(l)):
380 362 argline = self._fin.readline()[:-1]
381 363 arg, l = argline.split()
382 364 val = self._fin.read(int(l))
383 365 star[arg] = val
384 366 data['*'] = star
385 367 else:
386 368 val = self._fin.read(int(l))
387 369 data[arg] = val
388 370 return [data[k] for k in keys]
389 371
390 372 def forwardpayload(self, fpout):
391 373 # We initially send an empty response. This tells the client it is
392 374 # OK to start sending data. If a client sees any other response, it
393 375 # interprets it as an error.
394 376 _sshv1respondbytes(self._fout, b'')
395 377
396 378 # The file is in the form:
397 379 #
398 380 # <chunk size>\n<chunk>
399 381 # ...
400 382 # 0\n
401 383 count = int(self._fin.readline())
402 384 while count:
403 385 fpout.write(self._fin.read(count))
404 386 count = int(self._fin.readline())
405 387
406 388 @contextlib.contextmanager
407 389 def mayberedirectstdio(self):
408 390 yield None
409 391
410 392 def client(self):
411 393 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
412 394 return 'remote:ssh:' + client
413 395
414 396 def addcapabilities(self, repo, caps):
415 397 return caps
416 398
417 399 def checkperm(self, perm):
418 400 pass
419 401
420 402 class sshv2protocolhandler(sshv1protocolhandler):
421 403 """Protocol handler for version 2 of the SSH protocol."""
422 404
423 405 @property
424 406 def name(self):
425 407 return wireprototypes.SSHV2
426 408
427 409 def _runsshserver(ui, repo, fin, fout, ev):
428 410 # This function operates like a state machine of sorts. The following
429 411 # states are defined:
430 412 #
431 413 # protov1-serving
432 414 # Server is in protocol version 1 serving mode. Commands arrive on
433 415 # new lines. These commands are processed in this state, one command
434 416 # after the other.
435 417 #
436 418 # protov2-serving
437 419 # Server is in protocol version 2 serving mode.
438 420 #
439 421 # upgrade-initial
440 422 # The server is going to process an upgrade request.
441 423 #
442 424 # upgrade-v2-filter-legacy-handshake
443 425 # The protocol is being upgraded to version 2. The server is expecting
444 426 # the legacy handshake from version 1.
445 427 #
446 428 # upgrade-v2-finish
447 429 # The upgrade to version 2 of the protocol is imminent.
448 430 #
449 431 # shutdown
450 432 # The server is shutting down, possibly in reaction to a client event.
451 433 #
452 434 # And here are their transitions:
453 435 #
454 436 # protov1-serving -> shutdown
455 437 # When server receives an empty request or encounters another
456 438 # error.
457 439 #
458 440 # protov1-serving -> upgrade-initial
459 441 # An upgrade request line was seen.
460 442 #
461 443 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
462 444 # Upgrade to version 2 in progress. Server is expecting to
463 445 # process a legacy handshake.
464 446 #
465 447 # upgrade-v2-filter-legacy-handshake -> shutdown
466 448 # Client did not fulfill upgrade handshake requirements.
467 449 #
468 450 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
469 451 # Client fulfilled version 2 upgrade requirements. Finishing that
470 452 # upgrade.
471 453 #
472 454 # upgrade-v2-finish -> protov2-serving
473 455 # Protocol upgrade to version 2 complete. Server can now speak protocol
474 456 # version 2.
475 457 #
476 458 # protov2-serving -> protov1-serving
477 459 # Ths happens by default since protocol version 2 is the same as
478 460 # version 1 except for the handshake.
479 461
480 462 state = 'protov1-serving'
481 463 proto = sshv1protocolhandler(ui, fin, fout)
482 464 protoswitched = False
483 465
484 466 while not ev.is_set():
485 467 if state == 'protov1-serving':
486 468 # Commands are issued on new lines.
487 469 request = fin.readline()[:-1]
488 470
489 471 # Empty lines signal to terminate the connection.
490 472 if not request:
491 473 state = 'shutdown'
492 474 continue
493 475
494 476 # It looks like a protocol upgrade request. Transition state to
495 477 # handle it.
496 478 if request.startswith(b'upgrade '):
497 479 if protoswitched:
498 480 _sshv1respondooberror(fout, ui.ferr,
499 481 b'cannot upgrade protocols multiple '
500 482 b'times')
501 483 state = 'shutdown'
502 484 continue
503 485
504 486 state = 'upgrade-initial'
505 487 continue
506 488
507 489 available = wireproto.commands.commandavailable(request, proto)
508 490
509 491 # This command isn't available. Send an empty response and go
510 492 # back to waiting for a new command.
511 493 if not available:
512 494 _sshv1respondbytes(fout, b'')
513 495 continue
514 496
515 497 rsp = wireproto.dispatch(repo, proto, request)
516 498
517 499 if isinstance(rsp, bytes):
518 500 _sshv1respondbytes(fout, rsp)
519 501 elif isinstance(rsp, wireprototypes.bytesresponse):
520 502 _sshv1respondbytes(fout, rsp.data)
521 503 elif isinstance(rsp, wireprototypes.streamres):
522 504 _sshv1respondstream(fout, rsp)
523 505 elif isinstance(rsp, wireprototypes.streamreslegacy):
524 506 _sshv1respondstream(fout, rsp)
525 507 elif isinstance(rsp, wireprototypes.pushres):
526 508 _sshv1respondbytes(fout, b'')
527 509 _sshv1respondbytes(fout, b'%d' % rsp.res)
528 510 elif isinstance(rsp, wireprototypes.pusherr):
529 511 _sshv1respondbytes(fout, rsp.res)
530 512 elif isinstance(rsp, wireprototypes.ooberror):
531 513 _sshv1respondooberror(fout, ui.ferr, rsp.message)
532 514 else:
533 515 raise error.ProgrammingError('unhandled response type from '
534 516 'wire protocol command: %s' % rsp)
535 517
536 518 # For now, protocol version 2 serving just goes back to version 1.
537 519 elif state == 'protov2-serving':
538 520 state = 'protov1-serving'
539 521 continue
540 522
541 523 elif state == 'upgrade-initial':
542 524 # We should never transition into this state if we've switched
543 525 # protocols.
544 526 assert not protoswitched
545 527 assert proto.name == wireprototypes.SSHV1
546 528
547 529 # Expected: upgrade <token> <capabilities>
548 530 # If we get something else, the request is malformed. It could be
549 531 # from a future client that has altered the upgrade line content.
550 532 # We treat this as an unknown command.
551 533 try:
552 534 token, caps = request.split(b' ')[1:]
553 535 except ValueError:
554 536 _sshv1respondbytes(fout, b'')
555 537 state = 'protov1-serving'
556 538 continue
557 539
558 540 # Send empty response if we don't support upgrading protocols.
559 541 if not ui.configbool('experimental', 'sshserver.support-v2'):
560 542 _sshv1respondbytes(fout, b'')
561 543 state = 'protov1-serving'
562 544 continue
563 545
564 546 try:
565 547 caps = urlreq.parseqs(caps)
566 548 except ValueError:
567 549 _sshv1respondbytes(fout, b'')
568 550 state = 'protov1-serving'
569 551 continue
570 552
571 553 # We don't see an upgrade request to protocol version 2. Ignore
572 554 # the upgrade request.
573 555 wantedprotos = caps.get(b'proto', [b''])[0]
574 556 if SSHV2 not in wantedprotos:
575 557 _sshv1respondbytes(fout, b'')
576 558 state = 'protov1-serving'
577 559 continue
578 560
579 561 # It looks like we can honor this upgrade request to protocol 2.
580 562 # Filter the rest of the handshake protocol request lines.
581 563 state = 'upgrade-v2-filter-legacy-handshake'
582 564 continue
583 565
584 566 elif state == 'upgrade-v2-filter-legacy-handshake':
585 567 # Client should have sent legacy handshake after an ``upgrade``
586 568 # request. Expected lines:
587 569 #
588 570 # hello
589 571 # between
590 572 # pairs 81
591 573 # 0000...-0000...
592 574
593 575 ok = True
594 576 for line in (b'hello', b'between', b'pairs 81'):
595 577 request = fin.readline()[:-1]
596 578
597 579 if request != line:
598 580 _sshv1respondooberror(fout, ui.ferr,
599 581 b'malformed handshake protocol: '
600 582 b'missing %s' % line)
601 583 ok = False
602 584 state = 'shutdown'
603 585 break
604 586
605 587 if not ok:
606 588 continue
607 589
608 590 request = fin.read(81)
609 591 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
610 592 _sshv1respondooberror(fout, ui.ferr,
611 593 b'malformed handshake protocol: '
612 594 b'missing between argument value')
613 595 state = 'shutdown'
614 596 continue
615 597
616 598 state = 'upgrade-v2-finish'
617 599 continue
618 600
619 601 elif state == 'upgrade-v2-finish':
620 602 # Send the upgrade response.
621 603 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
622 604 servercaps = wireproto.capabilities(repo, proto)
623 605 rsp = b'capabilities: %s' % servercaps.data
624 606 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
625 607 fout.flush()
626 608
627 609 proto = sshv2protocolhandler(ui, fin, fout)
628 610 protoswitched = True
629 611
630 612 state = 'protov2-serving'
631 613 continue
632 614
633 615 elif state == 'shutdown':
634 616 break
635 617
636 618 else:
637 619 raise error.ProgrammingError('unhandled ssh server state: %s' %
638 620 state)
639 621
640 622 class sshserver(object):
641 623 def __init__(self, ui, repo, logfh=None):
642 624 self._ui = ui
643 625 self._repo = repo
644 626 self._fin = ui.fin
645 627 self._fout = ui.fout
646 628
647 629 # Log write I/O to stdout and stderr if configured.
648 630 if logfh:
649 631 self._fout = util.makeloggingfileobject(
650 632 logfh, self._fout, 'o', logdata=True)
651 633 ui.ferr = util.makeloggingfileobject(
652 634 logfh, ui.ferr, 'e', logdata=True)
653 635
654 636 hook.redirect(True)
655 637 ui.fout = repo.ui.fout = ui.ferr
656 638
657 639 # Prevent insertion/deletion of CRs
658 640 util.setbinary(self._fin)
659 641 util.setbinary(self._fout)
660 642
661 643 def serve_forever(self):
662 644 self.serveuntil(threading.Event())
663 645 sys.exit(0)
664 646
665 647 def serveuntil(self, ev):
666 648 """Serve until a threading.Event is set."""
667 649 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
General Comments 0
You need to be logged in to leave comments. Login now