##// END OF EJS Templates
wireproto: compress data from a generator...
Gregory Szorc -
r30206:d1051954 default
parent child Browse files
Show More
@@ -1,124 +1,133 b''
1 1 #
2 2 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import cgi
11 11 import zlib
12 12
13 13 from .common import (
14 14 HTTP_OK,
15 15 )
16 16
17 17 from .. import (
18 18 util,
19 19 wireproto,
20 20 )
21 21 stringio = util.stringio
22 22
23 23 urlerr = util.urlerr
24 24 urlreq = util.urlreq
25 25
26 26 HGTYPE = 'application/mercurial-0.1'
27 27 HGERRTYPE = 'application/hg-error'
28 28
29 29 class webproto(wireproto.abstractserverproto):
30 30 def __init__(self, req, ui):
31 31 self.req = req
32 32 self.response = ''
33 33 self.ui = ui
34 34 def getargs(self, args):
35 35 knownargs = self._args()
36 36 data = {}
37 37 keys = args.split()
38 38 for k in keys:
39 39 if k == '*':
40 40 star = {}
41 41 for key in knownargs.keys():
42 42 if key != 'cmd' and key not in keys:
43 43 star[key] = knownargs[key][0]
44 44 data['*'] = star
45 45 else:
46 46 data[k] = knownargs[k][0]
47 47 return [data[k] for k in keys]
48 48 def _args(self):
49 49 args = self.req.form.copy()
50 50 postlen = int(self.req.env.get('HTTP_X_HGARGS_POST', 0))
51 51 if postlen:
52 52 args.update(cgi.parse_qs(
53 53 self.req.read(postlen), keep_blank_values=True))
54 54 return args
55 55 chunks = []
56 56 i = 1
57 57 while True:
58 58 h = self.req.env.get('HTTP_X_HGARG_' + str(i))
59 59 if h is None:
60 60 break
61 61 chunks += [h]
62 62 i += 1
63 63 args.update(cgi.parse_qs(''.join(chunks), keep_blank_values=True))
64 64 return args
65 65 def getfile(self, fp):
66 66 length = int(self.req.env['CONTENT_LENGTH'])
67 67 for s in util.filechunkiter(self.req, limit=length):
68 68 fp.write(s)
69 69 def redirect(self):
70 70 self.oldio = self.ui.fout, self.ui.ferr
71 71 self.ui.ferr = self.ui.fout = stringio()
72 72 def restore(self):
73 73 val = self.ui.fout.getvalue()
74 74 self.ui.ferr, self.ui.fout = self.oldio
75 75 return val
76
76 77 def groupchunks(self, fh):
78 def getchunks():
79 while True:
80 chunk = fh.read(32768)
81 if not chunk:
82 break
83 yield chunk
84
85 return self.compresschunks(getchunks())
86
87 def compresschunks(self, chunks):
77 88 # Don't allow untrusted settings because disabling compression or
78 89 # setting a very high compression level could lead to flooding
79 90 # the server's network or CPU.
80 91 z = zlib.compressobj(self.ui.configint('server', 'zliblevel', -1))
81 while True:
82 chunk = fh.read(32768)
83 if not chunk:
84 break
92 for chunk in chunks:
85 93 data = z.compress(chunk)
86 94 # Not all calls to compress() emit data. It is cheaper to inspect
87 95 # that here than to send it via the generator.
88 96 if data:
89 97 yield data
90 98 yield z.flush()
99
91 100 def _client(self):
92 101 return 'remote:%s:%s:%s' % (
93 102 self.req.env.get('wsgi.url_scheme') or 'http',
94 103 urlreq.quote(self.req.env.get('REMOTE_HOST', '')),
95 104 urlreq.quote(self.req.env.get('REMOTE_USER', '')))
96 105
97 106 def iscmd(cmd):
98 107 return cmd in wireproto.commands
99 108
100 109 def call(repo, req, cmd):
101 110 p = webproto(req, repo.ui)
102 111 rsp = wireproto.dispatch(repo, p, cmd)
103 112 if isinstance(rsp, str):
104 113 req.respond(HTTP_OK, HGTYPE, body=rsp)
105 114 return []
106 115 elif isinstance(rsp, wireproto.streamres):
107 116 req.respond(HTTP_OK, HGTYPE)
108 117 return rsp.gen
109 118 elif isinstance(rsp, wireproto.pushres):
110 119 val = p.restore()
111 120 rsp = '%d\n%s' % (rsp.res, val)
112 121 req.respond(HTTP_OK, HGTYPE, body=rsp)
113 122 return []
114 123 elif isinstance(rsp, wireproto.pusherr):
115 124 # drain the incoming bundle
116 125 req.drain()
117 126 p.restore()
118 127 rsp = '0\n%s\n' % rsp.res
119 128 req.respond(HTTP_OK, HGTYPE, body=rsp)
120 129 return []
121 130 elif isinstance(rsp, wireproto.ooberror):
122 131 rsp = rsp.message
123 132 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
124 133 return []
@@ -1,131 +1,135 b''
1 1 # sshserver.py - ssh protocol server support for mercurial
2 2 #
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from __future__ import absolute_import
10 10
11 11 import os
12 12 import sys
13 13
14 14 from .i18n import _
15 15 from . import (
16 16 error,
17 17 hook,
18 18 util,
19 19 wireproto,
20 20 )
21 21
22 22 class sshserver(wireproto.abstractserverproto):
23 23 def __init__(self, ui, repo):
24 24 self.ui = ui
25 25 self.repo = repo
26 26 self.lock = None
27 27 self.fin = ui.fin
28 28 self.fout = ui.fout
29 29
30 30 hook.redirect(True)
31 31 ui.fout = repo.ui.fout = ui.ferr
32 32
33 33 # Prevent insertion/deletion of CRs
34 34 util.setbinary(self.fin)
35 35 util.setbinary(self.fout)
36 36
37 37 def getargs(self, args):
38 38 data = {}
39 39 keys = args.split()
40 40 for n in xrange(len(keys)):
41 41 argline = self.fin.readline()[:-1]
42 42 arg, l = argline.split()
43 43 if arg not in keys:
44 44 raise error.Abort(_("unexpected parameter %r") % arg)
45 45 if arg == '*':
46 46 star = {}
47 47 for k in xrange(int(l)):
48 48 argline = self.fin.readline()[:-1]
49 49 arg, l = argline.split()
50 50 val = self.fin.read(int(l))
51 51 star[arg] = val
52 52 data['*'] = star
53 53 else:
54 54 val = self.fin.read(int(l))
55 55 data[arg] = val
56 56 return [data[k] for k in keys]
57 57
58 58 def getarg(self, name):
59 59 return self.getargs(name)[0]
60 60
61 61 def getfile(self, fpout):
62 62 self.sendresponse('')
63 63 count = int(self.fin.readline())
64 64 while count:
65 65 fpout.write(self.fin.read(count))
66 66 count = int(self.fin.readline())
67 67
68 68 def redirect(self):
69 69 pass
70 70
71 71 def groupchunks(self, fh):
72 72 return iter(lambda: fh.read(4096), '')
73 73
74 def compresschunks(self, chunks):
75 for chunk in chunks:
76 yield chunk
77
74 78 def sendresponse(self, v):
75 79 self.fout.write("%d\n" % len(v))
76 80 self.fout.write(v)
77 81 self.fout.flush()
78 82
79 83 def sendstream(self, source):
80 84 write = self.fout.write
81 85 for chunk in source.gen:
82 86 write(chunk)
83 87 self.fout.flush()
84 88
85 89 def sendpushresponse(self, rsp):
86 90 self.sendresponse('')
87 91 self.sendresponse(str(rsp.res))
88 92
89 93 def sendpusherror(self, rsp):
90 94 self.sendresponse(rsp.res)
91 95
92 96 def sendooberror(self, rsp):
93 97 self.ui.ferr.write('%s\n-\n' % rsp.message)
94 98 self.ui.ferr.flush()
95 99 self.fout.write('\n')
96 100 self.fout.flush()
97 101
98 102 def serve_forever(self):
99 103 try:
100 104 while self.serve_one():
101 105 pass
102 106 finally:
103 107 if self.lock is not None:
104 108 self.lock.release()
105 109 sys.exit(0)
106 110
107 111 handlers = {
108 112 str: sendresponse,
109 113 wireproto.streamres: sendstream,
110 114 wireproto.pushres: sendpushresponse,
111 115 wireproto.pusherr: sendpusherror,
112 116 wireproto.ooberror: sendooberror,
113 117 }
114 118
115 119 def serve_one(self):
116 120 cmd = self.fin.readline()[:-1]
117 121 if cmd and cmd in wireproto.commands:
118 122 rsp = wireproto.dispatch(self.repo, self, cmd)
119 123 self.handlers[rsp.__class__](self, rsp)
120 124 elif cmd:
121 125 impl = getattr(self, 'do_' + cmd, None)
122 126 if impl:
123 127 r = impl()
124 128 if r is not None:
125 129 self.sendresponse(r)
126 130 else: self.sendresponse("")
127 131 return cmd != ''
128 132
129 133 def _client(self):
130 134 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
131 135 return 'remote:ssh:' + client
@@ -1,959 +1,965 b''
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import itertools
12 12 import os
13 13 import sys
14 14 import tempfile
15 15
16 16 from .i18n import _
17 17 from .node import (
18 18 bin,
19 19 hex,
20 20 )
21 21
22 22 from . import (
23 23 bundle2,
24 24 changegroup as changegroupmod,
25 25 encoding,
26 26 error,
27 27 exchange,
28 28 peer,
29 29 pushkey as pushkeymod,
30 30 streamclone,
31 31 util,
32 32 )
33 33
34 34 urlerr = util.urlerr
35 35 urlreq = util.urlreq
36 36
37 37 bundle2required = _(
38 38 'incompatible Mercurial client; bundle2 required\n'
39 39 '(see https://www.mercurial-scm.org/wiki/IncompatibleClient)\n')
40 40
41 41 class abstractserverproto(object):
42 42 """abstract class that summarizes the protocol API
43 43
44 44 Used as reference and documentation.
45 45 """
46 46
47 47 def getargs(self, args):
48 48 """return the value for arguments in <args>
49 49
50 50 returns a list of values (same order as <args>)"""
51 51 raise NotImplementedError()
52 52
53 53 def getfile(self, fp):
54 54 """write the whole content of a file into a file like object
55 55
56 56 The file is in the form::
57 57
58 58 (<chunk-size>\n<chunk>)+0\n
59 59
60 60 chunk size is the ascii version of the int.
61 61 """
62 62 raise NotImplementedError()
63 63
64 64 def redirect(self):
65 65 """may setup interception for stdout and stderr
66 66
67 67 See also the `restore` method."""
68 68 raise NotImplementedError()
69 69
70 70 # If the `redirect` function does install interception, the `restore`
71 71 # function MUST be defined. If interception is not used, this function
72 72 # MUST NOT be defined.
73 73 #
74 74 # left commented here on purpose
75 75 #
76 76 #def restore(self):
77 77 # """reinstall previous stdout and stderr and return intercepted stdout
78 78 # """
79 79 # raise NotImplementedError()
80 80
81 81 def groupchunks(self, fh):
82 82 """Generator of chunks to send to the client.
83 83
84 84 Some protocols may have compressed the contents.
85 85 """
86 86 raise NotImplementedError()
87 87
88 def compresschunks(self, chunks):
89 """Generator of possible compressed chunks to send to the client.
90
91 This is like ``groupchunks()`` except it accepts a generator as
92 its argument.
93 """
94 raise NotImplementedError()
95
88 96 class remotebatch(peer.batcher):
89 97 '''batches the queued calls; uses as few roundtrips as possible'''
90 98 def __init__(self, remote):
91 99 '''remote must support _submitbatch(encbatch) and
92 100 _submitone(op, encargs)'''
93 101 peer.batcher.__init__(self)
94 102 self.remote = remote
95 103 def submit(self):
96 104 req, rsp = [], []
97 105 for name, args, opts, resref in self.calls:
98 106 mtd = getattr(self.remote, name)
99 107 batchablefn = getattr(mtd, 'batchable', None)
100 108 if batchablefn is not None:
101 109 batchable = batchablefn(mtd.im_self, *args, **opts)
102 110 encargsorres, encresref = next(batchable)
103 111 if encresref:
104 112 req.append((name, encargsorres,))
105 113 rsp.append((batchable, encresref, resref,))
106 114 else:
107 115 resref.set(encargsorres)
108 116 else:
109 117 if req:
110 118 self._submitreq(req, rsp)
111 119 req, rsp = [], []
112 120 resref.set(mtd(*args, **opts))
113 121 if req:
114 122 self._submitreq(req, rsp)
115 123 def _submitreq(self, req, rsp):
116 124 encresults = self.remote._submitbatch(req)
117 125 for encres, r in zip(encresults, rsp):
118 126 batchable, encresref, resref = r
119 127 encresref.set(encres)
120 128 resref.set(next(batchable))
121 129
122 130 class remoteiterbatcher(peer.iterbatcher):
123 131 def __init__(self, remote):
124 132 super(remoteiterbatcher, self).__init__()
125 133 self._remote = remote
126 134
127 135 def __getattr__(self, name):
128 136 if not getattr(self._remote, name, False):
129 137 raise AttributeError(
130 138 'Attempted to iterbatch non-batchable call to %r' % name)
131 139 return super(remoteiterbatcher, self).__getattr__(name)
132 140
133 141 def submit(self):
134 142 """Break the batch request into many patch calls and pipeline them.
135 143
136 144 This is mostly valuable over http where request sizes can be
137 145 limited, but can be used in other places as well.
138 146 """
139 147 req, rsp = [], []
140 148 for name, args, opts, resref in self.calls:
141 149 mtd = getattr(self._remote, name)
142 150 batchable = mtd.batchable(mtd.im_self, *args, **opts)
143 151 encargsorres, encresref = next(batchable)
144 152 assert encresref
145 153 req.append((name, encargsorres))
146 154 rsp.append((batchable, encresref))
147 155 if req:
148 156 self._resultiter = self._remote._submitbatch(req)
149 157 self._rsp = rsp
150 158
151 159 def results(self):
152 160 for (batchable, encresref), encres in itertools.izip(
153 161 self._rsp, self._resultiter):
154 162 encresref.set(encres)
155 163 yield next(batchable)
156 164
157 165 # Forward a couple of names from peer to make wireproto interactions
158 166 # slightly more sensible.
159 167 batchable = peer.batchable
160 168 future = peer.future
161 169
162 170 # list of nodes encoding / decoding
163 171
164 172 def decodelist(l, sep=' '):
165 173 if l:
166 174 return map(bin, l.split(sep))
167 175 return []
168 176
169 177 def encodelist(l, sep=' '):
170 178 try:
171 179 return sep.join(map(hex, l))
172 180 except TypeError:
173 181 raise
174 182
175 183 # batched call argument encoding
176 184
177 185 def escapearg(plain):
178 186 return (plain
179 187 .replace(':', ':c')
180 188 .replace(',', ':o')
181 189 .replace(';', ':s')
182 190 .replace('=', ':e'))
183 191
184 192 def unescapearg(escaped):
185 193 return (escaped
186 194 .replace(':e', '=')
187 195 .replace(':s', ';')
188 196 .replace(':o', ',')
189 197 .replace(':c', ':'))
190 198
191 199 def encodebatchcmds(req):
192 200 """Return a ``cmds`` argument value for the ``batch`` command."""
193 201 cmds = []
194 202 for op, argsdict in req:
195 203 # Old servers didn't properly unescape argument names. So prevent
196 204 # the sending of argument names that may not be decoded properly by
197 205 # servers.
198 206 assert all(escapearg(k) == k for k in argsdict)
199 207
200 208 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
201 209 for k, v in argsdict.iteritems())
202 210 cmds.append('%s %s' % (op, args))
203 211
204 212 return ';'.join(cmds)
205 213
206 214 # mapping of options accepted by getbundle and their types
207 215 #
208 216 # Meant to be extended by extensions. It is extensions responsibility to ensure
209 217 # such options are properly processed in exchange.getbundle.
210 218 #
211 219 # supported types are:
212 220 #
213 221 # :nodes: list of binary nodes
214 222 # :csv: list of comma-separated values
215 223 # :scsv: list of comma-separated values return as set
216 224 # :plain: string with no transformation needed.
217 225 gboptsmap = {'heads': 'nodes',
218 226 'common': 'nodes',
219 227 'obsmarkers': 'boolean',
220 228 'bundlecaps': 'scsv',
221 229 'listkeys': 'csv',
222 230 'cg': 'boolean',
223 231 'cbattempted': 'boolean'}
224 232
225 233 # client side
226 234
227 235 class wirepeer(peer.peerrepository):
228 236 """Client-side interface for communicating with a peer repository.
229 237
230 238 Methods commonly call wire protocol commands of the same name.
231 239
232 240 See also httppeer.py and sshpeer.py for protocol-specific
233 241 implementations of this interface.
234 242 """
235 243 def batch(self):
236 244 if self.capable('batch'):
237 245 return remotebatch(self)
238 246 else:
239 247 return peer.localbatch(self)
240 248 def _submitbatch(self, req):
241 249 """run batch request <req> on the server
242 250
243 251 Returns an iterator of the raw responses from the server.
244 252 """
245 253 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
246 254 chunk = rsp.read(1024)
247 255 work = [chunk]
248 256 while chunk:
249 257 while ';' not in chunk and chunk:
250 258 chunk = rsp.read(1024)
251 259 work.append(chunk)
252 260 merged = ''.join(work)
253 261 while ';' in merged:
254 262 one, merged = merged.split(';', 1)
255 263 yield unescapearg(one)
256 264 chunk = rsp.read(1024)
257 265 work = [merged, chunk]
258 266 yield unescapearg(''.join(work))
259 267
260 268 def _submitone(self, op, args):
261 269 return self._call(op, **args)
262 270
263 271 def iterbatch(self):
264 272 return remoteiterbatcher(self)
265 273
266 274 @batchable
267 275 def lookup(self, key):
268 276 self.requirecap('lookup', _('look up remote revision'))
269 277 f = future()
270 278 yield {'key': encoding.fromlocal(key)}, f
271 279 d = f.value
272 280 success, data = d[:-1].split(" ", 1)
273 281 if int(success):
274 282 yield bin(data)
275 283 self._abort(error.RepoError(data))
276 284
277 285 @batchable
278 286 def heads(self):
279 287 f = future()
280 288 yield {}, f
281 289 d = f.value
282 290 try:
283 291 yield decodelist(d[:-1])
284 292 except ValueError:
285 293 self._abort(error.ResponseError(_("unexpected response:"), d))
286 294
287 295 @batchable
288 296 def known(self, nodes):
289 297 f = future()
290 298 yield {'nodes': encodelist(nodes)}, f
291 299 d = f.value
292 300 try:
293 301 yield [bool(int(b)) for b in d]
294 302 except ValueError:
295 303 self._abort(error.ResponseError(_("unexpected response:"), d))
296 304
297 305 @batchable
298 306 def branchmap(self):
299 307 f = future()
300 308 yield {}, f
301 309 d = f.value
302 310 try:
303 311 branchmap = {}
304 312 for branchpart in d.splitlines():
305 313 branchname, branchheads = branchpart.split(' ', 1)
306 314 branchname = encoding.tolocal(urlreq.unquote(branchname))
307 315 branchheads = decodelist(branchheads)
308 316 branchmap[branchname] = branchheads
309 317 yield branchmap
310 318 except TypeError:
311 319 self._abort(error.ResponseError(_("unexpected response:"), d))
312 320
313 321 def branches(self, nodes):
314 322 n = encodelist(nodes)
315 323 d = self._call("branches", nodes=n)
316 324 try:
317 325 br = [tuple(decodelist(b)) for b in d.splitlines()]
318 326 return br
319 327 except ValueError:
320 328 self._abort(error.ResponseError(_("unexpected response:"), d))
321 329
322 330 def between(self, pairs):
323 331 batch = 8 # avoid giant requests
324 332 r = []
325 333 for i in xrange(0, len(pairs), batch):
326 334 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
327 335 d = self._call("between", pairs=n)
328 336 try:
329 337 r.extend(l and decodelist(l) or [] for l in d.splitlines())
330 338 except ValueError:
331 339 self._abort(error.ResponseError(_("unexpected response:"), d))
332 340 return r
333 341
334 342 @batchable
335 343 def pushkey(self, namespace, key, old, new):
336 344 if not self.capable('pushkey'):
337 345 yield False, None
338 346 f = future()
339 347 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
340 348 yield {'namespace': encoding.fromlocal(namespace),
341 349 'key': encoding.fromlocal(key),
342 350 'old': encoding.fromlocal(old),
343 351 'new': encoding.fromlocal(new)}, f
344 352 d = f.value
345 353 d, output = d.split('\n', 1)
346 354 try:
347 355 d = bool(int(d))
348 356 except ValueError:
349 357 raise error.ResponseError(
350 358 _('push failed (unexpected response):'), d)
351 359 for l in output.splitlines(True):
352 360 self.ui.status(_('remote: '), l)
353 361 yield d
354 362
355 363 @batchable
356 364 def listkeys(self, namespace):
357 365 if not self.capable('pushkey'):
358 366 yield {}, None
359 367 f = future()
360 368 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
361 369 yield {'namespace': encoding.fromlocal(namespace)}, f
362 370 d = f.value
363 371 self.ui.debug('received listkey for "%s": %i bytes\n'
364 372 % (namespace, len(d)))
365 373 yield pushkeymod.decodekeys(d)
366 374
367 375 def stream_out(self):
368 376 return self._callstream('stream_out')
369 377
370 378 def changegroup(self, nodes, kind):
371 379 n = encodelist(nodes)
372 380 f = self._callcompressable("changegroup", roots=n)
373 381 return changegroupmod.cg1unpacker(f, 'UN')
374 382
375 383 def changegroupsubset(self, bases, heads, kind):
376 384 self.requirecap('changegroupsubset', _('look up remote changes'))
377 385 bases = encodelist(bases)
378 386 heads = encodelist(heads)
379 387 f = self._callcompressable("changegroupsubset",
380 388 bases=bases, heads=heads)
381 389 return changegroupmod.cg1unpacker(f, 'UN')
382 390
383 391 def getbundle(self, source, **kwargs):
384 392 self.requirecap('getbundle', _('look up remote changes'))
385 393 opts = {}
386 394 bundlecaps = kwargs.get('bundlecaps')
387 395 if bundlecaps is not None:
388 396 kwargs['bundlecaps'] = sorted(bundlecaps)
389 397 else:
390 398 bundlecaps = () # kwargs could have it to None
391 399 for key, value in kwargs.iteritems():
392 400 if value is None:
393 401 continue
394 402 keytype = gboptsmap.get(key)
395 403 if keytype is None:
396 404 assert False, 'unexpected'
397 405 elif keytype == 'nodes':
398 406 value = encodelist(value)
399 407 elif keytype in ('csv', 'scsv'):
400 408 value = ','.join(value)
401 409 elif keytype == 'boolean':
402 410 value = '%i' % bool(value)
403 411 elif keytype != 'plain':
404 412 raise KeyError('unknown getbundle option type %s'
405 413 % keytype)
406 414 opts[key] = value
407 415 f = self._callcompressable("getbundle", **opts)
408 416 if any((cap.startswith('HG2') for cap in bundlecaps)):
409 417 return bundle2.getunbundler(self.ui, f)
410 418 else:
411 419 return changegroupmod.cg1unpacker(f, 'UN')
412 420
413 421 def unbundle(self, cg, heads, url):
414 422 '''Send cg (a readable file-like object representing the
415 423 changegroup to push, typically a chunkbuffer object) to the
416 424 remote server as a bundle.
417 425
418 426 When pushing a bundle10 stream, return an integer indicating the
419 427 result of the push (see localrepository.addchangegroup()).
420 428
421 429 When pushing a bundle20 stream, return a bundle20 stream.
422 430
423 431 `url` is the url the client thinks it's pushing to, which is
424 432 visible to hooks.
425 433 '''
426 434
427 435 if heads != ['force'] and self.capable('unbundlehash'):
428 436 heads = encodelist(['hashed',
429 437 hashlib.sha1(''.join(sorted(heads))).digest()])
430 438 else:
431 439 heads = encodelist(heads)
432 440
433 441 if util.safehasattr(cg, 'deltaheader'):
434 442 # this a bundle10, do the old style call sequence
435 443 ret, output = self._callpush("unbundle", cg, heads=heads)
436 444 if ret == "":
437 445 raise error.ResponseError(
438 446 _('push failed:'), output)
439 447 try:
440 448 ret = int(ret)
441 449 except ValueError:
442 450 raise error.ResponseError(
443 451 _('push failed (unexpected response):'), ret)
444 452
445 453 for l in output.splitlines(True):
446 454 self.ui.status(_('remote: '), l)
447 455 else:
448 456 # bundle2 push. Send a stream, fetch a stream.
449 457 stream = self._calltwowaystream('unbundle', cg, heads=heads)
450 458 ret = bundle2.getunbundler(self.ui, stream)
451 459 return ret
452 460
453 461 def debugwireargs(self, one, two, three=None, four=None, five=None):
454 462 # don't pass optional arguments left at their default value
455 463 opts = {}
456 464 if three is not None:
457 465 opts['three'] = three
458 466 if four is not None:
459 467 opts['four'] = four
460 468 return self._call('debugwireargs', one=one, two=two, **opts)
461 469
462 470 def _call(self, cmd, **args):
463 471 """execute <cmd> on the server
464 472
465 473 The command is expected to return a simple string.
466 474
467 475 returns the server reply as a string."""
468 476 raise NotImplementedError()
469 477
470 478 def _callstream(self, cmd, **args):
471 479 """execute <cmd> on the server
472 480
473 481 The command is expected to return a stream. Note that if the
474 482 command doesn't return a stream, _callstream behaves
475 483 differently for ssh and http peers.
476 484
477 485 returns the server reply as a file like object.
478 486 """
479 487 raise NotImplementedError()
480 488
481 489 def _callcompressable(self, cmd, **args):
482 490 """execute <cmd> on the server
483 491
484 492 The command is expected to return a stream.
485 493
486 494 The stream may have been compressed in some implementations. This
487 495 function takes care of the decompression. This is the only difference
488 496 with _callstream.
489 497
490 498 returns the server reply as a file like object.
491 499 """
492 500 raise NotImplementedError()
493 501
494 502 def _callpush(self, cmd, fp, **args):
495 503 """execute a <cmd> on server
496 504
497 505 The command is expected to be related to a push. Push has a special
498 506 return method.
499 507
500 508 returns the server reply as a (ret, output) tuple. ret is either
501 509 empty (error) or a stringified int.
502 510 """
503 511 raise NotImplementedError()
504 512
505 513 def _calltwowaystream(self, cmd, fp, **args):
506 514 """execute <cmd> on server
507 515
508 516 The command will send a stream to the server and get a stream in reply.
509 517 """
510 518 raise NotImplementedError()
511 519
512 520 def _abort(self, exception):
513 521 """clearly abort the wire protocol connection and raise the exception
514 522 """
515 523 raise NotImplementedError()
516 524
517 525 # server side
518 526
519 527 # wire protocol command can either return a string or one of these classes.
520 528 class streamres(object):
521 529 """wireproto reply: binary stream
522 530
523 531 The call was successful and the result is a stream.
524 532 Iterate on the `self.gen` attribute to retrieve chunks.
525 533 """
526 534 def __init__(self, gen):
527 535 self.gen = gen
528 536
529 537 class pushres(object):
530 538 """wireproto reply: success with simple integer return
531 539
532 540 The call was successful and returned an integer contained in `self.res`.
533 541 """
534 542 def __init__(self, res):
535 543 self.res = res
536 544
537 545 class pusherr(object):
538 546 """wireproto reply: failure
539 547
540 548 The call failed. The `self.res` attribute contains the error message.
541 549 """
542 550 def __init__(self, res):
543 551 self.res = res
544 552
545 553 class ooberror(object):
546 554 """wireproto reply: failure of a batch of operation
547 555
548 556 Something failed during a batch call. The error message is stored in
549 557 `self.message`.
550 558 """
551 559 def __init__(self, message):
552 560 self.message = message
553 561
554 562 def getdispatchrepo(repo, proto, command):
555 563 """Obtain the repo used for processing wire protocol commands.
556 564
557 565 The intent of this function is to serve as a monkeypatch point for
558 566 extensions that need commands to operate on different repo views under
559 567 specialized circumstances.
560 568 """
561 569 return repo.filtered('served')
562 570
563 571 def dispatch(repo, proto, command):
564 572 repo = getdispatchrepo(repo, proto, command)
565 573 func, spec = commands[command]
566 574 args = proto.getargs(spec)
567 575 return func(repo, proto, *args)
568 576
569 577 def options(cmd, keys, others):
570 578 opts = {}
571 579 for k in keys:
572 580 if k in others:
573 581 opts[k] = others[k]
574 582 del others[k]
575 583 if others:
576 584 sys.stderr.write("warning: %s ignored unexpected arguments %s\n"
577 585 % (cmd, ",".join(others)))
578 586 return opts
579 587
580 588 def bundle1allowed(repo, action):
581 589 """Whether a bundle1 operation is allowed from the server.
582 590
583 591 Priority is:
584 592
585 593 1. server.bundle1gd.<action> (if generaldelta active)
586 594 2. server.bundle1.<action>
587 595 3. server.bundle1gd (if generaldelta active)
588 596 4. server.bundle1
589 597 """
590 598 ui = repo.ui
591 599 gd = 'generaldelta' in repo.requirements
592 600
593 601 if gd:
594 602 v = ui.configbool('server', 'bundle1gd.%s' % action, None)
595 603 if v is not None:
596 604 return v
597 605
598 606 v = ui.configbool('server', 'bundle1.%s' % action, None)
599 607 if v is not None:
600 608 return v
601 609
602 610 if gd:
603 611 v = ui.configbool('server', 'bundle1gd', None)
604 612 if v is not None:
605 613 return v
606 614
607 615 return ui.configbool('server', 'bundle1', True)
608 616
609 617 # list of commands
610 618 commands = {}
611 619
612 620 def wireprotocommand(name, args=''):
613 621 """decorator for wire protocol command"""
614 622 def register(func):
615 623 commands[name] = (func, args)
616 624 return func
617 625 return register
618 626
619 627 @wireprotocommand('batch', 'cmds *')
620 628 def batch(repo, proto, cmds, others):
621 629 repo = repo.filtered("served")
622 630 res = []
623 631 for pair in cmds.split(';'):
624 632 op, args = pair.split(' ', 1)
625 633 vals = {}
626 634 for a in args.split(','):
627 635 if a:
628 636 n, v = a.split('=')
629 637 vals[unescapearg(n)] = unescapearg(v)
630 638 func, spec = commands[op]
631 639 if spec:
632 640 keys = spec.split()
633 641 data = {}
634 642 for k in keys:
635 643 if k == '*':
636 644 star = {}
637 645 for key in vals.keys():
638 646 if key not in keys:
639 647 star[key] = vals[key]
640 648 data['*'] = star
641 649 else:
642 650 data[k] = vals[k]
643 651 result = func(repo, proto, *[data[k] for k in keys])
644 652 else:
645 653 result = func(repo, proto)
646 654 if isinstance(result, ooberror):
647 655 return result
648 656 res.append(escapearg(result))
649 657 return ';'.join(res)
650 658
651 659 @wireprotocommand('between', 'pairs')
652 660 def between(repo, proto, pairs):
653 661 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
654 662 r = []
655 663 for b in repo.between(pairs):
656 664 r.append(encodelist(b) + "\n")
657 665 return "".join(r)
658 666
659 667 @wireprotocommand('branchmap')
660 668 def branchmap(repo, proto):
661 669 branchmap = repo.branchmap()
662 670 heads = []
663 671 for branch, nodes in branchmap.iteritems():
664 672 branchname = urlreq.quote(encoding.fromlocal(branch))
665 673 branchnodes = encodelist(nodes)
666 674 heads.append('%s %s' % (branchname, branchnodes))
667 675 return '\n'.join(heads)
668 676
669 677 @wireprotocommand('branches', 'nodes')
670 678 def branches(repo, proto, nodes):
671 679 nodes = decodelist(nodes)
672 680 r = []
673 681 for b in repo.branches(nodes):
674 682 r.append(encodelist(b) + "\n")
675 683 return "".join(r)
676 684
677 685 @wireprotocommand('clonebundles', '')
678 686 def clonebundles(repo, proto):
679 687 """Server command for returning info for available bundles to seed clones.
680 688
681 689 Clients will parse this response and determine what bundle to fetch.
682 690
683 691 Extensions may wrap this command to filter or dynamically emit data
684 692 depending on the request. e.g. you could advertise URLs for the closest
685 693 data center given the client's IP address.
686 694 """
687 695 return repo.opener.tryread('clonebundles.manifest')
688 696
689 697 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
690 698 'known', 'getbundle', 'unbundlehash', 'batch']
691 699
692 700 def _capabilities(repo, proto):
693 701 """return a list of capabilities for a repo
694 702
695 703 This function exists to allow extensions to easily wrap capabilities
696 704 computation
697 705
698 706 - returns a lists: easy to alter
699 707 - change done here will be propagated to both `capabilities` and `hello`
700 708 command without any other action needed.
701 709 """
702 710 # copy to prevent modification of the global list
703 711 caps = list(wireprotocaps)
704 712 if streamclone.allowservergeneration(repo.ui):
705 713 if repo.ui.configbool('server', 'preferuncompressed', False):
706 714 caps.append('stream-preferred')
707 715 requiredformats = repo.requirements & repo.supportedformats
708 716 # if our local revlogs are just revlogv1, add 'stream' cap
709 717 if not requiredformats - set(('revlogv1',)):
710 718 caps.append('stream')
711 719 # otherwise, add 'streamreqs' detailing our local revlog format
712 720 else:
713 721 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
714 722 if repo.ui.configbool('experimental', 'bundle2-advertise', True):
715 723 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo))
716 724 caps.append('bundle2=' + urlreq.quote(capsblob))
717 725 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
718 726 caps.append(
719 727 'httpheader=%d' % repo.ui.configint('server', 'maxhttpheaderlen', 1024))
720 728 if repo.ui.configbool('experimental', 'httppostargs', False):
721 729 caps.append('httppostargs')
722 730 return caps
723 731
724 732 # If you are writing an extension and consider wrapping this function. Wrap
725 733 # `_capabilities` instead.
726 734 @wireprotocommand('capabilities')
727 735 def capabilities(repo, proto):
728 736 return ' '.join(_capabilities(repo, proto))
729 737
730 738 @wireprotocommand('changegroup', 'roots')
731 739 def changegroup(repo, proto, roots):
732 740 nodes = decodelist(roots)
733 741 cg = changegroupmod.changegroup(repo, nodes, 'serve')
734 742 return streamres(proto.groupchunks(cg))
735 743
736 744 @wireprotocommand('changegroupsubset', 'bases heads')
737 745 def changegroupsubset(repo, proto, bases, heads):
738 746 bases = decodelist(bases)
739 747 heads = decodelist(heads)
740 748 cg = changegroupmod.changegroupsubset(repo, bases, heads, 'serve')
741 749 return streamres(proto.groupchunks(cg))
742 750
743 751 @wireprotocommand('debugwireargs', 'one two *')
744 752 def debugwireargs(repo, proto, one, two, others):
745 753 # only accept optional args from the known set
746 754 opts = options('debugwireargs', ['three', 'four'], others)
747 755 return repo.debugwireargs(one, two, **opts)
748 756
749 757 @wireprotocommand('getbundle', '*')
750 758 def getbundle(repo, proto, others):
751 759 opts = options('getbundle', gboptsmap.keys(), others)
752 760 for k, v in opts.iteritems():
753 761 keytype = gboptsmap[k]
754 762 if keytype == 'nodes':
755 763 opts[k] = decodelist(v)
756 764 elif keytype == 'csv':
757 765 opts[k] = list(v.split(','))
758 766 elif keytype == 'scsv':
759 767 opts[k] = set(v.split(','))
760 768 elif keytype == 'boolean':
761 769 # Client should serialize False as '0', which is a non-empty string
762 770 # so it evaluates as a True bool.
763 771 if v == '0':
764 772 opts[k] = False
765 773 else:
766 774 opts[k] = bool(v)
767 775 elif keytype != 'plain':
768 776 raise KeyError('unknown getbundle option type %s'
769 777 % keytype)
770 778
771 779 if not bundle1allowed(repo, 'pull'):
772 780 if not exchange.bundle2requested(opts.get('bundlecaps')):
773 781 return ooberror(bundle2required)
774 782
775 783 chunks = exchange.getbundlechunks(repo, 'serve', **opts)
776 # TODO avoid util.chunkbuffer() here since it is adding overhead to
777 # what is fundamentally a generator proxying operation.
778 return streamres(proto.groupchunks(util.chunkbuffer(chunks)))
784 return streamres(proto.compresschunks(chunks))
779 785
780 786 @wireprotocommand('heads')
781 787 def heads(repo, proto):
782 788 h = repo.heads()
783 789 return encodelist(h) + "\n"
784 790
785 791 @wireprotocommand('hello')
786 792 def hello(repo, proto):
787 793 '''the hello command returns a set of lines describing various
788 794 interesting things about the server, in an RFC822-like format.
789 795 Currently the only one defined is "capabilities", which
790 796 consists of a line in the form:
791 797
792 798 capabilities: space separated list of tokens
793 799 '''
794 800 return "capabilities: %s\n" % (capabilities(repo, proto))
795 801
796 802 @wireprotocommand('listkeys', 'namespace')
797 803 def listkeys(repo, proto, namespace):
798 804 d = repo.listkeys(encoding.tolocal(namespace)).items()
799 805 return pushkeymod.encodekeys(d)
800 806
801 807 @wireprotocommand('lookup', 'key')
802 808 def lookup(repo, proto, key):
803 809 try:
804 810 k = encoding.tolocal(key)
805 811 c = repo[k]
806 812 r = c.hex()
807 813 success = 1
808 814 except Exception as inst:
809 815 r = str(inst)
810 816 success = 0
811 817 return "%s %s\n" % (success, r)
812 818
813 819 @wireprotocommand('known', 'nodes *')
814 820 def known(repo, proto, nodes, others):
815 821 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
816 822
817 823 @wireprotocommand('pushkey', 'namespace key old new')
818 824 def pushkey(repo, proto, namespace, key, old, new):
819 825 # compatibility with pre-1.8 clients which were accidentally
820 826 # sending raw binary nodes rather than utf-8-encoded hex
821 827 if len(new) == 20 and new.encode('string-escape') != new:
822 828 # looks like it could be a binary node
823 829 try:
824 830 new.decode('utf-8')
825 831 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
826 832 except UnicodeDecodeError:
827 833 pass # binary, leave unmodified
828 834 else:
829 835 new = encoding.tolocal(new) # normal path
830 836
831 837 if util.safehasattr(proto, 'restore'):
832 838
833 839 proto.redirect()
834 840
835 841 try:
836 842 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
837 843 encoding.tolocal(old), new) or False
838 844 except error.Abort:
839 845 r = False
840 846
841 847 output = proto.restore()
842 848
843 849 return '%s\n%s' % (int(r), output)
844 850
845 851 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
846 852 encoding.tolocal(old), new)
847 853 return '%s\n' % int(r)
848 854
849 855 @wireprotocommand('stream_out')
850 856 def stream(repo, proto):
851 857 '''If the server supports streaming clone, it advertises the "stream"
852 858 capability with a value representing the version and flags of the repo
853 859 it is serving. Client checks to see if it understands the format.
854 860 '''
855 861 if not streamclone.allowservergeneration(repo.ui):
856 862 return '1\n'
857 863
858 864 def getstream(it):
859 865 yield '0\n'
860 866 for chunk in it:
861 867 yield chunk
862 868
863 869 try:
864 870 # LockError may be raised before the first result is yielded. Don't
865 871 # emit output until we're sure we got the lock successfully.
866 872 it = streamclone.generatev1wireproto(repo)
867 873 return streamres(getstream(it))
868 874 except error.LockError:
869 875 return '2\n'
870 876
871 877 @wireprotocommand('unbundle', 'heads')
872 878 def unbundle(repo, proto, heads):
873 879 their_heads = decodelist(heads)
874 880
875 881 try:
876 882 proto.redirect()
877 883
878 884 exchange.check_heads(repo, their_heads, 'preparing changes')
879 885
880 886 # write bundle data to temporary file because it can be big
881 887 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
882 888 fp = os.fdopen(fd, 'wb+')
883 889 r = 0
884 890 try:
885 891 proto.getfile(fp)
886 892 fp.seek(0)
887 893 gen = exchange.readbundle(repo.ui, fp, None)
888 894 if (isinstance(gen, changegroupmod.cg1unpacker)
889 895 and not bundle1allowed(repo, 'push')):
890 896 return ooberror(bundle2required)
891 897
892 898 r = exchange.unbundle(repo, gen, their_heads, 'serve',
893 899 proto._client())
894 900 if util.safehasattr(r, 'addpart'):
895 901 # The return looks streamable, we are in the bundle2 case and
896 902 # should return a stream.
897 903 return streamres(r.getchunks())
898 904 return pushres(r)
899 905
900 906 finally:
901 907 fp.close()
902 908 os.unlink(tempname)
903 909
904 910 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
905 911 # handle non-bundle2 case first
906 912 if not getattr(exc, 'duringunbundle2', False):
907 913 try:
908 914 raise
909 915 except error.Abort:
910 916 # The old code we moved used sys.stderr directly.
911 917 # We did not change it to minimise code change.
912 918 # This need to be moved to something proper.
913 919 # Feel free to do it.
914 920 sys.stderr.write("abort: %s\n" % exc)
915 921 return pushres(0)
916 922 except error.PushRaced:
917 923 return pusherr(str(exc))
918 924
919 925 bundler = bundle2.bundle20(repo.ui)
920 926 for out in getattr(exc, '_bundle2salvagedoutput', ()):
921 927 bundler.addpart(out)
922 928 try:
923 929 try:
924 930 raise
925 931 except error.PushkeyFailed as exc:
926 932 # check client caps
927 933 remotecaps = getattr(exc, '_replycaps', None)
928 934 if (remotecaps is not None
929 935 and 'pushkey' not in remotecaps.get('error', ())):
930 936 # no support remote side, fallback to Abort handler.
931 937 raise
932 938 part = bundler.newpart('error:pushkey')
933 939 part.addparam('in-reply-to', exc.partid)
934 940 if exc.namespace is not None:
935 941 part.addparam('namespace', exc.namespace, mandatory=False)
936 942 if exc.key is not None:
937 943 part.addparam('key', exc.key, mandatory=False)
938 944 if exc.new is not None:
939 945 part.addparam('new', exc.new, mandatory=False)
940 946 if exc.old is not None:
941 947 part.addparam('old', exc.old, mandatory=False)
942 948 if exc.ret is not None:
943 949 part.addparam('ret', exc.ret, mandatory=False)
944 950 except error.BundleValueError as exc:
945 951 errpart = bundler.newpart('error:unsupportedcontent')
946 952 if exc.parttype is not None:
947 953 errpart.addparam('parttype', exc.parttype)
948 954 if exc.params:
949 955 errpart.addparam('params', '\0'.join(exc.params))
950 956 except error.Abort as exc:
951 957 manargs = [('message', str(exc))]
952 958 advargs = []
953 959 if exc.hint is not None:
954 960 advargs.append(('hint', exc.hint))
955 961 bundler.addpart(bundle2.bundlepart('error:abort',
956 962 manargs, advargs))
957 963 except error.PushRaced as exc:
958 964 bundler.newpart('error:pushraced', [('message', str(exc))])
959 965 return streamres(bundler.getchunks())
General Comments 0
You need to be logged in to leave comments. Login now