##// END OF EJS Templates
wireprotoserver: make name part of protocol interface...
Gregory Szorc -
r35891:29759c46 default
parent child Browse files
Show More
@@ -1,347 +1,360
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import abc
10 10 import cgi
11 11 import struct
12 12 import sys
13 13
14 14 from .i18n import _
15 15 from . import (
16 16 encoding,
17 17 error,
18 18 hook,
19 19 pycompat,
20 20 util,
21 21 wireproto,
22 22 )
23 23
24 24 stringio = util.stringio
25 25
26 26 urlerr = util.urlerr
27 27 urlreq = util.urlreq
28 28
29 29 HTTP_OK = 200
30 30
31 31 HGTYPE = 'application/mercurial-0.1'
32 32 HGTYPE2 = 'application/mercurial-0.2'
33 33 HGERRTYPE = 'application/hg-error'
34 34
35 35 class abstractserverproto(object):
36 36 """abstract class that summarizes the protocol API
37 37
38 38 Used as reference and documentation.
39 39 """
40 40
41 41 __metaclass__ = abc.ABCMeta
42 42
43 @abc.abstractproperty
44 def name(self):
45 """The name of the protocol implementation.
46
47 Used for uniquely identifying the transport type.
48 """
49
43 50 @abc.abstractmethod
44 51 def getargs(self, args):
45 52 """return the value for arguments in <args>
46 53
47 54 returns a list of values (same order as <args>)"""
48 55
49 56 @abc.abstractmethod
50 57 def getfile(self, fp):
51 58 """write the whole content of a file into a file like object
52 59
53 60 The file is in the form::
54 61
55 62 (<chunk-size>\n<chunk>)+0\n
56 63
57 64 chunk size is the ascii version of the int.
58 65 """
59 66
60 67 @abc.abstractmethod
61 68 def redirect(self):
62 69 """may setup interception for stdout and stderr
63 70
64 71 See also the `restore` method."""
65 72
66 73 # If the `redirect` function does install interception, the `restore`
67 74 # function MUST be defined. If interception is not used, this function
68 75 # MUST NOT be defined.
69 76 #
70 77 # left commented here on purpose
71 78 #
72 79 #def restore(self):
73 80 # """reinstall previous stdout and stderr and return intercepted stdout
74 81 # """
75 82 # raise NotImplementedError()
76 83
77 84 def decodevaluefromheaders(req, headerprefix):
78 85 """Decode a long value from multiple HTTP request headers.
79 86
80 87 Returns the value as a bytes, not a str.
81 88 """
82 89 chunks = []
83 90 i = 1
84 91 prefix = headerprefix.upper().replace(r'-', r'_')
85 92 while True:
86 93 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
87 94 if v is None:
88 95 break
89 96 chunks.append(pycompat.bytesurl(v))
90 97 i += 1
91 98
92 99 return ''.join(chunks)
93 100
94 101 class webproto(abstractserverproto):
95 102 def __init__(self, req, ui):
96 103 self._req = req
97 104 self._ui = ui
98 self.name = 'http'
105
106 @property
107 def name(self):
108 return 'http'
99 109
100 110 def getargs(self, args):
101 111 knownargs = self._args()
102 112 data = {}
103 113 keys = args.split()
104 114 for k in keys:
105 115 if k == '*':
106 116 star = {}
107 117 for key in knownargs.keys():
108 118 if key != 'cmd' and key not in keys:
109 119 star[key] = knownargs[key][0]
110 120 data['*'] = star
111 121 else:
112 122 data[k] = knownargs[k][0]
113 123 return [data[k] for k in keys]
114 124
115 125 def _args(self):
116 126 args = self._req.form.copy()
117 127 if pycompat.ispy3:
118 128 args = {k.encode('ascii'): [v.encode('ascii') for v in vs]
119 129 for k, vs in args.items()}
120 130 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
121 131 if postlen:
122 132 args.update(cgi.parse_qs(
123 133 self._req.read(postlen), keep_blank_values=True))
124 134 return args
125 135
126 136 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
127 137 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
128 138 return args
129 139
130 140 def getfile(self, fp):
131 141 length = int(self._req.env[r'CONTENT_LENGTH'])
132 142 # If httppostargs is used, we need to read Content-Length
133 143 # minus the amount that was consumed by args.
134 144 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
135 145 for s in util.filechunkiter(self._req, limit=length):
136 146 fp.write(s)
137 147
138 148 def redirect(self):
139 149 self._oldio = self._ui.fout, self._ui.ferr
140 150 self._ui.ferr = self._ui.fout = stringio()
141 151
142 152 def restore(self):
143 153 val = self._ui.fout.getvalue()
144 154 self._ui.ferr, self._ui.fout = self._oldio
145 155 return val
146 156
147 157 def _client(self):
148 158 return 'remote:%s:%s:%s' % (
149 159 self._req.env.get('wsgi.url_scheme') or 'http',
150 160 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
151 161 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
152 162
153 163 def responsetype(self, prefer_uncompressed):
154 164 """Determine the appropriate response type and compression settings.
155 165
156 166 Returns a tuple of (mediatype, compengine, engineopts).
157 167 """
158 168 # Determine the response media type and compression engine based
159 169 # on the request parameters.
160 170 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
161 171
162 172 if '0.2' in protocaps:
163 173 # All clients are expected to support uncompressed data.
164 174 if prefer_uncompressed:
165 175 return HGTYPE2, util._noopengine(), {}
166 176
167 177 # Default as defined by wire protocol spec.
168 178 compformats = ['zlib', 'none']
169 179 for cap in protocaps:
170 180 if cap.startswith('comp='):
171 181 compformats = cap[5:].split(',')
172 182 break
173 183
174 184 # Now find an agreed upon compression format.
175 185 for engine in wireproto.supportedcompengines(self._ui, self,
176 186 util.SERVERROLE):
177 187 if engine.wireprotosupport().name in compformats:
178 188 opts = {}
179 189 level = self._ui.configint('server',
180 190 '%slevel' % engine.name())
181 191 if level is not None:
182 192 opts['level'] = level
183 193
184 194 return HGTYPE2, engine, opts
185 195
186 196 # No mutually supported compression format. Fall back to the
187 197 # legacy protocol.
188 198
189 199 # Don't allow untrusted settings because disabling compression or
190 200 # setting a very high compression level could lead to flooding
191 201 # the server's network or CPU.
192 202 opts = {'level': self._ui.configint('server', 'zliblevel')}
193 203 return HGTYPE, util.compengines['zlib'], opts
194 204
195 205 def iscmd(cmd):
196 206 return cmd in wireproto.commands
197 207
198 208 def callhttp(repo, req, cmd):
199 209 proto = webproto(req, repo.ui)
200 210
201 211 def genversion2(gen, engine, engineopts):
202 212 # application/mercurial-0.2 always sends a payload header
203 213 # identifying the compression engine.
204 214 name = engine.wireprotosupport().name
205 215 assert 0 < len(name) < 256
206 216 yield struct.pack('B', len(name))
207 217 yield name
208 218
209 219 for chunk in gen:
210 220 yield chunk
211 221
212 222 rsp = wireproto.dispatch(repo, proto, cmd)
213 223 if isinstance(rsp, bytes):
214 224 req.respond(HTTP_OK, HGTYPE, body=rsp)
215 225 return []
216 226 elif isinstance(rsp, wireproto.streamres_legacy):
217 227 gen = rsp.gen
218 228 req.respond(HTTP_OK, HGTYPE)
219 229 return gen
220 230 elif isinstance(rsp, wireproto.streamres):
221 231 gen = rsp.gen
222 232
223 233 # This code for compression should not be streamres specific. It
224 234 # is here because we only compress streamres at the moment.
225 235 mediatype, engine, engineopts = proto.responsetype(
226 236 rsp.prefer_uncompressed)
227 237 gen = engine.compressstream(gen, engineopts)
228 238
229 239 if mediatype == HGTYPE2:
230 240 gen = genversion2(gen, engine, engineopts)
231 241
232 242 req.respond(HTTP_OK, mediatype)
233 243 return gen
234 244 elif isinstance(rsp, wireproto.pushres):
235 245 val = proto.restore()
236 246 rsp = '%d\n%s' % (rsp.res, val)
237 247 req.respond(HTTP_OK, HGTYPE, body=rsp)
238 248 return []
239 249 elif isinstance(rsp, wireproto.pusherr):
240 250 # drain the incoming bundle
241 251 req.drain()
242 252 proto.restore()
243 253 rsp = '0\n%s\n' % rsp.res
244 254 req.respond(HTTP_OK, HGTYPE, body=rsp)
245 255 return []
246 256 elif isinstance(rsp, wireproto.ooberror):
247 257 rsp = rsp.message
248 258 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
249 259 return []
250 260 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
251 261
252 262 class sshserver(abstractserverproto):
253 263 def __init__(self, ui, repo):
254 264 self._ui = ui
255 265 self._repo = repo
256 266 self._fin = ui.fin
257 267 self._fout = ui.fout
258 self.name = 'ssh'
259 268
260 269 hook.redirect(True)
261 270 ui.fout = repo.ui.fout = ui.ferr
262 271
263 272 # Prevent insertion/deletion of CRs
264 273 util.setbinary(self._fin)
265 274 util.setbinary(self._fout)
266 275
276 @property
277 def name(self):
278 return 'ssh'
279
267 280 def getargs(self, args):
268 281 data = {}
269 282 keys = args.split()
270 283 for n in xrange(len(keys)):
271 284 argline = self._fin.readline()[:-1]
272 285 arg, l = argline.split()
273 286 if arg not in keys:
274 287 raise error.Abort(_("unexpected parameter %r") % arg)
275 288 if arg == '*':
276 289 star = {}
277 290 for k in xrange(int(l)):
278 291 argline = self._fin.readline()[:-1]
279 292 arg, l = argline.split()
280 293 val = self._fin.read(int(l))
281 294 star[arg] = val
282 295 data['*'] = star
283 296 else:
284 297 val = self._fin.read(int(l))
285 298 data[arg] = val
286 299 return [data[k] for k in keys]
287 300
288 301 def getfile(self, fpout):
289 302 self._sendresponse('')
290 303 count = int(self._fin.readline())
291 304 while count:
292 305 fpout.write(self._fin.read(count))
293 306 count = int(self._fin.readline())
294 307
295 308 def redirect(self):
296 309 pass
297 310
298 311 def _sendresponse(self, v):
299 312 self._fout.write("%d\n" % len(v))
300 313 self._fout.write(v)
301 314 self._fout.flush()
302 315
303 316 def _sendstream(self, source):
304 317 write = self._fout.write
305 318 for chunk in source.gen:
306 319 write(chunk)
307 320 self._fout.flush()
308 321
309 322 def _sendpushresponse(self, rsp):
310 323 self._sendresponse('')
311 324 self._sendresponse(str(rsp.res))
312 325
313 326 def _sendpusherror(self, rsp):
314 327 self._sendresponse(rsp.res)
315 328
316 329 def _sendooberror(self, rsp):
317 330 self._ui.ferr.write('%s\n-\n' % rsp.message)
318 331 self._ui.ferr.flush()
319 332 self._fout.write('\n')
320 333 self._fout.flush()
321 334
322 335 def serve_forever(self):
323 336 while self.serve_one():
324 337 pass
325 338 sys.exit(0)
326 339
327 340 _handlers = {
328 341 str: _sendresponse,
329 342 wireproto.streamres: _sendstream,
330 343 wireproto.streamres_legacy: _sendstream,
331 344 wireproto.pushres: _sendpushresponse,
332 345 wireproto.pusherr: _sendpusherror,
333 346 wireproto.ooberror: _sendooberror,
334 347 }
335 348
336 349 def serve_one(self):
337 350 cmd = self._fin.readline()[:-1]
338 351 if cmd and cmd in wireproto.commands:
339 352 rsp = wireproto.dispatch(self._repo, self, cmd)
340 353 self._handlers[rsp.__class__](self, rsp)
341 354 elif cmd:
342 355 self._sendresponse("")
343 356 return cmd != ''
344 357
345 358 def _client(self):
346 359 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
347 360 return 'remote:ssh:' + client
General Comments 0
You need to be logged in to leave comments. Login now