##// END OF EJS Templates
wireprotoserver: make abstractserverproto a proper abstract base class...
Gregory Szorc -
r35890:68dc621f default
parent child Browse files
Show More
@@ -1,344 +1,347
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 import abc
9 10 import cgi
10 11 import struct
11 12 import sys
12 13
13 14 from .i18n import _
14 15 from . import (
15 16 encoding,
16 17 error,
17 18 hook,
18 19 pycompat,
19 20 util,
20 21 wireproto,
21 22 )
22 23
23 24 stringio = util.stringio
24 25
25 26 urlerr = util.urlerr
26 27 urlreq = util.urlreq
27 28
28 29 HTTP_OK = 200
29 30
30 31 HGTYPE = 'application/mercurial-0.1'
31 32 HGTYPE2 = 'application/mercurial-0.2'
32 33 HGERRTYPE = 'application/hg-error'
33 34
34 35 class abstractserverproto(object):
35 36 """abstract class that summarizes the protocol API
36 37
37 38 Used as reference and documentation.
38 39 """
39 40
41 __metaclass__ = abc.ABCMeta
42
43 @abc.abstractmethod
40 44 def getargs(self, args):
41 45 """return the value for arguments in <args>
42 46
43 47 returns a list of values (same order as <args>)"""
44 raise NotImplementedError()
45 48
49 @abc.abstractmethod
46 50 def getfile(self, fp):
47 51 """write the whole content of a file into a file like object
48 52
49 53 The file is in the form::
50 54
51 55 (<chunk-size>\n<chunk>)+0\n
52 56
53 57 chunk size is the ascii version of the int.
54 58 """
55 raise NotImplementedError()
56 59
60 @abc.abstractmethod
57 61 def redirect(self):
58 62 """may setup interception for stdout and stderr
59 63
60 64 See also the `restore` method."""
61 raise NotImplementedError()
62 65
63 66 # If the `redirect` function does install interception, the `restore`
64 67 # function MUST be defined. If interception is not used, this function
65 68 # MUST NOT be defined.
66 69 #
67 70 # left commented here on purpose
68 71 #
69 72 #def restore(self):
70 73 # """reinstall previous stdout and stderr and return intercepted stdout
71 74 # """
72 75 # raise NotImplementedError()
73 76
74 77 def decodevaluefromheaders(req, headerprefix):
75 78 """Decode a long value from multiple HTTP request headers.
76 79
77 80 Returns the value as a bytes, not a str.
78 81 """
79 82 chunks = []
80 83 i = 1
81 84 prefix = headerprefix.upper().replace(r'-', r'_')
82 85 while True:
83 86 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
84 87 if v is None:
85 88 break
86 89 chunks.append(pycompat.bytesurl(v))
87 90 i += 1
88 91
89 92 return ''.join(chunks)
90 93
91 94 class webproto(abstractserverproto):
92 95 def __init__(self, req, ui):
93 96 self._req = req
94 97 self._ui = ui
95 98 self.name = 'http'
96 99
97 100 def getargs(self, args):
98 101 knownargs = self._args()
99 102 data = {}
100 103 keys = args.split()
101 104 for k in keys:
102 105 if k == '*':
103 106 star = {}
104 107 for key in knownargs.keys():
105 108 if key != 'cmd' and key not in keys:
106 109 star[key] = knownargs[key][0]
107 110 data['*'] = star
108 111 else:
109 112 data[k] = knownargs[k][0]
110 113 return [data[k] for k in keys]
111 114
112 115 def _args(self):
113 116 args = self._req.form.copy()
114 117 if pycompat.ispy3:
115 118 args = {k.encode('ascii'): [v.encode('ascii') for v in vs]
116 119 for k, vs in args.items()}
117 120 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
118 121 if postlen:
119 122 args.update(cgi.parse_qs(
120 123 self._req.read(postlen), keep_blank_values=True))
121 124 return args
122 125
123 126 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
124 127 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
125 128 return args
126 129
127 130 def getfile(self, fp):
128 131 length = int(self._req.env[r'CONTENT_LENGTH'])
129 132 # If httppostargs is used, we need to read Content-Length
130 133 # minus the amount that was consumed by args.
131 134 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
132 135 for s in util.filechunkiter(self._req, limit=length):
133 136 fp.write(s)
134 137
135 138 def redirect(self):
136 139 self._oldio = self._ui.fout, self._ui.ferr
137 140 self._ui.ferr = self._ui.fout = stringio()
138 141
139 142 def restore(self):
140 143 val = self._ui.fout.getvalue()
141 144 self._ui.ferr, self._ui.fout = self._oldio
142 145 return val
143 146
144 147 def _client(self):
145 148 return 'remote:%s:%s:%s' % (
146 149 self._req.env.get('wsgi.url_scheme') or 'http',
147 150 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
148 151 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
149 152
150 153 def responsetype(self, prefer_uncompressed):
151 154 """Determine the appropriate response type and compression settings.
152 155
153 156 Returns a tuple of (mediatype, compengine, engineopts).
154 157 """
155 158 # Determine the response media type and compression engine based
156 159 # on the request parameters.
157 160 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
158 161
159 162 if '0.2' in protocaps:
160 163 # All clients are expected to support uncompressed data.
161 164 if prefer_uncompressed:
162 165 return HGTYPE2, util._noopengine(), {}
163 166
164 167 # Default as defined by wire protocol spec.
165 168 compformats = ['zlib', 'none']
166 169 for cap in protocaps:
167 170 if cap.startswith('comp='):
168 171 compformats = cap[5:].split(',')
169 172 break
170 173
171 174 # Now find an agreed upon compression format.
172 175 for engine in wireproto.supportedcompengines(self._ui, self,
173 176 util.SERVERROLE):
174 177 if engine.wireprotosupport().name in compformats:
175 178 opts = {}
176 179 level = self._ui.configint('server',
177 180 '%slevel' % engine.name())
178 181 if level is not None:
179 182 opts['level'] = level
180 183
181 184 return HGTYPE2, engine, opts
182 185
183 186 # No mutually supported compression format. Fall back to the
184 187 # legacy protocol.
185 188
186 189 # Don't allow untrusted settings because disabling compression or
187 190 # setting a very high compression level could lead to flooding
188 191 # the server's network or CPU.
189 192 opts = {'level': self._ui.configint('server', 'zliblevel')}
190 193 return HGTYPE, util.compengines['zlib'], opts
191 194
192 195 def iscmd(cmd):
193 196 return cmd in wireproto.commands
194 197
195 198 def callhttp(repo, req, cmd):
196 199 proto = webproto(req, repo.ui)
197 200
198 201 def genversion2(gen, engine, engineopts):
199 202 # application/mercurial-0.2 always sends a payload header
200 203 # identifying the compression engine.
201 204 name = engine.wireprotosupport().name
202 205 assert 0 < len(name) < 256
203 206 yield struct.pack('B', len(name))
204 207 yield name
205 208
206 209 for chunk in gen:
207 210 yield chunk
208 211
209 212 rsp = wireproto.dispatch(repo, proto, cmd)
210 213 if isinstance(rsp, bytes):
211 214 req.respond(HTTP_OK, HGTYPE, body=rsp)
212 215 return []
213 216 elif isinstance(rsp, wireproto.streamres_legacy):
214 217 gen = rsp.gen
215 218 req.respond(HTTP_OK, HGTYPE)
216 219 return gen
217 220 elif isinstance(rsp, wireproto.streamres):
218 221 gen = rsp.gen
219 222
220 223 # This code for compression should not be streamres specific. It
221 224 # is here because we only compress streamres at the moment.
222 225 mediatype, engine, engineopts = proto.responsetype(
223 226 rsp.prefer_uncompressed)
224 227 gen = engine.compressstream(gen, engineopts)
225 228
226 229 if mediatype == HGTYPE2:
227 230 gen = genversion2(gen, engine, engineopts)
228 231
229 232 req.respond(HTTP_OK, mediatype)
230 233 return gen
231 234 elif isinstance(rsp, wireproto.pushres):
232 235 val = proto.restore()
233 236 rsp = '%d\n%s' % (rsp.res, val)
234 237 req.respond(HTTP_OK, HGTYPE, body=rsp)
235 238 return []
236 239 elif isinstance(rsp, wireproto.pusherr):
237 240 # drain the incoming bundle
238 241 req.drain()
239 242 proto.restore()
240 243 rsp = '0\n%s\n' % rsp.res
241 244 req.respond(HTTP_OK, HGTYPE, body=rsp)
242 245 return []
243 246 elif isinstance(rsp, wireproto.ooberror):
244 247 rsp = rsp.message
245 248 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
246 249 return []
247 250 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
248 251
249 252 class sshserver(abstractserverproto):
250 253 def __init__(self, ui, repo):
251 254 self._ui = ui
252 255 self._repo = repo
253 256 self._fin = ui.fin
254 257 self._fout = ui.fout
255 258 self.name = 'ssh'
256 259
257 260 hook.redirect(True)
258 261 ui.fout = repo.ui.fout = ui.ferr
259 262
260 263 # Prevent insertion/deletion of CRs
261 264 util.setbinary(self._fin)
262 265 util.setbinary(self._fout)
263 266
264 267 def getargs(self, args):
265 268 data = {}
266 269 keys = args.split()
267 270 for n in xrange(len(keys)):
268 271 argline = self._fin.readline()[:-1]
269 272 arg, l = argline.split()
270 273 if arg not in keys:
271 274 raise error.Abort(_("unexpected parameter %r") % arg)
272 275 if arg == '*':
273 276 star = {}
274 277 for k in xrange(int(l)):
275 278 argline = self._fin.readline()[:-1]
276 279 arg, l = argline.split()
277 280 val = self._fin.read(int(l))
278 281 star[arg] = val
279 282 data['*'] = star
280 283 else:
281 284 val = self._fin.read(int(l))
282 285 data[arg] = val
283 286 return [data[k] for k in keys]
284 287
285 288 def getfile(self, fpout):
286 289 self._sendresponse('')
287 290 count = int(self._fin.readline())
288 291 while count:
289 292 fpout.write(self._fin.read(count))
290 293 count = int(self._fin.readline())
291 294
292 295 def redirect(self):
293 296 pass
294 297
295 298 def _sendresponse(self, v):
296 299 self._fout.write("%d\n" % len(v))
297 300 self._fout.write(v)
298 301 self._fout.flush()
299 302
300 303 def _sendstream(self, source):
301 304 write = self._fout.write
302 305 for chunk in source.gen:
303 306 write(chunk)
304 307 self._fout.flush()
305 308
306 309 def _sendpushresponse(self, rsp):
307 310 self._sendresponse('')
308 311 self._sendresponse(str(rsp.res))
309 312
310 313 def _sendpusherror(self, rsp):
311 314 self._sendresponse(rsp.res)
312 315
313 316 def _sendooberror(self, rsp):
314 317 self._ui.ferr.write('%s\n-\n' % rsp.message)
315 318 self._ui.ferr.flush()
316 319 self._fout.write('\n')
317 320 self._fout.flush()
318 321
319 322 def serve_forever(self):
320 323 while self.serve_one():
321 324 pass
322 325 sys.exit(0)
323 326
324 327 _handlers = {
325 328 str: _sendresponse,
326 329 wireproto.streamres: _sendstream,
327 330 wireproto.streamres_legacy: _sendstream,
328 331 wireproto.pushres: _sendpushresponse,
329 332 wireproto.pusherr: _sendpusherror,
330 333 wireproto.ooberror: _sendooberror,
331 334 }
332 335
333 336 def serve_one(self):
334 337 cmd = self._fin.readline()[:-1]
335 338 if cmd and cmd in wireproto.commands:
336 339 rsp = wireproto.dispatch(self._repo, self, cmd)
337 340 self._handlers[rsp.__class__](self, rsp)
338 341 elif cmd:
339 342 self._sendresponse("")
340 343 return cmd != ''
341 344
342 345 def _client(self):
343 346 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
344 347 return 'remote:ssh:' + client
General Comments 0
You need to be logged in to leave comments. Login now