##// END OF EJS Templates
peer: remove non iterating batcher (API)...
Gregory Szorc -
r33762:b47fe973 default
parent child Browse files
Show More
@@ -1,153 +1,140 b''
1 1 # peer.py - repository base classes for mercurial
2 2 #
3 3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from __future__ import absolute_import
10 10
11 11 from .i18n import _
12 12 from . import (
13 13 error,
14 14 util,
15 15 )
16 16
17 17 # abstract batching support
18 18
19 19 class future(object):
20 20 '''placeholder for a value to be set later'''
21 21 def set(self, value):
22 22 if util.safehasattr(self, 'value'):
23 23 raise error.RepoError("future is already set")
24 24 self.value = value
25 25
26 26 class batcher(object):
27 27 '''base class for batches of commands submittable in a single request
28 28
29 29 All methods invoked on instances of this class are simply queued and
30 30 return a a future for the result. Once you call submit(), all the queued
31 31 calls are performed and the results set in their respective futures.
32 32 '''
33 33 def __init__(self):
34 34 self.calls = []
35 35 def __getattr__(self, name):
36 36 def call(*args, **opts):
37 37 resref = future()
38 38 self.calls.append((name, args, opts, resref,))
39 39 return resref
40 40 return call
41 41 def submit(self):
42 42 raise NotImplementedError()
43 43
44 44 class iterbatcher(batcher):
45 45
46 46 def submit(self):
47 47 raise NotImplementedError()
48 48
49 49 def results(self):
50 50 raise NotImplementedError()
51 51
52 class localbatch(batcher):
53 '''performs the queued calls directly'''
54 def __init__(self, local):
55 batcher.__init__(self)
56 self.local = local
57 def submit(self):
58 for name, args, opts, resref in self.calls:
59 resref.set(getattr(self.local, name)(*args, **opts))
60
61 52 class localiterbatcher(iterbatcher):
62 53 def __init__(self, local):
63 54 super(iterbatcher, self).__init__()
64 55 self.local = local
65 56
66 57 def submit(self):
67 58 # submit for a local iter batcher is a noop
68 59 pass
69 60
70 61 def results(self):
71 62 for name, args, opts, resref in self.calls:
72 63 resref.set(getattr(self.local, name)(*args, **opts))
73 64 yield resref.value
74 65
75 66 def batchable(f):
76 67 '''annotation for batchable methods
77 68
78 69 Such methods must implement a coroutine as follows:
79 70
80 71 @batchable
81 72 def sample(self, one, two=None):
82 73 # Build list of encoded arguments suitable for your wire protocol:
83 74 encargs = [('one', encode(one),), ('two', encode(two),)]
84 75 # Create future for injection of encoded result:
85 76 encresref = future()
86 77 # Return encoded arguments and future:
87 78 yield encargs, encresref
88 79 # Assuming the future to be filled with the result from the batched
89 80 # request now. Decode it:
90 81 yield decode(encresref.value)
91 82
92 83 The decorator returns a function which wraps this coroutine as a plain
93 84 method, but adds the original method as an attribute called "batchable",
94 85 which is used by remotebatch to split the call into separate encoding and
95 86 decoding phases.
96 87 '''
97 88 def plain(*args, **opts):
98 89 batchable = f(*args, **opts)
99 90 encargsorres, encresref = next(batchable)
100 91 if not encresref:
101 92 return encargsorres # a local result in this case
102 93 self = args[0]
103 94 encresref.set(self._submitone(f.func_name, encargsorres))
104 95 return next(batchable)
105 96 setattr(plain, 'batchable', f)
106 97 return plain
107 98
108 99 class peerrepository(object):
109
110 def batch(self):
111 return localbatch(self)
112
113 100 def iterbatch(self):
114 101 """Batch requests but allow iterating over the results.
115 102
116 103 This is to allow interleaving responses with things like
117 104 progress updates for clients.
118 105 """
119 106 return localiterbatcher(self)
120 107
121 108 def capable(self, name):
122 109 '''tell whether repo supports named capability.
123 110 return False if not supported.
124 111 if boolean capability, return True.
125 112 if string capability, return string.'''
126 113 caps = self._capabilities()
127 114 if name in caps:
128 115 return True
129 116 name_eq = name + '='
130 117 for cap in caps:
131 118 if cap.startswith(name_eq):
132 119 return cap[len(name_eq):]
133 120 return False
134 121
135 122 def requirecap(self, name, purpose):
136 123 '''raise an exception if the given capability is not present'''
137 124 if not self.capable(name):
138 125 raise error.CapabilityError(
139 126 _('cannot %s; remote repository does not '
140 127 'support the %r capability') % (purpose, name))
141 128
142 129 def local(self):
143 130 '''return peer as a localrepo, or None'''
144 131 return None
145 132
146 133 def peer(self):
147 134 return self
148 135
149 136 def canpush(self):
150 137 return True
151 138
152 139 def close(self):
153 140 pass
@@ -1,1088 +1,1050 b''
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 import itertools
12 11 import os
13 12 import tempfile
14 13
15 14 from .i18n import _
16 15 from .node import (
17 16 bin,
18 17 hex,
19 18 nullid,
20 19 )
21 20
22 21 from . import (
23 22 bundle2,
24 23 changegroup as changegroupmod,
25 24 encoding,
26 25 error,
27 26 exchange,
28 27 peer,
29 28 pushkey as pushkeymod,
30 29 pycompat,
31 30 streamclone,
32 31 util,
33 32 )
34 33
35 34 urlerr = util.urlerr
36 35 urlreq = util.urlreq
37 36
38 37 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
39 38 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
40 39 'IncompatibleClient')
41 40 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
42 41
43 42 class abstractserverproto(object):
44 43 """abstract class that summarizes the protocol API
45 44
46 45 Used as reference and documentation.
47 46 """
48 47
49 48 def getargs(self, args):
50 49 """return the value for arguments in <args>
51 50
52 51 returns a list of values (same order as <args>)"""
53 52 raise NotImplementedError()
54 53
55 54 def getfile(self, fp):
56 55 """write the whole content of a file into a file like object
57 56
58 57 The file is in the form::
59 58
60 59 (<chunk-size>\n<chunk>)+0\n
61 60
62 61 chunk size is the ascii version of the int.
63 62 """
64 63 raise NotImplementedError()
65 64
66 65 def redirect(self):
67 66 """may setup interception for stdout and stderr
68 67
69 68 See also the `restore` method."""
70 69 raise NotImplementedError()
71 70
72 71 # If the `redirect` function does install interception, the `restore`
73 72 # function MUST be defined. If interception is not used, this function
74 73 # MUST NOT be defined.
75 74 #
76 75 # left commented here on purpose
77 76 #
78 77 #def restore(self):
79 78 # """reinstall previous stdout and stderr and return intercepted stdout
80 79 # """
81 80 # raise NotImplementedError()
82 81
83 class remotebatch(peer.batcher):
84 '''batches the queued calls; uses as few roundtrips as possible'''
85 def __init__(self, remote):
86 '''remote must support _submitbatch(encbatch) and
87 _submitone(op, encargs)'''
88 peer.batcher.__init__(self)
89 self.remote = remote
90 def submit(self):
91 req, rsp = [], []
92 for name, args, opts, resref in self.calls:
93 mtd = getattr(self.remote, name)
94 batchablefn = getattr(mtd, 'batchable', None)
95 if batchablefn is not None:
96 batchable = batchablefn(mtd.im_self, *args, **opts)
97 encargsorres, encresref = next(batchable)
98 assert encresref
99 req.append((name, encargsorres,))
100 rsp.append((batchable, encresref, resref,))
101 else:
102 if req:
103 self._submitreq(req, rsp)
104 req, rsp = [], []
105 resref.set(mtd(*args, **opts))
106 if req:
107 self._submitreq(req, rsp)
108 def _submitreq(self, req, rsp):
109 encresults = self.remote._submitbatch(req)
110 for encres, r in zip(encresults, rsp):
111 batchable, encresref, resref = r
112 encresref.set(encres)
113 resref.set(next(batchable))
114
115 82 class remoteiterbatcher(peer.iterbatcher):
116 83 def __init__(self, remote):
117 84 super(remoteiterbatcher, self).__init__()
118 85 self._remote = remote
119 86
120 87 def __getattr__(self, name):
121 88 # Validate this method is batchable, since submit() only supports
122 89 # batchable methods.
123 90 fn = getattr(self._remote, name)
124 91 if not getattr(fn, 'batchable', None):
125 92 raise error.ProgrammingError('Attempted to batch a non-batchable '
126 93 'call to %r' % name)
127 94
128 95 return super(remoteiterbatcher, self).__getattr__(name)
129 96
130 97 def submit(self):
131 98 """Break the batch request into many patch calls and pipeline them.
132 99
133 100 This is mostly valuable over http where request sizes can be
134 101 limited, but can be used in other places as well.
135 102 """
136 103 # 2-tuple of (command, arguments) that represents what will be
137 104 # sent over the wire.
138 105 requests = []
139 106
140 107 # 4-tuple of (command, final future, @batchable generator, remote
141 108 # future).
142 109 results = []
143 110
144 111 for command, args, opts, finalfuture in self.calls:
145 112 mtd = getattr(self._remote, command)
146 113 batchable = mtd.batchable(mtd.im_self, *args, **opts)
147 114
148 115 commandargs, fremote = next(batchable)
149 116 assert fremote
150 117 requests.append((command, commandargs))
151 118 results.append((command, finalfuture, batchable, fremote))
152 119
153 120 if requests:
154 121 self._resultiter = self._remote._submitbatch(requests)
155 122
156 123 self._results = results
157 124
158 125 def results(self):
159 126 for command, finalfuture, batchable, remotefuture in self._results:
160 127 # Get the raw result, set it in the remote future, feed it
161 128 # back into the @batchable generator so it can be decoded, and
162 129 # set the result on the final future to this value.
163 130 remoteresult = next(self._resultiter)
164 131 remotefuture.set(remoteresult)
165 132 finalfuture.set(next(batchable))
166 133
167 134 # Verify our @batchable generators only emit 2 values.
168 135 try:
169 136 next(batchable)
170 137 except StopIteration:
171 138 pass
172 139 else:
173 140 raise error.ProgrammingError('%s @batchable generator emitted '
174 141 'unexpected value count' % command)
175 142
176 143 yield finalfuture.value
177 144
178 145 # Forward a couple of names from peer to make wireproto interactions
179 146 # slightly more sensible.
180 147 batchable = peer.batchable
181 148 future = peer.future
182 149
183 150 # list of nodes encoding / decoding
184 151
185 152 def decodelist(l, sep=' '):
186 153 if l:
187 154 return map(bin, l.split(sep))
188 155 return []
189 156
190 157 def encodelist(l, sep=' '):
191 158 try:
192 159 return sep.join(map(hex, l))
193 160 except TypeError:
194 161 raise
195 162
196 163 # batched call argument encoding
197 164
198 165 def escapearg(plain):
199 166 return (plain
200 167 .replace(':', ':c')
201 168 .replace(',', ':o')
202 169 .replace(';', ':s')
203 170 .replace('=', ':e'))
204 171
205 172 def unescapearg(escaped):
206 173 return (escaped
207 174 .replace(':e', '=')
208 175 .replace(':s', ';')
209 176 .replace(':o', ',')
210 177 .replace(':c', ':'))
211 178
212 179 def encodebatchcmds(req):
213 180 """Return a ``cmds`` argument value for the ``batch`` command."""
214 181 cmds = []
215 182 for op, argsdict in req:
216 183 # Old servers didn't properly unescape argument names. So prevent
217 184 # the sending of argument names that may not be decoded properly by
218 185 # servers.
219 186 assert all(escapearg(k) == k for k in argsdict)
220 187
221 188 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
222 189 for k, v in argsdict.iteritems())
223 190 cmds.append('%s %s' % (op, args))
224 191
225 192 return ';'.join(cmds)
226 193
227 194 # mapping of options accepted by getbundle and their types
228 195 #
229 196 # Meant to be extended by extensions. It is extensions responsibility to ensure
230 197 # such options are properly processed in exchange.getbundle.
231 198 #
232 199 # supported types are:
233 200 #
234 201 # :nodes: list of binary nodes
235 202 # :csv: list of comma-separated values
236 203 # :scsv: list of comma-separated values return as set
237 204 # :plain: string with no transformation needed.
238 205 gboptsmap = {'heads': 'nodes',
239 206 'common': 'nodes',
240 207 'obsmarkers': 'boolean',
241 208 'bundlecaps': 'scsv',
242 209 'listkeys': 'csv',
243 210 'cg': 'boolean',
244 211 'cbattempted': 'boolean'}
245 212
246 213 # client side
247 214
248 215 class wirepeer(peer.peerrepository):
249 216 """Client-side interface for communicating with a peer repository.
250 217
251 218 Methods commonly call wire protocol commands of the same name.
252 219
253 220 See also httppeer.py and sshpeer.py for protocol-specific
254 221 implementations of this interface.
255 222 """
256 def batch(self):
257 if self.capable('batch'):
258 return remotebatch(self)
259 else:
260 return peer.localbatch(self)
261 223 def _submitbatch(self, req):
262 224 """run batch request <req> on the server
263 225
264 226 Returns an iterator of the raw responses from the server.
265 227 """
266 228 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
267 229 chunk = rsp.read(1024)
268 230 work = [chunk]
269 231 while chunk:
270 232 while ';' not in chunk and chunk:
271 233 chunk = rsp.read(1024)
272 234 work.append(chunk)
273 235 merged = ''.join(work)
274 236 while ';' in merged:
275 237 one, merged = merged.split(';', 1)
276 238 yield unescapearg(one)
277 239 chunk = rsp.read(1024)
278 240 work = [merged, chunk]
279 241 yield unescapearg(''.join(work))
280 242
281 243 def _submitone(self, op, args):
282 244 return self._call(op, **args)
283 245
284 246 def iterbatch(self):
285 247 return remoteiterbatcher(self)
286 248
287 249 @batchable
288 250 def lookup(self, key):
289 251 self.requirecap('lookup', _('look up remote revision'))
290 252 f = future()
291 253 yield {'key': encoding.fromlocal(key)}, f
292 254 d = f.value
293 255 success, data = d[:-1].split(" ", 1)
294 256 if int(success):
295 257 yield bin(data)
296 258 self._abort(error.RepoError(data))
297 259
298 260 @batchable
299 261 def heads(self):
300 262 f = future()
301 263 yield {}, f
302 264 d = f.value
303 265 try:
304 266 yield decodelist(d[:-1])
305 267 except ValueError:
306 268 self._abort(error.ResponseError(_("unexpected response:"), d))
307 269
308 270 @batchable
309 271 def known(self, nodes):
310 272 f = future()
311 273 yield {'nodes': encodelist(nodes)}, f
312 274 d = f.value
313 275 try:
314 276 yield [bool(int(b)) for b in d]
315 277 except ValueError:
316 278 self._abort(error.ResponseError(_("unexpected response:"), d))
317 279
318 280 @batchable
319 281 def branchmap(self):
320 282 f = future()
321 283 yield {}, f
322 284 d = f.value
323 285 try:
324 286 branchmap = {}
325 287 for branchpart in d.splitlines():
326 288 branchname, branchheads = branchpart.split(' ', 1)
327 289 branchname = encoding.tolocal(urlreq.unquote(branchname))
328 290 branchheads = decodelist(branchheads)
329 291 branchmap[branchname] = branchheads
330 292 yield branchmap
331 293 except TypeError:
332 294 self._abort(error.ResponseError(_("unexpected response:"), d))
333 295
334 296 def branches(self, nodes):
335 297 n = encodelist(nodes)
336 298 d = self._call("branches", nodes=n)
337 299 try:
338 300 br = [tuple(decodelist(b)) for b in d.splitlines()]
339 301 return br
340 302 except ValueError:
341 303 self._abort(error.ResponseError(_("unexpected response:"), d))
342 304
343 305 def between(self, pairs):
344 306 batch = 8 # avoid giant requests
345 307 r = []
346 308 for i in xrange(0, len(pairs), batch):
347 309 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
348 310 d = self._call("between", pairs=n)
349 311 try:
350 312 r.extend(l and decodelist(l) or [] for l in d.splitlines())
351 313 except ValueError:
352 314 self._abort(error.ResponseError(_("unexpected response:"), d))
353 315 return r
354 316
355 317 @batchable
356 318 def pushkey(self, namespace, key, old, new):
357 319 if not self.capable('pushkey'):
358 320 yield False, None
359 321 f = future()
360 322 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
361 323 yield {'namespace': encoding.fromlocal(namespace),
362 324 'key': encoding.fromlocal(key),
363 325 'old': encoding.fromlocal(old),
364 326 'new': encoding.fromlocal(new)}, f
365 327 d = f.value
366 328 d, output = d.split('\n', 1)
367 329 try:
368 330 d = bool(int(d))
369 331 except ValueError:
370 332 raise error.ResponseError(
371 333 _('push failed (unexpected response):'), d)
372 334 for l in output.splitlines(True):
373 335 self.ui.status(_('remote: '), l)
374 336 yield d
375 337
376 338 @batchable
377 339 def listkeys(self, namespace):
378 340 if not self.capable('pushkey'):
379 341 yield {}, None
380 342 f = future()
381 343 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
382 344 yield {'namespace': encoding.fromlocal(namespace)}, f
383 345 d = f.value
384 346 self.ui.debug('received listkey for "%s": %i bytes\n'
385 347 % (namespace, len(d)))
386 348 yield pushkeymod.decodekeys(d)
387 349
388 350 def stream_out(self):
389 351 return self._callstream('stream_out')
390 352
391 353 def changegroup(self, nodes, kind):
392 354 n = encodelist(nodes)
393 355 f = self._callcompressable("changegroup", roots=n)
394 356 return changegroupmod.cg1unpacker(f, 'UN')
395 357
396 358 def changegroupsubset(self, bases, heads, kind):
397 359 self.requirecap('changegroupsubset', _('look up remote changes'))
398 360 bases = encodelist(bases)
399 361 heads = encodelist(heads)
400 362 f = self._callcompressable("changegroupsubset",
401 363 bases=bases, heads=heads)
402 364 return changegroupmod.cg1unpacker(f, 'UN')
403 365
404 366 def getbundle(self, source, **kwargs):
405 367 self.requirecap('getbundle', _('look up remote changes'))
406 368 opts = {}
407 369 bundlecaps = kwargs.get('bundlecaps')
408 370 if bundlecaps is not None:
409 371 kwargs['bundlecaps'] = sorted(bundlecaps)
410 372 else:
411 373 bundlecaps = () # kwargs could have it to None
412 374 for key, value in kwargs.iteritems():
413 375 if value is None:
414 376 continue
415 377 keytype = gboptsmap.get(key)
416 378 if keytype is None:
417 379 assert False, 'unexpected'
418 380 elif keytype == 'nodes':
419 381 value = encodelist(value)
420 382 elif keytype in ('csv', 'scsv'):
421 383 value = ','.join(value)
422 384 elif keytype == 'boolean':
423 385 value = '%i' % bool(value)
424 386 elif keytype != 'plain':
425 387 raise KeyError('unknown getbundle option type %s'
426 388 % keytype)
427 389 opts[key] = value
428 390 f = self._callcompressable("getbundle", **opts)
429 391 if any((cap.startswith('HG2') for cap in bundlecaps)):
430 392 return bundle2.getunbundler(self.ui, f)
431 393 else:
432 394 return changegroupmod.cg1unpacker(f, 'UN')
433 395
434 396 def unbundle(self, cg, heads, url):
435 397 '''Send cg (a readable file-like object representing the
436 398 changegroup to push, typically a chunkbuffer object) to the
437 399 remote server as a bundle.
438 400
439 401 When pushing a bundle10 stream, return an integer indicating the
440 402 result of the push (see changegroup.apply()).
441 403
442 404 When pushing a bundle20 stream, return a bundle20 stream.
443 405
444 406 `url` is the url the client thinks it's pushing to, which is
445 407 visible to hooks.
446 408 '''
447 409
448 410 if heads != ['force'] and self.capable('unbundlehash'):
449 411 heads = encodelist(['hashed',
450 412 hashlib.sha1(''.join(sorted(heads))).digest()])
451 413 else:
452 414 heads = encodelist(heads)
453 415
454 416 if util.safehasattr(cg, 'deltaheader'):
455 417 # this a bundle10, do the old style call sequence
456 418 ret, output = self._callpush("unbundle", cg, heads=heads)
457 419 if ret == "":
458 420 raise error.ResponseError(
459 421 _('push failed:'), output)
460 422 try:
461 423 ret = int(ret)
462 424 except ValueError:
463 425 raise error.ResponseError(
464 426 _('push failed (unexpected response):'), ret)
465 427
466 428 for l in output.splitlines(True):
467 429 self.ui.status(_('remote: '), l)
468 430 else:
469 431 # bundle2 push. Send a stream, fetch a stream.
470 432 stream = self._calltwowaystream('unbundle', cg, heads=heads)
471 433 ret = bundle2.getunbundler(self.ui, stream)
472 434 return ret
473 435
474 436 def debugwireargs(self, one, two, three=None, four=None, five=None):
475 437 # don't pass optional arguments left at their default value
476 438 opts = {}
477 439 if three is not None:
478 440 opts['three'] = three
479 441 if four is not None:
480 442 opts['four'] = four
481 443 return self._call('debugwireargs', one=one, two=two, **opts)
482 444
483 445 def _call(self, cmd, **args):
484 446 """execute <cmd> on the server
485 447
486 448 The command is expected to return a simple string.
487 449
488 450 returns the server reply as a string."""
489 451 raise NotImplementedError()
490 452
491 453 def _callstream(self, cmd, **args):
492 454 """execute <cmd> on the server
493 455
494 456 The command is expected to return a stream. Note that if the
495 457 command doesn't return a stream, _callstream behaves
496 458 differently for ssh and http peers.
497 459
498 460 returns the server reply as a file like object.
499 461 """
500 462 raise NotImplementedError()
501 463
502 464 def _callcompressable(self, cmd, **args):
503 465 """execute <cmd> on the server
504 466
505 467 The command is expected to return a stream.
506 468
507 469 The stream may have been compressed in some implementations. This
508 470 function takes care of the decompression. This is the only difference
509 471 with _callstream.
510 472
511 473 returns the server reply as a file like object.
512 474 """
513 475 raise NotImplementedError()
514 476
515 477 def _callpush(self, cmd, fp, **args):
516 478 """execute a <cmd> on server
517 479
518 480 The command is expected to be related to a push. Push has a special
519 481 return method.
520 482
521 483 returns the server reply as a (ret, output) tuple. ret is either
522 484 empty (error) or a stringified int.
523 485 """
524 486 raise NotImplementedError()
525 487
526 488 def _calltwowaystream(self, cmd, fp, **args):
527 489 """execute <cmd> on server
528 490
529 491 The command will send a stream to the server and get a stream in reply.
530 492 """
531 493 raise NotImplementedError()
532 494
533 495 def _abort(self, exception):
534 496 """clearly abort the wire protocol connection and raise the exception
535 497 """
536 498 raise NotImplementedError()
537 499
538 500 # server side
539 501
540 502 # wire protocol command can either return a string or one of these classes.
541 503 class streamres(object):
542 504 """wireproto reply: binary stream
543 505
544 506 The call was successful and the result is a stream.
545 507
546 508 Accepts either a generator or an object with a ``read(size)`` method.
547 509
548 510 ``v1compressible`` indicates whether this data can be compressed to
549 511 "version 1" clients (technically: HTTP peers using
550 512 application/mercurial-0.1 media type). This flag should NOT be used on
551 513 new commands because new clients should support a more modern compression
552 514 mechanism.
553 515 """
554 516 def __init__(self, gen=None, reader=None, v1compressible=False):
555 517 self.gen = gen
556 518 self.reader = reader
557 519 self.v1compressible = v1compressible
558 520
559 521 class pushres(object):
560 522 """wireproto reply: success with simple integer return
561 523
562 524 The call was successful and returned an integer contained in `self.res`.
563 525 """
564 526 def __init__(self, res):
565 527 self.res = res
566 528
567 529 class pusherr(object):
568 530 """wireproto reply: failure
569 531
570 532 The call failed. The `self.res` attribute contains the error message.
571 533 """
572 534 def __init__(self, res):
573 535 self.res = res
574 536
575 537 class ooberror(object):
576 538 """wireproto reply: failure of a batch of operation
577 539
578 540 Something failed during a batch call. The error message is stored in
579 541 `self.message`.
580 542 """
581 543 def __init__(self, message):
582 544 self.message = message
583 545
584 546 def getdispatchrepo(repo, proto, command):
585 547 """Obtain the repo used for processing wire protocol commands.
586 548
587 549 The intent of this function is to serve as a monkeypatch point for
588 550 extensions that need commands to operate on different repo views under
589 551 specialized circumstances.
590 552 """
591 553 return repo.filtered('served')
592 554
593 555 def dispatch(repo, proto, command):
594 556 repo = getdispatchrepo(repo, proto, command)
595 557 func, spec = commands[command]
596 558 args = proto.getargs(spec)
597 559 return func(repo, proto, *args)
598 560
599 561 def options(cmd, keys, others):
600 562 opts = {}
601 563 for k in keys:
602 564 if k in others:
603 565 opts[k] = others[k]
604 566 del others[k]
605 567 if others:
606 568 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
607 569 % (cmd, ",".join(others)))
608 570 return opts
609 571
610 572 def bundle1allowed(repo, action):
611 573 """Whether a bundle1 operation is allowed from the server.
612 574
613 575 Priority is:
614 576
615 577 1. server.bundle1gd.<action> (if generaldelta active)
616 578 2. server.bundle1.<action>
617 579 3. server.bundle1gd (if generaldelta active)
618 580 4. server.bundle1
619 581 """
620 582 ui = repo.ui
621 583 gd = 'generaldelta' in repo.requirements
622 584
623 585 if gd:
624 586 v = ui.configbool('server', 'bundle1gd.%s' % action, None)
625 587 if v is not None:
626 588 return v
627 589
628 590 v = ui.configbool('server', 'bundle1.%s' % action, None)
629 591 if v is not None:
630 592 return v
631 593
632 594 if gd:
633 595 v = ui.configbool('server', 'bundle1gd')
634 596 if v is not None:
635 597 return v
636 598
637 599 return ui.configbool('server', 'bundle1')
638 600
639 601 def supportedcompengines(ui, proto, role):
640 602 """Obtain the list of supported compression engines for a request."""
641 603 assert role in (util.CLIENTROLE, util.SERVERROLE)
642 604
643 605 compengines = util.compengines.supportedwireengines(role)
644 606
645 607 # Allow config to override default list and ordering.
646 608 if role == util.SERVERROLE:
647 609 configengines = ui.configlist('server', 'compressionengines')
648 610 config = 'server.compressionengines'
649 611 else:
650 612 # This is currently implemented mainly to facilitate testing. In most
651 613 # cases, the server should be in charge of choosing a compression engine
652 614 # because a server has the most to lose from a sub-optimal choice. (e.g.
653 615 # CPU DoS due to an expensive engine or a network DoS due to poor
654 616 # compression ratio).
655 617 configengines = ui.configlist('experimental',
656 618 'clientcompressionengines')
657 619 config = 'experimental.clientcompressionengines'
658 620
659 621 # No explicit config. Filter out the ones that aren't supposed to be
660 622 # advertised and return default ordering.
661 623 if not configengines:
662 624 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
663 625 return [e for e in compengines
664 626 if getattr(e.wireprotosupport(), attr) > 0]
665 627
666 628 # If compression engines are listed in the config, assume there is a good
667 629 # reason for it (like server operators wanting to achieve specific
668 630 # performance characteristics). So fail fast if the config references
669 631 # unusable compression engines.
670 632 validnames = set(e.name() for e in compengines)
671 633 invalidnames = set(e for e in configengines if e not in validnames)
672 634 if invalidnames:
673 635 raise error.Abort(_('invalid compression engine defined in %s: %s') %
674 636 (config, ', '.join(sorted(invalidnames))))
675 637
676 638 compengines = [e for e in compengines if e.name() in configengines]
677 639 compengines = sorted(compengines,
678 640 key=lambda e: configengines.index(e.name()))
679 641
680 642 if not compengines:
681 643 raise error.Abort(_('%s config option does not specify any known '
682 644 'compression engines') % config,
683 645 hint=_('usable compression engines: %s') %
684 646 ', '.sorted(validnames))
685 647
686 648 return compengines
687 649
688 650 # list of commands
689 651 commands = {}
690 652
691 653 def wireprotocommand(name, args=''):
692 654 """decorator for wire protocol command"""
693 655 def register(func):
694 656 commands[name] = (func, args)
695 657 return func
696 658 return register
697 659
698 660 @wireprotocommand('batch', 'cmds *')
699 661 def batch(repo, proto, cmds, others):
700 662 repo = repo.filtered("served")
701 663 res = []
702 664 for pair in cmds.split(';'):
703 665 op, args = pair.split(' ', 1)
704 666 vals = {}
705 667 for a in args.split(','):
706 668 if a:
707 669 n, v = a.split('=')
708 670 vals[unescapearg(n)] = unescapearg(v)
709 671 func, spec = commands[op]
710 672 if spec:
711 673 keys = spec.split()
712 674 data = {}
713 675 for k in keys:
714 676 if k == '*':
715 677 star = {}
716 678 for key in vals.keys():
717 679 if key not in keys:
718 680 star[key] = vals[key]
719 681 data['*'] = star
720 682 else:
721 683 data[k] = vals[k]
722 684 result = func(repo, proto, *[data[k] for k in keys])
723 685 else:
724 686 result = func(repo, proto)
725 687 if isinstance(result, ooberror):
726 688 return result
727 689 res.append(escapearg(result))
728 690 return ';'.join(res)
729 691
730 692 @wireprotocommand('between', 'pairs')
731 693 def between(repo, proto, pairs):
732 694 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
733 695 r = []
734 696 for b in repo.between(pairs):
735 697 r.append(encodelist(b) + "\n")
736 698 return "".join(r)
737 699
738 700 @wireprotocommand('branchmap')
739 701 def branchmap(repo, proto):
740 702 branchmap = repo.branchmap()
741 703 heads = []
742 704 for branch, nodes in branchmap.iteritems():
743 705 branchname = urlreq.quote(encoding.fromlocal(branch))
744 706 branchnodes = encodelist(nodes)
745 707 heads.append('%s %s' % (branchname, branchnodes))
746 708 return '\n'.join(heads)
747 709
748 710 @wireprotocommand('branches', 'nodes')
749 711 def branches(repo, proto, nodes):
750 712 nodes = decodelist(nodes)
751 713 r = []
752 714 for b in repo.branches(nodes):
753 715 r.append(encodelist(b) + "\n")
754 716 return "".join(r)
755 717
756 718 @wireprotocommand('clonebundles', '')
757 719 def clonebundles(repo, proto):
758 720 """Server command for returning info for available bundles to seed clones.
759 721
760 722 Clients will parse this response and determine what bundle to fetch.
761 723
762 724 Extensions may wrap this command to filter or dynamically emit data
763 725 depending on the request. e.g. you could advertise URLs for the closest
764 726 data center given the client's IP address.
765 727 """
766 728 return repo.vfs.tryread('clonebundles.manifest')
767 729
768 730 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
769 731 'known', 'getbundle', 'unbundlehash', 'batch']
770 732
771 733 def _capabilities(repo, proto):
772 734 """return a list of capabilities for a repo
773 735
774 736 This function exists to allow extensions to easily wrap capabilities
775 737 computation
776 738
777 739 - returns a lists: easy to alter
778 740 - change done here will be propagated to both `capabilities` and `hello`
779 741 command without any other action needed.
780 742 """
781 743 # copy to prevent modification of the global list
782 744 caps = list(wireprotocaps)
783 745 if streamclone.allowservergeneration(repo):
784 746 if repo.ui.configbool('server', 'preferuncompressed'):
785 747 caps.append('stream-preferred')
786 748 requiredformats = repo.requirements & repo.supportedformats
787 749 # if our local revlogs are just revlogv1, add 'stream' cap
788 750 if not requiredformats - {'revlogv1'}:
789 751 caps.append('stream')
790 752 # otherwise, add 'streamreqs' detailing our local revlog format
791 753 else:
792 754 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
793 755 if repo.ui.configbool('experimental', 'bundle2-advertise'):
794 756 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo))
795 757 caps.append('bundle2=' + urlreq.quote(capsblob))
796 758 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
797 759
798 760 if proto.name == 'http':
799 761 caps.append('httpheader=%d' %
800 762 repo.ui.configint('server', 'maxhttpheaderlen'))
801 763 if repo.ui.configbool('experimental', 'httppostargs'):
802 764 caps.append('httppostargs')
803 765
804 766 # FUTURE advertise 0.2rx once support is implemented
805 767 # FUTURE advertise minrx and mintx after consulting config option
806 768 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
807 769
808 770 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
809 771 if compengines:
810 772 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
811 773 for e in compengines)
812 774 caps.append('compression=%s' % comptypes)
813 775
814 776 return caps
815 777
816 778 # If you are writing an extension and consider wrapping this function. Wrap
817 779 # `_capabilities` instead.
818 780 @wireprotocommand('capabilities')
819 781 def capabilities(repo, proto):
820 782 return ' '.join(_capabilities(repo, proto))
821 783
822 784 @wireprotocommand('changegroup', 'roots')
823 785 def changegroup(repo, proto, roots):
824 786 nodes = decodelist(roots)
825 787 cg = changegroupmod.changegroup(repo, nodes, 'serve')
826 788 return streamres(reader=cg, v1compressible=True)
827 789
828 790 @wireprotocommand('changegroupsubset', 'bases heads')
829 791 def changegroupsubset(repo, proto, bases, heads):
830 792 bases = decodelist(bases)
831 793 heads = decodelist(heads)
832 794 cg = changegroupmod.changegroupsubset(repo, bases, heads, 'serve')
833 795 return streamres(reader=cg, v1compressible=True)
834 796
835 797 @wireprotocommand('debugwireargs', 'one two *')
836 798 def debugwireargs(repo, proto, one, two, others):
837 799 # only accept optional args from the known set
838 800 opts = options('debugwireargs', ['three', 'four'], others)
839 801 return repo.debugwireargs(one, two, **opts)
840 802
841 803 @wireprotocommand('getbundle', '*')
842 804 def getbundle(repo, proto, others):
843 805 opts = options('getbundle', gboptsmap.keys(), others)
844 806 for k, v in opts.iteritems():
845 807 keytype = gboptsmap[k]
846 808 if keytype == 'nodes':
847 809 opts[k] = decodelist(v)
848 810 elif keytype == 'csv':
849 811 opts[k] = list(v.split(','))
850 812 elif keytype == 'scsv':
851 813 opts[k] = set(v.split(','))
852 814 elif keytype == 'boolean':
853 815 # Client should serialize False as '0', which is a non-empty string
854 816 # so it evaluates as a True bool.
855 817 if v == '0':
856 818 opts[k] = False
857 819 else:
858 820 opts[k] = bool(v)
859 821 elif keytype != 'plain':
860 822 raise KeyError('unknown getbundle option type %s'
861 823 % keytype)
862 824
863 825 if not bundle1allowed(repo, 'pull'):
864 826 if not exchange.bundle2requested(opts.get('bundlecaps')):
865 827 if proto.name == 'http':
866 828 return ooberror(bundle2required)
867 829 raise error.Abort(bundle2requiredmain,
868 830 hint=bundle2requiredhint)
869 831
870 832 try:
871 833 if repo.ui.configbool('server', 'disablefullbundle'):
872 834 # Check to see if this is a full clone.
873 835 clheads = set(repo.changelog.heads())
874 836 heads = set(opts.get('heads', set()))
875 837 common = set(opts.get('common', set()))
876 838 common.discard(nullid)
877 839 if not common and clheads == heads:
878 840 raise error.Abort(
879 841 _('server has pull-based clones disabled'),
880 842 hint=_('remove --pull if specified or upgrade Mercurial'))
881 843
882 844 chunks = exchange.getbundlechunks(repo, 'serve', **opts)
883 845 except error.Abort as exc:
884 846 # cleanly forward Abort error to the client
885 847 if not exchange.bundle2requested(opts.get('bundlecaps')):
886 848 if proto.name == 'http':
887 849 return ooberror(str(exc) + '\n')
888 850 raise # cannot do better for bundle1 + ssh
889 851 # bundle2 request expect a bundle2 reply
890 852 bundler = bundle2.bundle20(repo.ui)
891 853 manargs = [('message', str(exc))]
892 854 advargs = []
893 855 if exc.hint is not None:
894 856 advargs.append(('hint', exc.hint))
895 857 bundler.addpart(bundle2.bundlepart('error:abort',
896 858 manargs, advargs))
897 859 return streamres(gen=bundler.getchunks(), v1compressible=True)
898 860 return streamres(gen=chunks, v1compressible=True)
899 861
900 862 @wireprotocommand('heads')
901 863 def heads(repo, proto):
902 864 h = repo.heads()
903 865 return encodelist(h) + "\n"
904 866
905 867 @wireprotocommand('hello')
906 868 def hello(repo, proto):
907 869 '''the hello command returns a set of lines describing various
908 870 interesting things about the server, in an RFC822-like format.
909 871 Currently the only one defined is "capabilities", which
910 872 consists of a line in the form:
911 873
912 874 capabilities: space separated list of tokens
913 875 '''
914 876 return "capabilities: %s\n" % (capabilities(repo, proto))
915 877
916 878 @wireprotocommand('listkeys', 'namespace')
917 879 def listkeys(repo, proto, namespace):
918 880 d = repo.listkeys(encoding.tolocal(namespace)).items()
919 881 return pushkeymod.encodekeys(d)
920 882
921 883 @wireprotocommand('lookup', 'key')
922 884 def lookup(repo, proto, key):
923 885 try:
924 886 k = encoding.tolocal(key)
925 887 c = repo[k]
926 888 r = c.hex()
927 889 success = 1
928 890 except Exception as inst:
929 891 r = str(inst)
930 892 success = 0
931 893 return "%s %s\n" % (success, r)
932 894
933 895 @wireprotocommand('known', 'nodes *')
934 896 def known(repo, proto, nodes, others):
935 897 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
936 898
937 899 @wireprotocommand('pushkey', 'namespace key old new')
938 900 def pushkey(repo, proto, namespace, key, old, new):
939 901 # compatibility with pre-1.8 clients which were accidentally
940 902 # sending raw binary nodes rather than utf-8-encoded hex
941 903 if len(new) == 20 and util.escapestr(new) != new:
942 904 # looks like it could be a binary node
943 905 try:
944 906 new.decode('utf-8')
945 907 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
946 908 except UnicodeDecodeError:
947 909 pass # binary, leave unmodified
948 910 else:
949 911 new = encoding.tolocal(new) # normal path
950 912
951 913 if util.safehasattr(proto, 'restore'):
952 914
953 915 proto.redirect()
954 916
955 917 try:
956 918 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
957 919 encoding.tolocal(old), new) or False
958 920 except error.Abort:
959 921 r = False
960 922
961 923 output = proto.restore()
962 924
963 925 return '%s\n%s' % (int(r), output)
964 926
965 927 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
966 928 encoding.tolocal(old), new)
967 929 return '%s\n' % int(r)
968 930
969 931 @wireprotocommand('stream_out')
970 932 def stream(repo, proto):
971 933 '''If the server supports streaming clone, it advertises the "stream"
972 934 capability with a value representing the version and flags of the repo
973 935 it is serving. Client checks to see if it understands the format.
974 936 '''
975 937 if not streamclone.allowservergeneration(repo):
976 938 return '1\n'
977 939
978 940 def getstream(it):
979 941 yield '0\n'
980 942 for chunk in it:
981 943 yield chunk
982 944
983 945 try:
984 946 # LockError may be raised before the first result is yielded. Don't
985 947 # emit output until we're sure we got the lock successfully.
986 948 it = streamclone.generatev1wireproto(repo)
987 949 return streamres(gen=getstream(it))
988 950 except error.LockError:
989 951 return '2\n'
990 952
991 953 @wireprotocommand('unbundle', 'heads')
992 954 def unbundle(repo, proto, heads):
993 955 their_heads = decodelist(heads)
994 956
995 957 try:
996 958 proto.redirect()
997 959
998 960 exchange.check_heads(repo, their_heads, 'preparing changes')
999 961
1000 962 # write bundle data to temporary file because it can be big
1001 963 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1002 964 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1003 965 r = 0
1004 966 try:
1005 967 proto.getfile(fp)
1006 968 fp.seek(0)
1007 969 gen = exchange.readbundle(repo.ui, fp, None)
1008 970 if (isinstance(gen, changegroupmod.cg1unpacker)
1009 971 and not bundle1allowed(repo, 'push')):
1010 972 if proto.name == 'http':
1011 973 # need to special case http because stderr do not get to
1012 974 # the http client on failed push so we need to abuse some
1013 975 # other error type to make sure the message get to the
1014 976 # user.
1015 977 return ooberror(bundle2required)
1016 978 raise error.Abort(bundle2requiredmain,
1017 979 hint=bundle2requiredhint)
1018 980
1019 981 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1020 982 proto._client())
1021 983 if util.safehasattr(r, 'addpart'):
1022 984 # The return looks streamable, we are in the bundle2 case and
1023 985 # should return a stream.
1024 986 return streamres(gen=r.getchunks())
1025 987 return pushres(r)
1026 988
1027 989 finally:
1028 990 fp.close()
1029 991 os.unlink(tempname)
1030 992
1031 993 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1032 994 # handle non-bundle2 case first
1033 995 if not getattr(exc, 'duringunbundle2', False):
1034 996 try:
1035 997 raise
1036 998 except error.Abort:
1037 999 # The old code we moved used util.stderr directly.
1038 1000 # We did not change it to minimise code change.
1039 1001 # This need to be moved to something proper.
1040 1002 # Feel free to do it.
1041 1003 util.stderr.write("abort: %s\n" % exc)
1042 1004 if exc.hint is not None:
1043 1005 util.stderr.write("(%s)\n" % exc.hint)
1044 1006 return pushres(0)
1045 1007 except error.PushRaced:
1046 1008 return pusherr(str(exc))
1047 1009
1048 1010 bundler = bundle2.bundle20(repo.ui)
1049 1011 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1050 1012 bundler.addpart(out)
1051 1013 try:
1052 1014 try:
1053 1015 raise
1054 1016 except error.PushkeyFailed as exc:
1055 1017 # check client caps
1056 1018 remotecaps = getattr(exc, '_replycaps', None)
1057 1019 if (remotecaps is not None
1058 1020 and 'pushkey' not in remotecaps.get('error', ())):
1059 1021 # no support remote side, fallback to Abort handler.
1060 1022 raise
1061 1023 part = bundler.newpart('error:pushkey')
1062 1024 part.addparam('in-reply-to', exc.partid)
1063 1025 if exc.namespace is not None:
1064 1026 part.addparam('namespace', exc.namespace, mandatory=False)
1065 1027 if exc.key is not None:
1066 1028 part.addparam('key', exc.key, mandatory=False)
1067 1029 if exc.new is not None:
1068 1030 part.addparam('new', exc.new, mandatory=False)
1069 1031 if exc.old is not None:
1070 1032 part.addparam('old', exc.old, mandatory=False)
1071 1033 if exc.ret is not None:
1072 1034 part.addparam('ret', exc.ret, mandatory=False)
1073 1035 except error.BundleValueError as exc:
1074 1036 errpart = bundler.newpart('error:unsupportedcontent')
1075 1037 if exc.parttype is not None:
1076 1038 errpart.addparam('parttype', exc.parttype)
1077 1039 if exc.params:
1078 1040 errpart.addparam('params', '\0'.join(exc.params))
1079 1041 except error.Abort as exc:
1080 1042 manargs = [('message', str(exc))]
1081 1043 advargs = []
1082 1044 if exc.hint is not None:
1083 1045 advargs.append(('hint', exc.hint))
1084 1046 bundler.addpart(bundle2.bundlepart('error:abort',
1085 1047 manargs, advargs))
1086 1048 except error.PushRaced as exc:
1087 1049 bundler.newpart('error:pushraced', [('message', str(exc))])
1088 1050 return streamres(gen=bundler.getchunks())
@@ -1,61 +1,61 b''
1 1 from __future__ import absolute_import, print_function
2 2
3 3 from mercurial import (
4 4 util,
5 5 wireproto,
6 6 )
7 7 stringio = util.stringio
8 8
9 9 class proto(object):
10 10 def __init__(self, args):
11 11 self.args = args
12 12 def getargs(self, spec):
13 13 args = self.args
14 14 args.setdefault('*', {})
15 15 names = spec.split()
16 16 return [args[n] for n in names]
17 17
18 18 class clientpeer(wireproto.wirepeer):
19 19 def __init__(self, serverrepo):
20 20 self.serverrepo = serverrepo
21 21
22 22 def _capabilities(self):
23 23 return ['batch']
24 24
25 25 def _call(self, cmd, **args):
26 26 return wireproto.dispatch(self.serverrepo, proto(args), cmd)
27 27
28 28 def _callstream(self, cmd, **args):
29 29 return stringio(self._call(cmd, **args))
30 30
31 31 @wireproto.batchable
32 32 def greet(self, name):
33 33 f = wireproto.future()
34 34 yield {'name': mangle(name)}, f
35 35 yield unmangle(f.value)
36 36
37 37 class serverrepo(object):
38 38 def greet(self, name):
39 39 return "Hello, " + name
40 40
41 41 def filtered(self, name):
42 42 return self
43 43
44 44 def mangle(s):
45 45 return ''.join(chr(ord(c) + 1) for c in s)
46 46 def unmangle(s):
47 47 return ''.join(chr(ord(c) - 1) for c in s)
48 48
49 49 def greet(repo, proto, name):
50 50 return mangle(repo.greet(unmangle(name)))
51 51
52 52 wireproto.commands['greet'] = (greet, 'name',)
53 53
54 54 srv = serverrepo()
55 55 clt = clientpeer(srv)
56 56
57 57 print(clt.greet("Foobar"))
58 b = clt.batch()
59 fs = [b.greet(s) for s in ["Fo, =;:<o", "Bar"]]
58 b = clt.iterbatch()
59 map(b.greet, ('Fo, =;:<o', 'Bar'))
60 60 b.submit()
61 print([f.value for f in fs])
61 print([r for r in b.results()])
General Comments 0
You need to be logged in to leave comments. Login now