##// END OF EJS Templates
wireprotoserver: make abstractserverproto a proper abstract base class...
Gregory Szorc -
r35890:68dc621f default
parent child Browse files
Show More
@@ -1,344 +1,347
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import abc
9 import cgi
10 import cgi
10 import struct
11 import struct
11 import sys
12 import sys
12
13
13 from .i18n import _
14 from .i18n import _
14 from . import (
15 from . import (
15 encoding,
16 encoding,
16 error,
17 error,
17 hook,
18 hook,
18 pycompat,
19 pycompat,
19 util,
20 util,
20 wireproto,
21 wireproto,
21 )
22 )
22
23
23 stringio = util.stringio
24 stringio = util.stringio
24
25
25 urlerr = util.urlerr
26 urlerr = util.urlerr
26 urlreq = util.urlreq
27 urlreq = util.urlreq
27
28
28 HTTP_OK = 200
29 HTTP_OK = 200
29
30
30 HGTYPE = 'application/mercurial-0.1'
31 HGTYPE = 'application/mercurial-0.1'
31 HGTYPE2 = 'application/mercurial-0.2'
32 HGTYPE2 = 'application/mercurial-0.2'
32 HGERRTYPE = 'application/hg-error'
33 HGERRTYPE = 'application/hg-error'
33
34
34 class abstractserverproto(object):
35 class abstractserverproto(object):
35 """abstract class that summarizes the protocol API
36 """abstract class that summarizes the protocol API
36
37
37 Used as reference and documentation.
38 Used as reference and documentation.
38 """
39 """
39
40
41 __metaclass__ = abc.ABCMeta
42
43 @abc.abstractmethod
40 def getargs(self, args):
44 def getargs(self, args):
41 """return the value for arguments in <args>
45 """return the value for arguments in <args>
42
46
43 returns a list of values (same order as <args>)"""
47 returns a list of values (same order as <args>)"""
44 raise NotImplementedError()
45
48
49 @abc.abstractmethod
46 def getfile(self, fp):
50 def getfile(self, fp):
47 """write the whole content of a file into a file like object
51 """write the whole content of a file into a file like object
48
52
49 The file is in the form::
53 The file is in the form::
50
54
51 (<chunk-size>\n<chunk>)+0\n
55 (<chunk-size>\n<chunk>)+0\n
52
56
53 chunk size is the ascii version of the int.
57 chunk size is the ascii version of the int.
54 """
58 """
55 raise NotImplementedError()
56
59
60 @abc.abstractmethod
57 def redirect(self):
61 def redirect(self):
58 """may setup interception for stdout and stderr
62 """may setup interception for stdout and stderr
59
63
60 See also the `restore` method."""
64 See also the `restore` method."""
61 raise NotImplementedError()
62
65
63 # If the `redirect` function does install interception, the `restore`
66 # If the `redirect` function does install interception, the `restore`
64 # function MUST be defined. If interception is not used, this function
67 # function MUST be defined. If interception is not used, this function
65 # MUST NOT be defined.
68 # MUST NOT be defined.
66 #
69 #
67 # left commented here on purpose
70 # left commented here on purpose
68 #
71 #
69 #def restore(self):
72 #def restore(self):
70 # """reinstall previous stdout and stderr and return intercepted stdout
73 # """reinstall previous stdout and stderr and return intercepted stdout
71 # """
74 # """
72 # raise NotImplementedError()
75 # raise NotImplementedError()
73
76
74 def decodevaluefromheaders(req, headerprefix):
77 def decodevaluefromheaders(req, headerprefix):
75 """Decode a long value from multiple HTTP request headers.
78 """Decode a long value from multiple HTTP request headers.
76
79
77 Returns the value as a bytes, not a str.
80 Returns the value as a bytes, not a str.
78 """
81 """
79 chunks = []
82 chunks = []
80 i = 1
83 i = 1
81 prefix = headerprefix.upper().replace(r'-', r'_')
84 prefix = headerprefix.upper().replace(r'-', r'_')
82 while True:
85 while True:
83 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
86 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
84 if v is None:
87 if v is None:
85 break
88 break
86 chunks.append(pycompat.bytesurl(v))
89 chunks.append(pycompat.bytesurl(v))
87 i += 1
90 i += 1
88
91
89 return ''.join(chunks)
92 return ''.join(chunks)
90
93
91 class webproto(abstractserverproto):
94 class webproto(abstractserverproto):
92 def __init__(self, req, ui):
95 def __init__(self, req, ui):
93 self._req = req
96 self._req = req
94 self._ui = ui
97 self._ui = ui
95 self.name = 'http'
98 self.name = 'http'
96
99
97 def getargs(self, args):
100 def getargs(self, args):
98 knownargs = self._args()
101 knownargs = self._args()
99 data = {}
102 data = {}
100 keys = args.split()
103 keys = args.split()
101 for k in keys:
104 for k in keys:
102 if k == '*':
105 if k == '*':
103 star = {}
106 star = {}
104 for key in knownargs.keys():
107 for key in knownargs.keys():
105 if key != 'cmd' and key not in keys:
108 if key != 'cmd' and key not in keys:
106 star[key] = knownargs[key][0]
109 star[key] = knownargs[key][0]
107 data['*'] = star
110 data['*'] = star
108 else:
111 else:
109 data[k] = knownargs[k][0]
112 data[k] = knownargs[k][0]
110 return [data[k] for k in keys]
113 return [data[k] for k in keys]
111
114
112 def _args(self):
115 def _args(self):
113 args = self._req.form.copy()
116 args = self._req.form.copy()
114 if pycompat.ispy3:
117 if pycompat.ispy3:
115 args = {k.encode('ascii'): [v.encode('ascii') for v in vs]
118 args = {k.encode('ascii'): [v.encode('ascii') for v in vs]
116 for k, vs in args.items()}
119 for k, vs in args.items()}
117 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
120 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
118 if postlen:
121 if postlen:
119 args.update(cgi.parse_qs(
122 args.update(cgi.parse_qs(
120 self._req.read(postlen), keep_blank_values=True))
123 self._req.read(postlen), keep_blank_values=True))
121 return args
124 return args
122
125
123 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
126 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
124 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
127 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
125 return args
128 return args
126
129
127 def getfile(self, fp):
130 def getfile(self, fp):
128 length = int(self._req.env[r'CONTENT_LENGTH'])
131 length = int(self._req.env[r'CONTENT_LENGTH'])
129 # If httppostargs is used, we need to read Content-Length
132 # If httppostargs is used, we need to read Content-Length
130 # minus the amount that was consumed by args.
133 # minus the amount that was consumed by args.
131 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
134 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
132 for s in util.filechunkiter(self._req, limit=length):
135 for s in util.filechunkiter(self._req, limit=length):
133 fp.write(s)
136 fp.write(s)
134
137
135 def redirect(self):
138 def redirect(self):
136 self._oldio = self._ui.fout, self._ui.ferr
139 self._oldio = self._ui.fout, self._ui.ferr
137 self._ui.ferr = self._ui.fout = stringio()
140 self._ui.ferr = self._ui.fout = stringio()
138
141
139 def restore(self):
142 def restore(self):
140 val = self._ui.fout.getvalue()
143 val = self._ui.fout.getvalue()
141 self._ui.ferr, self._ui.fout = self._oldio
144 self._ui.ferr, self._ui.fout = self._oldio
142 return val
145 return val
143
146
144 def _client(self):
147 def _client(self):
145 return 'remote:%s:%s:%s' % (
148 return 'remote:%s:%s:%s' % (
146 self._req.env.get('wsgi.url_scheme') or 'http',
149 self._req.env.get('wsgi.url_scheme') or 'http',
147 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
150 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
148 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
151 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
149
152
150 def responsetype(self, prefer_uncompressed):
153 def responsetype(self, prefer_uncompressed):
151 """Determine the appropriate response type and compression settings.
154 """Determine the appropriate response type and compression settings.
152
155
153 Returns a tuple of (mediatype, compengine, engineopts).
156 Returns a tuple of (mediatype, compengine, engineopts).
154 """
157 """
155 # Determine the response media type and compression engine based
158 # Determine the response media type and compression engine based
156 # on the request parameters.
159 # on the request parameters.
157 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
160 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
158
161
159 if '0.2' in protocaps:
162 if '0.2' in protocaps:
160 # All clients are expected to support uncompressed data.
163 # All clients are expected to support uncompressed data.
161 if prefer_uncompressed:
164 if prefer_uncompressed:
162 return HGTYPE2, util._noopengine(), {}
165 return HGTYPE2, util._noopengine(), {}
163
166
164 # Default as defined by wire protocol spec.
167 # Default as defined by wire protocol spec.
165 compformats = ['zlib', 'none']
168 compformats = ['zlib', 'none']
166 for cap in protocaps:
169 for cap in protocaps:
167 if cap.startswith('comp='):
170 if cap.startswith('comp='):
168 compformats = cap[5:].split(',')
171 compformats = cap[5:].split(',')
169 break
172 break
170
173
171 # Now find an agreed upon compression format.
174 # Now find an agreed upon compression format.
172 for engine in wireproto.supportedcompengines(self._ui, self,
175 for engine in wireproto.supportedcompengines(self._ui, self,
173 util.SERVERROLE):
176 util.SERVERROLE):
174 if engine.wireprotosupport().name in compformats:
177 if engine.wireprotosupport().name in compformats:
175 opts = {}
178 opts = {}
176 level = self._ui.configint('server',
179 level = self._ui.configint('server',
177 '%slevel' % engine.name())
180 '%slevel' % engine.name())
178 if level is not None:
181 if level is not None:
179 opts['level'] = level
182 opts['level'] = level
180
183
181 return HGTYPE2, engine, opts
184 return HGTYPE2, engine, opts
182
185
183 # No mutually supported compression format. Fall back to the
186 # No mutually supported compression format. Fall back to the
184 # legacy protocol.
187 # legacy protocol.
185
188
186 # Don't allow untrusted settings because disabling compression or
189 # Don't allow untrusted settings because disabling compression or
187 # setting a very high compression level could lead to flooding
190 # setting a very high compression level could lead to flooding
188 # the server's network or CPU.
191 # the server's network or CPU.
189 opts = {'level': self._ui.configint('server', 'zliblevel')}
192 opts = {'level': self._ui.configint('server', 'zliblevel')}
190 return HGTYPE, util.compengines['zlib'], opts
193 return HGTYPE, util.compengines['zlib'], opts
191
194
192 def iscmd(cmd):
195 def iscmd(cmd):
193 return cmd in wireproto.commands
196 return cmd in wireproto.commands
194
197
195 def callhttp(repo, req, cmd):
198 def callhttp(repo, req, cmd):
196 proto = webproto(req, repo.ui)
199 proto = webproto(req, repo.ui)
197
200
198 def genversion2(gen, engine, engineopts):
201 def genversion2(gen, engine, engineopts):
199 # application/mercurial-0.2 always sends a payload header
202 # application/mercurial-0.2 always sends a payload header
200 # identifying the compression engine.
203 # identifying the compression engine.
201 name = engine.wireprotosupport().name
204 name = engine.wireprotosupport().name
202 assert 0 < len(name) < 256
205 assert 0 < len(name) < 256
203 yield struct.pack('B', len(name))
206 yield struct.pack('B', len(name))
204 yield name
207 yield name
205
208
206 for chunk in gen:
209 for chunk in gen:
207 yield chunk
210 yield chunk
208
211
209 rsp = wireproto.dispatch(repo, proto, cmd)
212 rsp = wireproto.dispatch(repo, proto, cmd)
210 if isinstance(rsp, bytes):
213 if isinstance(rsp, bytes):
211 req.respond(HTTP_OK, HGTYPE, body=rsp)
214 req.respond(HTTP_OK, HGTYPE, body=rsp)
212 return []
215 return []
213 elif isinstance(rsp, wireproto.streamres_legacy):
216 elif isinstance(rsp, wireproto.streamres_legacy):
214 gen = rsp.gen
217 gen = rsp.gen
215 req.respond(HTTP_OK, HGTYPE)
218 req.respond(HTTP_OK, HGTYPE)
216 return gen
219 return gen
217 elif isinstance(rsp, wireproto.streamres):
220 elif isinstance(rsp, wireproto.streamres):
218 gen = rsp.gen
221 gen = rsp.gen
219
222
220 # This code for compression should not be streamres specific. It
223 # This code for compression should not be streamres specific. It
221 # is here because we only compress streamres at the moment.
224 # is here because we only compress streamres at the moment.
222 mediatype, engine, engineopts = proto.responsetype(
225 mediatype, engine, engineopts = proto.responsetype(
223 rsp.prefer_uncompressed)
226 rsp.prefer_uncompressed)
224 gen = engine.compressstream(gen, engineopts)
227 gen = engine.compressstream(gen, engineopts)
225
228
226 if mediatype == HGTYPE2:
229 if mediatype == HGTYPE2:
227 gen = genversion2(gen, engine, engineopts)
230 gen = genversion2(gen, engine, engineopts)
228
231
229 req.respond(HTTP_OK, mediatype)
232 req.respond(HTTP_OK, mediatype)
230 return gen
233 return gen
231 elif isinstance(rsp, wireproto.pushres):
234 elif isinstance(rsp, wireproto.pushres):
232 val = proto.restore()
235 val = proto.restore()
233 rsp = '%d\n%s' % (rsp.res, val)
236 rsp = '%d\n%s' % (rsp.res, val)
234 req.respond(HTTP_OK, HGTYPE, body=rsp)
237 req.respond(HTTP_OK, HGTYPE, body=rsp)
235 return []
238 return []
236 elif isinstance(rsp, wireproto.pusherr):
239 elif isinstance(rsp, wireproto.pusherr):
237 # drain the incoming bundle
240 # drain the incoming bundle
238 req.drain()
241 req.drain()
239 proto.restore()
242 proto.restore()
240 rsp = '0\n%s\n' % rsp.res
243 rsp = '0\n%s\n' % rsp.res
241 req.respond(HTTP_OK, HGTYPE, body=rsp)
244 req.respond(HTTP_OK, HGTYPE, body=rsp)
242 return []
245 return []
243 elif isinstance(rsp, wireproto.ooberror):
246 elif isinstance(rsp, wireproto.ooberror):
244 rsp = rsp.message
247 rsp = rsp.message
245 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
248 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
246 return []
249 return []
247 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
250 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
248
251
249 class sshserver(abstractserverproto):
252 class sshserver(abstractserverproto):
250 def __init__(self, ui, repo):
253 def __init__(self, ui, repo):
251 self._ui = ui
254 self._ui = ui
252 self._repo = repo
255 self._repo = repo
253 self._fin = ui.fin
256 self._fin = ui.fin
254 self._fout = ui.fout
257 self._fout = ui.fout
255 self.name = 'ssh'
258 self.name = 'ssh'
256
259
257 hook.redirect(True)
260 hook.redirect(True)
258 ui.fout = repo.ui.fout = ui.ferr
261 ui.fout = repo.ui.fout = ui.ferr
259
262
260 # Prevent insertion/deletion of CRs
263 # Prevent insertion/deletion of CRs
261 util.setbinary(self._fin)
264 util.setbinary(self._fin)
262 util.setbinary(self._fout)
265 util.setbinary(self._fout)
263
266
264 def getargs(self, args):
267 def getargs(self, args):
265 data = {}
268 data = {}
266 keys = args.split()
269 keys = args.split()
267 for n in xrange(len(keys)):
270 for n in xrange(len(keys)):
268 argline = self._fin.readline()[:-1]
271 argline = self._fin.readline()[:-1]
269 arg, l = argline.split()
272 arg, l = argline.split()
270 if arg not in keys:
273 if arg not in keys:
271 raise error.Abort(_("unexpected parameter %r") % arg)
274 raise error.Abort(_("unexpected parameter %r") % arg)
272 if arg == '*':
275 if arg == '*':
273 star = {}
276 star = {}
274 for k in xrange(int(l)):
277 for k in xrange(int(l)):
275 argline = self._fin.readline()[:-1]
278 argline = self._fin.readline()[:-1]
276 arg, l = argline.split()
279 arg, l = argline.split()
277 val = self._fin.read(int(l))
280 val = self._fin.read(int(l))
278 star[arg] = val
281 star[arg] = val
279 data['*'] = star
282 data['*'] = star
280 else:
283 else:
281 val = self._fin.read(int(l))
284 val = self._fin.read(int(l))
282 data[arg] = val
285 data[arg] = val
283 return [data[k] for k in keys]
286 return [data[k] for k in keys]
284
287
285 def getfile(self, fpout):
288 def getfile(self, fpout):
286 self._sendresponse('')
289 self._sendresponse('')
287 count = int(self._fin.readline())
290 count = int(self._fin.readline())
288 while count:
291 while count:
289 fpout.write(self._fin.read(count))
292 fpout.write(self._fin.read(count))
290 count = int(self._fin.readline())
293 count = int(self._fin.readline())
291
294
292 def redirect(self):
295 def redirect(self):
293 pass
296 pass
294
297
295 def _sendresponse(self, v):
298 def _sendresponse(self, v):
296 self._fout.write("%d\n" % len(v))
299 self._fout.write("%d\n" % len(v))
297 self._fout.write(v)
300 self._fout.write(v)
298 self._fout.flush()
301 self._fout.flush()
299
302
300 def _sendstream(self, source):
303 def _sendstream(self, source):
301 write = self._fout.write
304 write = self._fout.write
302 for chunk in source.gen:
305 for chunk in source.gen:
303 write(chunk)
306 write(chunk)
304 self._fout.flush()
307 self._fout.flush()
305
308
306 def _sendpushresponse(self, rsp):
309 def _sendpushresponse(self, rsp):
307 self._sendresponse('')
310 self._sendresponse('')
308 self._sendresponse(str(rsp.res))
311 self._sendresponse(str(rsp.res))
309
312
310 def _sendpusherror(self, rsp):
313 def _sendpusherror(self, rsp):
311 self._sendresponse(rsp.res)
314 self._sendresponse(rsp.res)
312
315
313 def _sendooberror(self, rsp):
316 def _sendooberror(self, rsp):
314 self._ui.ferr.write('%s\n-\n' % rsp.message)
317 self._ui.ferr.write('%s\n-\n' % rsp.message)
315 self._ui.ferr.flush()
318 self._ui.ferr.flush()
316 self._fout.write('\n')
319 self._fout.write('\n')
317 self._fout.flush()
320 self._fout.flush()
318
321
319 def serve_forever(self):
322 def serve_forever(self):
320 while self.serve_one():
323 while self.serve_one():
321 pass
324 pass
322 sys.exit(0)
325 sys.exit(0)
323
326
324 _handlers = {
327 _handlers = {
325 str: _sendresponse,
328 str: _sendresponse,
326 wireproto.streamres: _sendstream,
329 wireproto.streamres: _sendstream,
327 wireproto.streamres_legacy: _sendstream,
330 wireproto.streamres_legacy: _sendstream,
328 wireproto.pushres: _sendpushresponse,
331 wireproto.pushres: _sendpushresponse,
329 wireproto.pusherr: _sendpusherror,
332 wireproto.pusherr: _sendpusherror,
330 wireproto.ooberror: _sendooberror,
333 wireproto.ooberror: _sendooberror,
331 }
334 }
332
335
333 def serve_one(self):
336 def serve_one(self):
334 cmd = self._fin.readline()[:-1]
337 cmd = self._fin.readline()[:-1]
335 if cmd and cmd in wireproto.commands:
338 if cmd and cmd in wireproto.commands:
336 rsp = wireproto.dispatch(self._repo, self, cmd)
339 rsp = wireproto.dispatch(self._repo, self, cmd)
337 self._handlers[rsp.__class__](self, rsp)
340 self._handlers[rsp.__class__](self, rsp)
338 elif cmd:
341 elif cmd:
339 self._sendresponse("")
342 self._sendresponse("")
340 return cmd != ''
343 return cmd != ''
341
344
342 def _client(self):
345 def _client(self):
343 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
346 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
344 return 'remote:ssh:' + client
347 return 'remote:ssh:' + client
General Comments 0
You need to be logged in to leave comments. Login now