##// END OF EJS Templates
hgweb: store and use request method on parsed request...
Gregory Szorc -
r36864:16292bbd default
parent child Browse files
Show More
@@ -1,322 +1,325
1 # hgweb/request.py - An http request from either CGI or the standalone server.
1 # hgweb/request.py - An http request from either CGI or the standalone server.
2 #
2 #
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
3 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from __future__ import absolute_import
9 from __future__ import absolute_import
10
10
11 import cgi
11 import cgi
12 import errno
12 import errno
13 import socket
13 import socket
14 import wsgiref.headers as wsgiheaders
14 import wsgiref.headers as wsgiheaders
15 #import wsgiref.validate
15 #import wsgiref.validate
16
16
17 from .common import (
17 from .common import (
18 ErrorResponse,
18 ErrorResponse,
19 HTTP_NOT_MODIFIED,
19 HTTP_NOT_MODIFIED,
20 statusmessage,
20 statusmessage,
21 )
21 )
22
22
23 from ..thirdparty import (
23 from ..thirdparty import (
24 attr,
24 attr,
25 )
25 )
26 from .. import (
26 from .. import (
27 pycompat,
27 pycompat,
28 util,
28 util,
29 )
29 )
30
30
31 shortcuts = {
31 shortcuts = {
32 'cl': [('cmd', ['changelog']), ('rev', None)],
32 'cl': [('cmd', ['changelog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
33 'sl': [('cmd', ['shortlog']), ('rev', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
34 'cs': [('cmd', ['changeset']), ('node', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
35 'f': [('cmd', ['file']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
36 'fl': [('cmd', ['filelog']), ('filenode', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
37 'fd': [('cmd', ['filediff']), ('node', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
38 'fa': [('cmd', ['annotate']), ('filenode', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
39 'mf': [('cmd', ['manifest']), ('manifest', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
40 'ca': [('cmd', ['archive']), ('node', None)],
41 'tags': [('cmd', ['tags'])],
41 'tags': [('cmd', ['tags'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
42 'tip': [('cmd', ['changeset']), ('node', ['tip'])],
43 'static': [('cmd', ['static']), ('file', None)]
43 'static': [('cmd', ['static']), ('file', None)]
44 }
44 }
45
45
46 def normalize(form):
46 def normalize(form):
47 # first expand the shortcuts
47 # first expand the shortcuts
48 for k in shortcuts:
48 for k in shortcuts:
49 if k in form:
49 if k in form:
50 for name, value in shortcuts[k]:
50 for name, value in shortcuts[k]:
51 if value is None:
51 if value is None:
52 value = form[k]
52 value = form[k]
53 form[name] = value
53 form[name] = value
54 del form[k]
54 del form[k]
55 # And strip the values
55 # And strip the values
56 bytesform = {}
56 bytesform = {}
57 for k, v in form.iteritems():
57 for k, v in form.iteritems():
58 bytesform[pycompat.bytesurl(k)] = [
58 bytesform[pycompat.bytesurl(k)] = [
59 pycompat.bytesurl(i.strip()) for i in v]
59 pycompat.bytesurl(i.strip()) for i in v]
60 return bytesform
60 return bytesform
61
61
62 @attr.s(frozen=True)
62 @attr.s(frozen=True)
63 class parsedrequest(object):
63 class parsedrequest(object):
64 """Represents a parsed WSGI request / static HTTP request parameters."""
64 """Represents a parsed WSGI request / static HTTP request parameters."""
65
65
66 # Request method.
67 method = attr.ib()
66 # Full URL for this request.
68 # Full URL for this request.
67 url = attr.ib()
69 url = attr.ib()
68 # URL without any path components. Just <proto>://<host><port>.
70 # URL without any path components. Just <proto>://<host><port>.
69 baseurl = attr.ib()
71 baseurl = attr.ib()
70 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
72 # Advertised URL. Like ``url`` and ``baseurl`` but uses SERVER_NAME instead
71 # of HTTP: Host header for hostname. This is likely what clients used.
73 # of HTTP: Host header for hostname. This is likely what clients used.
72 advertisedurl = attr.ib()
74 advertisedurl = attr.ib()
73 advertisedbaseurl = attr.ib()
75 advertisedbaseurl = attr.ib()
74 # WSGI application path.
76 # WSGI application path.
75 apppath = attr.ib()
77 apppath = attr.ib()
76 # List of path parts to be used for dispatch.
78 # List of path parts to be used for dispatch.
77 dispatchparts = attr.ib()
79 dispatchparts = attr.ib()
78 # URL path component (no query string) used for dispatch.
80 # URL path component (no query string) used for dispatch.
79 dispatchpath = attr.ib()
81 dispatchpath = attr.ib()
80 # Whether there is a path component to this request. This can be true
82 # Whether there is a path component to this request. This can be true
81 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
83 # when ``dispatchpath`` is empty due to REPO_NAME muckery.
82 havepathinfo = attr.ib()
84 havepathinfo = attr.ib()
83 # Raw query string (part after "?" in URL).
85 # Raw query string (part after "?" in URL).
84 querystring = attr.ib()
86 querystring = attr.ib()
85 # List of 2-tuples of query string arguments.
87 # List of 2-tuples of query string arguments.
86 querystringlist = attr.ib()
88 querystringlist = attr.ib()
87 # Dict of query string arguments. Values are lists with at least 1 item.
89 # Dict of query string arguments. Values are lists with at least 1 item.
88 querystringdict = attr.ib()
90 querystringdict = attr.ib()
89 # wsgiref.headers.Headers instance. Operates like a dict with case
91 # wsgiref.headers.Headers instance. Operates like a dict with case
90 # insensitive keys.
92 # insensitive keys.
91 headers = attr.ib()
93 headers = attr.ib()
92
94
93 def parserequestfromenv(env):
95 def parserequestfromenv(env):
94 """Parse URL components from environment variables.
96 """Parse URL components from environment variables.
95
97
96 WSGI defines request attributes via environment variables. This function
98 WSGI defines request attributes via environment variables. This function
97 parses the environment variables into a data structure.
99 parses the environment variables into a data structure.
98 """
100 """
99 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
101 # PEP-0333 defines the WSGI spec and is a useful reference for this code.
100
102
101 # We first validate that the incoming object conforms with the WSGI spec.
103 # We first validate that the incoming object conforms with the WSGI spec.
102 # We only want to be dealing with spec-conforming WSGI implementations.
104 # We only want to be dealing with spec-conforming WSGI implementations.
103 # TODO enable this once we fix internal violations.
105 # TODO enable this once we fix internal violations.
104 #wsgiref.validate.check_environ(env)
106 #wsgiref.validate.check_environ(env)
105
107
106 # PEP-0333 states that environment keys and values are native strings
108 # PEP-0333 states that environment keys and values are native strings
107 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
109 # (bytes on Python 2 and str on Python 3). The code points for the Unicode
108 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
110 # strings on Python 3 must be between \00000-\000FF. We deal with bytes
109 # in Mercurial, so mass convert string keys and values to bytes.
111 # in Mercurial, so mass convert string keys and values to bytes.
110 if pycompat.ispy3:
112 if pycompat.ispy3:
111 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
113 env = {k.encode('latin-1'): v for k, v in env.iteritems()}
112 env = {k: v.encode('latin-1') if isinstance(v, str) else v
114 env = {k: v.encode('latin-1') if isinstance(v, str) else v
113 for k, v in env.iteritems()}
115 for k, v in env.iteritems()}
114
116
115 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
117 # https://www.python.org/dev/peps/pep-0333/#environ-variables defines
116 # the environment variables.
118 # the environment variables.
117 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
119 # https://www.python.org/dev/peps/pep-0333/#url-reconstruction defines
118 # how URLs are reconstructed.
120 # how URLs are reconstructed.
119 fullurl = env['wsgi.url_scheme'] + '://'
121 fullurl = env['wsgi.url_scheme'] + '://'
120 advertisedfullurl = fullurl
122 advertisedfullurl = fullurl
121
123
122 def addport(s):
124 def addport(s):
123 if env['wsgi.url_scheme'] == 'https':
125 if env['wsgi.url_scheme'] == 'https':
124 if env['SERVER_PORT'] != '443':
126 if env['SERVER_PORT'] != '443':
125 s += ':' + env['SERVER_PORT']
127 s += ':' + env['SERVER_PORT']
126 else:
128 else:
127 if env['SERVER_PORT'] != '80':
129 if env['SERVER_PORT'] != '80':
128 s += ':' + env['SERVER_PORT']
130 s += ':' + env['SERVER_PORT']
129
131
130 return s
132 return s
131
133
132 if env.get('HTTP_HOST'):
134 if env.get('HTTP_HOST'):
133 fullurl += env['HTTP_HOST']
135 fullurl += env['HTTP_HOST']
134 else:
136 else:
135 fullurl += env['SERVER_NAME']
137 fullurl += env['SERVER_NAME']
136 fullurl = addport(fullurl)
138 fullurl = addport(fullurl)
137
139
138 advertisedfullurl += env['SERVER_NAME']
140 advertisedfullurl += env['SERVER_NAME']
139 advertisedfullurl = addport(advertisedfullurl)
141 advertisedfullurl = addport(advertisedfullurl)
140
142
141 baseurl = fullurl
143 baseurl = fullurl
142 advertisedbaseurl = advertisedfullurl
144 advertisedbaseurl = advertisedfullurl
143
145
144 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
146 fullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
145 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('SCRIPT_NAME', ''))
146 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
148 fullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
147 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
149 advertisedfullurl += util.urlreq.quote(env.get('PATH_INFO', ''))
148
150
149 if env.get('QUERY_STRING'):
151 if env.get('QUERY_STRING'):
150 fullurl += '?' + env['QUERY_STRING']
152 fullurl += '?' + env['QUERY_STRING']
151 advertisedfullurl += '?' + env['QUERY_STRING']
153 advertisedfullurl += '?' + env['QUERY_STRING']
152
154
153 # When dispatching requests, we look at the URL components (PATH_INFO
155 # When dispatching requests, we look at the URL components (PATH_INFO
154 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
156 # and QUERY_STRING) after the application root (SCRIPT_NAME). But hgwebdir
155 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
157 # has the concept of "virtual" repositories. This is defined via REPO_NAME.
156 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
158 # If REPO_NAME is defined, we append it to SCRIPT_NAME to form a new app
157 # root. We also exclude its path components from PATH_INFO when resolving
159 # root. We also exclude its path components from PATH_INFO when resolving
158 # the dispatch path.
160 # the dispatch path.
159
161
160 apppath = env['SCRIPT_NAME']
162 apppath = env['SCRIPT_NAME']
161
163
162 if env.get('REPO_NAME'):
164 if env.get('REPO_NAME'):
163 if not apppath.endswith('/'):
165 if not apppath.endswith('/'):
164 apppath += '/'
166 apppath += '/'
165
167
166 apppath += env.get('REPO_NAME')
168 apppath += env.get('REPO_NAME')
167
169
168 if 'PATH_INFO' in env:
170 if 'PATH_INFO' in env:
169 dispatchparts = env['PATH_INFO'].strip('/').split('/')
171 dispatchparts = env['PATH_INFO'].strip('/').split('/')
170
172
171 # Strip out repo parts.
173 # Strip out repo parts.
172 repoparts = env.get('REPO_NAME', '').split('/')
174 repoparts = env.get('REPO_NAME', '').split('/')
173 if dispatchparts[:len(repoparts)] == repoparts:
175 if dispatchparts[:len(repoparts)] == repoparts:
174 dispatchparts = dispatchparts[len(repoparts):]
176 dispatchparts = dispatchparts[len(repoparts):]
175 else:
177 else:
176 dispatchparts = []
178 dispatchparts = []
177
179
178 dispatchpath = '/'.join(dispatchparts)
180 dispatchpath = '/'.join(dispatchparts)
179
181
180 querystring = env.get('QUERY_STRING', '')
182 querystring = env.get('QUERY_STRING', '')
181
183
182 # We store as a list so we have ordering information. We also store as
184 # We store as a list so we have ordering information. We also store as
183 # a dict to facilitate fast lookup.
185 # a dict to facilitate fast lookup.
184 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
186 querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
185
187
186 querystringdict = {}
188 querystringdict = {}
187 for k, v in querystringlist:
189 for k, v in querystringlist:
188 if k in querystringdict:
190 if k in querystringdict:
189 querystringdict[k].append(v)
191 querystringdict[k].append(v)
190 else:
192 else:
191 querystringdict[k] = [v]
193 querystringdict[k] = [v]
192
194
193 # HTTP_* keys contain HTTP request headers. The Headers structure should
195 # HTTP_* keys contain HTTP request headers. The Headers structure should
194 # perform case normalization for us. We just rewrite underscore to dash
196 # perform case normalization for us. We just rewrite underscore to dash
195 # so keys match what likely went over the wire.
197 # so keys match what likely went over the wire.
196 headers = []
198 headers = []
197 for k, v in env.iteritems():
199 for k, v in env.iteritems():
198 if k.startswith('HTTP_'):
200 if k.startswith('HTTP_'):
199 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
201 headers.append((k[len('HTTP_'):].replace('_', '-'), v))
200
202
201 headers = wsgiheaders.Headers(headers)
203 headers = wsgiheaders.Headers(headers)
202
204
203 # This is kind of a lie because the HTTP header wasn't explicitly
205 # This is kind of a lie because the HTTP header wasn't explicitly
204 # sent. But for all intents and purposes it should be OK to lie about
206 # sent. But for all intents and purposes it should be OK to lie about
205 # this, since a consumer will use either value to determine how many
207 # this, since a consumer will use either value to determine how many
206 # bytes are available to read.
208 # bytes are available to read.
207 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
209 if 'CONTENT_LENGTH' in env and 'HTTP_CONTENT_LENGTH' not in env:
208 headers['Content-Length'] = env['CONTENT_LENGTH']
210 headers['Content-Length'] = env['CONTENT_LENGTH']
209
211
210 return parsedrequest(url=fullurl, baseurl=baseurl,
212 return parsedrequest(method=env['REQUEST_METHOD'],
213 url=fullurl, baseurl=baseurl,
211 advertisedurl=advertisedfullurl,
214 advertisedurl=advertisedfullurl,
212 advertisedbaseurl=advertisedbaseurl,
215 advertisedbaseurl=advertisedbaseurl,
213 apppath=apppath,
216 apppath=apppath,
214 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
217 dispatchparts=dispatchparts, dispatchpath=dispatchpath,
215 havepathinfo='PATH_INFO' in env,
218 havepathinfo='PATH_INFO' in env,
216 querystring=querystring,
219 querystring=querystring,
217 querystringlist=querystringlist,
220 querystringlist=querystringlist,
218 querystringdict=querystringdict,
221 querystringdict=querystringdict,
219 headers=headers)
222 headers=headers)
220
223
221 class wsgirequest(object):
224 class wsgirequest(object):
222 """Higher-level API for a WSGI request.
225 """Higher-level API for a WSGI request.
223
226
224 WSGI applications are invoked with 2 arguments. They are used to
227 WSGI applications are invoked with 2 arguments. They are used to
225 instantiate instances of this class, which provides higher-level APIs
228 instantiate instances of this class, which provides higher-level APIs
226 for obtaining request parameters, writing HTTP output, etc.
229 for obtaining request parameters, writing HTTP output, etc.
227 """
230 """
228 def __init__(self, wsgienv, start_response):
231 def __init__(self, wsgienv, start_response):
229 version = wsgienv[r'wsgi.version']
232 version = wsgienv[r'wsgi.version']
230 if (version < (1, 0)) or (version >= (2, 0)):
233 if (version < (1, 0)) or (version >= (2, 0)):
231 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
234 raise RuntimeError("Unknown and unsupported WSGI version %d.%d"
232 % version)
235 % version)
233 self.inp = wsgienv[r'wsgi.input']
236 self.inp = wsgienv[r'wsgi.input']
234 self.err = wsgienv[r'wsgi.errors']
237 self.err = wsgienv[r'wsgi.errors']
235 self.threaded = wsgienv[r'wsgi.multithread']
238 self.threaded = wsgienv[r'wsgi.multithread']
236 self.multiprocess = wsgienv[r'wsgi.multiprocess']
239 self.multiprocess = wsgienv[r'wsgi.multiprocess']
237 self.run_once = wsgienv[r'wsgi.run_once']
240 self.run_once = wsgienv[r'wsgi.run_once']
238 self.env = wsgienv
241 self.env = wsgienv
239 self.form = normalize(cgi.parse(self.inp,
242 self.form = normalize(cgi.parse(self.inp,
240 self.env,
243 self.env,
241 keep_blank_values=1))
244 keep_blank_values=1))
242 self._start_response = start_response
245 self._start_response = start_response
243 self.server_write = None
246 self.server_write = None
244 self.headers = []
247 self.headers = []
245
248
246 def __iter__(self):
249 def __iter__(self):
247 return iter([])
250 return iter([])
248
251
249 def read(self, count=-1):
252 def read(self, count=-1):
250 return self.inp.read(count)
253 return self.inp.read(count)
251
254
252 def drain(self):
255 def drain(self):
253 '''need to read all data from request, httplib is half-duplex'''
256 '''need to read all data from request, httplib is half-duplex'''
254 length = int(self.env.get('CONTENT_LENGTH') or 0)
257 length = int(self.env.get('CONTENT_LENGTH') or 0)
255 for s in util.filechunkiter(self.inp, limit=length):
258 for s in util.filechunkiter(self.inp, limit=length):
256 pass
259 pass
257
260
258 def respond(self, status, type, filename=None, body=None):
261 def respond(self, status, type, filename=None, body=None):
259 if not isinstance(type, str):
262 if not isinstance(type, str):
260 type = pycompat.sysstr(type)
263 type = pycompat.sysstr(type)
261 if self._start_response is not None:
264 if self._start_response is not None:
262 self.headers.append((r'Content-Type', type))
265 self.headers.append((r'Content-Type', type))
263 if filename:
266 if filename:
264 filename = (filename.rpartition('/')[-1]
267 filename = (filename.rpartition('/')[-1]
265 .replace('\\', '\\\\').replace('"', '\\"'))
268 .replace('\\', '\\\\').replace('"', '\\"'))
266 self.headers.append(('Content-Disposition',
269 self.headers.append(('Content-Disposition',
267 'inline; filename="%s"' % filename))
270 'inline; filename="%s"' % filename))
268 if body is not None:
271 if body is not None:
269 self.headers.append((r'Content-Length', str(len(body))))
272 self.headers.append((r'Content-Length', str(len(body))))
270
273
271 for k, v in self.headers:
274 for k, v in self.headers:
272 if not isinstance(v, str):
275 if not isinstance(v, str):
273 raise TypeError('header value must be string: %r' % (v,))
276 raise TypeError('header value must be string: %r' % (v,))
274
277
275 if isinstance(status, ErrorResponse):
278 if isinstance(status, ErrorResponse):
276 self.headers.extend(status.headers)
279 self.headers.extend(status.headers)
277 if status.code == HTTP_NOT_MODIFIED:
280 if status.code == HTTP_NOT_MODIFIED:
278 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
281 # RFC 2616 Section 10.3.5: 304 Not Modified has cases where
279 # it MUST NOT include any headers other than these and no
282 # it MUST NOT include any headers other than these and no
280 # body
283 # body
281 self.headers = [(k, v) for (k, v) in self.headers if
284 self.headers = [(k, v) for (k, v) in self.headers if
282 k in ('Date', 'ETag', 'Expires',
285 k in ('Date', 'ETag', 'Expires',
283 'Cache-Control', 'Vary')]
286 'Cache-Control', 'Vary')]
284 status = statusmessage(status.code, pycompat.bytestr(status))
287 status = statusmessage(status.code, pycompat.bytestr(status))
285 elif status == 200:
288 elif status == 200:
286 status = '200 Script output follows'
289 status = '200 Script output follows'
287 elif isinstance(status, int):
290 elif isinstance(status, int):
288 status = statusmessage(status)
291 status = statusmessage(status)
289
292
290 self.server_write = self._start_response(
293 self.server_write = self._start_response(
291 pycompat.sysstr(status), self.headers)
294 pycompat.sysstr(status), self.headers)
292 self._start_response = None
295 self._start_response = None
293 self.headers = []
296 self.headers = []
294 if body is not None:
297 if body is not None:
295 self.write(body)
298 self.write(body)
296 self.server_write = None
299 self.server_write = None
297
300
298 def write(self, thing):
301 def write(self, thing):
299 if thing:
302 if thing:
300 try:
303 try:
301 self.server_write(thing)
304 self.server_write(thing)
302 except socket.error as inst:
305 except socket.error as inst:
303 if inst[0] != errno.ECONNRESET:
306 if inst[0] != errno.ECONNRESET:
304 raise
307 raise
305
308
306 def writelines(self, lines):
309 def writelines(self, lines):
307 for line in lines:
310 for line in lines:
308 self.write(line)
311 self.write(line)
309
312
310 def flush(self):
313 def flush(self):
311 return None
314 return None
312
315
313 def close(self):
316 def close(self):
314 return None
317 return None
315
318
316 def wsgiapplication(app_maker):
319 def wsgiapplication(app_maker):
317 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
320 '''For compatibility with old CGI scripts. A plain hgweb() or hgwebdir()
318 can and should now be used as a WSGI application.'''
321 can and should now be used as a WSGI application.'''
319 application = app_maker()
322 application = app_maker()
320 def run_wsgi(env, respond):
323 def run_wsgi(env, respond):
321 return application(env, respond)
324 return application(env, respond)
322 return run_wsgi
325 return run_wsgi
@@ -1,667 +1,667
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import contextlib
9 import contextlib
10 import struct
10 import struct
11 import sys
11 import sys
12 import threading
12 import threading
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 wireprototypes,
22 wireprototypes,
23 )
23 )
24
24
25 stringio = util.stringio
25 stringio = util.stringio
26
26
27 urlerr = util.urlerr
27 urlerr = util.urlerr
28 urlreq = util.urlreq
28 urlreq = util.urlreq
29
29
30 HTTP_OK = 200
30 HTTP_OK = 200
31
31
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
34 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
35
35
36 SSHV1 = wireprototypes.SSHV1
36 SSHV1 = wireprototypes.SSHV1
37 SSHV2 = wireprototypes.SSHV2
37 SSHV2 = wireprototypes.SSHV2
38
38
39 def decodevaluefromheaders(req, headerprefix):
39 def decodevaluefromheaders(req, headerprefix):
40 """Decode a long value from multiple HTTP request headers.
40 """Decode a long value from multiple HTTP request headers.
41
41
42 Returns the value as a bytes, not a str.
42 Returns the value as a bytes, not a str.
43 """
43 """
44 chunks = []
44 chunks = []
45 i = 1
45 i = 1
46 while True:
46 while True:
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
48 if v is None:
48 if v is None:
49 break
49 break
50 chunks.append(pycompat.bytesurl(v))
50 chunks.append(pycompat.bytesurl(v))
51 i += 1
51 i += 1
52
52
53 return ''.join(chunks)
53 return ''.join(chunks)
54
54
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
56 def __init__(self, wsgireq, req, ui, checkperm):
56 def __init__(self, wsgireq, req, ui, checkperm):
57 self._wsgireq = wsgireq
57 self._wsgireq = wsgireq
58 self._req = req
58 self._req = req
59 self._ui = ui
59 self._ui = ui
60 self._checkperm = checkperm
60 self._checkperm = checkperm
61
61
62 @property
62 @property
63 def name(self):
63 def name(self):
64 return 'http-v1'
64 return 'http-v1'
65
65
66 def getargs(self, args):
66 def getargs(self, args):
67 knownargs = self._args()
67 knownargs = self._args()
68 data = {}
68 data = {}
69 keys = args.split()
69 keys = args.split()
70 for k in keys:
70 for k in keys:
71 if k == '*':
71 if k == '*':
72 star = {}
72 star = {}
73 for key in knownargs.keys():
73 for key in knownargs.keys():
74 if key != 'cmd' and key not in keys:
74 if key != 'cmd' and key not in keys:
75 star[key] = knownargs[key][0]
75 star[key] = knownargs[key][0]
76 data['*'] = star
76 data['*'] = star
77 else:
77 else:
78 data[k] = knownargs[k][0]
78 data[k] = knownargs[k][0]
79 return [data[k] for k in keys]
79 return [data[k] for k in keys]
80
80
81 def _args(self):
81 def _args(self):
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
84 if postlen:
84 if postlen:
85 args.update(urlreq.parseqs(
85 args.update(urlreq.parseqs(
86 self._wsgireq.read(postlen), keep_blank_values=True))
86 self._wsgireq.read(postlen), keep_blank_values=True))
87 return args
87 return args
88
88
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
91 return args
91 return args
92
92
93 def forwardpayload(self, fp):
93 def forwardpayload(self, fp):
94 # Existing clients *always* send Content-Length.
94 # Existing clients *always* send Content-Length.
95 length = int(self._req.headers[b'Content-Length'])
95 length = int(self._req.headers[b'Content-Length'])
96
96
97 # If httppostargs is used, we need to read Content-Length
97 # If httppostargs is used, we need to read Content-Length
98 # minus the amount that was consumed by args.
98 # minus the amount that was consumed by args.
99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
99 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
100 for s in util.filechunkiter(self._wsgireq, limit=length):
100 for s in util.filechunkiter(self._wsgireq, limit=length):
101 fp.write(s)
101 fp.write(s)
102
102
103 @contextlib.contextmanager
103 @contextlib.contextmanager
104 def mayberedirectstdio(self):
104 def mayberedirectstdio(self):
105 oldout = self._ui.fout
105 oldout = self._ui.fout
106 olderr = self._ui.ferr
106 olderr = self._ui.ferr
107
107
108 out = util.stringio()
108 out = util.stringio()
109
109
110 try:
110 try:
111 self._ui.fout = out
111 self._ui.fout = out
112 self._ui.ferr = out
112 self._ui.ferr = out
113 yield out
113 yield out
114 finally:
114 finally:
115 self._ui.fout = oldout
115 self._ui.fout = oldout
116 self._ui.ferr = olderr
116 self._ui.ferr = olderr
117
117
118 def client(self):
118 def client(self):
119 return 'remote:%s:%s:%s' % (
119 return 'remote:%s:%s:%s' % (
120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
120 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
121 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
122 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
123
123
124 def addcapabilities(self, repo, caps):
124 def addcapabilities(self, repo, caps):
125 caps.append('httpheader=%d' %
125 caps.append('httpheader=%d' %
126 repo.ui.configint('server', 'maxhttpheaderlen'))
126 repo.ui.configint('server', 'maxhttpheaderlen'))
127 if repo.ui.configbool('experimental', 'httppostargs'):
127 if repo.ui.configbool('experimental', 'httppostargs'):
128 caps.append('httppostargs')
128 caps.append('httppostargs')
129
129
130 # FUTURE advertise 0.2rx once support is implemented
130 # FUTURE advertise 0.2rx once support is implemented
131 # FUTURE advertise minrx and mintx after consulting config option
131 # FUTURE advertise minrx and mintx after consulting config option
132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
132 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
133
133
134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
134 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
135 if compengines:
135 if compengines:
136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
136 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
137 for e in compengines)
137 for e in compengines)
138 caps.append('compression=%s' % comptypes)
138 caps.append('compression=%s' % comptypes)
139
139
140 return caps
140 return caps
141
141
142 def checkperm(self, perm):
142 def checkperm(self, perm):
143 return self._checkperm(perm)
143 return self._checkperm(perm)
144
144
145 # This method exists mostly so that extensions like remotefilelog can
145 # This method exists mostly so that extensions like remotefilelog can
146 # disable a kludgey legacy method only over http. As of early 2018,
146 # disable a kludgey legacy method only over http. As of early 2018,
147 # there are no other known users, so with any luck we can discard this
147 # there are no other known users, so with any luck we can discard this
148 # hook if remotefilelog becomes a first-party extension.
148 # hook if remotefilelog becomes a first-party extension.
149 def iscmd(cmd):
149 def iscmd(cmd):
150 return cmd in wireproto.commands
150 return cmd in wireproto.commands
151
151
152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
152 def handlewsgirequest(rctx, wsgireq, req, checkperm):
153 """Possibly process a wire protocol request.
153 """Possibly process a wire protocol request.
154
154
155 If the current request is a wire protocol request, the request is
155 If the current request is a wire protocol request, the request is
156 processed by this function.
156 processed by this function.
157
157
158 ``wsgireq`` is a ``wsgirequest`` instance.
158 ``wsgireq`` is a ``wsgirequest`` instance.
159 ``req`` is a ``parsedrequest`` instance.
159 ``req`` is a ``parsedrequest`` instance.
160
160
161 Returns a 2-tuple of (bool, response) where the 1st element indicates
161 Returns a 2-tuple of (bool, response) where the 1st element indicates
162 whether the request was handled and the 2nd element is a return
162 whether the request was handled and the 2nd element is a return
163 value for a WSGI application (often a generator of bytes).
163 value for a WSGI application (often a generator of bytes).
164 """
164 """
165 # Avoid cycle involving hg module.
165 # Avoid cycle involving hg module.
166 from .hgweb import common as hgwebcommon
166 from .hgweb import common as hgwebcommon
167
167
168 repo = rctx.repo
168 repo = rctx.repo
169
169
170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
170 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
171 # string parameter. If it isn't present, this isn't a wire protocol
171 # string parameter. If it isn't present, this isn't a wire protocol
172 # request.
172 # request.
173 if 'cmd' not in req.querystringdict:
173 if 'cmd' not in req.querystringdict:
174 return False, None
174 return False, None
175
175
176 cmd = req.querystringdict['cmd'][0]
176 cmd = req.querystringdict['cmd'][0]
177
177
178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
178 # The "cmd" request parameter is used by both the wire protocol and hgweb.
179 # While not all wire protocol commands are available for all transports,
179 # While not all wire protocol commands are available for all transports,
180 # if we see a "cmd" value that resembles a known wire protocol command, we
180 # if we see a "cmd" value that resembles a known wire protocol command, we
181 # route it to a protocol handler. This is better than routing possible
181 # route it to a protocol handler. This is better than routing possible
182 # wire protocol requests to hgweb because it prevents hgweb from using
182 # wire protocol requests to hgweb because it prevents hgweb from using
183 # known wire protocol commands and it is less confusing for machine
183 # known wire protocol commands and it is less confusing for machine
184 # clients.
184 # clients.
185 if not iscmd(cmd):
185 if not iscmd(cmd):
186 return False, None
186 return False, None
187
187
188 # The "cmd" query string argument is only valid on the root path of the
188 # The "cmd" query string argument is only valid on the root path of the
189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
189 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
190 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
191 # in this case. We send an HTTP 404 for backwards compatibility reasons.
192 if req.dispatchpath:
192 if req.dispatchpath:
193 res = _handlehttperror(
193 res = _handlehttperror(
194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
194 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
195 req, cmd)
195 req, cmd)
196
196
197 return True, res
197 return True, res
198
198
199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
199 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
200 lambda perm: checkperm(rctx, wsgireq, perm))
200 lambda perm: checkperm(rctx, wsgireq, perm))
201
201
202 # The permissions checker should be the only thing that can raise an
202 # The permissions checker should be the only thing that can raise an
203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
203 # ErrorResponse. It is kind of a layer violation to catch an hgweb
204 # exception here. So consider refactoring into an exception type that
204 # exception here. So consider refactoring into an exception type that
205 # is associated with the wire protocol.
205 # is associated with the wire protocol.
206 try:
206 try:
207 res = _callhttp(repo, wsgireq, req, proto, cmd)
207 res = _callhttp(repo, wsgireq, req, proto, cmd)
208 except hgwebcommon.ErrorResponse as e:
208 except hgwebcommon.ErrorResponse as e:
209 res = _handlehttperror(e, wsgireq, req, cmd)
209 res = _handlehttperror(e, wsgireq, req, cmd)
210
210
211 return True, res
211 return True, res
212
212
213 def _httpresponsetype(ui, req, prefer_uncompressed):
213 def _httpresponsetype(ui, req, prefer_uncompressed):
214 """Determine the appropriate response type and compression settings.
214 """Determine the appropriate response type and compression settings.
215
215
216 Returns a tuple of (mediatype, compengine, engineopts).
216 Returns a tuple of (mediatype, compengine, engineopts).
217 """
217 """
218 # Determine the response media type and compression engine based
218 # Determine the response media type and compression engine based
219 # on the request parameters.
219 # on the request parameters.
220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
220 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
221
221
222 if '0.2' in protocaps:
222 if '0.2' in protocaps:
223 # All clients are expected to support uncompressed data.
223 # All clients are expected to support uncompressed data.
224 if prefer_uncompressed:
224 if prefer_uncompressed:
225 return HGTYPE2, util._noopengine(), {}
225 return HGTYPE2, util._noopengine(), {}
226
226
227 # Default as defined by wire protocol spec.
227 # Default as defined by wire protocol spec.
228 compformats = ['zlib', 'none']
228 compformats = ['zlib', 'none']
229 for cap in protocaps:
229 for cap in protocaps:
230 if cap.startswith('comp='):
230 if cap.startswith('comp='):
231 compformats = cap[5:].split(',')
231 compformats = cap[5:].split(',')
232 break
232 break
233
233
234 # Now find an agreed upon compression format.
234 # Now find an agreed upon compression format.
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
235 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 if engine.wireprotosupport().name in compformats:
236 if engine.wireprotosupport().name in compformats:
237 opts = {}
237 opts = {}
238 level = ui.configint('server', '%slevel' % engine.name())
238 level = ui.configint('server', '%slevel' % engine.name())
239 if level is not None:
239 if level is not None:
240 opts['level'] = level
240 opts['level'] = level
241
241
242 return HGTYPE2, engine, opts
242 return HGTYPE2, engine, opts
243
243
244 # No mutually supported compression format. Fall back to the
244 # No mutually supported compression format. Fall back to the
245 # legacy protocol.
245 # legacy protocol.
246
246
247 # Don't allow untrusted settings because disabling compression or
247 # Don't allow untrusted settings because disabling compression or
248 # setting a very high compression level could lead to flooding
248 # setting a very high compression level could lead to flooding
249 # the server's network or CPU.
249 # the server's network or CPU.
250 opts = {'level': ui.configint('server', 'zliblevel')}
250 opts = {'level': ui.configint('server', 'zliblevel')}
251 return HGTYPE, util.compengines['zlib'], opts
251 return HGTYPE, util.compengines['zlib'], opts
252
252
253 def _callhttp(repo, wsgireq, req, proto, cmd):
253 def _callhttp(repo, wsgireq, req, proto, cmd):
254 def genversion2(gen, engine, engineopts):
254 def genversion2(gen, engine, engineopts):
255 # application/mercurial-0.2 always sends a payload header
255 # application/mercurial-0.2 always sends a payload header
256 # identifying the compression engine.
256 # identifying the compression engine.
257 name = engine.wireprotosupport().name
257 name = engine.wireprotosupport().name
258 assert 0 < len(name) < 256
258 assert 0 < len(name) < 256
259 yield struct.pack('B', len(name))
259 yield struct.pack('B', len(name))
260 yield name
260 yield name
261
261
262 for chunk in gen:
262 for chunk in gen:
263 yield chunk
263 yield chunk
264
264
265 if not wireproto.commands.commandavailable(cmd, proto):
265 if not wireproto.commands.commandavailable(cmd, proto):
266 wsgireq.respond(HTTP_OK, HGERRTYPE,
266 wsgireq.respond(HTTP_OK, HGERRTYPE,
267 body=_('requested wire protocol command is not '
267 body=_('requested wire protocol command is not '
268 'available over HTTP'))
268 'available over HTTP'))
269 return []
269 return []
270
270
271 proto.checkperm(wireproto.commands[cmd].permission)
271 proto.checkperm(wireproto.commands[cmd].permission)
272
272
273 rsp = wireproto.dispatch(repo, proto, cmd)
273 rsp = wireproto.dispatch(repo, proto, cmd)
274
274
275 if isinstance(rsp, bytes):
275 if isinstance(rsp, bytes):
276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
276 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
277 return []
277 return []
278 elif isinstance(rsp, wireprototypes.bytesresponse):
278 elif isinstance(rsp, wireprototypes.bytesresponse):
279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
279 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
280 return []
280 return []
281 elif isinstance(rsp, wireprototypes.streamreslegacy):
281 elif isinstance(rsp, wireprototypes.streamreslegacy):
282 gen = rsp.gen
282 gen = rsp.gen
283 wsgireq.respond(HTTP_OK, HGTYPE)
283 wsgireq.respond(HTTP_OK, HGTYPE)
284 return gen
284 return gen
285 elif isinstance(rsp, wireprototypes.streamres):
285 elif isinstance(rsp, wireprototypes.streamres):
286 gen = rsp.gen
286 gen = rsp.gen
287
287
288 # This code for compression should not be streamres specific. It
288 # This code for compression should not be streamres specific. It
289 # is here because we only compress streamres at the moment.
289 # is here because we only compress streamres at the moment.
290 mediatype, engine, engineopts = _httpresponsetype(
290 mediatype, engine, engineopts = _httpresponsetype(
291 repo.ui, req, rsp.prefer_uncompressed)
291 repo.ui, req, rsp.prefer_uncompressed)
292 gen = engine.compressstream(gen, engineopts)
292 gen = engine.compressstream(gen, engineopts)
293
293
294 if mediatype == HGTYPE2:
294 if mediatype == HGTYPE2:
295 gen = genversion2(gen, engine, engineopts)
295 gen = genversion2(gen, engine, engineopts)
296
296
297 wsgireq.respond(HTTP_OK, mediatype)
297 wsgireq.respond(HTTP_OK, mediatype)
298 return gen
298 return gen
299 elif isinstance(rsp, wireprototypes.pushres):
299 elif isinstance(rsp, wireprototypes.pushres):
300 rsp = '%d\n%s' % (rsp.res, rsp.output)
300 rsp = '%d\n%s' % (rsp.res, rsp.output)
301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
301 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
302 return []
302 return []
303 elif isinstance(rsp, wireprototypes.pusherr):
303 elif isinstance(rsp, wireprototypes.pusherr):
304 # This is the httplib workaround documented in _handlehttperror().
304 # This is the httplib workaround documented in _handlehttperror().
305 wsgireq.drain()
305 wsgireq.drain()
306
306
307 rsp = '0\n%s\n' % rsp.res
307 rsp = '0\n%s\n' % rsp.res
308 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
308 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
309 return []
309 return []
310 elif isinstance(rsp, wireprototypes.ooberror):
310 elif isinstance(rsp, wireprototypes.ooberror):
311 rsp = rsp.message
311 rsp = rsp.message
312 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
312 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
313 return []
313 return []
314 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
314 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
315
315
316 def _handlehttperror(e, wsgireq, req, cmd):
316 def _handlehttperror(e, wsgireq, req, cmd):
317 """Called when an ErrorResponse is raised during HTTP request processing."""
317 """Called when an ErrorResponse is raised during HTTP request processing."""
318
318
319 # Clients using Python's httplib are stateful: the HTTP client
319 # Clients using Python's httplib are stateful: the HTTP client
320 # won't process an HTTP response until all request data is
320 # won't process an HTTP response until all request data is
321 # sent to the server. The intent of this code is to ensure
321 # sent to the server. The intent of this code is to ensure
322 # we always read HTTP request data from the client, thus
322 # we always read HTTP request data from the client, thus
323 # ensuring httplib transitions to a state that allows it to read
323 # ensuring httplib transitions to a state that allows it to read
324 # the HTTP response. In other words, it helps prevent deadlocks
324 # the HTTP response. In other words, it helps prevent deadlocks
325 # on clients using httplib.
325 # on clients using httplib.
326
326
327 if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
327 if (req.method == 'POST' and
328 # But not if Expect: 100-continue is being used.
328 # But not if Expect: 100-continue is being used.
329 (req.headers.get('Expect', '').lower() != '100-continue')):
329 (req.headers.get('Expect', '').lower() != '100-continue')):
330 wsgireq.drain()
330 wsgireq.drain()
331 else:
331 else:
332 wsgireq.headers.append((r'Connection', r'Close'))
332 wsgireq.headers.append((r'Connection', r'Close'))
333
333
334 # TODO This response body assumes the failed command was
334 # TODO This response body assumes the failed command was
335 # "unbundle." That assumption is not always valid.
335 # "unbundle." That assumption is not always valid.
336 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
336 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
337
337
338 return ''
338 return ''
339
339
340 def _sshv1respondbytes(fout, value):
340 def _sshv1respondbytes(fout, value):
341 """Send a bytes response for protocol version 1."""
341 """Send a bytes response for protocol version 1."""
342 fout.write('%d\n' % len(value))
342 fout.write('%d\n' % len(value))
343 fout.write(value)
343 fout.write(value)
344 fout.flush()
344 fout.flush()
345
345
346 def _sshv1respondstream(fout, source):
346 def _sshv1respondstream(fout, source):
347 write = fout.write
347 write = fout.write
348 for chunk in source.gen:
348 for chunk in source.gen:
349 write(chunk)
349 write(chunk)
350 fout.flush()
350 fout.flush()
351
351
352 def _sshv1respondooberror(fout, ferr, rsp):
352 def _sshv1respondooberror(fout, ferr, rsp):
353 ferr.write(b'%s\n-\n' % rsp)
353 ferr.write(b'%s\n-\n' % rsp)
354 ferr.flush()
354 ferr.flush()
355 fout.write(b'\n')
355 fout.write(b'\n')
356 fout.flush()
356 fout.flush()
357
357
358 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
358 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
359 """Handler for requests serviced via version 1 of SSH protocol."""
359 """Handler for requests serviced via version 1 of SSH protocol."""
360 def __init__(self, ui, fin, fout):
360 def __init__(self, ui, fin, fout):
361 self._ui = ui
361 self._ui = ui
362 self._fin = fin
362 self._fin = fin
363 self._fout = fout
363 self._fout = fout
364
364
365 @property
365 @property
366 def name(self):
366 def name(self):
367 return wireprototypes.SSHV1
367 return wireprototypes.SSHV1
368
368
369 def getargs(self, args):
369 def getargs(self, args):
370 data = {}
370 data = {}
371 keys = args.split()
371 keys = args.split()
372 for n in xrange(len(keys)):
372 for n in xrange(len(keys)):
373 argline = self._fin.readline()[:-1]
373 argline = self._fin.readline()[:-1]
374 arg, l = argline.split()
374 arg, l = argline.split()
375 if arg not in keys:
375 if arg not in keys:
376 raise error.Abort(_("unexpected parameter %r") % arg)
376 raise error.Abort(_("unexpected parameter %r") % arg)
377 if arg == '*':
377 if arg == '*':
378 star = {}
378 star = {}
379 for k in xrange(int(l)):
379 for k in xrange(int(l)):
380 argline = self._fin.readline()[:-1]
380 argline = self._fin.readline()[:-1]
381 arg, l = argline.split()
381 arg, l = argline.split()
382 val = self._fin.read(int(l))
382 val = self._fin.read(int(l))
383 star[arg] = val
383 star[arg] = val
384 data['*'] = star
384 data['*'] = star
385 else:
385 else:
386 val = self._fin.read(int(l))
386 val = self._fin.read(int(l))
387 data[arg] = val
387 data[arg] = val
388 return [data[k] for k in keys]
388 return [data[k] for k in keys]
389
389
390 def forwardpayload(self, fpout):
390 def forwardpayload(self, fpout):
391 # We initially send an empty response. This tells the client it is
391 # We initially send an empty response. This tells the client it is
392 # OK to start sending data. If a client sees any other response, it
392 # OK to start sending data. If a client sees any other response, it
393 # interprets it as an error.
393 # interprets it as an error.
394 _sshv1respondbytes(self._fout, b'')
394 _sshv1respondbytes(self._fout, b'')
395
395
396 # The file is in the form:
396 # The file is in the form:
397 #
397 #
398 # <chunk size>\n<chunk>
398 # <chunk size>\n<chunk>
399 # ...
399 # ...
400 # 0\n
400 # 0\n
401 count = int(self._fin.readline())
401 count = int(self._fin.readline())
402 while count:
402 while count:
403 fpout.write(self._fin.read(count))
403 fpout.write(self._fin.read(count))
404 count = int(self._fin.readline())
404 count = int(self._fin.readline())
405
405
406 @contextlib.contextmanager
406 @contextlib.contextmanager
407 def mayberedirectstdio(self):
407 def mayberedirectstdio(self):
408 yield None
408 yield None
409
409
410 def client(self):
410 def client(self):
411 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
411 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
412 return 'remote:ssh:' + client
412 return 'remote:ssh:' + client
413
413
414 def addcapabilities(self, repo, caps):
414 def addcapabilities(self, repo, caps):
415 return caps
415 return caps
416
416
417 def checkperm(self, perm):
417 def checkperm(self, perm):
418 pass
418 pass
419
419
420 class sshv2protocolhandler(sshv1protocolhandler):
420 class sshv2protocolhandler(sshv1protocolhandler):
421 """Protocol handler for version 2 of the SSH protocol."""
421 """Protocol handler for version 2 of the SSH protocol."""
422
422
423 @property
423 @property
424 def name(self):
424 def name(self):
425 return wireprototypes.SSHV2
425 return wireprototypes.SSHV2
426
426
427 def _runsshserver(ui, repo, fin, fout, ev):
427 def _runsshserver(ui, repo, fin, fout, ev):
428 # This function operates like a state machine of sorts. The following
428 # This function operates like a state machine of sorts. The following
429 # states are defined:
429 # states are defined:
430 #
430 #
431 # protov1-serving
431 # protov1-serving
432 # Server is in protocol version 1 serving mode. Commands arrive on
432 # Server is in protocol version 1 serving mode. Commands arrive on
433 # new lines. These commands are processed in this state, one command
433 # new lines. These commands are processed in this state, one command
434 # after the other.
434 # after the other.
435 #
435 #
436 # protov2-serving
436 # protov2-serving
437 # Server is in protocol version 2 serving mode.
437 # Server is in protocol version 2 serving mode.
438 #
438 #
439 # upgrade-initial
439 # upgrade-initial
440 # The server is going to process an upgrade request.
440 # The server is going to process an upgrade request.
441 #
441 #
442 # upgrade-v2-filter-legacy-handshake
442 # upgrade-v2-filter-legacy-handshake
443 # The protocol is being upgraded to version 2. The server is expecting
443 # The protocol is being upgraded to version 2. The server is expecting
444 # the legacy handshake from version 1.
444 # the legacy handshake from version 1.
445 #
445 #
446 # upgrade-v2-finish
446 # upgrade-v2-finish
447 # The upgrade to version 2 of the protocol is imminent.
447 # The upgrade to version 2 of the protocol is imminent.
448 #
448 #
449 # shutdown
449 # shutdown
450 # The server is shutting down, possibly in reaction to a client event.
450 # The server is shutting down, possibly in reaction to a client event.
451 #
451 #
452 # And here are their transitions:
452 # And here are their transitions:
453 #
453 #
454 # protov1-serving -> shutdown
454 # protov1-serving -> shutdown
455 # When server receives an empty request or encounters another
455 # When server receives an empty request or encounters another
456 # error.
456 # error.
457 #
457 #
458 # protov1-serving -> upgrade-initial
458 # protov1-serving -> upgrade-initial
459 # An upgrade request line was seen.
459 # An upgrade request line was seen.
460 #
460 #
461 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
461 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
462 # Upgrade to version 2 in progress. Server is expecting to
462 # Upgrade to version 2 in progress. Server is expecting to
463 # process a legacy handshake.
463 # process a legacy handshake.
464 #
464 #
465 # upgrade-v2-filter-legacy-handshake -> shutdown
465 # upgrade-v2-filter-legacy-handshake -> shutdown
466 # Client did not fulfill upgrade handshake requirements.
466 # Client did not fulfill upgrade handshake requirements.
467 #
467 #
468 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
468 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
469 # Client fulfilled version 2 upgrade requirements. Finishing that
469 # Client fulfilled version 2 upgrade requirements. Finishing that
470 # upgrade.
470 # upgrade.
471 #
471 #
472 # upgrade-v2-finish -> protov2-serving
472 # upgrade-v2-finish -> protov2-serving
473 # Protocol upgrade to version 2 complete. Server can now speak protocol
473 # Protocol upgrade to version 2 complete. Server can now speak protocol
474 # version 2.
474 # version 2.
475 #
475 #
476 # protov2-serving -> protov1-serving
476 # protov2-serving -> protov1-serving
477 # This happens by default since protocol version 2 is the same as
477 # This happens by default since protocol version 2 is the same as
478 # version 1 except for the handshake.
478 # version 1 except for the handshake.
479
479
480 state = 'protov1-serving'
480 state = 'protov1-serving'
481 proto = sshv1protocolhandler(ui, fin, fout)
481 proto = sshv1protocolhandler(ui, fin, fout)
482 protoswitched = False
482 protoswitched = False
483
483
484 while not ev.is_set():
484 while not ev.is_set():
485 if state == 'protov1-serving':
485 if state == 'protov1-serving':
486 # Commands are issued on new lines.
486 # Commands are issued on new lines.
487 request = fin.readline()[:-1]
487 request = fin.readline()[:-1]
488
488
489 # Empty lines signal to terminate the connection.
489 # Empty lines signal to terminate the connection.
490 if not request:
490 if not request:
491 state = 'shutdown'
491 state = 'shutdown'
492 continue
492 continue
493
493
494 # It looks like a protocol upgrade request. Transition state to
494 # It looks like a protocol upgrade request. Transition state to
495 # handle it.
495 # handle it.
496 if request.startswith(b'upgrade '):
496 if request.startswith(b'upgrade '):
497 if protoswitched:
497 if protoswitched:
498 _sshv1respondooberror(fout, ui.ferr,
498 _sshv1respondooberror(fout, ui.ferr,
499 b'cannot upgrade protocols multiple '
499 b'cannot upgrade protocols multiple '
500 b'times')
500 b'times')
501 state = 'shutdown'
501 state = 'shutdown'
502 continue
502 continue
503
503
504 state = 'upgrade-initial'
504 state = 'upgrade-initial'
505 continue
505 continue
506
506
507 available = wireproto.commands.commandavailable(request, proto)
507 available = wireproto.commands.commandavailable(request, proto)
508
508
509 # This command isn't available. Send an empty response and go
509 # This command isn't available. Send an empty response and go
510 # back to waiting for a new command.
510 # back to waiting for a new command.
511 if not available:
511 if not available:
512 _sshv1respondbytes(fout, b'')
512 _sshv1respondbytes(fout, b'')
513 continue
513 continue
514
514
515 rsp = wireproto.dispatch(repo, proto, request)
515 rsp = wireproto.dispatch(repo, proto, request)
516
516
517 if isinstance(rsp, bytes):
517 if isinstance(rsp, bytes):
518 _sshv1respondbytes(fout, rsp)
518 _sshv1respondbytes(fout, rsp)
519 elif isinstance(rsp, wireprototypes.bytesresponse):
519 elif isinstance(rsp, wireprototypes.bytesresponse):
520 _sshv1respondbytes(fout, rsp.data)
520 _sshv1respondbytes(fout, rsp.data)
521 elif isinstance(rsp, wireprototypes.streamres):
521 elif isinstance(rsp, wireprototypes.streamres):
522 _sshv1respondstream(fout, rsp)
522 _sshv1respondstream(fout, rsp)
523 elif isinstance(rsp, wireprototypes.streamreslegacy):
523 elif isinstance(rsp, wireprototypes.streamreslegacy):
524 _sshv1respondstream(fout, rsp)
524 _sshv1respondstream(fout, rsp)
525 elif isinstance(rsp, wireprototypes.pushres):
525 elif isinstance(rsp, wireprototypes.pushres):
526 _sshv1respondbytes(fout, b'')
526 _sshv1respondbytes(fout, b'')
527 _sshv1respondbytes(fout, b'%d' % rsp.res)
527 _sshv1respondbytes(fout, b'%d' % rsp.res)
528 elif isinstance(rsp, wireprototypes.pusherr):
528 elif isinstance(rsp, wireprototypes.pusherr):
529 _sshv1respondbytes(fout, rsp.res)
529 _sshv1respondbytes(fout, rsp.res)
530 elif isinstance(rsp, wireprototypes.ooberror):
530 elif isinstance(rsp, wireprototypes.ooberror):
531 _sshv1respondooberror(fout, ui.ferr, rsp.message)
531 _sshv1respondooberror(fout, ui.ferr, rsp.message)
532 else:
532 else:
533 raise error.ProgrammingError('unhandled response type from '
533 raise error.ProgrammingError('unhandled response type from '
534 'wire protocol command: %s' % rsp)
534 'wire protocol command: %s' % rsp)
535
535
536 # For now, protocol version 2 serving just goes back to version 1.
536 # For now, protocol version 2 serving just goes back to version 1.
537 elif state == 'protov2-serving':
537 elif state == 'protov2-serving':
538 state = 'protov1-serving'
538 state = 'protov1-serving'
539 continue
539 continue
540
540
541 elif state == 'upgrade-initial':
541 elif state == 'upgrade-initial':
542 # We should never transition into this state if we've switched
542 # We should never transition into this state if we've switched
543 # protocols.
543 # protocols.
544 assert not protoswitched
544 assert not protoswitched
545 assert proto.name == wireprototypes.SSHV1
545 assert proto.name == wireprototypes.SSHV1
546
546
547 # Expected: upgrade <token> <capabilities>
547 # Expected: upgrade <token> <capabilities>
548 # If we get something else, the request is malformed. It could be
548 # If we get something else, the request is malformed. It could be
549 # from a future client that has altered the upgrade line content.
549 # from a future client that has altered the upgrade line content.
550 # We treat this as an unknown command.
550 # We treat this as an unknown command.
551 try:
551 try:
552 token, caps = request.split(b' ')[1:]
552 token, caps = request.split(b' ')[1:]
553 except ValueError:
553 except ValueError:
554 _sshv1respondbytes(fout, b'')
554 _sshv1respondbytes(fout, b'')
555 state = 'protov1-serving'
555 state = 'protov1-serving'
556 continue
556 continue
557
557
558 # Send empty response if we don't support upgrading protocols.
558 # Send empty response if we don't support upgrading protocols.
559 if not ui.configbool('experimental', 'sshserver.support-v2'):
559 if not ui.configbool('experimental', 'sshserver.support-v2'):
560 _sshv1respondbytes(fout, b'')
560 _sshv1respondbytes(fout, b'')
561 state = 'protov1-serving'
561 state = 'protov1-serving'
562 continue
562 continue
563
563
564 try:
564 try:
565 caps = urlreq.parseqs(caps)
565 caps = urlreq.parseqs(caps)
566 except ValueError:
566 except ValueError:
567 _sshv1respondbytes(fout, b'')
567 _sshv1respondbytes(fout, b'')
568 state = 'protov1-serving'
568 state = 'protov1-serving'
569 continue
569 continue
570
570
571 # We don't see an upgrade request to protocol version 2. Ignore
571 # We don't see an upgrade request to protocol version 2. Ignore
572 # the upgrade request.
572 # the upgrade request.
573 wantedprotos = caps.get(b'proto', [b''])[0]
573 wantedprotos = caps.get(b'proto', [b''])[0]
574 if SSHV2 not in wantedprotos:
574 if SSHV2 not in wantedprotos:
575 _sshv1respondbytes(fout, b'')
575 _sshv1respondbytes(fout, b'')
576 state = 'protov1-serving'
576 state = 'protov1-serving'
577 continue
577 continue
578
578
579 # It looks like we can honor this upgrade request to protocol 2.
579 # It looks like we can honor this upgrade request to protocol 2.
580 # Filter the rest of the handshake protocol request lines.
580 # Filter the rest of the handshake protocol request lines.
581 state = 'upgrade-v2-filter-legacy-handshake'
581 state = 'upgrade-v2-filter-legacy-handshake'
582 continue
582 continue
583
583
584 elif state == 'upgrade-v2-filter-legacy-handshake':
584 elif state == 'upgrade-v2-filter-legacy-handshake':
585 # Client should have sent legacy handshake after an ``upgrade``
585 # Client should have sent legacy handshake after an ``upgrade``
586 # request. Expected lines:
586 # request. Expected lines:
587 #
587 #
588 # hello
588 # hello
589 # between
589 # between
590 # pairs 81
590 # pairs 81
591 # 0000...-0000...
591 # 0000...-0000...
592
592
593 ok = True
593 ok = True
594 for line in (b'hello', b'between', b'pairs 81'):
594 for line in (b'hello', b'between', b'pairs 81'):
595 request = fin.readline()[:-1]
595 request = fin.readline()[:-1]
596
596
597 if request != line:
597 if request != line:
598 _sshv1respondooberror(fout, ui.ferr,
598 _sshv1respondooberror(fout, ui.ferr,
599 b'malformed handshake protocol: '
599 b'malformed handshake protocol: '
600 b'missing %s' % line)
600 b'missing %s' % line)
601 ok = False
601 ok = False
602 state = 'shutdown'
602 state = 'shutdown'
603 break
603 break
604
604
605 if not ok:
605 if not ok:
606 continue
606 continue
607
607
608 request = fin.read(81)
608 request = fin.read(81)
609 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
609 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
610 _sshv1respondooberror(fout, ui.ferr,
610 _sshv1respondooberror(fout, ui.ferr,
611 b'malformed handshake protocol: '
611 b'malformed handshake protocol: '
612 b'missing between argument value')
612 b'missing between argument value')
613 state = 'shutdown'
613 state = 'shutdown'
614 continue
614 continue
615
615
616 state = 'upgrade-v2-finish'
616 state = 'upgrade-v2-finish'
617 continue
617 continue
618
618
619 elif state == 'upgrade-v2-finish':
619 elif state == 'upgrade-v2-finish':
620 # Send the upgrade response.
620 # Send the upgrade response.
621 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
621 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
622 servercaps = wireproto.capabilities(repo, proto)
622 servercaps = wireproto.capabilities(repo, proto)
623 rsp = b'capabilities: %s' % servercaps.data
623 rsp = b'capabilities: %s' % servercaps.data
624 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
624 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
625 fout.flush()
625 fout.flush()
626
626
627 proto = sshv2protocolhandler(ui, fin, fout)
627 proto = sshv2protocolhandler(ui, fin, fout)
628 protoswitched = True
628 protoswitched = True
629
629
630 state = 'protov2-serving'
630 state = 'protov2-serving'
631 continue
631 continue
632
632
633 elif state == 'shutdown':
633 elif state == 'shutdown':
634 break
634 break
635
635
636 else:
636 else:
637 raise error.ProgrammingError('unhandled ssh server state: %s' %
637 raise error.ProgrammingError('unhandled ssh server state: %s' %
638 state)
638 state)
639
639
640 class sshserver(object):
640 class sshserver(object):
641 def __init__(self, ui, repo, logfh=None):
641 def __init__(self, ui, repo, logfh=None):
642 self._ui = ui
642 self._ui = ui
643 self._repo = repo
643 self._repo = repo
644 self._fin = ui.fin
644 self._fin = ui.fin
645 self._fout = ui.fout
645 self._fout = ui.fout
646
646
647 # Log write I/O to stdout and stderr if configured.
647 # Log write I/O to stdout and stderr if configured.
648 if logfh:
648 if logfh:
649 self._fout = util.makeloggingfileobject(
649 self._fout = util.makeloggingfileobject(
650 logfh, self._fout, 'o', logdata=True)
650 logfh, self._fout, 'o', logdata=True)
651 ui.ferr = util.makeloggingfileobject(
651 ui.ferr = util.makeloggingfileobject(
652 logfh, ui.ferr, 'e', logdata=True)
652 logfh, ui.ferr, 'e', logdata=True)
653
653
654 hook.redirect(True)
654 hook.redirect(True)
655 ui.fout = repo.ui.fout = ui.ferr
655 ui.fout = repo.ui.fout = ui.ferr
656
656
657 # Prevent insertion/deletion of CRs
657 # Prevent insertion/deletion of CRs
658 util.setbinary(self._fin)
658 util.setbinary(self._fin)
659 util.setbinary(self._fout)
659 util.setbinary(self._fout)
660
660
661 def serve_forever(self):
661 def serve_forever(self):
662 self.serveuntil(threading.Event())
662 self.serveuntil(threading.Event())
663 sys.exit(0)
663 sys.exit(0)
664
664
665 def serveuntil(self, ev):
665 def serveuntil(self, ev):
666 """Serve until a threading.Event is set."""
666 """Serve until a threading.Event is set."""
667 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
667 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
General Comments 0
You need to be logged in to leave comments. Login now