##// END OF EJS Templates
wireproto: don't special case bundlecaps, but sort all scsv arguments...
Joerg Sonnenberger -
r37430:1d459f61 default
parent child Browse files
Show More
@@ -1,1182 +1,1180 b''
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import os
12 12 import tempfile
13 13
14 14 from .i18n import _
15 15 from .node import (
16 16 bin,
17 17 hex,
18 18 nullid,
19 19 )
20 20
21 21 from . import (
22 22 bundle2,
23 23 changegroup as changegroupmod,
24 24 discovery,
25 25 encoding,
26 26 error,
27 27 exchange,
28 28 peer,
29 29 pushkey as pushkeymod,
30 30 pycompat,
31 31 repository,
32 32 streamclone,
33 33 util,
34 34 wireprototypes,
35 35 )
36 36
37 37 from .utils import (
38 38 procutil,
39 39 stringutil,
40 40 )
41 41
42 42 urlerr = util.urlerr
43 43 urlreq = util.urlreq
44 44
45 45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
46 46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
47 47 'IncompatibleClient')
48 48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
49 49
50 50 class remoteiterbatcher(peer.iterbatcher):
51 51 def __init__(self, remote):
52 52 super(remoteiterbatcher, self).__init__()
53 53 self._remote = remote
54 54
55 55 def __getattr__(self, name):
56 56 # Validate this method is batchable, since submit() only supports
57 57 # batchable methods.
58 58 fn = getattr(self._remote, name)
59 59 if not getattr(fn, 'batchable', None):
60 60 raise error.ProgrammingError('Attempted to batch a non-batchable '
61 61 'call to %r' % name)
62 62
63 63 return super(remoteiterbatcher, self).__getattr__(name)
64 64
65 65 def submit(self):
66 66 """Break the batch request into many patch calls and pipeline them.
67 67
68 68 This is mostly valuable over http where request sizes can be
69 69 limited, but can be used in other places as well.
70 70 """
71 71 # 2-tuple of (command, arguments) that represents what will be
72 72 # sent over the wire.
73 73 requests = []
74 74
75 75 # 4-tuple of (command, final future, @batchable generator, remote
76 76 # future).
77 77 results = []
78 78
79 79 for command, args, opts, finalfuture in self.calls:
80 80 mtd = getattr(self._remote, command)
81 81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
82 82
83 83 commandargs, fremote = next(batchable)
84 84 assert fremote
85 85 requests.append((command, commandargs))
86 86 results.append((command, finalfuture, batchable, fremote))
87 87
88 88 if requests:
89 89 self._resultiter = self._remote._submitbatch(requests)
90 90
91 91 self._results = results
92 92
93 93 def results(self):
94 94 for command, finalfuture, batchable, remotefuture in self._results:
95 95 # Get the raw result, set it in the remote future, feed it
96 96 # back into the @batchable generator so it can be decoded, and
97 97 # set the result on the final future to this value.
98 98 remoteresult = next(self._resultiter)
99 99 remotefuture.set(remoteresult)
100 100 finalfuture.set(next(batchable))
101 101
102 102 # Verify our @batchable generators only emit 2 values.
103 103 try:
104 104 next(batchable)
105 105 except StopIteration:
106 106 pass
107 107 else:
108 108 raise error.ProgrammingError('%s @batchable generator emitted '
109 109 'unexpected value count' % command)
110 110
111 111 yield finalfuture.value
112 112
113 113 # Forward a couple of names from peer to make wireproto interactions
114 114 # slightly more sensible.
115 115 batchable = peer.batchable
116 116 future = peer.future
117 117
118 118 # list of nodes encoding / decoding
119 119
120 120 def decodelist(l, sep=' '):
121 121 if l:
122 122 return [bin(v) for v in l.split(sep)]
123 123 return []
124 124
125 125 def encodelist(l, sep=' '):
126 126 try:
127 127 return sep.join(map(hex, l))
128 128 except TypeError:
129 129 raise
130 130
131 131 # batched call argument encoding
132 132
133 133 def escapearg(plain):
134 134 return (plain
135 135 .replace(':', ':c')
136 136 .replace(',', ':o')
137 137 .replace(';', ':s')
138 138 .replace('=', ':e'))
139 139
140 140 def unescapearg(escaped):
141 141 return (escaped
142 142 .replace(':e', '=')
143 143 .replace(':s', ';')
144 144 .replace(':o', ',')
145 145 .replace(':c', ':'))
146 146
147 147 def encodebatchcmds(req):
148 148 """Return a ``cmds`` argument value for the ``batch`` command."""
149 149 cmds = []
150 150 for op, argsdict in req:
151 151 # Old servers didn't properly unescape argument names. So prevent
152 152 # the sending of argument names that may not be decoded properly by
153 153 # servers.
154 154 assert all(escapearg(k) == k for k in argsdict)
155 155
156 156 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
157 157 for k, v in argsdict.iteritems())
158 158 cmds.append('%s %s' % (op, args))
159 159
160 160 return ';'.join(cmds)
161 161
162 162 def clientcompressionsupport(proto):
163 163 """Returns a list of compression methods supported by the client.
164 164
165 165 Returns a list of the compression methods supported by the client
166 166 according to the protocol capabilities. If no such capability has
167 167 been announced, fallback to the default of zlib and uncompressed.
168 168 """
169 169 for cap in proto.getprotocaps():
170 170 if cap.startswith('comp='):
171 171 return cap[5:].split(',')
172 172 return ['zlib', 'none']
173 173
174 174 # mapping of options accepted by getbundle and their types
175 175 #
176 176 # Meant to be extended by extensions. It is extensions responsibility to ensure
177 177 # such options are properly processed in exchange.getbundle.
178 178 #
179 179 # supported types are:
180 180 #
181 181 # :nodes: list of binary nodes
182 182 # :csv: list of comma-separated values
183 183 # :scsv: list of comma-separated values return as set
184 184 # :plain: string with no transformation needed.
185 185 gboptsmap = {'heads': 'nodes',
186 186 'bookmarks': 'boolean',
187 187 'common': 'nodes',
188 188 'obsmarkers': 'boolean',
189 189 'phases': 'boolean',
190 190 'bundlecaps': 'scsv',
191 191 'listkeys': 'csv',
192 192 'cg': 'boolean',
193 193 'cbattempted': 'boolean',
194 194 'stream': 'boolean',
195 195 }
196 196
197 197 # client side
198 198
199 199 class wirepeer(repository.legacypeer):
200 200 """Client-side interface for communicating with a peer repository.
201 201
202 202 Methods commonly call wire protocol commands of the same name.
203 203
204 204 See also httppeer.py and sshpeer.py for protocol-specific
205 205 implementations of this interface.
206 206 """
207 207 # Begin of ipeercommands interface.
208 208
209 209 def iterbatch(self):
210 210 return remoteiterbatcher(self)
211 211
212 212 @batchable
213 213 def lookup(self, key):
214 214 self.requirecap('lookup', _('look up remote revision'))
215 215 f = future()
216 216 yield {'key': encoding.fromlocal(key)}, f
217 217 d = f.value
218 218 success, data = d[:-1].split(" ", 1)
219 219 if int(success):
220 220 yield bin(data)
221 221 else:
222 222 self._abort(error.RepoError(data))
223 223
224 224 @batchable
225 225 def heads(self):
226 226 f = future()
227 227 yield {}, f
228 228 d = f.value
229 229 try:
230 230 yield decodelist(d[:-1])
231 231 except ValueError:
232 232 self._abort(error.ResponseError(_("unexpected response:"), d))
233 233
234 234 @batchable
235 235 def known(self, nodes):
236 236 f = future()
237 237 yield {'nodes': encodelist(nodes)}, f
238 238 d = f.value
239 239 try:
240 240 yield [bool(int(b)) for b in d]
241 241 except ValueError:
242 242 self._abort(error.ResponseError(_("unexpected response:"), d))
243 243
244 244 @batchable
245 245 def branchmap(self):
246 246 f = future()
247 247 yield {}, f
248 248 d = f.value
249 249 try:
250 250 branchmap = {}
251 251 for branchpart in d.splitlines():
252 252 branchname, branchheads = branchpart.split(' ', 1)
253 253 branchname = encoding.tolocal(urlreq.unquote(branchname))
254 254 branchheads = decodelist(branchheads)
255 255 branchmap[branchname] = branchheads
256 256 yield branchmap
257 257 except TypeError:
258 258 self._abort(error.ResponseError(_("unexpected response:"), d))
259 259
260 260 @batchable
261 261 def listkeys(self, namespace):
262 262 if not self.capable('pushkey'):
263 263 yield {}, None
264 264 f = future()
265 265 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
266 266 yield {'namespace': encoding.fromlocal(namespace)}, f
267 267 d = f.value
268 268 self.ui.debug('received listkey for "%s": %i bytes\n'
269 269 % (namespace, len(d)))
270 270 yield pushkeymod.decodekeys(d)
271 271
272 272 @batchable
273 273 def pushkey(self, namespace, key, old, new):
274 274 if not self.capable('pushkey'):
275 275 yield False, None
276 276 f = future()
277 277 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
278 278 yield {'namespace': encoding.fromlocal(namespace),
279 279 'key': encoding.fromlocal(key),
280 280 'old': encoding.fromlocal(old),
281 281 'new': encoding.fromlocal(new)}, f
282 282 d = f.value
283 283 d, output = d.split('\n', 1)
284 284 try:
285 285 d = bool(int(d))
286 286 except ValueError:
287 287 raise error.ResponseError(
288 288 _('push failed (unexpected response):'), d)
289 289 for l in output.splitlines(True):
290 290 self.ui.status(_('remote: '), l)
291 291 yield d
292 292
293 293 def stream_out(self):
294 294 return self._callstream('stream_out')
295 295
296 296 def getbundle(self, source, **kwargs):
297 297 kwargs = pycompat.byteskwargs(kwargs)
298 298 self.requirecap('getbundle', _('look up remote changes'))
299 299 opts = {}
300 bundlecaps = kwargs.get('bundlecaps')
301 if bundlecaps is not None:
302 kwargs['bundlecaps'] = sorted(bundlecaps)
303 else:
304 bundlecaps = () # kwargs could have it to None
300 bundlecaps = kwargs.get('bundlecaps') or set()
305 301 for key, value in kwargs.iteritems():
306 302 if value is None:
307 303 continue
308 304 keytype = gboptsmap.get(key)
309 305 if keytype is None:
310 306 raise error.ProgrammingError(
311 307 'Unexpectedly None keytype for key %s' % key)
312 308 elif keytype == 'nodes':
313 309 value = encodelist(value)
314 elif keytype in ('csv', 'scsv'):
310 elif keytype == 'csv':
315 311 value = ','.join(value)
312 elif keytype == 'scsv':
313 value = ','.join(sorted(value))
316 314 elif keytype == 'boolean':
317 315 value = '%i' % bool(value)
318 316 elif keytype != 'plain':
319 317 raise KeyError('unknown getbundle option type %s'
320 318 % keytype)
321 319 opts[key] = value
322 320 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
323 321 if any((cap.startswith('HG2') for cap in bundlecaps)):
324 322 return bundle2.getunbundler(self.ui, f)
325 323 else:
326 324 return changegroupmod.cg1unpacker(f, 'UN')
327 325
328 326 def unbundle(self, cg, heads, url):
329 327 '''Send cg (a readable file-like object representing the
330 328 changegroup to push, typically a chunkbuffer object) to the
331 329 remote server as a bundle.
332 330
333 331 When pushing a bundle10 stream, return an integer indicating the
334 332 result of the push (see changegroup.apply()).
335 333
336 334 When pushing a bundle20 stream, return a bundle20 stream.
337 335
338 336 `url` is the url the client thinks it's pushing to, which is
339 337 visible to hooks.
340 338 '''
341 339
342 340 if heads != ['force'] and self.capable('unbundlehash'):
343 341 heads = encodelist(['hashed',
344 342 hashlib.sha1(''.join(sorted(heads))).digest()])
345 343 else:
346 344 heads = encodelist(heads)
347 345
348 346 if util.safehasattr(cg, 'deltaheader'):
349 347 # this a bundle10, do the old style call sequence
350 348 ret, output = self._callpush("unbundle", cg, heads=heads)
351 349 if ret == "":
352 350 raise error.ResponseError(
353 351 _('push failed:'), output)
354 352 try:
355 353 ret = int(ret)
356 354 except ValueError:
357 355 raise error.ResponseError(
358 356 _('push failed (unexpected response):'), ret)
359 357
360 358 for l in output.splitlines(True):
361 359 self.ui.status(_('remote: '), l)
362 360 else:
363 361 # bundle2 push. Send a stream, fetch a stream.
364 362 stream = self._calltwowaystream('unbundle', cg, heads=heads)
365 363 ret = bundle2.getunbundler(self.ui, stream)
366 364 return ret
367 365
368 366 # End of ipeercommands interface.
369 367
370 368 # Begin of ipeerlegacycommands interface.
371 369
372 370 def branches(self, nodes):
373 371 n = encodelist(nodes)
374 372 d = self._call("branches", nodes=n)
375 373 try:
376 374 br = [tuple(decodelist(b)) for b in d.splitlines()]
377 375 return br
378 376 except ValueError:
379 377 self._abort(error.ResponseError(_("unexpected response:"), d))
380 378
381 379 def between(self, pairs):
382 380 batch = 8 # avoid giant requests
383 381 r = []
384 382 for i in xrange(0, len(pairs), batch):
385 383 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
386 384 d = self._call("between", pairs=n)
387 385 try:
388 386 r.extend(l and decodelist(l) or [] for l in d.splitlines())
389 387 except ValueError:
390 388 self._abort(error.ResponseError(_("unexpected response:"), d))
391 389 return r
392 390
393 391 def changegroup(self, nodes, kind):
394 392 n = encodelist(nodes)
395 393 f = self._callcompressable("changegroup", roots=n)
396 394 return changegroupmod.cg1unpacker(f, 'UN')
397 395
398 396 def changegroupsubset(self, bases, heads, kind):
399 397 self.requirecap('changegroupsubset', _('look up remote changes'))
400 398 bases = encodelist(bases)
401 399 heads = encodelist(heads)
402 400 f = self._callcompressable("changegroupsubset",
403 401 bases=bases, heads=heads)
404 402 return changegroupmod.cg1unpacker(f, 'UN')
405 403
406 404 # End of ipeerlegacycommands interface.
407 405
408 406 def _submitbatch(self, req):
409 407 """run batch request <req> on the server
410 408
411 409 Returns an iterator of the raw responses from the server.
412 410 """
413 411 ui = self.ui
414 412 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
415 413 ui.debug('devel-peer-request: batched-content\n')
416 414 for op, args in req:
417 415 msg = 'devel-peer-request: - %s (%d arguments)\n'
418 416 ui.debug(msg % (op, len(args)))
419 417
420 418 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
421 419 chunk = rsp.read(1024)
422 420 work = [chunk]
423 421 while chunk:
424 422 while ';' not in chunk and chunk:
425 423 chunk = rsp.read(1024)
426 424 work.append(chunk)
427 425 merged = ''.join(work)
428 426 while ';' in merged:
429 427 one, merged = merged.split(';', 1)
430 428 yield unescapearg(one)
431 429 chunk = rsp.read(1024)
432 430 work = [merged, chunk]
433 431 yield unescapearg(''.join(work))
434 432
435 433 def _submitone(self, op, args):
436 434 return self._call(op, **pycompat.strkwargs(args))
437 435
438 436 def debugwireargs(self, one, two, three=None, four=None, five=None):
439 437 # don't pass optional arguments left at their default value
440 438 opts = {}
441 439 if three is not None:
442 440 opts[r'three'] = three
443 441 if four is not None:
444 442 opts[r'four'] = four
445 443 return self._call('debugwireargs', one=one, two=two, **opts)
446 444
447 445 def _call(self, cmd, **args):
448 446 """execute <cmd> on the server
449 447
450 448 The command is expected to return a simple string.
451 449
452 450 returns the server reply as a string."""
453 451 raise NotImplementedError()
454 452
455 453 def _callstream(self, cmd, **args):
456 454 """execute <cmd> on the server
457 455
458 456 The command is expected to return a stream. Note that if the
459 457 command doesn't return a stream, _callstream behaves
460 458 differently for ssh and http peers.
461 459
462 460 returns the server reply as a file like object.
463 461 """
464 462 raise NotImplementedError()
465 463
466 464 def _callcompressable(self, cmd, **args):
467 465 """execute <cmd> on the server
468 466
469 467 The command is expected to return a stream.
470 468
471 469 The stream may have been compressed in some implementations. This
472 470 function takes care of the decompression. This is the only difference
473 471 with _callstream.
474 472
475 473 returns the server reply as a file like object.
476 474 """
477 475 raise NotImplementedError()
478 476
479 477 def _callpush(self, cmd, fp, **args):
480 478 """execute a <cmd> on server
481 479
482 480 The command is expected to be related to a push. Push has a special
483 481 return method.
484 482
485 483 returns the server reply as a (ret, output) tuple. ret is either
486 484 empty (error) or a stringified int.
487 485 """
488 486 raise NotImplementedError()
489 487
490 488 def _calltwowaystream(self, cmd, fp, **args):
491 489 """execute <cmd> on server
492 490
493 491 The command will send a stream to the server and get a stream in reply.
494 492 """
495 493 raise NotImplementedError()
496 494
497 495 def _abort(self, exception):
498 496 """clearly abort the wire protocol connection and raise the exception
499 497 """
500 498 raise NotImplementedError()
501 499
502 500 # server side
503 501
504 502 # wire protocol command can either return a string or one of these classes.
505 503
506 504 def getdispatchrepo(repo, proto, command):
507 505 """Obtain the repo used for processing wire protocol commands.
508 506
509 507 The intent of this function is to serve as a monkeypatch point for
510 508 extensions that need commands to operate on different repo views under
511 509 specialized circumstances.
512 510 """
513 511 return repo.filtered('served')
514 512
515 513 def dispatch(repo, proto, command):
516 514 repo = getdispatchrepo(repo, proto, command)
517 515
518 516 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
519 517 commandtable = commandsv2 if transportversion == 2 else commands
520 518 func, spec = commandtable[command]
521 519
522 520 args = proto.getargs(spec)
523 521 return func(repo, proto, *args)
524 522
525 523 def options(cmd, keys, others):
526 524 opts = {}
527 525 for k in keys:
528 526 if k in others:
529 527 opts[k] = others[k]
530 528 del others[k]
531 529 if others:
532 530 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
533 531 % (cmd, ",".join(others)))
534 532 return opts
535 533
536 534 def bundle1allowed(repo, action):
537 535 """Whether a bundle1 operation is allowed from the server.
538 536
539 537 Priority is:
540 538
541 539 1. server.bundle1gd.<action> (if generaldelta active)
542 540 2. server.bundle1.<action>
543 541 3. server.bundle1gd (if generaldelta active)
544 542 4. server.bundle1
545 543 """
546 544 ui = repo.ui
547 545 gd = 'generaldelta' in repo.requirements
548 546
549 547 if gd:
550 548 v = ui.configbool('server', 'bundle1gd.%s' % action)
551 549 if v is not None:
552 550 return v
553 551
554 552 v = ui.configbool('server', 'bundle1.%s' % action)
555 553 if v is not None:
556 554 return v
557 555
558 556 if gd:
559 557 v = ui.configbool('server', 'bundle1gd')
560 558 if v is not None:
561 559 return v
562 560
563 561 return ui.configbool('server', 'bundle1')
564 562
565 563 def supportedcompengines(ui, role):
566 564 """Obtain the list of supported compression engines for a request."""
567 565 assert role in (util.CLIENTROLE, util.SERVERROLE)
568 566
569 567 compengines = util.compengines.supportedwireengines(role)
570 568
571 569 # Allow config to override default list and ordering.
572 570 if role == util.SERVERROLE:
573 571 configengines = ui.configlist('server', 'compressionengines')
574 572 config = 'server.compressionengines'
575 573 else:
576 574 # This is currently implemented mainly to facilitate testing. In most
577 575 # cases, the server should be in charge of choosing a compression engine
578 576 # because a server has the most to lose from a sub-optimal choice. (e.g.
579 577 # CPU DoS due to an expensive engine or a network DoS due to poor
580 578 # compression ratio).
581 579 configengines = ui.configlist('experimental',
582 580 'clientcompressionengines')
583 581 config = 'experimental.clientcompressionengines'
584 582
585 583 # No explicit config. Filter out the ones that aren't supposed to be
586 584 # advertised and return default ordering.
587 585 if not configengines:
588 586 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
589 587 return [e for e in compengines
590 588 if getattr(e.wireprotosupport(), attr) > 0]
591 589
592 590 # If compression engines are listed in the config, assume there is a good
593 591 # reason for it (like server operators wanting to achieve specific
594 592 # performance characteristics). So fail fast if the config references
595 593 # unusable compression engines.
596 594 validnames = set(e.name() for e in compengines)
597 595 invalidnames = set(e for e in configengines if e not in validnames)
598 596 if invalidnames:
599 597 raise error.Abort(_('invalid compression engine defined in %s: %s') %
600 598 (config, ', '.join(sorted(invalidnames))))
601 599
602 600 compengines = [e for e in compengines if e.name() in configengines]
603 601 compengines = sorted(compengines,
604 602 key=lambda e: configengines.index(e.name()))
605 603
606 604 if not compengines:
607 605 raise error.Abort(_('%s config option does not specify any known '
608 606 'compression engines') % config,
609 607 hint=_('usable compression engines: %s') %
610 608 ', '.sorted(validnames))
611 609
612 610 return compengines
613 611
614 612 class commandentry(object):
615 613 """Represents a declared wire protocol command."""
616 614 def __init__(self, func, args='', transports=None,
617 615 permission='push'):
618 616 self.func = func
619 617 self.args = args
620 618 self.transports = transports or set()
621 619 self.permission = permission
622 620
623 621 def _merge(self, func, args):
624 622 """Merge this instance with an incoming 2-tuple.
625 623
626 624 This is called when a caller using the old 2-tuple API attempts
627 625 to replace an instance. The incoming values are merged with
628 626 data not captured by the 2-tuple and a new instance containing
629 627 the union of the two objects is returned.
630 628 """
631 629 return commandentry(func, args=args, transports=set(self.transports),
632 630 permission=self.permission)
633 631
634 632 # Old code treats instances as 2-tuples. So expose that interface.
635 633 def __iter__(self):
636 634 yield self.func
637 635 yield self.args
638 636
639 637 def __getitem__(self, i):
640 638 if i == 0:
641 639 return self.func
642 640 elif i == 1:
643 641 return self.args
644 642 else:
645 643 raise IndexError('can only access elements 0 and 1')
646 644
647 645 class commanddict(dict):
648 646 """Container for registered wire protocol commands.
649 647
650 648 It behaves like a dict. But __setitem__ is overwritten to allow silent
651 649 coercion of values from 2-tuples for API compatibility.
652 650 """
653 651 def __setitem__(self, k, v):
654 652 if isinstance(v, commandentry):
655 653 pass
656 654 # Cast 2-tuples to commandentry instances.
657 655 elif isinstance(v, tuple):
658 656 if len(v) != 2:
659 657 raise ValueError('command tuples must have exactly 2 elements')
660 658
661 659 # It is common for extensions to wrap wire protocol commands via
662 660 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
663 661 # doing this aren't aware of the new API that uses objects to store
664 662 # command entries, we automatically merge old state with new.
665 663 if k in self:
666 664 v = self[k]._merge(v[0], v[1])
667 665 else:
668 666 # Use default values from @wireprotocommand.
669 667 v = commandentry(v[0], args=v[1],
670 668 transports=set(wireprototypes.TRANSPORTS),
671 669 permission='push')
672 670 else:
673 671 raise ValueError('command entries must be commandentry instances '
674 672 'or 2-tuples')
675 673
676 674 return super(commanddict, self).__setitem__(k, v)
677 675
678 676 def commandavailable(self, command, proto):
679 677 """Determine if a command is available for the requested protocol."""
680 678 assert proto.name in wireprototypes.TRANSPORTS
681 679
682 680 entry = self.get(command)
683 681
684 682 if not entry:
685 683 return False
686 684
687 685 if proto.name not in entry.transports:
688 686 return False
689 687
690 688 return True
691 689
692 690 # Constants specifying which transports a wire protocol command should be
693 691 # available on. For use with @wireprotocommand.
694 692 POLICY_ALL = 'all'
695 693 POLICY_V1_ONLY = 'v1-only'
696 694 POLICY_V2_ONLY = 'v2-only'
697 695
698 696 # For version 1 transports.
699 697 commands = commanddict()
700 698
701 699 # For version 2 transports.
702 700 commandsv2 = commanddict()
703 701
704 702 def wireprotocommand(name, args='', transportpolicy=POLICY_ALL,
705 703 permission='push'):
706 704 """Decorator to declare a wire protocol command.
707 705
708 706 ``name`` is the name of the wire protocol command being provided.
709 707
710 708 ``args`` is a space-delimited list of named arguments that the command
711 709 accepts. ``*`` is a special value that says to accept all arguments.
712 710
713 711 ``transportpolicy`` is a POLICY_* constant denoting which transports
714 712 this wire protocol command should be exposed to. By default, commands
715 713 are exposed to all wire protocol transports.
716 714
717 715 ``permission`` defines the permission type needed to run this command.
718 716 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
719 717 respectively. Default is to assume command requires ``push`` permissions
720 718 because otherwise commands not declaring their permissions could modify
721 719 a repository that is supposed to be read-only.
722 720 """
723 721 if transportpolicy == POLICY_ALL:
724 722 transports = set(wireprototypes.TRANSPORTS)
725 723 transportversions = {1, 2}
726 724 elif transportpolicy == POLICY_V1_ONLY:
727 725 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
728 726 if v['version'] == 1}
729 727 transportversions = {1}
730 728 elif transportpolicy == POLICY_V2_ONLY:
731 729 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
732 730 if v['version'] == 2}
733 731 transportversions = {2}
734 732 else:
735 733 raise error.ProgrammingError('invalid transport policy value: %s' %
736 734 transportpolicy)
737 735
738 736 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
739 737 # SSHv2.
740 738 # TODO undo this hack when SSH is using the unified frame protocol.
741 739 if name == b'batch':
742 740 transports.add(wireprototypes.SSHV2)
743 741
744 742 if permission not in ('push', 'pull'):
745 743 raise error.ProgrammingError('invalid wire protocol permission; '
746 744 'got %s; expected "push" or "pull"' %
747 745 permission)
748 746
749 747 def register(func):
750 748 if 1 in transportversions:
751 749 if name in commands:
752 750 raise error.ProgrammingError('%s command already registered '
753 751 'for version 1' % name)
754 752 commands[name] = commandentry(func, args=args,
755 753 transports=transports,
756 754 permission=permission)
757 755 if 2 in transportversions:
758 756 if name in commandsv2:
759 757 raise error.ProgrammingError('%s command already registered '
760 758 'for version 2' % name)
761 759 commandsv2[name] = commandentry(func, args=args,
762 760 transports=transports,
763 761 permission=permission)
764 762
765 763 return func
766 764 return register
767 765
768 766 # TODO define a more appropriate permissions type to use for this.
769 767 @wireprotocommand('batch', 'cmds *', permission='pull',
770 768 transportpolicy=POLICY_V1_ONLY)
771 769 def batch(repo, proto, cmds, others):
772 770 repo = repo.filtered("served")
773 771 res = []
774 772 for pair in cmds.split(';'):
775 773 op, args = pair.split(' ', 1)
776 774 vals = {}
777 775 for a in args.split(','):
778 776 if a:
779 777 n, v = a.split('=')
780 778 vals[unescapearg(n)] = unescapearg(v)
781 779 func, spec = commands[op]
782 780
783 781 # Validate that client has permissions to perform this command.
784 782 perm = commands[op].permission
785 783 assert perm in ('push', 'pull')
786 784 proto.checkperm(perm)
787 785
788 786 if spec:
789 787 keys = spec.split()
790 788 data = {}
791 789 for k in keys:
792 790 if k == '*':
793 791 star = {}
794 792 for key in vals.keys():
795 793 if key not in keys:
796 794 star[key] = vals[key]
797 795 data['*'] = star
798 796 else:
799 797 data[k] = vals[k]
800 798 result = func(repo, proto, *[data[k] for k in keys])
801 799 else:
802 800 result = func(repo, proto)
803 801 if isinstance(result, wireprototypes.ooberror):
804 802 return result
805 803
806 804 # For now, all batchable commands must return bytesresponse or
807 805 # raw bytes (for backwards compatibility).
808 806 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
809 807 if isinstance(result, wireprototypes.bytesresponse):
810 808 result = result.data
811 809 res.append(escapearg(result))
812 810
813 811 return wireprototypes.bytesresponse(';'.join(res))
814 812
815 813 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
816 814 permission='pull')
817 815 def between(repo, proto, pairs):
818 816 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
819 817 r = []
820 818 for b in repo.between(pairs):
821 819 r.append(encodelist(b) + "\n")
822 820
823 821 return wireprototypes.bytesresponse(''.join(r))
824 822
825 823 @wireprotocommand('branchmap', permission='pull')
826 824 def branchmap(repo, proto):
827 825 branchmap = repo.branchmap()
828 826 heads = []
829 827 for branch, nodes in branchmap.iteritems():
830 828 branchname = urlreq.quote(encoding.fromlocal(branch))
831 829 branchnodes = encodelist(nodes)
832 830 heads.append('%s %s' % (branchname, branchnodes))
833 831
834 832 return wireprototypes.bytesresponse('\n'.join(heads))
835 833
836 834 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
837 835 permission='pull')
838 836 def branches(repo, proto, nodes):
839 837 nodes = decodelist(nodes)
840 838 r = []
841 839 for b in repo.branches(nodes):
842 840 r.append(encodelist(b) + "\n")
843 841
844 842 return wireprototypes.bytesresponse(''.join(r))
845 843
846 844 @wireprotocommand('clonebundles', '', permission='pull')
847 845 def clonebundles(repo, proto):
848 846 """Server command for returning info for available bundles to seed clones.
849 847
850 848 Clients will parse this response and determine what bundle to fetch.
851 849
852 850 Extensions may wrap this command to filter or dynamically emit data
853 851 depending on the request. e.g. you could advertise URLs for the closest
854 852 data center given the client's IP address.
855 853 """
856 854 return wireprototypes.bytesresponse(
857 855 repo.vfs.tryread('clonebundles.manifest'))
858 856
859 857 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
860 858 'known', 'getbundle', 'unbundlehash']
861 859
862 860 def _capabilities(repo, proto):
863 861 """return a list of capabilities for a repo
864 862
865 863 This function exists to allow extensions to easily wrap capabilities
866 864 computation
867 865
868 866 - returns a lists: easy to alter
869 867 - change done here will be propagated to both `capabilities` and `hello`
870 868 command without any other action needed.
871 869 """
872 870 # copy to prevent modification of the global list
873 871 caps = list(wireprotocaps)
874 872
875 873 # Command of same name as capability isn't exposed to version 1 of
876 874 # transports. So conditionally add it.
877 875 if commands.commandavailable('changegroupsubset', proto):
878 876 caps.append('changegroupsubset')
879 877
880 878 if streamclone.allowservergeneration(repo):
881 879 if repo.ui.configbool('server', 'preferuncompressed'):
882 880 caps.append('stream-preferred')
883 881 requiredformats = repo.requirements & repo.supportedformats
884 882 # if our local revlogs are just revlogv1, add 'stream' cap
885 883 if not requiredformats - {'revlogv1'}:
886 884 caps.append('stream')
887 885 # otherwise, add 'streamreqs' detailing our local revlog format
888 886 else:
889 887 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
890 888 if repo.ui.configbool('experimental', 'bundle2-advertise'):
891 889 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
892 890 caps.append('bundle2=' + urlreq.quote(capsblob))
893 891 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
894 892
895 893 return proto.addcapabilities(repo, caps)
896 894
897 895 # If you are writing an extension and consider wrapping this function. Wrap
898 896 # `_capabilities` instead.
899 897 @wireprotocommand('capabilities', permission='pull')
900 898 def capabilities(repo, proto):
901 899 return wireprototypes.bytesresponse(' '.join(_capabilities(repo, proto)))
902 900
903 901 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
904 902 permission='pull')
905 903 def changegroup(repo, proto, roots):
906 904 nodes = decodelist(roots)
907 905 outgoing = discovery.outgoing(repo, missingroots=nodes,
908 906 missingheads=repo.heads())
909 907 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
910 908 gen = iter(lambda: cg.read(32768), '')
911 909 return wireprototypes.streamres(gen=gen)
912 910
913 911 @wireprotocommand('changegroupsubset', 'bases heads',
914 912 transportpolicy=POLICY_V1_ONLY,
915 913 permission='pull')
916 914 def changegroupsubset(repo, proto, bases, heads):
917 915 bases = decodelist(bases)
918 916 heads = decodelist(heads)
919 917 outgoing = discovery.outgoing(repo, missingroots=bases,
920 918 missingheads=heads)
921 919 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
922 920 gen = iter(lambda: cg.read(32768), '')
923 921 return wireprototypes.streamres(gen=gen)
924 922
925 923 @wireprotocommand('debugwireargs', 'one two *',
926 924 permission='pull')
927 925 def debugwireargs(repo, proto, one, two, others):
928 926 # only accept optional args from the known set
929 927 opts = options('debugwireargs', ['three', 'four'], others)
930 928 return wireprototypes.bytesresponse(repo.debugwireargs(
931 929 one, two, **pycompat.strkwargs(opts)))
932 930
933 931 @wireprotocommand('getbundle', '*', permission='pull')
934 932 def getbundle(repo, proto, others):
935 933 opts = options('getbundle', gboptsmap.keys(), others)
936 934 for k, v in opts.iteritems():
937 935 keytype = gboptsmap[k]
938 936 if keytype == 'nodes':
939 937 opts[k] = decodelist(v)
940 938 elif keytype == 'csv':
941 939 opts[k] = list(v.split(','))
942 940 elif keytype == 'scsv':
943 941 opts[k] = set(v.split(','))
944 942 elif keytype == 'boolean':
945 943 # Client should serialize False as '0', which is a non-empty string
946 944 # so it evaluates as a True bool.
947 945 if v == '0':
948 946 opts[k] = False
949 947 else:
950 948 opts[k] = bool(v)
951 949 elif keytype != 'plain':
952 950 raise KeyError('unknown getbundle option type %s'
953 951 % keytype)
954 952
955 953 if not bundle1allowed(repo, 'pull'):
956 954 if not exchange.bundle2requested(opts.get('bundlecaps')):
957 955 if proto.name == 'http-v1':
958 956 return wireprototypes.ooberror(bundle2required)
959 957 raise error.Abort(bundle2requiredmain,
960 958 hint=bundle2requiredhint)
961 959
962 960 prefercompressed = True
963 961
964 962 try:
965 963 if repo.ui.configbool('server', 'disablefullbundle'):
966 964 # Check to see if this is a full clone.
967 965 clheads = set(repo.changelog.heads())
968 966 changegroup = opts.get('cg', True)
969 967 heads = set(opts.get('heads', set()))
970 968 common = set(opts.get('common', set()))
971 969 common.discard(nullid)
972 970 if changegroup and not common and clheads == heads:
973 971 raise error.Abort(
974 972 _('server has pull-based clones disabled'),
975 973 hint=_('remove --pull if specified or upgrade Mercurial'))
976 974
977 975 info, chunks = exchange.getbundlechunks(repo, 'serve',
978 976 **pycompat.strkwargs(opts))
979 977 prefercompressed = info.get('prefercompressed', True)
980 978 except error.Abort as exc:
981 979 # cleanly forward Abort error to the client
982 980 if not exchange.bundle2requested(opts.get('bundlecaps')):
983 981 if proto.name == 'http-v1':
984 982 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
985 983 raise # cannot do better for bundle1 + ssh
986 984 # bundle2 request expect a bundle2 reply
987 985 bundler = bundle2.bundle20(repo.ui)
988 986 manargs = [('message', pycompat.bytestr(exc))]
989 987 advargs = []
990 988 if exc.hint is not None:
991 989 advargs.append(('hint', exc.hint))
992 990 bundler.addpart(bundle2.bundlepart('error:abort',
993 991 manargs, advargs))
994 992 chunks = bundler.getchunks()
995 993 prefercompressed = False
996 994
997 995 return wireprototypes.streamres(
998 996 gen=chunks, prefer_uncompressed=not prefercompressed)
999 997
1000 998 @wireprotocommand('heads', permission='pull')
1001 999 def heads(repo, proto):
1002 1000 h = repo.heads()
1003 1001 return wireprototypes.bytesresponse(encodelist(h) + '\n')
1004 1002
1005 1003 @wireprotocommand('hello', permission='pull')
1006 1004 def hello(repo, proto):
1007 1005 """Called as part of SSH handshake to obtain server info.
1008 1006
1009 1007 Returns a list of lines describing interesting things about the
1010 1008 server, in an RFC822-like format.
1011 1009
1012 1010 Currently, the only one defined is ``capabilities``, which consists of a
1013 1011 line of space separated tokens describing server abilities:
1014 1012
1015 1013 capabilities: <token0> <token1> <token2>
1016 1014 """
1017 1015 caps = capabilities(repo, proto).data
1018 1016 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1019 1017
1020 1018 @wireprotocommand('listkeys', 'namespace', permission='pull')
1021 1019 def listkeys(repo, proto, namespace):
1022 1020 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1023 1021 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1024 1022
1025 1023 @wireprotocommand('lookup', 'key', permission='pull')
1026 1024 def lookup(repo, proto, key):
1027 1025 try:
1028 1026 k = encoding.tolocal(key)
1029 1027 n = repo.lookup(k)
1030 1028 r = hex(n)
1031 1029 success = 1
1032 1030 except Exception as inst:
1033 1031 r = stringutil.forcebytestr(inst)
1034 1032 success = 0
1035 1033 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1036 1034
1037 1035 @wireprotocommand('known', 'nodes *', permission='pull')
1038 1036 def known(repo, proto, nodes, others):
1039 1037 v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
1040 1038 return wireprototypes.bytesresponse(v)
1041 1039
1042 1040 @wireprotocommand('protocaps', 'caps', permission='pull',
1043 1041 transportpolicy=POLICY_V1_ONLY)
1044 1042 def protocaps(repo, proto, caps):
1045 1043 if proto.name == wireprototypes.SSHV1:
1046 1044 proto._protocaps = set(caps.split(' '))
1047 1045 return wireprototypes.bytesresponse('OK')
1048 1046
1049 1047 @wireprotocommand('pushkey', 'namespace key old new', permission='push')
1050 1048 def pushkey(repo, proto, namespace, key, old, new):
1051 1049 # compatibility with pre-1.8 clients which were accidentally
1052 1050 # sending raw binary nodes rather than utf-8-encoded hex
1053 1051 if len(new) == 20 and stringutil.escapestr(new) != new:
1054 1052 # looks like it could be a binary node
1055 1053 try:
1056 1054 new.decode('utf-8')
1057 1055 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1058 1056 except UnicodeDecodeError:
1059 1057 pass # binary, leave unmodified
1060 1058 else:
1061 1059 new = encoding.tolocal(new) # normal path
1062 1060
1063 1061 with proto.mayberedirectstdio() as output:
1064 1062 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1065 1063 encoding.tolocal(old), new) or False
1066 1064
1067 1065 output = output.getvalue() if output else ''
1068 1066 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1069 1067
1070 1068 @wireprotocommand('stream_out', permission='pull')
1071 1069 def stream(repo, proto):
1072 1070 '''If the server supports streaming clone, it advertises the "stream"
1073 1071 capability with a value representing the version and flags of the repo
1074 1072 it is serving. Client checks to see if it understands the format.
1075 1073 '''
1076 1074 return wireprototypes.streamreslegacy(
1077 1075 streamclone.generatev1wireproto(repo))
1078 1076
1079 1077 @wireprotocommand('unbundle', 'heads', permission='push')
1080 1078 def unbundle(repo, proto, heads):
1081 1079 their_heads = decodelist(heads)
1082 1080
1083 1081 with proto.mayberedirectstdio() as output:
1084 1082 try:
1085 1083 exchange.check_heads(repo, their_heads, 'preparing changes')
1086 1084
1087 1085 # write bundle data to temporary file because it can be big
1088 1086 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1089 1087 fp = os.fdopen(fd, r'wb+')
1090 1088 r = 0
1091 1089 try:
1092 1090 proto.forwardpayload(fp)
1093 1091 fp.seek(0)
1094 1092 gen = exchange.readbundle(repo.ui, fp, None)
1095 1093 if (isinstance(gen, changegroupmod.cg1unpacker)
1096 1094 and not bundle1allowed(repo, 'push')):
1097 1095 if proto.name == 'http-v1':
1098 1096 # need to special case http because stderr do not get to
1099 1097 # the http client on failed push so we need to abuse
1100 1098 # some other error type to make sure the message get to
1101 1099 # the user.
1102 1100 return wireprototypes.ooberror(bundle2required)
1103 1101 raise error.Abort(bundle2requiredmain,
1104 1102 hint=bundle2requiredhint)
1105 1103
1106 1104 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1107 1105 proto.client())
1108 1106 if util.safehasattr(r, 'addpart'):
1109 1107 # The return looks streamable, we are in the bundle2 case
1110 1108 # and should return a stream.
1111 1109 return wireprototypes.streamreslegacy(gen=r.getchunks())
1112 1110 return wireprototypes.pushres(
1113 1111 r, output.getvalue() if output else '')
1114 1112
1115 1113 finally:
1116 1114 fp.close()
1117 1115 os.unlink(tempname)
1118 1116
1119 1117 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1120 1118 # handle non-bundle2 case first
1121 1119 if not getattr(exc, 'duringunbundle2', False):
1122 1120 try:
1123 1121 raise
1124 1122 except error.Abort:
1125 1123 # The old code we moved used procutil.stderr directly.
1126 1124 # We did not change it to minimise code change.
1127 1125 # This need to be moved to something proper.
1128 1126 # Feel free to do it.
1129 1127 procutil.stderr.write("abort: %s\n" % exc)
1130 1128 if exc.hint is not None:
1131 1129 procutil.stderr.write("(%s)\n" % exc.hint)
1132 1130 procutil.stderr.flush()
1133 1131 return wireprototypes.pushres(
1134 1132 0, output.getvalue() if output else '')
1135 1133 except error.PushRaced:
1136 1134 return wireprototypes.pusherr(
1137 1135 pycompat.bytestr(exc),
1138 1136 output.getvalue() if output else '')
1139 1137
1140 1138 bundler = bundle2.bundle20(repo.ui)
1141 1139 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1142 1140 bundler.addpart(out)
1143 1141 try:
1144 1142 try:
1145 1143 raise
1146 1144 except error.PushkeyFailed as exc:
1147 1145 # check client caps
1148 1146 remotecaps = getattr(exc, '_replycaps', None)
1149 1147 if (remotecaps is not None
1150 1148 and 'pushkey' not in remotecaps.get('error', ())):
1151 1149 # no support remote side, fallback to Abort handler.
1152 1150 raise
1153 1151 part = bundler.newpart('error:pushkey')
1154 1152 part.addparam('in-reply-to', exc.partid)
1155 1153 if exc.namespace is not None:
1156 1154 part.addparam('namespace', exc.namespace,
1157 1155 mandatory=False)
1158 1156 if exc.key is not None:
1159 1157 part.addparam('key', exc.key, mandatory=False)
1160 1158 if exc.new is not None:
1161 1159 part.addparam('new', exc.new, mandatory=False)
1162 1160 if exc.old is not None:
1163 1161 part.addparam('old', exc.old, mandatory=False)
1164 1162 if exc.ret is not None:
1165 1163 part.addparam('ret', exc.ret, mandatory=False)
1166 1164 except error.BundleValueError as exc:
1167 1165 errpart = bundler.newpart('error:unsupportedcontent')
1168 1166 if exc.parttype is not None:
1169 1167 errpart.addparam('parttype', exc.parttype)
1170 1168 if exc.params:
1171 1169 errpart.addparam('params', '\0'.join(exc.params))
1172 1170 except error.Abort as exc:
1173 1171 manargs = [('message', stringutil.forcebytestr(exc))]
1174 1172 advargs = []
1175 1173 if exc.hint is not None:
1176 1174 advargs.append(('hint', exc.hint))
1177 1175 bundler.addpart(bundle2.bundlepart('error:abort',
1178 1176 manargs, advargs))
1179 1177 except error.PushRaced as exc:
1180 1178 bundler.newpart('error:pushraced',
1181 1179 [('message', stringutil.forcebytestr(exc))])
1182 1180 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
General Comments 0
You need to be logged in to leave comments. Login now