##// END OF EJS Templates
wireprotoserver: access headers through parsed request...
Gregory Szorc -
r36862:14f70c44 default
parent child Browse files
Show More
@@ -1,669 +1,668 b''
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import contextlib
9 import contextlib
10 import struct
10 import struct
11 import sys
11 import sys
12 import threading
12 import threading
13
13
14 from .i18n import _
14 from .i18n import _
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 hook,
18 hook,
19 pycompat,
19 pycompat,
20 util,
20 util,
21 wireproto,
21 wireproto,
22 wireprototypes,
22 wireprototypes,
23 )
23 )
24
24
25 stringio = util.stringio
25 stringio = util.stringio
26
26
27 urlerr = util.urlerr
27 urlerr = util.urlerr
28 urlreq = util.urlreq
28 urlreq = util.urlreq
29
29
30 HTTP_OK = 200
30 HTTP_OK = 200
31
31
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
34 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
35
35
36 SSHV1 = wireprototypes.SSHV1
36 SSHV1 = wireprototypes.SSHV1
37 SSHV2 = wireprototypes.SSHV2
37 SSHV2 = wireprototypes.SSHV2
38
38
39 def decodevaluefromheaders(wsgireq, headerprefix):
39 def decodevaluefromheaders(req, headerprefix):
40 """Decode a long value from multiple HTTP request headers.
40 """Decode a long value from multiple HTTP request headers.
41
41
42 Returns the value as a bytes, not a str.
42 Returns the value as a bytes, not a str.
43 """
43 """
44 chunks = []
44 chunks = []
45 i = 1
45 i = 1
46 prefix = headerprefix.upper().replace(r'-', r'_')
47 while True:
46 while True:
48 v = wsgireq.env.get(r'HTTP_%s_%d' % (prefix, i))
47 v = req.headers.get(b'%s-%d' % (headerprefix, i))
49 if v is None:
48 if v is None:
50 break
49 break
51 chunks.append(pycompat.bytesurl(v))
50 chunks.append(pycompat.bytesurl(v))
52 i += 1
51 i += 1
53
52
54 return ''.join(chunks)
53 return ''.join(chunks)
55
54
56 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
55 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
57 def __init__(self, wsgireq, ui, checkperm):
56 def __init__(self, wsgireq, req, ui, checkperm):
58 self._wsgireq = wsgireq
57 self._wsgireq = wsgireq
58 self._req = req
59 self._ui = ui
59 self._ui = ui
60 self._checkperm = checkperm
60 self._checkperm = checkperm
61
61
62 @property
62 @property
63 def name(self):
63 def name(self):
64 return 'http-v1'
64 return 'http-v1'
65
65
66 def getargs(self, args):
66 def getargs(self, args):
67 knownargs = self._args()
67 knownargs = self._args()
68 data = {}
68 data = {}
69 keys = args.split()
69 keys = args.split()
70 for k in keys:
70 for k in keys:
71 if k == '*':
71 if k == '*':
72 star = {}
72 star = {}
73 for key in knownargs.keys():
73 for key in knownargs.keys():
74 if key != 'cmd' and key not in keys:
74 if key != 'cmd' and key not in keys:
75 star[key] = knownargs[key][0]
75 star[key] = knownargs[key][0]
76 data['*'] = star
76 data['*'] = star
77 else:
77 else:
78 data[k] = knownargs[k][0]
78 data[k] = knownargs[k][0]
79 return [data[k] for k in keys]
79 return [data[k] for k in keys]
80
80
81 def _args(self):
81 def _args(self):
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
82 args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
83 postlen = int(self._wsgireq.env.get(r'HTTP_X_HGARGS_POST', 0))
83 postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
84 if postlen:
84 if postlen:
85 args.update(urlreq.parseqs(
85 args.update(urlreq.parseqs(
86 self._wsgireq.read(postlen), keep_blank_values=True))
86 self._wsgireq.read(postlen), keep_blank_values=True))
87 return args
87 return args
88
88
89 argvalue = decodevaluefromheaders(self._wsgireq, r'X-HgArg')
89 argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
90 args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
91 return args
91 return args
92
92
93 def forwardpayload(self, fp):
93 def forwardpayload(self, fp):
94 if r'HTTP_CONTENT_LENGTH' in self._wsgireq.env:
94 if b'Content-Length' in self._req.headers:
95 length = int(self._wsgireq.env[r'HTTP_CONTENT_LENGTH'])
95 length = int(self._req.headers[b'Content-Length'])
96 else:
96 else:
97 length = int(self._wsgireq.env[r'CONTENT_LENGTH'])
97 length = int(self._wsgireq.env[r'CONTENT_LENGTH'])
98 # If httppostargs is used, we need to read Content-Length
98 # If httppostargs is used, we need to read Content-Length
99 # minus the amount that was consumed by args.
99 # minus the amount that was consumed by args.
100 length -= int(self._wsgireq.env.get(r'HTTP_X_HGARGS_POST', 0))
100 length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
101 for s in util.filechunkiter(self._wsgireq, limit=length):
101 for s in util.filechunkiter(self._wsgireq, limit=length):
102 fp.write(s)
102 fp.write(s)
103
103
104 @contextlib.contextmanager
104 @contextlib.contextmanager
105 def mayberedirectstdio(self):
105 def mayberedirectstdio(self):
106 oldout = self._ui.fout
106 oldout = self._ui.fout
107 olderr = self._ui.ferr
107 olderr = self._ui.ferr
108
108
109 out = util.stringio()
109 out = util.stringio()
110
110
111 try:
111 try:
112 self._ui.fout = out
112 self._ui.fout = out
113 self._ui.ferr = out
113 self._ui.ferr = out
114 yield out
114 yield out
115 finally:
115 finally:
116 self._ui.fout = oldout
116 self._ui.fout = oldout
117 self._ui.ferr = olderr
117 self._ui.ferr = olderr
118
118
119 def client(self):
119 def client(self):
120 return 'remote:%s:%s:%s' % (
120 return 'remote:%s:%s:%s' % (
121 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
121 self._wsgireq.env.get('wsgi.url_scheme') or 'http',
122 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
122 urlreq.quote(self._wsgireq.env.get('REMOTE_HOST', '')),
123 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
123 urlreq.quote(self._wsgireq.env.get('REMOTE_USER', '')))
124
124
125 def addcapabilities(self, repo, caps):
125 def addcapabilities(self, repo, caps):
126 caps.append('httpheader=%d' %
126 caps.append('httpheader=%d' %
127 repo.ui.configint('server', 'maxhttpheaderlen'))
127 repo.ui.configint('server', 'maxhttpheaderlen'))
128 if repo.ui.configbool('experimental', 'httppostargs'):
128 if repo.ui.configbool('experimental', 'httppostargs'):
129 caps.append('httppostargs')
129 caps.append('httppostargs')
130
130
131 # FUTURE advertise 0.2rx once support is implemented
131 # FUTURE advertise 0.2rx once support is implemented
132 # FUTURE advertise minrx and mintx after consulting config option
132 # FUTURE advertise minrx and mintx after consulting config option
133 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
133 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
134
134
135 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
135 compengines = wireproto.supportedcompengines(repo.ui, util.SERVERROLE)
136 if compengines:
136 if compengines:
137 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
137 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
138 for e in compengines)
138 for e in compengines)
139 caps.append('compression=%s' % comptypes)
139 caps.append('compression=%s' % comptypes)
140
140
141 return caps
141 return caps
142
142
143 def checkperm(self, perm):
143 def checkperm(self, perm):
144 return self._checkperm(perm)
144 return self._checkperm(perm)
145
145
146 # This method exists mostly so that extensions like remotefilelog can
146 # This method exists mostly so that extensions like remotefilelog can
147 # disable a kludgey legacy method only over http. As of early 2018,
147 # disable a kludgey legacy method only over http. As of early 2018,
148 # there are no other known users, so with any luck we can discard this
148 # there are no other known users, so with any luck we can discard this
149 # hook if remotefilelog becomes a first-party extension.
149 # hook if remotefilelog becomes a first-party extension.
150 def iscmd(cmd):
150 def iscmd(cmd):
151 return cmd in wireproto.commands
151 return cmd in wireproto.commands
152
152
153 def handlewsgirequest(rctx, wsgireq, req, checkperm):
153 def handlewsgirequest(rctx, wsgireq, req, checkperm):
154 """Possibly process a wire protocol request.
154 """Possibly process a wire protocol request.
155
155
156 If the current request is a wire protocol request, the request is
156 If the current request is a wire protocol request, the request is
157 processed by this function.
157 processed by this function.
158
158
159 ``wsgireq`` is a ``wsgirequest`` instance.
159 ``wsgireq`` is a ``wsgirequest`` instance.
160 ``req`` is a ``parsedrequest`` instance.
160 ``req`` is a ``parsedrequest`` instance.
161
161
162 Returns a 2-tuple of (bool, response) where the 1st element indicates
162 Returns a 2-tuple of (bool, response) where the 1st element indicates
163 whether the request was handled and the 2nd element is a return
163 whether the request was handled and the 2nd element is a return
164 value for a WSGI application (often a generator of bytes).
164 value for a WSGI application (often a generator of bytes).
165 """
165 """
166 # Avoid cycle involving hg module.
166 # Avoid cycle involving hg module.
167 from .hgweb import common as hgwebcommon
167 from .hgweb import common as hgwebcommon
168
168
169 repo = rctx.repo
169 repo = rctx.repo
170
170
171 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
171 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
172 # string parameter. If it isn't present, this isn't a wire protocol
172 # string parameter. If it isn't present, this isn't a wire protocol
173 # request.
173 # request.
174 if 'cmd' not in req.querystringdict:
174 if 'cmd' not in req.querystringdict:
175 return False, None
175 return False, None
176
176
177 cmd = req.querystringdict['cmd'][0]
177 cmd = req.querystringdict['cmd'][0]
178
178
179 # The "cmd" request parameter is used by both the wire protocol and hgweb.
179 # The "cmd" request parameter is used by both the wire protocol and hgweb.
180 # While not all wire protocol commands are available for all transports,
180 # While not all wire protocol commands are available for all transports,
181 # if we see a "cmd" value that resembles a known wire protocol command, we
181 # if we see a "cmd" value that resembles a known wire protocol command, we
182 # route it to a protocol handler. This is better than routing possible
182 # route it to a protocol handler. This is better than routing possible
183 # wire protocol requests to hgweb because it prevents hgweb from using
183 # wire protocol requests to hgweb because it prevents hgweb from using
184 # known wire protocol commands and it is less confusing for machine
184 # known wire protocol commands and it is less confusing for machine
185 # clients.
185 # clients.
186 if not iscmd(cmd):
186 if not iscmd(cmd):
187 return False, None
187 return False, None
188
188
189 # The "cmd" query string argument is only valid on the root path of the
189 # The "cmd" query string argument is only valid on the root path of the
190 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
190 # repo. e.g. ``/?cmd=foo``, ``/repo?cmd=foo``. URL paths within the repo
191 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
191 # like ``/blah?cmd=foo`` are not allowed. So don't recognize the request
192 # in this case. We send an HTTP 404 for backwards compatibility reasons.
192 # in this case. We send an HTTP 404 for backwards compatibility reasons.
193 if req.dispatchpath:
193 if req.dispatchpath:
194 res = _handlehttperror(
194 res = _handlehttperror(
195 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
195 hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
196 cmd)
196 req, cmd)
197
197
198 return True, res
198 return True, res
199
199
200 proto = httpv1protocolhandler(wsgireq, repo.ui,
200 proto = httpv1protocolhandler(wsgireq, req, repo.ui,
201 lambda perm: checkperm(rctx, wsgireq, perm))
201 lambda perm: checkperm(rctx, wsgireq, perm))
202
202
203 # The permissions checker should be the only thing that can raise an
203 # The permissions checker should be the only thing that can raise an
204 # ErrorResponse. It is kind of a layer violation to catch an hgweb
204 # ErrorResponse. It is kind of a layer violation to catch an hgweb
205 # exception here. So consider refactoring into a exception type that
205 # exception here. So consider refactoring into a exception type that
206 # is associated with the wire protocol.
206 # is associated with the wire protocol.
207 try:
207 try:
208 res = _callhttp(repo, wsgireq, proto, cmd)
208 res = _callhttp(repo, wsgireq, req, proto, cmd)
209 except hgwebcommon.ErrorResponse as e:
209 except hgwebcommon.ErrorResponse as e:
210 res = _handlehttperror(e, wsgireq, cmd)
210 res = _handlehttperror(e, wsgireq, req, cmd)
211
211
212 return True, res
212 return True, res
213
213
214 def _httpresponsetype(ui, wsgireq, prefer_uncompressed):
214 def _httpresponsetype(ui, req, prefer_uncompressed):
215 """Determine the appropriate response type and compression settings.
215 """Determine the appropriate response type and compression settings.
216
216
217 Returns a tuple of (mediatype, compengine, engineopts).
217 Returns a tuple of (mediatype, compengine, engineopts).
218 """
218 """
219 # Determine the response media type and compression engine based
219 # Determine the response media type and compression engine based
220 # on the request parameters.
220 # on the request parameters.
221 protocaps = decodevaluefromheaders(wsgireq, r'X-HgProto').split(' ')
221 protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
222
222
223 if '0.2' in protocaps:
223 if '0.2' in protocaps:
224 # All clients are expected to support uncompressed data.
224 # All clients are expected to support uncompressed data.
225 if prefer_uncompressed:
225 if prefer_uncompressed:
226 return HGTYPE2, util._noopengine(), {}
226 return HGTYPE2, util._noopengine(), {}
227
227
228 # Default as defined by wire protocol spec.
228 # Default as defined by wire protocol spec.
229 compformats = ['zlib', 'none']
229 compformats = ['zlib', 'none']
230 for cap in protocaps:
230 for cap in protocaps:
231 if cap.startswith('comp='):
231 if cap.startswith('comp='):
232 compformats = cap[5:].split(',')
232 compformats = cap[5:].split(',')
233 break
233 break
234
234
235 # Now find an agreed upon compression format.
235 # Now find an agreed upon compression format.
236 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
236 for engine in wireproto.supportedcompengines(ui, util.SERVERROLE):
237 if engine.wireprotosupport().name in compformats:
237 if engine.wireprotosupport().name in compformats:
238 opts = {}
238 opts = {}
239 level = ui.configint('server', '%slevel' % engine.name())
239 level = ui.configint('server', '%slevel' % engine.name())
240 if level is not None:
240 if level is not None:
241 opts['level'] = level
241 opts['level'] = level
242
242
243 return HGTYPE2, engine, opts
243 return HGTYPE2, engine, opts
244
244
245 # No mutually supported compression format. Fall back to the
245 # No mutually supported compression format. Fall back to the
246 # legacy protocol.
246 # legacy protocol.
247
247
248 # Don't allow untrusted settings because disabling compression or
248 # Don't allow untrusted settings because disabling compression or
249 # setting a very high compression level could lead to flooding
249 # setting a very high compression level could lead to flooding
250 # the server's network or CPU.
250 # the server's network or CPU.
251 opts = {'level': ui.configint('server', 'zliblevel')}
251 opts = {'level': ui.configint('server', 'zliblevel')}
252 return HGTYPE, util.compengines['zlib'], opts
252 return HGTYPE, util.compengines['zlib'], opts
253
253
254 def _callhttp(repo, wsgireq, proto, cmd):
254 def _callhttp(repo, wsgireq, req, proto, cmd):
255 def genversion2(gen, engine, engineopts):
255 def genversion2(gen, engine, engineopts):
256 # application/mercurial-0.2 always sends a payload header
256 # application/mercurial-0.2 always sends a payload header
257 # identifying the compression engine.
257 # identifying the compression engine.
258 name = engine.wireprotosupport().name
258 name = engine.wireprotosupport().name
259 assert 0 < len(name) < 256
259 assert 0 < len(name) < 256
260 yield struct.pack('B', len(name))
260 yield struct.pack('B', len(name))
261 yield name
261 yield name
262
262
263 for chunk in gen:
263 for chunk in gen:
264 yield chunk
264 yield chunk
265
265
266 if not wireproto.commands.commandavailable(cmd, proto):
266 if not wireproto.commands.commandavailable(cmd, proto):
267 wsgireq.respond(HTTP_OK, HGERRTYPE,
267 wsgireq.respond(HTTP_OK, HGERRTYPE,
268 body=_('requested wire protocol command is not '
268 body=_('requested wire protocol command is not '
269 'available over HTTP'))
269 'available over HTTP'))
270 return []
270 return []
271
271
272 proto.checkperm(wireproto.commands[cmd].permission)
272 proto.checkperm(wireproto.commands[cmd].permission)
273
273
274 rsp = wireproto.dispatch(repo, proto, cmd)
274 rsp = wireproto.dispatch(repo, proto, cmd)
275
275
276 if isinstance(rsp, bytes):
276 if isinstance(rsp, bytes):
277 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
277 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
278 return []
278 return []
279 elif isinstance(rsp, wireprototypes.bytesresponse):
279 elif isinstance(rsp, wireprototypes.bytesresponse):
280 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
280 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp.data)
281 return []
281 return []
282 elif isinstance(rsp, wireprototypes.streamreslegacy):
282 elif isinstance(rsp, wireprototypes.streamreslegacy):
283 gen = rsp.gen
283 gen = rsp.gen
284 wsgireq.respond(HTTP_OK, HGTYPE)
284 wsgireq.respond(HTTP_OK, HGTYPE)
285 return gen
285 return gen
286 elif isinstance(rsp, wireprototypes.streamres):
286 elif isinstance(rsp, wireprototypes.streamres):
287 gen = rsp.gen
287 gen = rsp.gen
288
288
289 # This code for compression should not be streamres specific. It
289 # This code for compression should not be streamres specific. It
290 # is here because we only compress streamres at the moment.
290 # is here because we only compress streamres at the moment.
291 mediatype, engine, engineopts = _httpresponsetype(
291 mediatype, engine, engineopts = _httpresponsetype(
292 repo.ui, wsgireq, rsp.prefer_uncompressed)
292 repo.ui, req, rsp.prefer_uncompressed)
293 gen = engine.compressstream(gen, engineopts)
293 gen = engine.compressstream(gen, engineopts)
294
294
295 if mediatype == HGTYPE2:
295 if mediatype == HGTYPE2:
296 gen = genversion2(gen, engine, engineopts)
296 gen = genversion2(gen, engine, engineopts)
297
297
298 wsgireq.respond(HTTP_OK, mediatype)
298 wsgireq.respond(HTTP_OK, mediatype)
299 return gen
299 return gen
300 elif isinstance(rsp, wireprototypes.pushres):
300 elif isinstance(rsp, wireprototypes.pushres):
301 rsp = '%d\n%s' % (rsp.res, rsp.output)
301 rsp = '%d\n%s' % (rsp.res, rsp.output)
302 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
302 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
303 return []
303 return []
304 elif isinstance(rsp, wireprototypes.pusherr):
304 elif isinstance(rsp, wireprototypes.pusherr):
305 # This is the httplib workaround documented in _handlehttperror().
305 # This is the httplib workaround documented in _handlehttperror().
306 wsgireq.drain()
306 wsgireq.drain()
307
307
308 rsp = '0\n%s\n' % rsp.res
308 rsp = '0\n%s\n' % rsp.res
309 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
309 wsgireq.respond(HTTP_OK, HGTYPE, body=rsp)
310 return []
310 return []
311 elif isinstance(rsp, wireprototypes.ooberror):
311 elif isinstance(rsp, wireprototypes.ooberror):
312 rsp = rsp.message
312 rsp = rsp.message
313 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
313 wsgireq.respond(HTTP_OK, HGERRTYPE, body=rsp)
314 return []
314 return []
315 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
315 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
316
316
317 def _handlehttperror(e, wsgireq, cmd):
317 def _handlehttperror(e, wsgireq, req, cmd):
318 """Called when an ErrorResponse is raised during HTTP request processing."""
318 """Called when an ErrorResponse is raised during HTTP request processing."""
319
319
320 # Clients using Python's httplib are stateful: the HTTP client
320 # Clients using Python's httplib are stateful: the HTTP client
321 # won't process an HTTP response until all request data is
321 # won't process an HTTP response until all request data is
322 # sent to the server. The intent of this code is to ensure
322 # sent to the server. The intent of this code is to ensure
323 # we always read HTTP request data from the client, thus
323 # we always read HTTP request data from the client, thus
324 # ensuring httplib transitions to a state that allows it to read
324 # ensuring httplib transitions to a state that allows it to read
325 # the HTTP response. In other words, it helps prevent deadlocks
325 # the HTTP response. In other words, it helps prevent deadlocks
326 # on clients using httplib.
326 # on clients using httplib.
327
327
328 if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
328 if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
329 # But not if Expect: 100-continue is being used.
329 # But not if Expect: 100-continue is being used.
330 (wsgireq.env.get('HTTP_EXPECT',
330 (req.headers.get('Expect', '').lower() != '100-continue')):
331 '').lower() != '100-continue')):
332 wsgireq.drain()
331 wsgireq.drain()
333 else:
332 else:
334 wsgireq.headers.append((r'Connection', r'Close'))
333 wsgireq.headers.append((r'Connection', r'Close'))
335
334
336 # TODO This response body assumes the failed command was
335 # TODO This response body assumes the failed command was
337 # "unbundle." That assumption is not always valid.
336 # "unbundle." That assumption is not always valid.
338 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
337 wsgireq.respond(e, HGTYPE, body='0\n%s\n' % pycompat.bytestr(e))
339
338
340 return ''
339 return ''
341
340
342 def _sshv1respondbytes(fout, value):
341 def _sshv1respondbytes(fout, value):
343 """Send a bytes response for protocol version 1."""
342 """Send a bytes response for protocol version 1."""
344 fout.write('%d\n' % len(value))
343 fout.write('%d\n' % len(value))
345 fout.write(value)
344 fout.write(value)
346 fout.flush()
345 fout.flush()
347
346
348 def _sshv1respondstream(fout, source):
347 def _sshv1respondstream(fout, source):
349 write = fout.write
348 write = fout.write
350 for chunk in source.gen:
349 for chunk in source.gen:
351 write(chunk)
350 write(chunk)
352 fout.flush()
351 fout.flush()
353
352
354 def _sshv1respondooberror(fout, ferr, rsp):
353 def _sshv1respondooberror(fout, ferr, rsp):
355 ferr.write(b'%s\n-\n' % rsp)
354 ferr.write(b'%s\n-\n' % rsp)
356 ferr.flush()
355 ferr.flush()
357 fout.write(b'\n')
356 fout.write(b'\n')
358 fout.flush()
357 fout.flush()
359
358
360 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
359 class sshv1protocolhandler(wireprototypes.baseprotocolhandler):
361 """Handler for requests services via version 1 of SSH protocol."""
360 """Handler for requests services via version 1 of SSH protocol."""
362 def __init__(self, ui, fin, fout):
361 def __init__(self, ui, fin, fout):
363 self._ui = ui
362 self._ui = ui
364 self._fin = fin
363 self._fin = fin
365 self._fout = fout
364 self._fout = fout
366
365
367 @property
366 @property
368 def name(self):
367 def name(self):
369 return wireprototypes.SSHV1
368 return wireprototypes.SSHV1
370
369
371 def getargs(self, args):
370 def getargs(self, args):
372 data = {}
371 data = {}
373 keys = args.split()
372 keys = args.split()
374 for n in xrange(len(keys)):
373 for n in xrange(len(keys)):
375 argline = self._fin.readline()[:-1]
374 argline = self._fin.readline()[:-1]
376 arg, l = argline.split()
375 arg, l = argline.split()
377 if arg not in keys:
376 if arg not in keys:
378 raise error.Abort(_("unexpected parameter %r") % arg)
377 raise error.Abort(_("unexpected parameter %r") % arg)
379 if arg == '*':
378 if arg == '*':
380 star = {}
379 star = {}
381 for k in xrange(int(l)):
380 for k in xrange(int(l)):
382 argline = self._fin.readline()[:-1]
381 argline = self._fin.readline()[:-1]
383 arg, l = argline.split()
382 arg, l = argline.split()
384 val = self._fin.read(int(l))
383 val = self._fin.read(int(l))
385 star[arg] = val
384 star[arg] = val
386 data['*'] = star
385 data['*'] = star
387 else:
386 else:
388 val = self._fin.read(int(l))
387 val = self._fin.read(int(l))
389 data[arg] = val
388 data[arg] = val
390 return [data[k] for k in keys]
389 return [data[k] for k in keys]
391
390
392 def forwardpayload(self, fpout):
391 def forwardpayload(self, fpout):
393 # We initially send an empty response. This tells the client it is
392 # We initially send an empty response. This tells the client it is
394 # OK to start sending data. If a client sees any other response, it
393 # OK to start sending data. If a client sees any other response, it
395 # interprets it as an error.
394 # interprets it as an error.
396 _sshv1respondbytes(self._fout, b'')
395 _sshv1respondbytes(self._fout, b'')
397
396
398 # The file is in the form:
397 # The file is in the form:
399 #
398 #
400 # <chunk size>\n<chunk>
399 # <chunk size>\n<chunk>
401 # ...
400 # ...
402 # 0\n
401 # 0\n
403 count = int(self._fin.readline())
402 count = int(self._fin.readline())
404 while count:
403 while count:
405 fpout.write(self._fin.read(count))
404 fpout.write(self._fin.read(count))
406 count = int(self._fin.readline())
405 count = int(self._fin.readline())
407
406
408 @contextlib.contextmanager
407 @contextlib.contextmanager
409 def mayberedirectstdio(self):
408 def mayberedirectstdio(self):
410 yield None
409 yield None
411
410
412 def client(self):
411 def client(self):
413 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
412 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
414 return 'remote:ssh:' + client
413 return 'remote:ssh:' + client
415
414
416 def addcapabilities(self, repo, caps):
415 def addcapabilities(self, repo, caps):
417 return caps
416 return caps
418
417
419 def checkperm(self, perm):
418 def checkperm(self, perm):
420 pass
419 pass
421
420
422 class sshv2protocolhandler(sshv1protocolhandler):
421 class sshv2protocolhandler(sshv1protocolhandler):
423 """Protocol handler for version 2 of the SSH protocol."""
422 """Protocol handler for version 2 of the SSH protocol."""
424
423
425 @property
424 @property
426 def name(self):
425 def name(self):
427 return wireprototypes.SSHV2
426 return wireprototypes.SSHV2
428
427
429 def _runsshserver(ui, repo, fin, fout, ev):
428 def _runsshserver(ui, repo, fin, fout, ev):
430 # This function operates like a state machine of sorts. The following
429 # This function operates like a state machine of sorts. The following
431 # states are defined:
430 # states are defined:
432 #
431 #
433 # protov1-serving
432 # protov1-serving
434 # Server is in protocol version 1 serving mode. Commands arrive on
433 # Server is in protocol version 1 serving mode. Commands arrive on
435 # new lines. These commands are processed in this state, one command
434 # new lines. These commands are processed in this state, one command
436 # after the other.
435 # after the other.
437 #
436 #
438 # protov2-serving
437 # protov2-serving
439 # Server is in protocol version 2 serving mode.
438 # Server is in protocol version 2 serving mode.
440 #
439 #
441 # upgrade-initial
440 # upgrade-initial
442 # The server is going to process an upgrade request.
441 # The server is going to process an upgrade request.
443 #
442 #
444 # upgrade-v2-filter-legacy-handshake
443 # upgrade-v2-filter-legacy-handshake
445 # The protocol is being upgraded to version 2. The server is expecting
444 # The protocol is being upgraded to version 2. The server is expecting
446 # the legacy handshake from version 1.
445 # the legacy handshake from version 1.
447 #
446 #
448 # upgrade-v2-finish
447 # upgrade-v2-finish
449 # The upgrade to version 2 of the protocol is imminent.
448 # The upgrade to version 2 of the protocol is imminent.
450 #
449 #
451 # shutdown
450 # shutdown
452 # The server is shutting down, possibly in reaction to a client event.
451 # The server is shutting down, possibly in reaction to a client event.
453 #
452 #
454 # And here are their transitions:
453 # And here are their transitions:
455 #
454 #
456 # protov1-serving -> shutdown
455 # protov1-serving -> shutdown
457 # When server receives an empty request or encounters another
456 # When server receives an empty request or encounters another
458 # error.
457 # error.
459 #
458 #
460 # protov1-serving -> upgrade-initial
459 # protov1-serving -> upgrade-initial
461 # An upgrade request line was seen.
460 # An upgrade request line was seen.
462 #
461 #
463 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
462 # upgrade-initial -> upgrade-v2-filter-legacy-handshake
464 # Upgrade to version 2 in progress. Server is expecting to
463 # Upgrade to version 2 in progress. Server is expecting to
465 # process a legacy handshake.
464 # process a legacy handshake.
466 #
465 #
467 # upgrade-v2-filter-legacy-handshake -> shutdown
466 # upgrade-v2-filter-legacy-handshake -> shutdown
468 # Client did not fulfill upgrade handshake requirements.
467 # Client did not fulfill upgrade handshake requirements.
469 #
468 #
470 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
469 # upgrade-v2-filter-legacy-handshake -> upgrade-v2-finish
471 # Client fulfilled version 2 upgrade requirements. Finishing that
470 # Client fulfilled version 2 upgrade requirements. Finishing that
472 # upgrade.
471 # upgrade.
473 #
472 #
474 # upgrade-v2-finish -> protov2-serving
473 # upgrade-v2-finish -> protov2-serving
475 # Protocol upgrade to version 2 complete. Server can now speak protocol
474 # Protocol upgrade to version 2 complete. Server can now speak protocol
476 # version 2.
475 # version 2.
477 #
476 #
478 # protov2-serving -> protov1-serving
477 # protov2-serving -> protov1-serving
479 # Ths happens by default since protocol version 2 is the same as
478 # Ths happens by default since protocol version 2 is the same as
480 # version 1 except for the handshake.
479 # version 1 except for the handshake.
481
480
482 state = 'protov1-serving'
481 state = 'protov1-serving'
483 proto = sshv1protocolhandler(ui, fin, fout)
482 proto = sshv1protocolhandler(ui, fin, fout)
484 protoswitched = False
483 protoswitched = False
485
484
486 while not ev.is_set():
485 while not ev.is_set():
487 if state == 'protov1-serving':
486 if state == 'protov1-serving':
488 # Commands are issued on new lines.
487 # Commands are issued on new lines.
489 request = fin.readline()[:-1]
488 request = fin.readline()[:-1]
490
489
491 # Empty lines signal to terminate the connection.
490 # Empty lines signal to terminate the connection.
492 if not request:
491 if not request:
493 state = 'shutdown'
492 state = 'shutdown'
494 continue
493 continue
495
494
496 # It looks like a protocol upgrade request. Transition state to
495 # It looks like a protocol upgrade request. Transition state to
497 # handle it.
496 # handle it.
498 if request.startswith(b'upgrade '):
497 if request.startswith(b'upgrade '):
499 if protoswitched:
498 if protoswitched:
500 _sshv1respondooberror(fout, ui.ferr,
499 _sshv1respondooberror(fout, ui.ferr,
501 b'cannot upgrade protocols multiple '
500 b'cannot upgrade protocols multiple '
502 b'times')
501 b'times')
503 state = 'shutdown'
502 state = 'shutdown'
504 continue
503 continue
505
504
506 state = 'upgrade-initial'
505 state = 'upgrade-initial'
507 continue
506 continue
508
507
509 available = wireproto.commands.commandavailable(request, proto)
508 available = wireproto.commands.commandavailable(request, proto)
510
509
511 # This command isn't available. Send an empty response and go
510 # This command isn't available. Send an empty response and go
512 # back to waiting for a new command.
511 # back to waiting for a new command.
513 if not available:
512 if not available:
514 _sshv1respondbytes(fout, b'')
513 _sshv1respondbytes(fout, b'')
515 continue
514 continue
516
515
517 rsp = wireproto.dispatch(repo, proto, request)
516 rsp = wireproto.dispatch(repo, proto, request)
518
517
519 if isinstance(rsp, bytes):
518 if isinstance(rsp, bytes):
520 _sshv1respondbytes(fout, rsp)
519 _sshv1respondbytes(fout, rsp)
521 elif isinstance(rsp, wireprototypes.bytesresponse):
520 elif isinstance(rsp, wireprototypes.bytesresponse):
522 _sshv1respondbytes(fout, rsp.data)
521 _sshv1respondbytes(fout, rsp.data)
523 elif isinstance(rsp, wireprototypes.streamres):
522 elif isinstance(rsp, wireprototypes.streamres):
524 _sshv1respondstream(fout, rsp)
523 _sshv1respondstream(fout, rsp)
525 elif isinstance(rsp, wireprototypes.streamreslegacy):
524 elif isinstance(rsp, wireprototypes.streamreslegacy):
526 _sshv1respondstream(fout, rsp)
525 _sshv1respondstream(fout, rsp)
527 elif isinstance(rsp, wireprototypes.pushres):
526 elif isinstance(rsp, wireprototypes.pushres):
528 _sshv1respondbytes(fout, b'')
527 _sshv1respondbytes(fout, b'')
529 _sshv1respondbytes(fout, b'%d' % rsp.res)
528 _sshv1respondbytes(fout, b'%d' % rsp.res)
530 elif isinstance(rsp, wireprototypes.pusherr):
529 elif isinstance(rsp, wireprototypes.pusherr):
531 _sshv1respondbytes(fout, rsp.res)
530 _sshv1respondbytes(fout, rsp.res)
532 elif isinstance(rsp, wireprototypes.ooberror):
531 elif isinstance(rsp, wireprototypes.ooberror):
533 _sshv1respondooberror(fout, ui.ferr, rsp.message)
532 _sshv1respondooberror(fout, ui.ferr, rsp.message)
534 else:
533 else:
535 raise error.ProgrammingError('unhandled response type from '
534 raise error.ProgrammingError('unhandled response type from '
536 'wire protocol command: %s' % rsp)
535 'wire protocol command: %s' % rsp)
537
536
538 # For now, protocol version 2 serving just goes back to version 1.
537 # For now, protocol version 2 serving just goes back to version 1.
539 elif state == 'protov2-serving':
538 elif state == 'protov2-serving':
540 state = 'protov1-serving'
539 state = 'protov1-serving'
541 continue
540 continue
542
541
543 elif state == 'upgrade-initial':
542 elif state == 'upgrade-initial':
544 # We should never transition into this state if we've switched
543 # We should never transition into this state if we've switched
545 # protocols.
544 # protocols.
546 assert not protoswitched
545 assert not protoswitched
547 assert proto.name == wireprototypes.SSHV1
546 assert proto.name == wireprototypes.SSHV1
548
547
549 # Expected: upgrade <token> <capabilities>
548 # Expected: upgrade <token> <capabilities>
550 # If we get something else, the request is malformed. It could be
549 # If we get something else, the request is malformed. It could be
551 # from a future client that has altered the upgrade line content.
550 # from a future client that has altered the upgrade line content.
552 # We treat this as an unknown command.
551 # We treat this as an unknown command.
553 try:
552 try:
554 token, caps = request.split(b' ')[1:]
553 token, caps = request.split(b' ')[1:]
555 except ValueError:
554 except ValueError:
556 _sshv1respondbytes(fout, b'')
555 _sshv1respondbytes(fout, b'')
557 state = 'protov1-serving'
556 state = 'protov1-serving'
558 continue
557 continue
559
558
560 # Send empty response if we don't support upgrading protocols.
559 # Send empty response if we don't support upgrading protocols.
561 if not ui.configbool('experimental', 'sshserver.support-v2'):
560 if not ui.configbool('experimental', 'sshserver.support-v2'):
562 _sshv1respondbytes(fout, b'')
561 _sshv1respondbytes(fout, b'')
563 state = 'protov1-serving'
562 state = 'protov1-serving'
564 continue
563 continue
565
564
566 try:
565 try:
567 caps = urlreq.parseqs(caps)
566 caps = urlreq.parseqs(caps)
568 except ValueError:
567 except ValueError:
569 _sshv1respondbytes(fout, b'')
568 _sshv1respondbytes(fout, b'')
570 state = 'protov1-serving'
569 state = 'protov1-serving'
571 continue
570 continue
572
571
573 # We don't see an upgrade request to protocol version 2. Ignore
572 # We don't see an upgrade request to protocol version 2. Ignore
574 # the upgrade request.
573 # the upgrade request.
575 wantedprotos = caps.get(b'proto', [b''])[0]
574 wantedprotos = caps.get(b'proto', [b''])[0]
576 if SSHV2 not in wantedprotos:
575 if SSHV2 not in wantedprotos:
577 _sshv1respondbytes(fout, b'')
576 _sshv1respondbytes(fout, b'')
578 state = 'protov1-serving'
577 state = 'protov1-serving'
579 continue
578 continue
580
579
581 # It looks like we can honor this upgrade request to protocol 2.
580 # It looks like we can honor this upgrade request to protocol 2.
582 # Filter the rest of the handshake protocol request lines.
581 # Filter the rest of the handshake protocol request lines.
583 state = 'upgrade-v2-filter-legacy-handshake'
582 state = 'upgrade-v2-filter-legacy-handshake'
584 continue
583 continue
585
584
586 elif state == 'upgrade-v2-filter-legacy-handshake':
585 elif state == 'upgrade-v2-filter-legacy-handshake':
587 # Client should have sent legacy handshake after an ``upgrade``
586 # Client should have sent legacy handshake after an ``upgrade``
588 # request. Expected lines:
587 # request. Expected lines:
589 #
588 #
590 # hello
589 # hello
591 # between
590 # between
592 # pairs 81
591 # pairs 81
593 # 0000...-0000...
592 # 0000...-0000...
594
593
595 ok = True
594 ok = True
596 for line in (b'hello', b'between', b'pairs 81'):
595 for line in (b'hello', b'between', b'pairs 81'):
597 request = fin.readline()[:-1]
596 request = fin.readline()[:-1]
598
597
599 if request != line:
598 if request != line:
600 _sshv1respondooberror(fout, ui.ferr,
599 _sshv1respondooberror(fout, ui.ferr,
601 b'malformed handshake protocol: '
600 b'malformed handshake protocol: '
602 b'missing %s' % line)
601 b'missing %s' % line)
603 ok = False
602 ok = False
604 state = 'shutdown'
603 state = 'shutdown'
605 break
604 break
606
605
607 if not ok:
606 if not ok:
608 continue
607 continue
609
608
610 request = fin.read(81)
609 request = fin.read(81)
611 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
610 if request != b'%s-%s' % (b'0' * 40, b'0' * 40):
612 _sshv1respondooberror(fout, ui.ferr,
611 _sshv1respondooberror(fout, ui.ferr,
613 b'malformed handshake protocol: '
612 b'malformed handshake protocol: '
614 b'missing between argument value')
613 b'missing between argument value')
615 state = 'shutdown'
614 state = 'shutdown'
616 continue
615 continue
617
616
618 state = 'upgrade-v2-finish'
617 state = 'upgrade-v2-finish'
619 continue
618 continue
620
619
621 elif state == 'upgrade-v2-finish':
620 elif state == 'upgrade-v2-finish':
622 # Send the upgrade response.
621 # Send the upgrade response.
623 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
622 fout.write(b'upgraded %s %s\n' % (token, SSHV2))
624 servercaps = wireproto.capabilities(repo, proto)
623 servercaps = wireproto.capabilities(repo, proto)
625 rsp = b'capabilities: %s' % servercaps.data
624 rsp = b'capabilities: %s' % servercaps.data
626 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
625 fout.write(b'%d\n%s\n' % (len(rsp), rsp))
627 fout.flush()
626 fout.flush()
628
627
629 proto = sshv2protocolhandler(ui, fin, fout)
628 proto = sshv2protocolhandler(ui, fin, fout)
630 protoswitched = True
629 protoswitched = True
631
630
632 state = 'protov2-serving'
631 state = 'protov2-serving'
633 continue
632 continue
634
633
635 elif state == 'shutdown':
634 elif state == 'shutdown':
636 break
635 break
637
636
638 else:
637 else:
639 raise error.ProgrammingError('unhandled ssh server state: %s' %
638 raise error.ProgrammingError('unhandled ssh server state: %s' %
640 state)
639 state)
641
640
642 class sshserver(object):
641 class sshserver(object):
643 def __init__(self, ui, repo, logfh=None):
642 def __init__(self, ui, repo, logfh=None):
644 self._ui = ui
643 self._ui = ui
645 self._repo = repo
644 self._repo = repo
646 self._fin = ui.fin
645 self._fin = ui.fin
647 self._fout = ui.fout
646 self._fout = ui.fout
648
647
649 # Log write I/O to stdout and stderr if configured.
648 # Log write I/O to stdout and stderr if configured.
650 if logfh:
649 if logfh:
651 self._fout = util.makeloggingfileobject(
650 self._fout = util.makeloggingfileobject(
652 logfh, self._fout, 'o', logdata=True)
651 logfh, self._fout, 'o', logdata=True)
653 ui.ferr = util.makeloggingfileobject(
652 ui.ferr = util.makeloggingfileobject(
654 logfh, ui.ferr, 'e', logdata=True)
653 logfh, ui.ferr, 'e', logdata=True)
655
654
656 hook.redirect(True)
655 hook.redirect(True)
657 ui.fout = repo.ui.fout = ui.ferr
656 ui.fout = repo.ui.fout = ui.ferr
658
657
659 # Prevent insertion/deletion of CRs
658 # Prevent insertion/deletion of CRs
660 util.setbinary(self._fin)
659 util.setbinary(self._fin)
661 util.setbinary(self._fout)
660 util.setbinary(self._fout)
662
661
663 def serve_forever(self):
662 def serve_forever(self):
664 self.serveuntil(threading.Event())
663 self.serveuntil(threading.Event())
665 sys.exit(0)
664 sys.exit(0)
666
665
667 def serveuntil(self, ev):
666 def serveuntil(self, ev):
668 """Serve until a threading.Event is set."""
667 """Serve until a threading.Event is set."""
669 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
668 _runsshserver(self._ui, self._repo, self._fin, self._fout, ev)
General Comments 0
You need to be logged in to leave comments. Login now