##// END OF EJS Templates
wireprotoserver: make response handling attributes private...
Gregory Szorc -
r35889:49ea1ba1 default
parent child Browse files
Show More
@@ -1,344 +1,344 b''
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import cgi
10 10 import struct
11 11 import sys
12 12
13 13 from .i18n import _
14 14 from . import (
15 15 encoding,
16 16 error,
17 17 hook,
18 18 pycompat,
19 19 util,
20 20 wireproto,
21 21 )
22 22
23 23 stringio = util.stringio
24 24
25 25 urlerr = util.urlerr
26 26 urlreq = util.urlreq
27 27
28 28 HTTP_OK = 200
29 29
30 30 HGTYPE = 'application/mercurial-0.1'
31 31 HGTYPE2 = 'application/mercurial-0.2'
32 32 HGERRTYPE = 'application/hg-error'
33 33
34 34 class abstractserverproto(object):
35 35 """abstract class that summarizes the protocol API
36 36
37 37 Used as reference and documentation.
38 38 """
39 39
40 40 def getargs(self, args):
41 41 """return the value for arguments in <args>
42 42
43 43 returns a list of values (same order as <args>)"""
44 44 raise NotImplementedError()
45 45
46 46 def getfile(self, fp):
47 47 """write the whole content of a file into a file like object
48 48
49 49 The file is in the form::
50 50
51 51 (<chunk-size>\n<chunk>)+0\n
52 52
53 53 chunk size is the ascii version of the int.
54 54 """
55 55 raise NotImplementedError()
56 56
57 57 def redirect(self):
58 58 """may setup interception for stdout and stderr
59 59
60 60 See also the `restore` method."""
61 61 raise NotImplementedError()
62 62
63 63 # If the `redirect` function does install interception, the `restore`
64 64 # function MUST be defined. If interception is not used, this function
65 65 # MUST NOT be defined.
66 66 #
67 67 # left commented here on purpose
68 68 #
69 69 #def restore(self):
70 70 # """reinstall previous stdout and stderr and return intercepted stdout
71 71 # """
72 72 # raise NotImplementedError()
73 73
74 74 def decodevaluefromheaders(req, headerprefix):
75 75 """Decode a long value from multiple HTTP request headers.
76 76
77 77 Returns the value as a bytes, not a str.
78 78 """
79 79 chunks = []
80 80 i = 1
81 81 prefix = headerprefix.upper().replace(r'-', r'_')
82 82 while True:
83 83 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
84 84 if v is None:
85 85 break
86 86 chunks.append(pycompat.bytesurl(v))
87 87 i += 1
88 88
89 89 return ''.join(chunks)
90 90
91 91 class webproto(abstractserverproto):
92 92 def __init__(self, req, ui):
93 93 self._req = req
94 94 self._ui = ui
95 95 self.name = 'http'
96 96
97 97 def getargs(self, args):
98 98 knownargs = self._args()
99 99 data = {}
100 100 keys = args.split()
101 101 for k in keys:
102 102 if k == '*':
103 103 star = {}
104 104 for key in knownargs.keys():
105 105 if key != 'cmd' and key not in keys:
106 106 star[key] = knownargs[key][0]
107 107 data['*'] = star
108 108 else:
109 109 data[k] = knownargs[k][0]
110 110 return [data[k] for k in keys]
111 111
112 112 def _args(self):
113 113 args = self._req.form.copy()
114 114 if pycompat.ispy3:
115 115 args = {k.encode('ascii'): [v.encode('ascii') for v in vs]
116 116 for k, vs in args.items()}
117 117 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
118 118 if postlen:
119 119 args.update(cgi.parse_qs(
120 120 self._req.read(postlen), keep_blank_values=True))
121 121 return args
122 122
123 123 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
124 124 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
125 125 return args
126 126
127 127 def getfile(self, fp):
128 128 length = int(self._req.env[r'CONTENT_LENGTH'])
129 129 # If httppostargs is used, we need to read Content-Length
130 130 # minus the amount that was consumed by args.
131 131 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
132 132 for s in util.filechunkiter(self._req, limit=length):
133 133 fp.write(s)
134 134
135 135 def redirect(self):
136 136 self._oldio = self._ui.fout, self._ui.ferr
137 137 self._ui.ferr = self._ui.fout = stringio()
138 138
139 139 def restore(self):
140 140 val = self._ui.fout.getvalue()
141 141 self._ui.ferr, self._ui.fout = self._oldio
142 142 return val
143 143
144 144 def _client(self):
145 145 return 'remote:%s:%s:%s' % (
146 146 self._req.env.get('wsgi.url_scheme') or 'http',
147 147 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
148 148 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
149 149
150 150 def responsetype(self, prefer_uncompressed):
151 151 """Determine the appropriate response type and compression settings.
152 152
153 153 Returns a tuple of (mediatype, compengine, engineopts).
154 154 """
155 155 # Determine the response media type and compression engine based
156 156 # on the request parameters.
157 157 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
158 158
159 159 if '0.2' in protocaps:
160 160 # All clients are expected to support uncompressed data.
161 161 if prefer_uncompressed:
162 162 return HGTYPE2, util._noopengine(), {}
163 163
164 164 # Default as defined by wire protocol spec.
165 165 compformats = ['zlib', 'none']
166 166 for cap in protocaps:
167 167 if cap.startswith('comp='):
168 168 compformats = cap[5:].split(',')
169 169 break
170 170
171 171 # Now find an agreed upon compression format.
172 172 for engine in wireproto.supportedcompengines(self._ui, self,
173 173 util.SERVERROLE):
174 174 if engine.wireprotosupport().name in compformats:
175 175 opts = {}
176 176 level = self._ui.configint('server',
177 177 '%slevel' % engine.name())
178 178 if level is not None:
179 179 opts['level'] = level
180 180
181 181 return HGTYPE2, engine, opts
182 182
183 183 # No mutually supported compression format. Fall back to the
184 184 # legacy protocol.
185 185
186 186 # Don't allow untrusted settings because disabling compression or
187 187 # setting a very high compression level could lead to flooding
188 188 # the server's network or CPU.
189 189 opts = {'level': self._ui.configint('server', 'zliblevel')}
190 190 return HGTYPE, util.compengines['zlib'], opts
191 191
192 192 def iscmd(cmd):
193 193 return cmd in wireproto.commands
194 194
195 195 def callhttp(repo, req, cmd):
196 196 proto = webproto(req, repo.ui)
197 197
198 198 def genversion2(gen, engine, engineopts):
199 199 # application/mercurial-0.2 always sends a payload header
200 200 # identifying the compression engine.
201 201 name = engine.wireprotosupport().name
202 202 assert 0 < len(name) < 256
203 203 yield struct.pack('B', len(name))
204 204 yield name
205 205
206 206 for chunk in gen:
207 207 yield chunk
208 208
209 209 rsp = wireproto.dispatch(repo, proto, cmd)
210 210 if isinstance(rsp, bytes):
211 211 req.respond(HTTP_OK, HGTYPE, body=rsp)
212 212 return []
213 213 elif isinstance(rsp, wireproto.streamres_legacy):
214 214 gen = rsp.gen
215 215 req.respond(HTTP_OK, HGTYPE)
216 216 return gen
217 217 elif isinstance(rsp, wireproto.streamres):
218 218 gen = rsp.gen
219 219
220 220 # This code for compression should not be streamres specific. It
221 221 # is here because we only compress streamres at the moment.
222 222 mediatype, engine, engineopts = proto.responsetype(
223 223 rsp.prefer_uncompressed)
224 224 gen = engine.compressstream(gen, engineopts)
225 225
226 226 if mediatype == HGTYPE2:
227 227 gen = genversion2(gen, engine, engineopts)
228 228
229 229 req.respond(HTTP_OK, mediatype)
230 230 return gen
231 231 elif isinstance(rsp, wireproto.pushres):
232 232 val = proto.restore()
233 233 rsp = '%d\n%s' % (rsp.res, val)
234 234 req.respond(HTTP_OK, HGTYPE, body=rsp)
235 235 return []
236 236 elif isinstance(rsp, wireproto.pusherr):
237 237 # drain the incoming bundle
238 238 req.drain()
239 239 proto.restore()
240 240 rsp = '0\n%s\n' % rsp.res
241 241 req.respond(HTTP_OK, HGTYPE, body=rsp)
242 242 return []
243 243 elif isinstance(rsp, wireproto.ooberror):
244 244 rsp = rsp.message
245 245 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
246 246 return []
247 247 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
248 248
249 249 class sshserver(abstractserverproto):
250 250 def __init__(self, ui, repo):
251 251 self._ui = ui
252 252 self._repo = repo
253 253 self._fin = ui.fin
254 254 self._fout = ui.fout
255 255 self.name = 'ssh'
256 256
257 257 hook.redirect(True)
258 258 ui.fout = repo.ui.fout = ui.ferr
259 259
260 260 # Prevent insertion/deletion of CRs
261 261 util.setbinary(self._fin)
262 262 util.setbinary(self._fout)
263 263
264 264 def getargs(self, args):
265 265 data = {}
266 266 keys = args.split()
267 267 for n in xrange(len(keys)):
268 268 argline = self._fin.readline()[:-1]
269 269 arg, l = argline.split()
270 270 if arg not in keys:
271 271 raise error.Abort(_("unexpected parameter %r") % arg)
272 272 if arg == '*':
273 273 star = {}
274 274 for k in xrange(int(l)):
275 275 argline = self._fin.readline()[:-1]
276 276 arg, l = argline.split()
277 277 val = self._fin.read(int(l))
278 278 star[arg] = val
279 279 data['*'] = star
280 280 else:
281 281 val = self._fin.read(int(l))
282 282 data[arg] = val
283 283 return [data[k] for k in keys]
284 284
285 285 def getfile(self, fpout):
286 self.sendresponse('')
286 self._sendresponse('')
287 287 count = int(self._fin.readline())
288 288 while count:
289 289 fpout.write(self._fin.read(count))
290 290 count = int(self._fin.readline())
291 291
292 292 def redirect(self):
293 293 pass
294 294
295 def sendresponse(self, v):
295 def _sendresponse(self, v):
296 296 self._fout.write("%d\n" % len(v))
297 297 self._fout.write(v)
298 298 self._fout.flush()
299 299
300 def sendstream(self, source):
300 def _sendstream(self, source):
301 301 write = self._fout.write
302 302 for chunk in source.gen:
303 303 write(chunk)
304 304 self._fout.flush()
305 305
306 def sendpushresponse(self, rsp):
307 self.sendresponse('')
308 self.sendresponse(str(rsp.res))
306 def _sendpushresponse(self, rsp):
307 self._sendresponse('')
308 self._sendresponse(str(rsp.res))
309 309
310 def sendpusherror(self, rsp):
311 self.sendresponse(rsp.res)
310 def _sendpusherror(self, rsp):
311 self._sendresponse(rsp.res)
312 312
313 def sendooberror(self, rsp):
313 def _sendooberror(self, rsp):
314 314 self._ui.ferr.write('%s\n-\n' % rsp.message)
315 315 self._ui.ferr.flush()
316 316 self._fout.write('\n')
317 317 self._fout.flush()
318 318
319 319 def serve_forever(self):
320 320 while self.serve_one():
321 321 pass
322 322 sys.exit(0)
323 323
324 handlers = {
325 str: sendresponse,
326 wireproto.streamres: sendstream,
327 wireproto.streamres_legacy: sendstream,
328 wireproto.pushres: sendpushresponse,
329 wireproto.pusherr: sendpusherror,
330 wireproto.ooberror: sendooberror,
324 _handlers = {
325 str: _sendresponse,
326 wireproto.streamres: _sendstream,
327 wireproto.streamres_legacy: _sendstream,
328 wireproto.pushres: _sendpushresponse,
329 wireproto.pusherr: _sendpusherror,
330 wireproto.ooberror: _sendooberror,
331 331 }
332 332
333 333 def serve_one(self):
334 334 cmd = self._fin.readline()[:-1]
335 335 if cmd and cmd in wireproto.commands:
336 336 rsp = wireproto.dispatch(self._repo, self, cmd)
337 self.handlers[rsp.__class__](self, rsp)
337 self._handlers[rsp.__class__](self, rsp)
338 338 elif cmd:
339 self.sendresponse("")
339 self._sendresponse("")
340 340 return cmd != ''
341 341
342 342 def _client(self):
343 343 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
344 344 return 'remote:ssh:' + client
General Comments 0
You need to be logged in to leave comments. Login now