wireproto: move value encoding functions to wireprototypes (API)...
Gregory Szorc
r37630:5e71dea7 default
@@ -1,1186 +1,1187
1 1 # Infinite push
2 2 #
3 3 # Copyright 2016 Facebook, Inc.
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7 """ store some pushes in a remote blob store on the server (EXPERIMENTAL)
8 8
9 9 [infinitepush]
10 10 # Server-side and client-side option. Pattern of the infinitepush bookmark
11 11 branchpattern = PATTERN
12 12
13 13 # Server or client
14 14 server = False
15 15
16 16 # Server-side option. Possible values: 'disk' or 'sql'. Fails if not set
17 17 indextype = disk
18 18
19 19 # Server-side option. Used only if indextype=sql.
20 20 # Format: 'IP:PORT:DB_NAME:USER:PASSWORD'
21 21 sqlhost = IP:PORT:DB_NAME:USER:PASSWORD
22 22
23 23 # Server-side option. Used only if indextype=disk.
24 24 # Filesystem path to the index store
25 25 indexpath = PATH
26 26
27 27 # Server-side option. Possible values: 'disk' or 'external'
28 28 # Fails if not set
29 29 storetype = disk
30 30
31 31 # Server-side option.
32 32 # Path to the binary that will save bundle to the bundlestore
33 33 # Formatted cmd line will be passed to it (see `put_args`)
34 34 put_binary = put
35 35
36 36 # Server-side option. Used only if storetype=external.
37 37 # Format cmd-line string for put binary. Placeholder: {filename}
38 38 put_args = {filename}
39 39
40 40 # Server-side option.
41 41 # Path to the binary that get bundle from the bundlestore.
42 42 # Formatted cmd line will be passed to it (see `get_args`)
43 43 get_binary = get
44 44
45 45 # Server-side option. Used only if storetype=external.
46 46 # Format cmd-line string for get binary. Placeholders: {filename} {handle}
47 47 get_args = {filename} {handle}
48 48
49 49 # Server-side option
50 50 logfile = FILE
51 51
52 52 # Server-side option
53 53 loglevel = DEBUG
54 54
55 55 # Server-side option. Used only if indextype=sql.
56 56 # Sets mysql wait_timeout option.
57 57 waittimeout = 300
58 58
59 59 # Server-side option. Used only if indextype=sql.
60 60 # Sets mysql innodb_lock_wait_timeout option.
61 61 locktimeout = 120
62 62
63 63 # Server-side option. Used only if indextype=sql.
64 64 # Name of the repository
65 65 reponame = ''
66 66
67 67 # Client-side option. Used by --list-remote option. List of remote scratch
68 68 # patterns to list if no patterns are specified.
69 69 defaultremotepatterns = ['*']
70 70
71 71 # Instructs infinitepush to forward all received bundle2 parts to the
72 72 # bundle for storage. Defaults to False.
73 73 storeallparts = True
74 74
75 75 # Routes each incoming push to the bundlestore. Defaults to False.
76 76 pushtobundlestore = True
77 77
78 78 [remotenames]
79 79 # Client-side option
80 80 # This option should be set only if remotenames extension is enabled.
81 81 # Whether remote bookmarks are tracked by remotenames extension.
82 82 bookmarks = True
83 83 """
84 84
85 85 from __future__ import absolute_import
86 86
87 87 import collections
88 88 import contextlib
89 89 import errno
90 90 import functools
91 91 import logging
92 92 import os
93 93 import random
94 94 import re
95 95 import socket
96 96 import subprocess
97 97 import tempfile
98 98 import time
99 99
100 100 from mercurial.node import (
101 101 bin,
102 102 hex,
103 103 )
104 104
105 105 from mercurial.i18n import _
106 106
107 107 from mercurial.utils import (
108 108 procutil,
109 109 stringutil,
110 110 )
111 111
112 112 from mercurial import (
113 113 bundle2,
114 114 changegroup,
115 115 commands,
116 116 discovery,
117 117 encoding,
118 118 error,
119 119 exchange,
120 120 extensions,
121 121 hg,
122 122 localrepo,
123 123 peer,
124 124 phases,
125 125 pushkey,
126 126 pycompat,
127 127 registrar,
128 128 util,
129 129 wireproto,
130 wireprototypes,
130 131 )
131 132
132 133 from . import (
133 134 bundleparts,
134 135 common,
135 136 )
136 137
137 138 # Note for extension authors: ONLY specify testedwith = 'ships-with-hg-core' for
138 139 # extensions which SHIP WITH MERCURIAL. Non-mainline extensions should
139 140 # be specifying the version(s) of Mercurial they are tested with, or
140 141 # leave the attribute unspecified.
141 142 testedwith = 'ships-with-hg-core'
142 143
143 144 configtable = {}
144 145 configitem = registrar.configitem(configtable)
145 146
146 147 configitem('infinitepush', 'server',
147 148 default=False,
148 149 )
149 150 configitem('infinitepush', 'storetype',
150 151 default='',
151 152 )
152 153 configitem('infinitepush', 'indextype',
153 154 default='',
154 155 )
155 156 configitem('infinitepush', 'indexpath',
156 157 default='',
157 158 )
158 159 configitem('infinitepush', 'storeallparts',
159 160 default=False,
160 161 )
161 162 configitem('infinitepush', 'reponame',
162 163 default='',
163 164 )
164 165 configitem('scratchbranch', 'storepath',
165 166 default='',
166 167 )
167 168 configitem('infinitepush', 'branchpattern',
168 169 default='',
169 170 )
170 171 configitem('infinitepush', 'pushtobundlestore',
171 172 default=False,
172 173 )
173 174 configitem('experimental', 'server-bundlestore-bookmark',
174 175 default='',
175 176 )
176 177 configitem('experimental', 'infinitepush-scratchpush',
177 178 default=False,
178 179 )
179 180
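The options documented in the module docstring are registered above with ``registrar.configitem`` and read back through the regular ``ui.config*`` API. As a minimal illustration (not part of this change; the helper name is hypothetical), a server-side setup could be read and sanity-checked like this:

def _sketchreadserverconfig(ui):
    # Hypothetical helper: read and validate the main server-side options
    # documented in the module docstring above.
    if not ui.configbool('infinitepush', 'server'):
        return None
    indextype = ui.config('infinitepush', 'indextype')    # 'disk' or 'sql'
    storetype = ui.config('infinitepush', 'storetype')    # 'disk' or 'external'
    if indextype not in ('disk', 'sql'):
        raise error.Abort(_('infinitepush.indextype must be "disk" or "sql"'))
    if storetype not in ('disk', 'external'):
        raise error.Abort(_('infinitepush.storetype must be "disk" or "external"'))
    return {
        'indextype': indextype,
        'storetype': storetype,
        'indexpath': ui.config('infinitepush', 'indexpath'),
        'reponame': ui.config('infinitepush', 'reponame'),
    }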
180 181 experimental = 'experimental'
181 182 configbookmark = 'server-bundlestore-bookmark'
182 183 configscratchpush = 'infinitepush-scratchpush'
183 184
184 185 scratchbranchparttype = bundleparts.scratchbranchparttype
185 186 revsetpredicate = registrar.revsetpredicate()
186 187 templatekeyword = registrar.templatekeyword()
187 188 _scratchbranchmatcher = lambda x: False
188 189 _maybehash = re.compile(r'^[a-f0-9]+$').search
189 190
190 191 def _buildexternalbundlestore(ui):
191 192 put_args = ui.configlist('infinitepush', 'put_args', [])
192 193 put_binary = ui.config('infinitepush', 'put_binary')
193 194 if not put_binary:
194 195 raise error.Abort('put binary is not specified')
195 196 get_args = ui.configlist('infinitepush', 'get_args', [])
196 197 get_binary = ui.config('infinitepush', 'get_binary')
197 198 if not get_binary:
198 199 raise error.Abort('get binary is not specified')
199 200 from . import store
200 201 return store.externalbundlestore(put_binary, put_args, get_binary, get_args)
201 202
202 203 def _buildsqlindex(ui):
203 204 sqlhost = ui.config('infinitepush', 'sqlhost')
204 205 if not sqlhost:
205 206 raise error.Abort(_('please set infinitepush.sqlhost'))
206 207 host, port, db, user, password = sqlhost.split(':')
207 208 reponame = ui.config('infinitepush', 'reponame')
208 209 if not reponame:
209 210 raise error.Abort(_('please set infinitepush.reponame'))
210 211
211 212 logfile = ui.config('infinitepush', 'logfile', '')
212 213 waittimeout = ui.configint('infinitepush', 'waittimeout', 300)
213 214 locktimeout = ui.configint('infinitepush', 'locktimeout', 120)
214 215 from . import sqlindexapi
215 216 return sqlindexapi.sqlindexapi(
216 217 reponame, host, port, db, user, password,
217 218 logfile, _getloglevel(ui), waittimeout=waittimeout,
218 219 locktimeout=locktimeout)
219 220
220 221 def _getloglevel(ui):
221 222 loglevel = ui.config('infinitepush', 'loglevel', 'DEBUG')
222 223 numeric_loglevel = getattr(logging, loglevel.upper(), None)
223 224 if not isinstance(numeric_loglevel, int):
224 225 raise error.Abort(_('invalid log level %s') % loglevel)
225 226 return numeric_loglevel
226 227
227 228 def _tryhoist(ui, remotebookmark):
228 229 '''returns a bookmark with the hoisted part removed
229 230
230 231 The remotenames extension has a 'hoist' config that allows remote
231 232 bookmarks to be used without specifying the remote path. For example,
232 233 'hg update master' works as well as 'hg update remote/master'. We want
233 234 to allow the same in infinitepush.
234 235 '''
235 236
236 237 if common.isremotebooksenabled(ui):
237 238 hoist = ui.config('remotenames', 'hoistedpeer') + '/'
238 239 if remotebookmark.startswith(hoist):
239 240 return remotebookmark[len(hoist):]
240 241 return remotebookmark
241 242
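As a standalone sketch of the hoisting rule above (with the hoist value hard-coded to 'default' instead of read from ``remotenames.hoistedpeer``):

def _hoistexample(remotebookmark, hoistedpeer='default'):
    # Same stripping logic as _tryhoist(), without the ui/config plumbing.
    hoist = hoistedpeer + '/'
    if remotebookmark.startswith(hoist):
        return remotebookmark[len(hoist):]
    return remotebookmark

assert _hoistexample('default/master') == 'master'        # hoisted prefix stripped
assert _hoistexample('other/master') == 'other/master'    # different remote: unchanged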
242 243 class bundlestore(object):
243 244 def __init__(self, repo):
244 245 self._repo = repo
245 246 storetype = self._repo.ui.config('infinitepush', 'storetype')
246 247 if storetype == 'disk':
247 248 from . import store
248 249 self.store = store.filebundlestore(self._repo.ui, self._repo)
249 250 elif storetype == 'external':
250 251 self.store = _buildexternalbundlestore(self._repo.ui)
251 252 else:
252 253 raise error.Abort(
253 254 _('unknown infinitepush store type specified %s') % storetype)
254 255
255 256 indextype = self._repo.ui.config('infinitepush', 'indextype')
256 257 if indextype == 'disk':
257 258 from . import fileindexapi
258 259 self.index = fileindexapi.fileindexapi(self._repo)
259 260 elif indextype == 'sql':
260 261 self.index = _buildsqlindex(self._repo.ui)
261 262 else:
262 263 raise error.Abort(
263 264 _('unknown infinitepush index type specified %s') % indextype)
264 265
265 266 def _isserver(ui):
266 267 return ui.configbool('infinitepush', 'server')
267 268
268 269 def reposetup(ui, repo):
269 270 if _isserver(ui) and repo.local():
270 271 repo.bundlestore = bundlestore(repo)
271 272
272 273 def extsetup(ui):
273 274 commonsetup(ui)
274 275 if _isserver(ui):
275 276 serverextsetup(ui)
276 277 else:
277 278 clientextsetup(ui)
278 279
279 280 def commonsetup(ui):
280 281 wireproto.commands['listkeyspatterns'] = (
281 282 wireprotolistkeyspatterns, 'namespace patterns')
282 283 scratchbranchpat = ui.config('infinitepush', 'branchpattern')
283 284 if scratchbranchpat:
284 285 global _scratchbranchmatcher
285 286 kind, pat, _scratchbranchmatcher = \
286 287 stringutil.stringmatcher(scratchbranchpat)
287 288
288 289 def serverextsetup(ui):
289 290 origpushkeyhandler = bundle2.parthandlermapping['pushkey']
290 291
291 292 def newpushkeyhandler(*args, **kwargs):
292 293 bundle2pushkey(origpushkeyhandler, *args, **kwargs)
293 294 newpushkeyhandler.params = origpushkeyhandler.params
294 295 bundle2.parthandlermapping['pushkey'] = newpushkeyhandler
295 296
296 297 orighandlephasehandler = bundle2.parthandlermapping['phase-heads']
297 298 newphaseheadshandler = lambda *args, **kwargs: \
298 299 bundle2handlephases(orighandlephasehandler, *args, **kwargs)
299 300 newphaseheadshandler.params = orighandlephasehandler.params
300 301 bundle2.parthandlermapping['phase-heads'] = newphaseheadshandler
301 302
302 303 extensions.wrapfunction(localrepo.localrepository, 'listkeys',
303 304 localrepolistkeys)
304 305 wireproto.commands['lookup'] = (
305 306 _lookupwrap(wireproto.commands['lookup'][0]), 'key')
306 307 extensions.wrapfunction(exchange, 'getbundlechunks', getbundlechunks)
307 308
308 309 extensions.wrapfunction(bundle2, 'processparts', processparts)
309 310
310 311 def clientextsetup(ui):
311 312 entry = extensions.wrapcommand(commands.table, 'push', _push)
312 313
313 314 entry[1].append(
314 315 ('', 'bundle-store', None,
315 316 _('force push to go to bundle store (EXPERIMENTAL)')))
316 317
317 318 extensions.wrapcommand(commands.table, 'pull', _pull)
318 319
319 320 extensions.wrapfunction(discovery, 'checkheads', _checkheads)
320 321
321 322 wireproto.wirepeer.listkeyspatterns = listkeyspatterns
322 323
323 324 partorder = exchange.b2partsgenorder
324 325 index = partorder.index('changeset')
325 326 partorder.insert(
326 327 index, partorder.pop(partorder.index(scratchbranchparttype)))
327 328
328 329 def _checkheads(orig, pushop):
329 330 if pushop.ui.configbool(experimental, configscratchpush, False):
330 331 return
331 332 return orig(pushop)
332 333
333 334 def wireprotolistkeyspatterns(repo, proto, namespace, patterns):
334 patterns = wireproto.decodelist(patterns)
335 patterns = wireprototypes.decodelist(patterns)
335 336 d = repo.listkeys(encoding.tolocal(namespace), patterns).iteritems()
336 337 return pushkey.encodekeys(d)
337 338
338 339 def localrepolistkeys(orig, self, namespace, patterns=None):
339 340 if namespace == 'bookmarks' and patterns:
340 341 index = self.bundlestore.index
341 342 results = {}
342 343 bookmarks = orig(self, namespace)
343 344 for pattern in patterns:
344 345 results.update(index.getbookmarks(pattern))
345 346 if pattern.endswith('*'):
346 347 pattern = 're:^' + pattern[:-1] + '.*'
347 348 kind, pat, matcher = stringutil.stringmatcher(pattern)
348 349 for bookmark, node in bookmarks.iteritems():
349 350 if matcher(bookmark):
350 351 results[bookmark] = node
351 352 return results
352 353 else:
353 354 return orig(self, namespace)
354 355
355 356 @peer.batchable
356 357 def listkeyspatterns(self, namespace, patterns):
357 358 if not self.capable('pushkey'):
358 359 yield {}, None
359 360 f = peer.future()
360 361 self.ui.debug('preparing listkeys for "%s" with pattern "%s"\n' %
361 362 (namespace, patterns))
362 363 yield {
363 364 'namespace': encoding.fromlocal(namespace),
364 'patterns': wireproto.encodelist(patterns)
365 'patterns': wireprototypes.encodelist(patterns)
365 366 }, f
366 367 d = f.value
367 368 self.ui.debug('received listkey for "%s": %i bytes\n'
368 369 % (namespace, len(d)))
369 370 yield pushkey.decodekeys(d)
370 371
371 372 def _readbundlerevs(bundlerepo):
372 373 return list(bundlerepo.revs('bundle()'))
373 374
374 375 def _includefilelogstobundle(bundlecaps, bundlerepo, bundlerevs, ui):
375 376 '''Tells remotefilelog to include all changed files in the changegroup
376 377
377 378 By default remotefilelog doesn't include file content in the changegroup.
378 379 But we need to include it when we are fetching from the bundlestore.
379 380 '''
380 381 changedfiles = set()
381 382 cl = bundlerepo.changelog
382 383 for r in bundlerevs:
383 384 # [3] means changed files
384 385 changedfiles.update(cl.read(r)[3])
385 386 if not changedfiles:
386 387 return bundlecaps
387 388
388 389 changedfiles = '\0'.join(changedfiles)
389 390 newcaps = []
390 391 appended = False
391 392 for cap in (bundlecaps or []):
392 393 if cap.startswith('excludepattern='):
393 394 newcaps.append('\0'.join((cap, changedfiles)))
394 395 appended = True
395 396 else:
396 397 newcaps.append(cap)
397 398 if not appended:
398 399 # No excludepattern cap found. Just append it
399 400 newcaps.append('excludepattern=' + changedfiles)
400 401
401 402 return newcaps
402 403
403 404 def _rebundle(bundlerepo, bundleroots, unknownhead):
404 405 '''
405 406 A bundle may include more revisions than the user requested. For example,
406 407 if the user asks for one revision, the bundle may also contain its
407 408 descendants. This function filters out all revisions the user did not request.
408 409 '''
409 410 parts = []
410 411
411 412 version = '02'
412 413 outgoing = discovery.outgoing(bundlerepo, commonheads=bundleroots,
413 414 missingheads=[unknownhead])
414 415 cgstream = changegroup.makestream(bundlerepo, outgoing, version, 'pull')
415 416 cgstream = util.chunkbuffer(cgstream).read()
416 417 cgpart = bundle2.bundlepart('changegroup', data=cgstream)
417 418 cgpart.addparam('version', version)
418 419 parts.append(cgpart)
419 420
420 421 return parts
421 422
422 423 def _getbundleroots(oldrepo, bundlerepo, bundlerevs):
423 424 cl = bundlerepo.changelog
424 425 bundleroots = []
425 426 for rev in bundlerevs:
426 427 node = cl.node(rev)
427 428 parents = cl.parents(node)
428 429 for parent in parents:
429 430 # include all revs that exist in the main repo
430 431 # to make sure that bundle may apply client-side
431 432 if parent in oldrepo:
432 433 bundleroots.append(parent)
433 434 return bundleroots
434 435
435 436 def _needsrebundling(head, bundlerepo):
436 437 bundleheads = list(bundlerepo.revs('heads(bundle())'))
437 438 return not (len(bundleheads) == 1 and
438 439 bundlerepo[bundleheads[0]].node() == head)
439 440
440 441 def _generateoutputparts(head, bundlerepo, bundleroots, bundlefile):
441 442 '''generates the bundle that will be sent to the user
442 443
443 444 returns tuple with raw bundle string and bundle type
444 445 '''
445 446 parts = []
446 447 if not _needsrebundling(head, bundlerepo):
447 448 with util.posixfile(bundlefile, "rb") as f:
448 449 unbundler = exchange.readbundle(bundlerepo.ui, f, bundlefile)
449 450 if isinstance(unbundler, changegroup.cg1unpacker):
450 451 part = bundle2.bundlepart('changegroup',
451 452 data=unbundler._stream.read())
452 453 part.addparam('version', '01')
453 454 parts.append(part)
454 455 elif isinstance(unbundler, bundle2.unbundle20):
455 456 haschangegroup = False
456 457 for part in unbundler.iterparts():
457 458 if part.type == 'changegroup':
458 459 haschangegroup = True
459 460 newpart = bundle2.bundlepart(part.type, data=part.read())
460 461 for key, value in part.params.iteritems():
461 462 newpart.addparam(key, value)
462 463 parts.append(newpart)
463 464
464 465 if not haschangegroup:
465 466 raise error.Abort(
466 467 'unexpected bundle without changegroup part, ' +
467 468 'head: %s' % hex(head),
468 469 hint='report to administrator')
469 470 else:
470 471 raise error.Abort('unknown bundle type')
471 472 else:
472 473 parts = _rebundle(bundlerepo, bundleroots, head)
473 474
474 475 return parts
475 476
476 477 def getbundlechunks(orig, repo, source, heads=None, bundlecaps=None, **kwargs):
477 478 heads = heads or []
478 479 # newheads are parents of roots of scratch bundles that were requested
479 480 newphases = {}
480 481 scratchbundles = []
481 482 newheads = []
482 483 scratchheads = []
483 484 nodestobundle = {}
484 485 allbundlestocleanup = []
485 486 try:
486 487 for head in heads:
487 488 if head not in repo.changelog.nodemap:
488 489 if head not in nodestobundle:
489 490 newbundlefile = common.downloadbundle(repo, head)
490 491 bundlepath = "bundle:%s+%s" % (repo.root, newbundlefile)
491 492 bundlerepo = hg.repository(repo.ui, bundlepath)
492 493
493 494 allbundlestocleanup.append((bundlerepo, newbundlefile))
494 495 bundlerevs = set(_readbundlerevs(bundlerepo))
495 496 bundlecaps = _includefilelogstobundle(
496 497 bundlecaps, bundlerepo, bundlerevs, repo.ui)
497 498 cl = bundlerepo.changelog
498 499 bundleroots = _getbundleroots(repo, bundlerepo, bundlerevs)
499 500 for rev in bundlerevs:
500 501 node = cl.node(rev)
501 502 newphases[hex(node)] = str(phases.draft)
502 503 nodestobundle[node] = (bundlerepo, bundleroots,
503 504 newbundlefile)
504 505
505 506 scratchbundles.append(
506 507 _generateoutputparts(head, *nodestobundle[head]))
507 508 newheads.extend(bundleroots)
508 509 scratchheads.append(head)
509 510 finally:
510 511 for bundlerepo, bundlefile in allbundlestocleanup:
511 512 bundlerepo.close()
512 513 try:
513 514 os.unlink(bundlefile)
514 515 except (IOError, OSError):
515 516 # if we can't cleanup the file then just ignore the error,
516 517 # no need to fail
517 518 pass
518 519
519 520 pullfrombundlestore = bool(scratchbundles)
520 521 wrappedchangegrouppart = False
521 522 wrappedlistkeys = False
522 523 oldchangegrouppart = exchange.getbundle2partsmapping['changegroup']
523 524 try:
524 525 def _changegrouppart(bundler, *args, **kwargs):
525 526 # Order is important here. First add non-scratch part
526 527 # and only then add parts with scratch bundles because
527 528 # non-scratch part contains parents of roots of scratch bundles.
528 529 result = oldchangegrouppart(bundler, *args, **kwargs)
529 530 for bundle in scratchbundles:
530 531 for part in bundle:
531 532 bundler.addpart(part)
532 533 return result
533 534
534 535 exchange.getbundle2partsmapping['changegroup'] = _changegrouppart
535 536 wrappedchangegrouppart = True
536 537
537 538 def _listkeys(orig, self, namespace):
538 539 origvalues = orig(self, namespace)
539 540 if namespace == 'phases' and pullfrombundlestore:
540 541 if origvalues.get('publishing') == 'True':
541 542 # Make repo non-publishing to preserve draft phase
542 543 del origvalues['publishing']
543 544 origvalues.update(newphases)
544 545 return origvalues
545 546
546 547 extensions.wrapfunction(localrepo.localrepository, 'listkeys',
547 548 _listkeys)
548 549 wrappedlistkeys = True
549 550 heads = list((set(newheads) | set(heads)) - set(scratchheads))
550 551 result = orig(repo, source, heads=heads,
551 552 bundlecaps=bundlecaps, **kwargs)
552 553 finally:
553 554 if wrappedchangegrouppart:
554 555 exchange.getbundle2partsmapping['changegroup'] = oldchangegrouppart
555 556 if wrappedlistkeys:
556 557 extensions.unwrapfunction(localrepo.localrepository, 'listkeys',
557 558 _listkeys)
558 559 return result
559 560
560 561 def _lookupwrap(orig):
561 562 def _lookup(repo, proto, key):
562 563 localkey = encoding.tolocal(key)
563 564
564 565 if isinstance(localkey, str) and _scratchbranchmatcher(localkey):
565 566 scratchnode = repo.bundlestore.index.getnode(localkey)
566 567 if scratchnode:
567 568 return "%s %s\n" % (1, scratchnode)
568 569 else:
569 570 return "%s %s\n" % (0, 'scratch branch %s not found' % localkey)
570 571 else:
571 572 try:
572 573 r = hex(repo.lookup(localkey))
573 574 return "%s %s\n" % (1, r)
574 575 except Exception as inst:
575 576 if repo.bundlestore.index.getbundle(localkey):
576 577 return "%s %s\n" % (1, localkey)
577 578 else:
578 579 r = str(inst)
579 580 return "%s %s\n" % (0, r)
580 581 return _lookup
581 582
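The wrapped ``lookup`` above keeps the standard wire shape: a success flag and a payload separated by one space, terminated by a newline. A small sketch (hypothetical helper) of how the client side, ``wirepeer.lookup`` in wireproto.py, decodes such a response:

def _decodelookupresponse(d):
    # Mirrors the parsing in wirepeer.lookup(): drop the trailing newline,
    # split off the success flag, return (success, payload).
    success, data = d[:-1].split(' ', 1)
    return bool(int(success)), data

assert _decodelookupresponse('1 ' + 'ab' * 20 + '\n') == (True, 'ab' * 20)
assert _decodelookupresponse('0 scratch branch foo not found\n') == \
    (False, 'scratch branch foo not found')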
582 583 def _pull(orig, ui, repo, source="default", **opts):
583 584 opts = pycompat.byteskwargs(opts)
584 585 # Copy paste from `pull` command
585 586 source, branches = hg.parseurl(ui.expandpath(source), opts.get('branch'))
586 587
587 588 scratchbookmarks = {}
588 589 unfi = repo.unfiltered()
589 590 unknownnodes = []
590 591 for rev in opts.get('rev', []):
591 592 if rev not in unfi:
592 593 unknownnodes.append(rev)
593 594 if opts.get('bookmark'):
594 595 bookmarks = []
595 596 revs = opts.get('rev') or []
596 597 for bookmark in opts.get('bookmark'):
597 598 if _scratchbranchmatcher(bookmark):
598 599 # rev is not known yet
599 600 # it will be fetched with listkeyspatterns next
600 601 scratchbookmarks[bookmark] = 'REVTOFETCH'
601 602 else:
602 603 bookmarks.append(bookmark)
603 604
604 605 if scratchbookmarks:
605 606 other = hg.peer(repo, opts, source)
606 607 fetchedbookmarks = other.listkeyspatterns(
607 608 'bookmarks', patterns=scratchbookmarks)
608 609 for bookmark in scratchbookmarks:
609 610 if bookmark not in fetchedbookmarks:
610 611 raise error.Abort('remote bookmark %s not found!' %
611 612 bookmark)
612 613 scratchbookmarks[bookmark] = fetchedbookmarks[bookmark]
613 614 revs.append(fetchedbookmarks[bookmark])
614 615 opts['bookmark'] = bookmarks
615 616 opts['rev'] = revs
616 617
617 618 if scratchbookmarks or unknownnodes:
618 619 # Set anyincoming to True
619 620 extensions.wrapfunction(discovery, 'findcommonincoming',
620 621 _findcommonincoming)
621 622 try:
622 623 # Remote scratch bookmarks will be deleted because remotenames doesn't
623 624 # know about them. Let's save them before the pull and restore them after
624 625 remotescratchbookmarks = _readscratchremotebookmarks(ui, repo, source)
625 626 result = orig(ui, repo, source, **pycompat.strkwargs(opts))
626 627 # TODO(stash): race condition is possible
627 628 # if scratch bookmarks was updated right after orig.
628 629 # But that's unlikely and shouldn't be harmful.
629 630 if common.isremotebooksenabled(ui):
630 631 remotescratchbookmarks.update(scratchbookmarks)
631 632 _saveremotebookmarks(repo, remotescratchbookmarks, source)
632 633 else:
633 634 _savelocalbookmarks(repo, scratchbookmarks)
634 635 return result
635 636 finally:
636 637 if scratchbookmarks:
637 638 extensions.unwrapfunction(discovery, 'findcommonincoming')
638 639
639 640 def _readscratchremotebookmarks(ui, repo, other):
640 641 if common.isremotebooksenabled(ui):
641 642 remotenamesext = extensions.find('remotenames')
642 643 remotepath = remotenamesext.activepath(repo.ui, other)
643 644 result = {}
644 645 # Let's refresh remotenames to make sure they are up to date.
645 646 # It seems that `repo.names['remotebookmarks']` may return stale bookmarks,
646 647 # which results in deleting scratch bookmarks. Our best guess at how to
647 648 # fix it is to use `clearnames()`
648 649 repo._remotenames.clearnames()
649 650 for remotebookmark in repo.names['remotebookmarks'].listnames(repo):
650 651 path, bookname = remotenamesext.splitremotename(remotebookmark)
651 652 if path == remotepath and _scratchbranchmatcher(bookname):
652 653 nodes = repo.names['remotebookmarks'].nodes(repo,
653 654 remotebookmark)
654 655 if nodes:
655 656 result[bookname] = hex(nodes[0])
656 657 return result
657 658 else:
658 659 return {}
659 660
660 661 def _saveremotebookmarks(repo, newbookmarks, remote):
661 662 remotenamesext = extensions.find('remotenames')
662 663 remotepath = remotenamesext.activepath(repo.ui, remote)
663 664 branches = collections.defaultdict(list)
664 665 bookmarks = {}
665 666 remotenames = remotenamesext.readremotenames(repo)
666 667 for hexnode, nametype, remote, rname in remotenames:
667 668 if remote != remotepath:
668 669 continue
669 670 if nametype == 'bookmarks':
670 671 if rname in newbookmarks:
671 672 # This can happen if we have a normal bookmark that matches the
672 673 # scratch branch pattern. In this case just use the current
673 674 # bookmark node
674 675 del newbookmarks[rname]
675 676 bookmarks[rname] = hexnode
676 677 elif nametype == 'branches':
677 678 # saveremotenames expects 20 byte binary nodes for branches
678 679 branches[rname].append(bin(hexnode))
679 680
680 681 for bookmark, hexnode in newbookmarks.iteritems():
681 682 bookmarks[bookmark] = hexnode
682 683 remotenamesext.saveremotenames(repo, remotepath, branches, bookmarks)
683 684
684 685 def _savelocalbookmarks(repo, bookmarks):
685 686 if not bookmarks:
686 687 return
687 688 with repo.wlock(), repo.lock(), repo.transaction('bookmark') as tr:
688 689 changes = []
689 690 for scratchbook, node in bookmarks.iteritems():
690 691 changectx = repo[node]
691 692 changes.append((scratchbook, changectx.node()))
692 693 repo._bookmarks.applychanges(repo, tr, changes)
693 694
694 695 def _findcommonincoming(orig, *args, **kwargs):
695 696 common, inc, remoteheads = orig(*args, **kwargs)
696 697 return common, True, remoteheads
697 698
698 699 def _push(orig, ui, repo, dest=None, *args, **opts):
699 700
700 701 bookmark = opts.get(r'bookmark')
701 702 # we only support pushing one infinitepush bookmark at once
702 703 if len(bookmark) == 1:
703 704 bookmark = bookmark[0]
704 705 else:
705 706 bookmark = ''
706 707
707 708 oldphasemove = None
708 709 overrides = {(experimental, configbookmark): bookmark}
709 710
710 711 with ui.configoverride(overrides, 'infinitepush'):
711 712 scratchpush = opts.get('bundle_store')
712 713 if _scratchbranchmatcher(bookmark):
713 714 scratchpush = True
714 715 # bundle2 can be sent back after push (for example, bundle2
715 716 # containing `pushkey` part to update bookmarks)
716 717 ui.setconfig(experimental, 'bundle2.pushback', True)
717 718
718 719 if scratchpush:
719 720 # this is an infinitepush, we don't want the bookmark to be applied
720 721 # rather, it should be stored in the bundlestore
721 722 opts[r'bookmark'] = []
722 723 ui.setconfig(experimental, configscratchpush, True)
723 724 oldphasemove = extensions.wrapfunction(exchange,
724 725 '_localphasemove',
725 726 _phasemove)
726 727 # Copy-paste from `push` command
727 728 path = ui.paths.getpath(dest, default=('default-push', 'default'))
728 729 if not path:
729 730 raise error.Abort(_('default repository not configured!'),
730 731 hint=_("see 'hg help config.paths'"))
731 732 destpath = path.pushloc or path.loc
732 733 # Remote scratch bookmarks will be deleted because remotenames doesn't
733 734 # know about them. Let's save them before the push and restore them after
734 735 remotescratchbookmarks = _readscratchremotebookmarks(ui, repo, destpath)
735 736 result = orig(ui, repo, dest, *args, **opts)
736 737 if common.isremotebooksenabled(ui):
737 738 if bookmark and scratchpush:
738 739 other = hg.peer(repo, opts, destpath)
739 740 fetchedbookmarks = other.listkeyspatterns('bookmarks',
740 741 patterns=[bookmark])
741 742 remotescratchbookmarks.update(fetchedbookmarks)
742 743 _saveremotebookmarks(repo, remotescratchbookmarks, destpath)
743 744 if oldphasemove:
744 745 exchange._localphasemove = oldphasemove
745 746 return result
746 747
747 748 def _deleteinfinitepushbookmarks(ui, repo, path, names):
748 749 """Prune remote names by removing the bookmarks we don't want anymore,
749 750 then writing the result back to disk
750 751 """
751 752 remotenamesext = extensions.find('remotenames')
752 753
753 754 # remotename format is:
754 755 # (node, nametype ("branches" or "bookmarks"), remote, name)
755 756 nametype_idx = 1
756 757 remote_idx = 2
757 758 name_idx = 3
758 759 remotenames = [remotename for remotename in \
759 760 remotenamesext.readremotenames(repo) \
760 761 if remotename[remote_idx] == path]
761 762 remote_bm_names = [remotename[name_idx] for remotename in \
762 763 remotenames if remotename[nametype_idx] == "bookmarks"]
763 764
764 765 for name in names:
765 766 if name not in remote_bm_names:
766 767 raise error.Abort(_("infinitepush bookmark '{}' does not exist "
767 768 "in path '{}'").format(name, path))
768 769
769 770 bookmarks = {}
770 771 branches = collections.defaultdict(list)
771 772 for node, nametype, remote, name in remotenames:
772 773 if nametype == "bookmarks" and name not in names:
773 774 bookmarks[name] = node
774 775 elif nametype == "branches":
775 776 # saveremotenames wants binary nodes for branches
776 777 branches[name].append(bin(node))
777 778
778 779 remotenamesext.saveremotenames(repo, path, branches, bookmarks)
779 780
780 781 def _phasemove(orig, pushop, nodes, phase=phases.public):
781 782 """prevent commits from being marked public
782 783
783 784 Since these are going to a scratch branch, they aren't really being
784 785 published."""
785 786
786 787 if phase != phases.public:
787 788 orig(pushop, nodes, phase)
788 789
789 790 @exchange.b2partsgenerator(scratchbranchparttype)
790 791 def partgen(pushop, bundler):
791 792 bookmark = pushop.ui.config(experimental, configbookmark)
792 793 scratchpush = pushop.ui.configbool(experimental, configscratchpush)
793 794 if 'changesets' in pushop.stepsdone or not scratchpush:
794 795 return
795 796
796 797 if scratchbranchparttype not in bundle2.bundle2caps(pushop.remote):
797 798 return
798 799
799 800 pushop.stepsdone.add('changesets')
800 801 if not pushop.outgoing.missing:
801 802 pushop.ui.status(_('no changes found\n'))
802 803 pushop.cgresult = 0
803 804 return
804 805
805 806 # This parameter tells the server that the following bundle is an
806 807 # infinitepush. This lets it switch the part processing to our infinitepush
807 808 # code path.
808 809 bundler.addparam("infinitepush", "True")
809 810
810 811 scratchparts = bundleparts.getscratchbranchparts(pushop.repo,
811 812 pushop.remote,
812 813 pushop.outgoing,
813 814 pushop.ui,
814 815 bookmark)
815 816
816 817 for scratchpart in scratchparts:
817 818 bundler.addpart(scratchpart)
818 819
819 820 def handlereply(op):
820 821 # server either succeeds or aborts; no code to read
821 822 pushop.cgresult = 1
822 823
823 824 return handlereply
824 825
825 826 bundle2.capabilities[bundleparts.scratchbranchparttype] = ()
826 827
827 828 def _getrevs(bundle, oldnode, force, bookmark):
828 829 'extracts and validates the revs to be imported'
829 830 revs = [bundle[r] for r in bundle.revs('sort(bundle())')]
830 831
831 832 # new bookmark
832 833 if oldnode is None:
833 834 return revs
834 835
835 836 # Fast forward update
836 837 if oldnode in bundle and list(bundle.set('bundle() & %s::', oldnode)):
837 838 return revs
838 839
839 840 return revs
840 841
841 842 @contextlib.contextmanager
842 843 def logservicecall(logger, service, **kwargs):
843 844 start = time.time()
844 845 logger(service, eventtype='start', **kwargs)
845 846 try:
846 847 yield
847 848 logger(service, eventtype='success',
848 849 elapsedms=(time.time() - start) * 1000, **kwargs)
849 850 except Exception as e:
850 851 logger(service, eventtype='failure',
851 852 elapsedms=(time.time() - start) * 1000, errormsg=str(e),
852 853 **kwargs)
853 854 raise
854 855
855 856 def _getorcreateinfinitepushlogger(op):
856 857 logger = op.records['infinitepushlogger']
857 858 if not logger:
858 859 ui = op.repo.ui
859 860 try:
860 861 username = procutil.getuser()
861 862 except Exception:
862 863 username = 'unknown'
863 864 # Generate random request id to be able to find all logged entries
864 865 # for the same request. Since requestid is pseudo-generated it may
865 866 # not be unique, but we assume that (hostname, username, requestid)
866 867 # is unique.
867 868 random.seed()
868 869 requestid = random.randint(0, 2000000000)
869 870 hostname = socket.gethostname()
870 871 logger = functools.partial(ui.log, 'infinitepush', user=username,
871 872 requestid=requestid, hostname=hostname,
872 873 reponame=ui.config('infinitepush',
873 874 'reponame'))
874 875 op.records.add('infinitepushlogger', logger)
875 876 else:
876 877 logger = logger[0]
877 878 return logger
878 879
879 880 def storetobundlestore(orig, repo, op, unbundler):
880 881 """stores the incoming bundle coming from the push command in the bundlestore
881 882 instead of applying it to the revlogs"""
882 883
883 884 repo.ui.status(_("storing changesets on the bundlestore\n"))
884 885 bundler = bundle2.bundle20(repo.ui)
885 886
886 887 # processing each part and storing it in bundler
887 888 with bundle2.partiterator(repo, op, unbundler) as parts:
888 889 for part in parts:
889 890 bundlepart = None
890 891 if part.type == 'replycaps':
891 892 # This configures the current operation to allow reply parts.
892 893 bundle2._processpart(op, part)
893 894 else:
894 895 bundlepart = bundle2.bundlepart(part.type, data=part.read())
895 896 for key, value in part.params.iteritems():
896 897 bundlepart.addparam(key, value)
897 898
898 899 # Certain parts require a response
899 900 if part.type in ('pushkey', 'changegroup'):
900 901 if op.reply is not None:
901 902 rpart = op.reply.newpart('reply:%s' % part.type)
902 903 rpart.addparam('in-reply-to', str(part.id),
903 904 mandatory=False)
904 905 rpart.addparam('return', '1', mandatory=False)
905 906
906 907 op.records.add(part.type, {
907 908 'return': 1,
908 909 })
909 910 if bundlepart:
910 911 bundler.addpart(bundlepart)
911 912
912 913 # storing the bundle in the bundlestore
913 914 buf = util.chunkbuffer(bundler.getchunks())
914 915 fd, bundlefile = tempfile.mkstemp()
915 916 try:
916 917 try:
917 918 fp = os.fdopen(fd, r'wb')
918 919 fp.write(buf.read())
919 920 finally:
920 921 fp.close()
921 922 storebundle(op, {}, bundlefile)
922 923 finally:
923 924 try:
924 925 os.unlink(bundlefile)
925 926 except Exception:
926 927 # we would rather see the original exception
927 928 pass
928 929
929 930 def processparts(orig, repo, op, unbundler):
930 931
931 932 # make sure we don't wrap processparts in case of `hg unbundle`
932 933 if op.source == 'unbundle':
933 934 return orig(repo, op, unbundler)
934 935
935 936 # this server routes each push to bundle store
936 937 if repo.ui.configbool('infinitepush', 'pushtobundlestore'):
937 938 return storetobundlestore(orig, repo, op, unbundler)
938 939
939 940 if unbundler.params.get('infinitepush') != 'True':
940 941 return orig(repo, op, unbundler)
941 942
942 943 handleallparts = repo.ui.configbool('infinitepush', 'storeallparts')
943 944
944 945 bundler = bundle2.bundle20(repo.ui)
945 946 cgparams = None
946 947 with bundle2.partiterator(repo, op, unbundler) as parts:
947 948 for part in parts:
948 949 bundlepart = None
949 950 if part.type == 'replycaps':
950 951 # This configures the current operation to allow reply parts.
951 952 bundle2._processpart(op, part)
952 953 elif part.type == bundleparts.scratchbranchparttype:
953 954 # Scratch branch parts need to be converted to normal
954 955 # changegroup parts, and the extra parameters stored for later
955 956 # when we upload to the store. Eventually those parameters will
956 957 # be put on the actual bundle instead of this part, then we can
957 958 # send a vanilla changegroup instead of the scratchbranch part.
958 959 cgversion = part.params.get('cgversion', '01')
959 960 bundlepart = bundle2.bundlepart('changegroup', data=part.read())
960 961 bundlepart.addparam('version', cgversion)
961 962 cgparams = part.params
962 963
963 964 # If we're not dumping all parts into the new bundle, we need to
964 965 # alert the future pushkey and phase-heads handler to skip
965 966 # the part.
966 967 if not handleallparts:
967 968 op.records.add(scratchbranchparttype + '_skippushkey', True)
968 969 op.records.add(scratchbranchparttype + '_skipphaseheads',
969 970 True)
970 971 else:
971 972 if handleallparts:
972 973 # Ideally we would not process any parts, and instead just
973 974 # forward them to the bundle for storage, but since this
974 975 # differs from previous behavior, we need to put it behind a
975 976 # config flag for incremental rollout.
976 977 bundlepart = bundle2.bundlepart(part.type, data=part.read())
977 978 for key, value in part.params.iteritems():
978 979 bundlepart.addparam(key, value)
979 980
980 981 # Certain parts require a response
981 982 if part.type == 'pushkey':
982 983 if op.reply is not None:
983 984 rpart = op.reply.newpart('reply:pushkey')
984 985 rpart.addparam('in-reply-to', str(part.id),
985 986 mandatory=False)
986 987 rpart.addparam('return', '1', mandatory=False)
987 988 else:
988 989 bundle2._processpart(op, part)
989 990
990 991 if handleallparts:
991 992 op.records.add(part.type, {
992 993 'return': 1,
993 994 })
994 995 if bundlepart:
995 996 bundler.addpart(bundlepart)
996 997
997 998 # If commits were sent, store them
998 999 if cgparams:
999 1000 buf = util.chunkbuffer(bundler.getchunks())
1000 1001 fd, bundlefile = tempfile.mkstemp()
1001 1002 try:
1002 1003 try:
1003 1004 fp = os.fdopen(fd, r'wb')
1004 1005 fp.write(buf.read())
1005 1006 finally:
1006 1007 fp.close()
1007 1008 storebundle(op, cgparams, bundlefile)
1008 1009 finally:
1009 1010 try:
1010 1011 os.unlink(bundlefile)
1011 1012 except Exception:
1012 1013 # we would rather see the original exception
1013 1014 pass
1014 1015
1015 1016 def storebundle(op, params, bundlefile):
1016 1017 log = _getorcreateinfinitepushlogger(op)
1017 1018 parthandlerstart = time.time()
1018 1019 log(scratchbranchparttype, eventtype='start')
1019 1020 index = op.repo.bundlestore.index
1020 1021 store = op.repo.bundlestore.store
1021 1022 op.records.add(scratchbranchparttype + '_skippushkey', True)
1022 1023
1023 1024 bundle = None
1024 1025 try: # guards bundle
1025 1026 bundlepath = "bundle:%s+%s" % (op.repo.root, bundlefile)
1026 1027 bundle = hg.repository(op.repo.ui, bundlepath)
1027 1028
1028 1029 bookmark = params.get('bookmark')
1029 1030 bookprevnode = params.get('bookprevnode', '')
1030 1031 force = params.get('force')
1031 1032
1032 1033 if bookmark:
1033 1034 oldnode = index.getnode(bookmark)
1034 1035 else:
1035 1036 oldnode = None
1036 1037 bundleheads = bundle.revs('heads(bundle())')
1037 1038 if bookmark and len(bundleheads) > 1:
1038 1039 raise error.Abort(
1039 1040 _('cannot push more than one head to a scratch branch'))
1040 1041
1041 1042 revs = _getrevs(bundle, oldnode, force, bookmark)
1042 1043
1043 1044 # Notify the user of what is being pushed
1044 1045 plural = 's' if len(revs) > 1 else ''
1045 1046 op.repo.ui.warn(_("pushing %d commit%s:\n") % (len(revs), plural))
1046 1047 maxoutput = 10
1047 1048 for i in range(0, min(len(revs), maxoutput)):
1048 1049 firstline = bundle[revs[i]].description().split('\n')[0][:50]
1049 1050 op.repo.ui.warn((" %s %s\n") % (revs[i], firstline))
1050 1051
1051 1052 if len(revs) > maxoutput + 1:
1052 1053 op.repo.ui.warn((" ...\n"))
1053 1054 firstline = bundle[revs[-1]].description().split('\n')[0][:50]
1054 1055 op.repo.ui.warn((" %s %s\n") % (revs[-1], firstline))
1055 1056
1056 1057 nodesctx = [bundle[rev] for rev in revs]
1057 1058 inindex = lambda rev: bool(index.getbundle(bundle[rev].hex()))
1058 1059 if bundleheads:
1059 1060 newheadscount = sum(not inindex(rev) for rev in bundleheads)
1060 1061 else:
1061 1062 newheadscount = 0
1062 1063 # If there's a bookmark specified, there should be only one head,
1063 1064 # so we choose the last node, which will be that head.
1064 1065 # If a bug or malicious client allows there to be a bookmark
1065 1066 # with multiple heads, we will place the bookmark on the last head.
1066 1067 bookmarknode = nodesctx[-1].hex() if nodesctx else None
1067 1068 key = None
1068 1069 if newheadscount:
1069 1070 with open(bundlefile, 'r') as f:
1070 1071 bundledata = f.read()
1071 1072 with logservicecall(log, 'bundlestore',
1072 1073 bundlesize=len(bundledata)):
1073 1074 bundlesizelimit = 100 * 1024 * 1024 # 100 MB
1074 1075 if len(bundledata) > bundlesizelimit:
1075 1076 error_msg = ('bundle is too big: %d bytes. ' +
1076 1077 'max allowed size is 100 MB')
1077 1078 raise error.Abort(error_msg % (len(bundledata),))
1078 1079 key = store.write(bundledata)
1079 1080
1080 1081 with logservicecall(log, 'index', newheadscount=newheadscount), index:
1081 1082 if key:
1082 1083 index.addbundle(key, nodesctx)
1083 1084 if bookmark:
1084 1085 index.addbookmark(bookmark, bookmarknode)
1085 1086 _maybeaddpushbackpart(op, bookmark, bookmarknode,
1086 1087 bookprevnode, params)
1087 1088 log(scratchbranchparttype, eventtype='success',
1088 1089 elapsedms=(time.time() - parthandlerstart) * 1000)
1089 1090
1090 1091 except Exception as e:
1091 1092 log(scratchbranchparttype, eventtype='failure',
1092 1093 elapsedms=(time.time() - parthandlerstart) * 1000,
1093 1094 errormsg=str(e))
1094 1095 raise
1095 1096 finally:
1096 1097 if bundle:
1097 1098 bundle.close()
1098 1099
1099 1100 @bundle2.parthandler(scratchbranchparttype,
1100 1101 ('bookmark', 'bookprevnode', 'force',
1101 1102 'pushbackbookmarks', 'cgversion'))
1102 1103 def bundle2scratchbranch(op, part):
1103 1104 '''unbundle a bundle2 part containing a changegroup to store'''
1104 1105
1105 1106 bundler = bundle2.bundle20(op.repo.ui)
1106 1107 cgversion = part.params.get('cgversion', '01')
1107 1108 cgpart = bundle2.bundlepart('changegroup', data=part.read())
1108 1109 cgpart.addparam('version', cgversion)
1109 1110 bundler.addpart(cgpart)
1110 1111 buf = util.chunkbuffer(bundler.getchunks())
1111 1112
1112 1113 fd, bundlefile = tempfile.mkstemp()
1113 1114 try:
1114 1115 try:
1115 1116 fp = os.fdopen(fd, r'wb')
1116 1117 fp.write(buf.read())
1117 1118 finally:
1118 1119 fp.close()
1119 1120 storebundle(op, part.params, bundlefile)
1120 1121 finally:
1121 1122 try:
1122 1123 os.unlink(bundlefile)
1123 1124 except OSError as e:
1124 1125 if e.errno != errno.ENOENT:
1125 1126 raise
1126 1127
1127 1128 return 1
1128 1129
1129 1130 def _maybeaddpushbackpart(op, bookmark, newnode, oldnode, params):
1130 1131 if params.get('pushbackbookmarks'):
1131 1132 if op.reply and 'pushback' in op.reply.capabilities:
1132 1133 params = {
1133 1134 'namespace': 'bookmarks',
1134 1135 'key': bookmark,
1135 1136 'new': newnode,
1136 1137 'old': oldnode,
1137 1138 }
1138 1139 op.reply.newpart('pushkey', mandatoryparams=params.iteritems())
1139 1140
1140 1141 def bundle2pushkey(orig, op, part):
1141 1142 '''Wrapper of bundle2.handlepushkey()
1142 1143
1143 1144 The only goal is to skip calling the original function if the flag is set.
1144 1145 It's set if an infinitepush push is happening.
1145 1146 '''
1146 1147 if op.records[scratchbranchparttype + '_skippushkey']:
1147 1148 if op.reply is not None:
1148 1149 rpart = op.reply.newpart('reply:pushkey')
1149 1150 rpart.addparam('in-reply-to', str(part.id), mandatory=False)
1150 1151 rpart.addparam('return', '1', mandatory=False)
1151 1152 return 1
1152 1153
1153 1154 return orig(op, part)
1154 1155
1155 1156 def bundle2handlephases(orig, op, part):
1156 1157 '''Wrapper of bundle2.handlephases()
1157 1158
1158 1159 The only goal is to skip calling the original function if the flag is set.
1159 1160 It's set if an infinitepush push is happening.
1160 1161 '''
1161 1162
1162 1163 if op.records[scratchbranchparttype + '_skipphaseheads']:
1163 1164 return
1164 1165
1165 1166 return orig(op, part)
1166 1167
1167 1168 def _asyncsavemetadata(root, nodes):
1168 1169 '''starts a separate process that fills metadata for the nodes
1169 1170
1170 1171 This function creates a separate process and doesn't wait for its
1171 1172 completion. This was done to avoid slowing down pushes
1172 1173 '''
1173 1174
1174 1175 maxnodes = 50
1175 1176 if len(nodes) > maxnodes:
1176 1177 return
1177 1178 nodesargs = []
1178 1179 for node in nodes:
1179 1180 nodesargs.append('--node')
1180 1181 nodesargs.append(node)
1181 1182 with open(os.devnull, 'w+b') as devnull:
1182 1183 cmdline = [util.hgexecutable(), 'debugfillinfinitepushmetadata',
1183 1184 '-R', root] + nodesargs
1184 1185 # Process will run in background. We don't care about the return code
1185 1186 subprocess.Popen(cmdline, close_fds=True, shell=False,
1186 1187 stdin=devnull, stdout=devnull, stderr=devnull)
@@ -1,1307 +1,1287
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import os
12 12 import tempfile
13 13
14 14 from .i18n import _
15 15 from .node import (
16 16 bin,
17 17 hex,
18 18 nullid,
19 19 )
20 20
21 21 from . import (
22 22 bundle2,
23 23 changegroup as changegroupmod,
24 24 discovery,
25 25 encoding,
26 26 error,
27 27 exchange,
28 28 peer,
29 29 pushkey as pushkeymod,
30 30 pycompat,
31 31 repository,
32 32 streamclone,
33 33 util,
34 34 wireprototypes,
35 35 )
36 36
37 37 from .utils import (
38 38 procutil,
39 39 stringutil,
40 40 )
41 41
42 42 urlerr = util.urlerr
43 43 urlreq = util.urlreq
44 44
45 45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
46 46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
47 47 'IncompatibleClient')
48 48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
49 49
50 50 class remoteiterbatcher(peer.iterbatcher):
51 51 def __init__(self, remote):
52 52 super(remoteiterbatcher, self).__init__()
53 53 self._remote = remote
54 54
55 55 def __getattr__(self, name):
56 56 # Validate this method is batchable, since submit() only supports
57 57 # batchable methods.
58 58 fn = getattr(self._remote, name)
59 59 if not getattr(fn, 'batchable', None):
60 60 raise error.ProgrammingError('Attempted to batch a non-batchable '
61 61 'call to %r' % name)
62 62
63 63 return super(remoteiterbatcher, self).__getattr__(name)
64 64
65 65 def submit(self):
66 66 """Break the batch request into many individual calls and pipeline them.
67 67
68 68 This is mostly valuable over http where request sizes can be
69 69 limited, but can be used in other places as well.
70 70 """
71 71 # 2-tuple of (command, arguments) that represents what will be
72 72 # sent over the wire.
73 73 requests = []
74 74
75 75 # 4-tuple of (command, final future, @batchable generator, remote
76 76 # future).
77 77 results = []
78 78
79 79 for command, args, opts, finalfuture in self.calls:
80 80 mtd = getattr(self._remote, command)
81 81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
82 82
83 83 commandargs, fremote = next(batchable)
84 84 assert fremote
85 85 requests.append((command, commandargs))
86 86 results.append((command, finalfuture, batchable, fremote))
87 87
88 88 if requests:
89 89 self._resultiter = self._remote._submitbatch(requests)
90 90
91 91 self._results = results
92 92
93 93 def results(self):
94 94 for command, finalfuture, batchable, remotefuture in self._results:
95 95 # Get the raw result, set it in the remote future, feed it
96 96 # back into the @batchable generator so it can be decoded, and
97 97 # set the result on the final future to this value.
98 98 remoteresult = next(self._resultiter)
99 99 remotefuture.set(remoteresult)
100 100 finalfuture.set(next(batchable))
101 101
102 102 # Verify our @batchable generators only emit 2 values.
103 103 try:
104 104 next(batchable)
105 105 except StopIteration:
106 106 pass
107 107 else:
108 108 raise error.ProgrammingError('%s @batchable generator emitted '
109 109 'unexpected value count' % command)
110 110
111 111 yield finalfuture.value
112 112
113 113 # Forward a couple of names from peer to make wireproto interactions
114 114 # slightly more sensible.
115 115 batchable = peer.batchable
116 116 future = peer.future
117 117
118 # list of nodes encoding / decoding
119
120 def decodelist(l, sep=' '):
121 if l:
122 return [bin(v) for v in l.split(sep)]
123 return []
124
125 def encodelist(l, sep=' '):
126 try:
127 return sep.join(map(hex, l))
128 except TypeError:
129 raise
130
131 # batched call argument encoding
132
133 def escapearg(plain):
134 return (plain
135 .replace(':', ':c')
136 .replace(',', ':o')
137 .replace(';', ':s')
138 .replace('=', ':e'))
139
140 def unescapearg(escaped):
141 return (escaped
142 .replace(':e', '=')
143 .replace(':s', ';')
144 .replace(':o', ',')
145 .replace(':c', ':'))
146 118
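The helpers removed above now live in wireprototypes as ``encodelist``/``decodelist`` and ``escapebatcharg``/``unescapebatcharg`` (see the updated call sites below). A short sketch of their behaviour, assuming the moved functions keep the semantics of the removed code:

from mercurial import wireprototypes

# Nodes are hex-encoded and joined with a separator (a space by default).
wire = wireprototypes.encodelist(['\x00' * 20, '\xff' * 20])
assert wire == '00' * 20 + ' ' + 'ff' * 20
assert wireprototypes.decodelist(wire) == ['\x00' * 20, '\xff' * 20]

# Batch arguments escape the characters that structure a batched request.
assert wireprototypes.escapebatcharg('a:b,c;d=e') == 'a:cb:oc:sd:ee'
assert wireprototypes.unescapebatcharg('a:cb:oc:sd:ee') == 'a:b,c;d=e'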
147 119 def encodebatchcmds(req):
148 120 """Return a ``cmds`` argument value for the ``batch`` command."""
121 escapearg = wireprototypes.escapebatcharg
122
149 123 cmds = []
150 124 for op, argsdict in req:
151 125 # Old servers didn't properly unescape argument names. So prevent
152 126 # the sending of argument names that may not be decoded properly by
153 127 # servers.
154 128 assert all(escapearg(k) == k for k in argsdict)
155 129
156 130 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
157 131 for k, v in argsdict.iteritems())
158 132 cmds.append('%s %s' % (op, args))
159 133
160 134 return ';'.join(cmds)
161 135
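For example (illustrative values), a batch combining a no-argument ``heads`` call and a one-argument ``known`` call serializes as follows under the function above:

# Commands are ';'-separated; each command is '<op> <k1=v1,k2=v2,...>'.
req = [('heads', {}), ('known', {'nodes': '00' * 20})]
assert encodebatchcmds(req) == 'heads ;known nodes=' + '00' * 20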
162 136 def clientcompressionsupport(proto):
163 137 """Returns a list of compression methods supported by the client.
164 138
165 139 Returns a list of the compression methods supported by the client
166 140 according to the protocol capabilities. If no such capability has
167 141 been announced, fallback to the default of zlib and uncompressed.
168 142 """
169 143 for cap in proto.getprotocaps():
170 144 if cap.startswith('comp='):
171 145 return cap[5:].split(',')
172 146 return ['zlib', 'none']
173 147
174 148 # mapping of options accepted by getbundle and their types
175 149 #
176 150 # Meant to be extended by extensions. It is the extension's responsibility to ensure
177 151 # such options are properly processed in exchange.getbundle.
178 152 #
179 153 # supported types are:
180 154 #
181 155 # :nodes: list of binary nodes
182 156 # :csv: list of comma-separated values
183 157 # :scsv: list of comma-separated values returned as a set
184 158 # :plain: string with no transformation needed.
185 159 gboptsmap = {'heads': 'nodes',
186 160 'bookmarks': 'boolean',
187 161 'common': 'nodes',
188 162 'obsmarkers': 'boolean',
189 163 'phases': 'boolean',
190 164 'bundlecaps': 'scsv',
191 165 'listkeys': 'csv',
192 166 'cg': 'boolean',
193 167 'cbattempted': 'boolean',
194 168 'stream': 'boolean',
195 169 }
196 170
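A hypothetical helper (not part of Mercurial) summarizing how ``wirepeer.getbundle()`` below serializes each option type before putting it on the wire:

def _encodegetbundleopt(keytype, value):
    # Sketch mirroring the per-type encoding in wirepeer.getbundle().
    if keytype == 'nodes':
        return wireprototypes.encodelist(value)  # hex nodes, space-separated
    elif keytype == 'csv':
        return ','.join(value)
    elif keytype == 'scsv':
        return ','.join(sorted(value))
    elif keytype == 'boolean':
        return '%i' % bool(value)                # '0' or '1'
    elif keytype == 'plain':
        return value
    raise KeyError('unknown getbundle option type %s' % keytype)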
197 171 # client side
198 172
199 173 class wirepeer(repository.legacypeer):
200 174 """Client-side interface for communicating with a peer repository.
201 175
202 176 Methods commonly call wire protocol commands of the same name.
203 177
204 178 See also httppeer.py and sshpeer.py for protocol-specific
205 179 implementations of this interface.
206 180 """
207 181 # Begin of ipeercommands interface.
208 182
209 183 def iterbatch(self):
210 184 return remoteiterbatcher(self)
211 185
212 186 @batchable
213 187 def lookup(self, key):
214 188 self.requirecap('lookup', _('look up remote revision'))
215 189 f = future()
216 190 yield {'key': encoding.fromlocal(key)}, f
217 191 d = f.value
218 192 success, data = d[:-1].split(" ", 1)
219 193 if int(success):
220 194 yield bin(data)
221 195 else:
222 196 self._abort(error.RepoError(data))
223 197
224 198 @batchable
225 199 def heads(self):
226 200 f = future()
227 201 yield {}, f
228 202 d = f.value
229 203 try:
230 yield decodelist(d[:-1])
204 yield wireprototypes.decodelist(d[:-1])
231 205 except ValueError:
232 206 self._abort(error.ResponseError(_("unexpected response:"), d))
233 207
234 208 @batchable
235 209 def known(self, nodes):
236 210 f = future()
237 yield {'nodes': encodelist(nodes)}, f
211 yield {'nodes': wireprototypes.encodelist(nodes)}, f
238 212 d = f.value
239 213 try:
240 214 yield [bool(int(b)) for b in d]
241 215 except ValueError:
242 216 self._abort(error.ResponseError(_("unexpected response:"), d))
243 217
244 218 @batchable
245 219 def branchmap(self):
246 220 f = future()
247 221 yield {}, f
248 222 d = f.value
249 223 try:
250 224 branchmap = {}
251 225 for branchpart in d.splitlines():
252 226 branchname, branchheads = branchpart.split(' ', 1)
253 227 branchname = encoding.tolocal(urlreq.unquote(branchname))
254 branchheads = decodelist(branchheads)
228 branchheads = wireprototypes.decodelist(branchheads)
255 229 branchmap[branchname] = branchheads
256 230 yield branchmap
257 231 except TypeError:
258 232 self._abort(error.ResponseError(_("unexpected response:"), d))
259 233
260 234 @batchable
261 235 def listkeys(self, namespace):
262 236 if not self.capable('pushkey'):
263 237 yield {}, None
264 238 f = future()
265 239 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
266 240 yield {'namespace': encoding.fromlocal(namespace)}, f
267 241 d = f.value
268 242 self.ui.debug('received listkey for "%s": %i bytes\n'
269 243 % (namespace, len(d)))
270 244 yield pushkeymod.decodekeys(d)
271 245
272 246 @batchable
273 247 def pushkey(self, namespace, key, old, new):
274 248 if not self.capable('pushkey'):
275 249 yield False, None
276 250 f = future()
277 251 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
278 252 yield {'namespace': encoding.fromlocal(namespace),
279 253 'key': encoding.fromlocal(key),
280 254 'old': encoding.fromlocal(old),
281 255 'new': encoding.fromlocal(new)}, f
282 256 d = f.value
283 257 d, output = d.split('\n', 1)
284 258 try:
285 259 d = bool(int(d))
286 260 except ValueError:
287 261 raise error.ResponseError(
288 262 _('push failed (unexpected response):'), d)
289 263 for l in output.splitlines(True):
290 264 self.ui.status(_('remote: '), l)
291 265 yield d
292 266
293 267 def stream_out(self):
294 268 return self._callstream('stream_out')
295 269
296 270 def getbundle(self, source, **kwargs):
297 271 kwargs = pycompat.byteskwargs(kwargs)
298 272 self.requirecap('getbundle', _('look up remote changes'))
299 273 opts = {}
300 274 bundlecaps = kwargs.get('bundlecaps') or set()
301 275 for key, value in kwargs.iteritems():
302 276 if value is None:
303 277 continue
304 278 keytype = gboptsmap.get(key)
305 279 if keytype is None:
306 280 raise error.ProgrammingError(
307 281 'Unexpectedly None keytype for key %s' % key)
308 282 elif keytype == 'nodes':
309 value = encodelist(value)
283 value = wireprototypes.encodelist(value)
310 284 elif keytype == 'csv':
311 285 value = ','.join(value)
312 286 elif keytype == 'scsv':
313 287 value = ','.join(sorted(value))
314 288 elif keytype == 'boolean':
315 289 value = '%i' % bool(value)
316 290 elif keytype != 'plain':
317 291 raise KeyError('unknown getbundle option type %s'
318 292 % keytype)
319 293 opts[key] = value
320 294 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
321 295 if any((cap.startswith('HG2') for cap in bundlecaps)):
322 296 return bundle2.getunbundler(self.ui, f)
323 297 else:
324 298 return changegroupmod.cg1unpacker(f, 'UN')
325 299
326 300 def unbundle(self, cg, heads, url):
327 301 '''Send cg (a readable file-like object representing the
328 302 changegroup to push, typically a chunkbuffer object) to the
329 303 remote server as a bundle.
330 304
331 305 When pushing a bundle10 stream, return an integer indicating the
332 306 result of the push (see changegroup.apply()).
333 307
334 308 When pushing a bundle20 stream, return a bundle20 stream.
335 309
336 310 `url` is the url the client thinks it's pushing to, which is
337 311 visible to hooks.
338 312 '''
339 313
340 314 if heads != ['force'] and self.capable('unbundlehash'):
341 heads = encodelist(['hashed',
342 hashlib.sha1(''.join(sorted(heads))).digest()])
315 heads = wireprototypes.encodelist(
316 ['hashed', hashlib.sha1(''.join(sorted(heads))).digest()])
343 317 else:
344 heads = encodelist(heads)
318 heads = wireprototypes.encodelist(heads)
345 319
346 320 if util.safehasattr(cg, 'deltaheader'):
347 321 # this is a bundle10, do the old style call sequence
348 322 ret, output = self._callpush("unbundle", cg, heads=heads)
349 323 if ret == "":
350 324 raise error.ResponseError(
351 325 _('push failed:'), output)
352 326 try:
353 327 ret = int(ret)
354 328 except ValueError:
355 329 raise error.ResponseError(
356 330 _('push failed (unexpected response):'), ret)
357 331
358 332 for l in output.splitlines(True):
359 333 self.ui.status(_('remote: '), l)
360 334 else:
361 335 # bundle2 push. Send a stream, fetch a stream.
362 336 stream = self._calltwowaystream('unbundle', cg, heads=heads)
363 337 ret = bundle2.getunbundler(self.ui, stream)
364 338 return ret
365 339
366 340 # End of ipeercommands interface.
367 341
368 342 # Beginning of ipeerlegacycommands interface.
369 343
370 344 def branches(self, nodes):
371 n = encodelist(nodes)
345 n = wireprototypes.encodelist(nodes)
372 346 d = self._call("branches", nodes=n)
373 347 try:
374 br = [tuple(decodelist(b)) for b in d.splitlines()]
348 br = [tuple(wireprototypes.decodelist(b)) for b in d.splitlines()]
375 349 return br
376 350 except ValueError:
377 351 self._abort(error.ResponseError(_("unexpected response:"), d))
378 352
379 353 def between(self, pairs):
380 354 batch = 8 # avoid giant requests
381 355 r = []
382 356 for i in xrange(0, len(pairs), batch):
383 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
357 n = " ".join([wireprototypes.encodelist(p, '-')
358 for p in pairs[i:i + batch]])
384 359 d = self._call("between", pairs=n)
385 360 try:
386 r.extend(l and decodelist(l) or [] for l in d.splitlines())
361 r.extend(l and wireprototypes.decodelist(l) or []
362 for l in d.splitlines())
387 363 except ValueError:
388 364 self._abort(error.ResponseError(_("unexpected response:"), d))
389 365 return r
390 366
391 367 def changegroup(self, nodes, kind):
392 n = encodelist(nodes)
368 n = wireprototypes.encodelist(nodes)
393 369 f = self._callcompressable("changegroup", roots=n)
394 370 return changegroupmod.cg1unpacker(f, 'UN')
395 371
396 372 def changegroupsubset(self, bases, heads, kind):
397 373 self.requirecap('changegroupsubset', _('look up remote changes'))
398 bases = encodelist(bases)
399 heads = encodelist(heads)
374 bases = wireprototypes.encodelist(bases)
375 heads = wireprototypes.encodelist(heads)
400 376 f = self._callcompressable("changegroupsubset",
401 377 bases=bases, heads=heads)
402 378 return changegroupmod.cg1unpacker(f, 'UN')
403 379
404 380 # End of ipeerlegacycommands interface.
405 381
406 382 def _submitbatch(self, req):
407 383 """run batch request <req> on the server
408 384
409 385 Returns an iterator of the raw responses from the server.
410 386 """
411 387 ui = self.ui
412 388 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
413 389 ui.debug('devel-peer-request: batched-content\n')
414 390 for op, args in req:
415 391 msg = 'devel-peer-request: - %s (%d arguments)\n'
416 392 ui.debug(msg % (op, len(args)))
417 393
394 unescapearg = wireprototypes.unescapebatcharg
395
418 396 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
419 397 chunk = rsp.read(1024)
420 398 work = [chunk]
421 399 while chunk:
422 400 while ';' not in chunk and chunk:
423 401 chunk = rsp.read(1024)
424 402 work.append(chunk)
425 403 merged = ''.join(work)
426 404 while ';' in merged:
427 405 one, merged = merged.split(';', 1)
428 406 yield unescapearg(one)
429 407 chunk = rsp.read(1024)
430 408 work = [merged, chunk]
431 409 yield unescapearg(''.join(work))
432 410
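# Illustrative sketch, not part of this changeset: the ';'-separated framing
# that _submitbatch() parses, round-tripped through the escape helpers this
# commit moves into wireprototypes. The payload strings are made up.
from mercurial import wireprototypes

raw = ';'.join([wireprototypes.escapebatcharg('one;two'),
                wireprototypes.escapebatcharg('a=b,c')])
assert raw == 'one:stwo;a:eb:oc'
decoded = [wireprototypes.unescapebatcharg(part) for part in raw.split(';')]
assert decoded == ['one;two', 'a=b,c']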
433 411 def _submitone(self, op, args):
434 412 return self._call(op, **pycompat.strkwargs(args))
435 413
436 414 def debugwireargs(self, one, two, three=None, four=None, five=None):
437 415 # don't pass optional arguments left at their default value
438 416 opts = {}
439 417 if three is not None:
440 418 opts[r'three'] = three
441 419 if four is not None:
442 420 opts[r'four'] = four
443 421 return self._call('debugwireargs', one=one, two=two, **opts)
444 422
445 423 def _call(self, cmd, **args):
446 424 """execute <cmd> on the server
447 425
448 426 The command is expected to return a simple string.
449 427
450 428 returns the server reply as a string."""
451 429 raise NotImplementedError()
452 430
453 431 def _callstream(self, cmd, **args):
454 432 """execute <cmd> on the server
455 433
456 434 The command is expected to return a stream. Note that if the
457 435 command doesn't return a stream, _callstream behaves
458 436 differently for ssh and http peers.
459 437
460 438 returns the server reply as a file like object.
461 439 """
462 440 raise NotImplementedError()
463 441
464 442 def _callcompressable(self, cmd, **args):
465 443 """execute <cmd> on the server
466 444
467 445 The command is expected to return a stream.
468 446
469 447 The stream may have been compressed in some implementations. This
470 448 function takes care of the decompression. This is the only difference
471 449 with _callstream.
472 450
473 451 returns the server reply as a file like object.
474 452 """
475 453 raise NotImplementedError()
476 454
477 455 def _callpush(self, cmd, fp, **args):
478 456 """execute a <cmd> on server
479 457
480 458 The command is expected to be related to a push. Push has a special
481 459 return method.
482 460
483 461 returns the server reply as a (ret, output) tuple. ret is either
484 462 empty (error) or a stringified int.
485 463 """
486 464 raise NotImplementedError()
487 465
488 466 def _calltwowaystream(self, cmd, fp, **args):
489 467 """execute <cmd> on server
490 468
491 469 The command will send a stream to the server and get a stream in reply.
492 470 """
493 471 raise NotImplementedError()
494 472
495 473 def _abort(self, exception):
496 474 """clearly abort the wire protocol connection and raise the exception
497 475 """
498 476 raise NotImplementedError()
499 477
500 478 # server side
501 479
502 480 # wire protocol command can either return a string or one of these classes.
503 481
504 482 def getdispatchrepo(repo, proto, command):
505 483 """Obtain the repo used for processing wire protocol commands.
506 484
507 485 The intent of this function is to serve as a monkeypatch point for
508 486 extensions that need commands to operate on different repo views under
509 487 specialized circumstances.
510 488 """
511 489 return repo.filtered('served')
512 490
513 491 def dispatch(repo, proto, command):
514 492 repo = getdispatchrepo(repo, proto, command)
515 493
516 494 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
517 495 commandtable = commandsv2 if transportversion == 2 else commands
518 496 func, spec = commandtable[command]
519 497
520 498 args = proto.getargs(spec)
521 499
522 500 # Version 1 protocols define arguments as a list. Version 2 uses a dict.
523 501 if isinstance(args, list):
524 502 return func(repo, proto, *args)
525 503 elif isinstance(args, dict):
526 504 return func(repo, proto, **args)
527 505 else:
528 506 raise error.ProgrammingError('unexpected type returned from '
529 507 'proto.getargs(): %s' % type(args))
530 508
531 509 def options(cmd, keys, others):
532 510 opts = {}
533 511 for k in keys:
534 512 if k in others:
535 513 opts[k] = others[k]
536 514 del others[k]
537 515 if others:
538 516 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
539 517 % (cmd, ",".join(others)))
540 518 return opts
541 519
542 520 def bundle1allowed(repo, action):
543 521 """Whether a bundle1 operation is allowed from the server.
544 522
545 523 Priority is:
546 524
547 525 1. server.bundle1gd.<action> (if generaldelta active)
548 526 2. server.bundle1.<action>
549 527 3. server.bundle1gd (if generaldelta active)
550 528 4. server.bundle1
551 529 """
552 530 ui = repo.ui
553 531 gd = 'generaldelta' in repo.requirements
554 532
555 533 if gd:
556 534 v = ui.configbool('server', 'bundle1gd.%s' % action)
557 535 if v is not None:
558 536 return v
559 537
560 538 v = ui.configbool('server', 'bundle1.%s' % action)
561 539 if v is not None:
562 540 return v
563 541
564 542 if gd:
565 543 v = ui.configbool('server', 'bundle1gd')
566 544 if v is not None:
567 545 return v
568 546
569 547 return ui.configbool('server', 'bundle1')
570 548
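# Illustrative note, not part of this changeset: with a generaldelta repo and
# a hypothetical server config such as
#
#   [server]
#   bundle1gd.push = false    # 1. most specific for this action, wins
#   bundle1.push = true       # 2. consulted only if the line above is unset
#   bundle1gd = true          # 3. generaldelta-wide fallback
#   bundle1 = true            # 4. global fallback
#
# bundle1allowed(repo, 'push') returns False, because the lookup above stops
# at the first option that is explicitly set.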
571 549 def supportedcompengines(ui, role):
572 550 """Obtain the list of supported compression engines for a request."""
573 551 assert role in (util.CLIENTROLE, util.SERVERROLE)
574 552
575 553 compengines = util.compengines.supportedwireengines(role)
576 554
577 555 # Allow config to override default list and ordering.
578 556 if role == util.SERVERROLE:
579 557 configengines = ui.configlist('server', 'compressionengines')
580 558 config = 'server.compressionengines'
581 559 else:
582 560 # This is currently implemented mainly to facilitate testing. In most
583 561 # cases, the server should be in charge of choosing a compression engine
584 562 # because a server has the most to lose from a sub-optimal choice. (e.g.
585 563 # CPU DoS due to an expensive engine or a network DoS due to poor
586 564 # compression ratio).
587 565 configengines = ui.configlist('experimental',
588 566 'clientcompressionengines')
589 567 config = 'experimental.clientcompressionengines'
590 568
591 569 # No explicit config. Filter out the ones that aren't supposed to be
592 570 # advertised and return default ordering.
593 571 if not configengines:
594 572 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
595 573 return [e for e in compengines
596 574 if getattr(e.wireprotosupport(), attr) > 0]
597 575
598 576 # If compression engines are listed in the config, assume there is a good
599 577 # reason for it (like server operators wanting to achieve specific
600 578 # performance characteristics). So fail fast if the config references
601 579 # unusable compression engines.
602 580 validnames = set(e.name() for e in compengines)
603 581 invalidnames = set(e for e in configengines if e not in validnames)
604 582 if invalidnames:
605 583 raise error.Abort(_('invalid compression engine defined in %s: %s') %
606 584 (config, ', '.join(sorted(invalidnames))))
607 585
608 586 compengines = [e for e in compengines if e.name() in configengines]
609 587 compengines = sorted(compengines,
610 588 key=lambda e: configengines.index(e.name()))
611 589
612 590 if not compengines:
613 591 raise error.Abort(_('%s config option does not specify any known '
614 592 'compression engines') % config,
615 593 hint=_('usable compression engines: %s') %
616 594 ', '.join(sorted(validnames)))
617 595
618 596 return compengines
619 597
620 598 class commandentry(object):
621 599 """Represents a declared wire protocol command."""
622 600 def __init__(self, func, args='', transports=None,
623 601 permission='push'):
624 602 self.func = func
625 603 self.args = args
626 604 self.transports = transports or set()
627 605 self.permission = permission
628 606
629 607 def _merge(self, func, args):
630 608 """Merge this instance with an incoming 2-tuple.
631 609
632 610 This is called when a caller using the old 2-tuple API attempts
633 611 to replace an instance. The incoming values are merged with
634 612 data not captured by the 2-tuple and a new instance containing
635 613 the union of the two objects is returned.
636 614 """
637 615 return commandentry(func, args=args, transports=set(self.transports),
638 616 permission=self.permission)
639 617
640 618 # Old code treats instances as 2-tuples. So expose that interface.
641 619 def __iter__(self):
642 620 yield self.func
643 621 yield self.args
644 622
645 623 def __getitem__(self, i):
646 624 if i == 0:
647 625 return self.func
648 626 elif i == 1:
649 627 return self.args
650 628 else:
651 629 raise IndexError('can only access elements 0 and 1')
652 630
653 631 class commanddict(dict):
654 632 """Container for registered wire protocol commands.
655 633
656 634 It behaves like a dict. But __setitem__ is overwritten to allow silent
657 635 coercion of values from 2-tuples for API compatibility.
658 636 """
659 637 def __setitem__(self, k, v):
660 638 if isinstance(v, commandentry):
661 639 pass
662 640 # Cast 2-tuples to commandentry instances.
663 641 elif isinstance(v, tuple):
664 642 if len(v) != 2:
665 643 raise ValueError('command tuples must have exactly 2 elements')
666 644
667 645 # It is common for extensions to wrap wire protocol commands via
668 646 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
669 647 # doing this aren't aware of the new API that uses objects to store
670 648 # command entries, we automatically merge old state with new.
671 649 if k in self:
672 650 v = self[k]._merge(v[0], v[1])
673 651 else:
674 652 # Use default values from @wireprotocommand.
675 653 v = commandentry(v[0], args=v[1],
676 654 transports=set(wireprototypes.TRANSPORTS),
677 655 permission='push')
678 656 else:
679 657 raise ValueError('command entries must be commandentry instances '
680 658 'or 2-tuples')
681 659
682 660 return super(commanddict, self).__setitem__(k, v)
683 661
684 662 def commandavailable(self, command, proto):
685 663 """Determine if a command is available for the requested protocol."""
686 664 assert proto.name in wireprototypes.TRANSPORTS
687 665
688 666 entry = self.get(command)
689 667
690 668 if not entry:
691 669 return False
692 670
693 671 if proto.name not in entry.transports:
694 672 return False
695 673
696 674 return True
697 675
698 676 # Constants specifying which transports a wire protocol command should be
699 677 # available on. For use with @wireprotocommand.
700 678 POLICY_V1_ONLY = 'v1-only'
701 679 POLICY_V2_ONLY = 'v2-only'
702 680
703 681 # For version 1 transports.
704 682 commands = commanddict()
705 683
706 684 # For version 2 transports.
707 685 commandsv2 = commanddict()
708 686
709 687 def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
710 688 permission='push'):
711 689 """Decorator to declare a wire protocol command.
712 690
713 691 ``name`` is the name of the wire protocol command being provided.
714 692
715 693 ``args`` defines the named arguments accepted by the command. It is
716 694 ideally a dict mapping argument names to their types. For backwards
717 695 compatibility, it can be a space-delimited list of argument names. For
718 696 version 1 transports, ``*`` denotes a special value that says to accept
719 697 all named arguments.
720 698
721 699 ``transportpolicy`` is a POLICY_* constant denoting which transports
722 700 this wire protocol command should be exposed to. By default, commands
723 701 are exposed to all wire protocol transports.
724 702
725 703 ``permission`` defines the permission type needed to run this command.
726 704 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
727 705 respectively. Default is to assume command requires ``push`` permissions
728 706 because otherwise commands not declaring their permissions could modify
729 707 a repository that is supposed to be read-only.
730 708 """
731 709 if transportpolicy == POLICY_V1_ONLY:
732 710 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
733 711 if v['version'] == 1}
734 712 transportversion = 1
735 713 elif transportpolicy == POLICY_V2_ONLY:
736 714 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
737 715 if v['version'] == 2}
738 716 transportversion = 2
739 717 else:
740 718 raise error.ProgrammingError('invalid transport policy value: %s' %
741 719 transportpolicy)
742 720
743 721 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
744 722 # SSHv2.
745 723 # TODO undo this hack when SSH is using the unified frame protocol.
746 724 if name == b'batch':
747 725 transports.add(wireprototypes.SSHV2)
748 726
749 727 if permission not in ('push', 'pull'):
750 728 raise error.ProgrammingError('invalid wire protocol permission; '
751 729 'got %s; expected "push" or "pull"' %
752 730 permission)
753 731
754 732 if transportversion == 1:
755 733 if args is None:
756 734 args = ''
757 735
758 736 if not isinstance(args, bytes):
759 737 raise error.ProgrammingError('arguments for version 1 commands '
760 738 'must be declared as bytes')
761 739 elif transportversion == 2:
762 740 if args is None:
763 741 args = {}
764 742
765 743 if not isinstance(args, dict):
766 744 raise error.ProgrammingError('arguments for version 2 commands '
767 745 'must be declared as dicts')
768 746
769 747 def register(func):
770 748 if transportversion == 1:
771 749 if name in commands:
772 750 raise error.ProgrammingError('%s command already registered '
773 751 'for version 1' % name)
774 752 commands[name] = commandentry(func, args=args,
775 753 transports=transports,
776 754 permission=permission)
777 755 elif transportversion == 2:
778 756 if name in commandsv2:
779 757 raise error.ProgrammingError('%s command already registered '
780 758 'for version 2' % name)
781 759
782 760 commandsv2[name] = commandentry(func, args=args,
783 761 transports=transports,
784 762 permission=permission)
785 763 else:
786 764 raise error.ProgrammingError('unhandled transport version: %d' %
787 765 transportversion)
788 766
789 767 return func
790 768 return register
791 769
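# Illustrative sketch, not part of this changeset: declaring a hypothetical
# version 1 command with the decorator above. The name 'mystatus' and its
# behaviour are invented; the argument handling mirrors the real commands
# registered below.
@wireprotocommand('mystatus', 'key', permission='pull',
                  transportpolicy=POLICY_V1_ONLY)
def mystatus(repo, proto, key):
    # v1 commands receive their declared arguments positionally and may
    # return raw bytes or a wireprototypes response object.
    return wireprototypes.bytesresponse('%s ok\n' % key)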
792 770 # TODO define a more appropriate permissions type to use for this.
793 771 @wireprotocommand('batch', 'cmds *', permission='pull',
794 772 transportpolicy=POLICY_V1_ONLY)
795 773 def batch(repo, proto, cmds, others):
774 unescapearg = wireprototypes.unescapebatcharg
796 775 repo = repo.filtered("served")
797 776 res = []
798 777 for pair in cmds.split(';'):
799 778 op, args = pair.split(' ', 1)
800 779 vals = {}
801 780 for a in args.split(','):
802 781 if a:
803 782 n, v = a.split('=')
804 783 vals[unescapearg(n)] = unescapearg(v)
805 784 func, spec = commands[op]
806 785
807 786 # Validate that client has permissions to perform this command.
808 787 perm = commands[op].permission
809 788 assert perm in ('push', 'pull')
810 789 proto.checkperm(perm)
811 790
812 791 if spec:
813 792 keys = spec.split()
814 793 data = {}
815 794 for k in keys:
816 795 if k == '*':
817 796 star = {}
818 797 for key in vals.keys():
819 798 if key not in keys:
820 799 star[key] = vals[key]
821 800 data['*'] = star
822 801 else:
823 802 data[k] = vals[k]
824 803 result = func(repo, proto, *[data[k] for k in keys])
825 804 else:
826 805 result = func(repo, proto)
827 806 if isinstance(result, wireprototypes.ooberror):
828 807 return result
829 808
830 809 # For now, all batchable commands must return bytesresponse or
831 810 # raw bytes (for backwards compatibility).
832 811 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
833 812 if isinstance(result, wireprototypes.bytesresponse):
834 813 result = result.data
835 res.append(escapearg(result))
814 res.append(wireprototypes.escapebatcharg(result))
836 815
837 816 return wireprototypes.bytesresponse(';'.join(res))
838 817
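# Illustrative sketch, not part of this changeset: the 'cmds' payload that
# batch() above parses is a ';'-separated list of '<command> <name>=<value>,...'
# entries whose names and values are escaped with
# wireprototypes.escapebatcharg. The node value below is a placeholder.
cmds = ';'.join([
    'heads ',   # trailing space separates the (empty) argument list
    'known nodes=' + wireprototypes.escapebatcharg('11' * 20),
])
# -> 'heads ;known nodes=111...1' (40 hex digits)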
839 818 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
840 819 permission='pull')
841 820 def between(repo, proto, pairs):
842 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
821 pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
843 822 r = []
844 823 for b in repo.between(pairs):
845 r.append(encodelist(b) + "\n")
824 r.append(wireprototypes.encodelist(b) + "\n")
846 825
847 826 return wireprototypes.bytesresponse(''.join(r))
848 827
849 828 @wireprotocommand('branchmap', permission='pull',
850 829 transportpolicy=POLICY_V1_ONLY)
851 830 def branchmap(repo, proto):
852 831 branchmap = repo.branchmap()
853 832 heads = []
854 833 for branch, nodes in branchmap.iteritems():
855 834 branchname = urlreq.quote(encoding.fromlocal(branch))
856 branchnodes = encodelist(nodes)
835 branchnodes = wireprototypes.encodelist(nodes)
857 836 heads.append('%s %s' % (branchname, branchnodes))
858 837
859 838 return wireprototypes.bytesresponse('\n'.join(heads))
860 839
861 840 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
862 841 permission='pull')
863 842 def branches(repo, proto, nodes):
864 nodes = decodelist(nodes)
843 nodes = wireprototypes.decodelist(nodes)
865 844 r = []
866 845 for b in repo.branches(nodes):
867 r.append(encodelist(b) + "\n")
846 r.append(wireprototypes.encodelist(b) + "\n")
868 847
869 848 return wireprototypes.bytesresponse(''.join(r))
870 849
871 850 @wireprotocommand('clonebundles', '', permission='pull',
872 851 transportpolicy=POLICY_V1_ONLY)
873 852 def clonebundles(repo, proto):
874 853 """Server command for returning info for available bundles to seed clones.
875 854
876 855 Clients will parse this response and determine what bundle to fetch.
877 856
878 857 Extensions may wrap this command to filter or dynamically emit data
879 858 depending on the request. e.g. you could advertise URLs for the closest
880 859 data center given the client's IP address.
881 860 """
882 861 return wireprototypes.bytesresponse(
883 862 repo.vfs.tryread('clonebundles.manifest'))
884 863
885 864 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
886 865 'known', 'getbundle', 'unbundlehash']
887 866
888 867 def _capabilities(repo, proto):
889 868 """return a list of capabilities for a repo
890 869
891 870 This function exists to allow extensions to easily wrap capabilities
892 871 computation
893 872
894 873 - returns a list: easy to alter
895 874 - changes made here will be propagated to both `capabilities` and `hello`
896 875 commands without any other action needed.
897 876 """
898 877 # copy to prevent modification of the global list
899 878 caps = list(wireprotocaps)
900 879
901 880 # Command of same name as capability isn't exposed to version 1 of
902 881 # transports. So conditionally add it.
903 882 if commands.commandavailable('changegroupsubset', proto):
904 883 caps.append('changegroupsubset')
905 884
906 885 if streamclone.allowservergeneration(repo):
907 886 if repo.ui.configbool('server', 'preferuncompressed'):
908 887 caps.append('stream-preferred')
909 888 requiredformats = repo.requirements & repo.supportedformats
910 889 # if our local revlogs are just revlogv1, add 'stream' cap
911 890 if not requiredformats - {'revlogv1'}:
912 891 caps.append('stream')
913 892 # otherwise, add 'streamreqs' detailing our local revlog format
914 893 else:
915 894 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
916 895 if repo.ui.configbool('experimental', 'bundle2-advertise'):
917 896 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
918 897 caps.append('bundle2=' + urlreq.quote(capsblob))
919 898 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
920 899
921 900 return proto.addcapabilities(repo, caps)
922 901
923 902 # If you are writing an extension and consider wrapping this function, wrap
924 903 # `_capabilities` instead.
925 904 @wireprotocommand('capabilities', permission='pull',
926 905 transportpolicy=POLICY_V1_ONLY)
927 906 def capabilities(repo, proto):
928 907 caps = _capabilities(repo, proto)
929 908 return wireprototypes.bytesresponse(' '.join(sorted(caps)))
930 909
931 910 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
932 911 permission='pull')
933 912 def changegroup(repo, proto, roots):
934 nodes = decodelist(roots)
913 nodes = wireprototypes.decodelist(roots)
935 914 outgoing = discovery.outgoing(repo, missingroots=nodes,
936 915 missingheads=repo.heads())
937 916 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
938 917 gen = iter(lambda: cg.read(32768), '')
939 918 return wireprototypes.streamres(gen=gen)
940 919
941 920 @wireprotocommand('changegroupsubset', 'bases heads',
942 921 transportpolicy=POLICY_V1_ONLY,
943 922 permission='pull')
944 923 def changegroupsubset(repo, proto, bases, heads):
945 bases = decodelist(bases)
946 heads = decodelist(heads)
924 bases = wireprototypes.decodelist(bases)
925 heads = wireprototypes.decodelist(heads)
947 926 outgoing = discovery.outgoing(repo, missingroots=bases,
948 927 missingheads=heads)
949 928 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
950 929 gen = iter(lambda: cg.read(32768), '')
951 930 return wireprototypes.streamres(gen=gen)
952 931
953 932 @wireprotocommand('debugwireargs', 'one two *',
954 933 permission='pull', transportpolicy=POLICY_V1_ONLY)
955 934 def debugwireargs(repo, proto, one, two, others):
956 935 # only accept optional args from the known set
957 936 opts = options('debugwireargs', ['three', 'four'], others)
958 937 return wireprototypes.bytesresponse(repo.debugwireargs(
959 938 one, two, **pycompat.strkwargs(opts)))
960 939
961 940 def find_pullbundle(repo, proto, opts, clheads, heads, common):
962 941 """Return a file object for the first matching pullbundle.
963 942
964 943 Pullbundles are specified in .hg/pullbundles.manifest similar to
965 944 clonebundles.
966 945 For each entry, the bundle specification is checked for compatibility:
967 946 - Client features vs the BUNDLESPEC.
968 947 - Revisions shared with the clients vs base revisions of the bundle.
969 948 A bundle can be applied only if all its base revisions are known by
970 949 the client.
971 950 - At least one leaf of the bundle's DAG is missing on the client.
972 951 - Every leaf of the bundle's DAG is part of the node set the client wants.
973 952 E.g. do not send a bundle of all changes if the client wants only
974 953 one specific branch of many.
975 954 """
976 955 def decodehexstring(s):
977 956 return set([h.decode('hex') for h in s.split(';')])
978 957
979 958 manifest = repo.vfs.tryread('pullbundles.manifest')
980 959 if not manifest:
981 960 return None
982 961 res = exchange.parseclonebundlesmanifest(repo, manifest)
983 962 res = exchange.filterclonebundleentries(repo, res)
984 963 if not res:
985 964 return None
986 965 cl = repo.changelog
987 966 heads_anc = cl.ancestors([cl.rev(rev) for rev in heads], inclusive=True)
988 967 common_anc = cl.ancestors([cl.rev(rev) for rev in common], inclusive=True)
989 968 compformats = clientcompressionsupport(proto)
990 969 for entry in res:
991 970 if 'COMPRESSION' in entry and entry['COMPRESSION'] not in compformats:
992 971 continue
993 972 # No test yet for VERSION, since V2 is supported by any client
994 973 # that advertises partial pulls
995 974 if 'heads' in entry:
996 975 try:
997 976 bundle_heads = decodehexstring(entry['heads'])
998 977 except TypeError:
999 978 # Bad heads entry
1000 979 continue
1001 980 if bundle_heads.issubset(common):
1002 981 continue # Nothing new
1003 982 if all(cl.rev(rev) in common_anc for rev in bundle_heads):
1004 983 continue # Still nothing new
1005 984 if any(cl.rev(rev) not in heads_anc and
1006 985 cl.rev(rev) not in common_anc for rev in bundle_heads):
1007 986 continue
1008 987 if 'bases' in entry:
1009 988 try:
1010 989 bundle_bases = decodehexstring(entry['bases'])
1011 990 except TypeError:
1012 991 # Bad bases entry
1013 992 continue
1014 993 if not all(cl.rev(rev) in common_anc for rev in bundle_bases):
1015 994 continue
1016 995 path = entry['URL']
1017 996 repo.ui.debug('sending pullbundle "%s"\n' % path)
1018 997 try:
1019 998 return repo.vfs.open(path)
1020 999 except IOError:
1021 1000 repo.ui.debug('pullbundle "%s" not accessible\n' % path)
1022 1001 continue
1023 1002 return None
1024 1003
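# Illustrative note, not part of this changeset: a hypothetical
# .hg/pullbundles.manifest entry shaped for the keys consulted above -- one
# bundle per line, a URL or repo-relative path followed by space-separated
# properties, with 'heads' and 'bases' holding ';'-separated hex nodes:
#
#   bundles/recent.hg BUNDLESPEC=gzip-v2 heads=<hexnode>;<hexnode> bases=<hexnode>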
1025 1004 @wireprotocommand('getbundle', '*', permission='pull',
1026 1005 transportpolicy=POLICY_V1_ONLY)
1027 1006 def getbundle(repo, proto, others):
1028 1007 opts = options('getbundle', gboptsmap.keys(), others)
1029 1008 for k, v in opts.iteritems():
1030 1009 keytype = gboptsmap[k]
1031 1010 if keytype == 'nodes':
1032 opts[k] = decodelist(v)
1011 opts[k] = wireprototypes.decodelist(v)
1033 1012 elif keytype == 'csv':
1034 1013 opts[k] = list(v.split(','))
1035 1014 elif keytype == 'scsv':
1036 1015 opts[k] = set(v.split(','))
1037 1016 elif keytype == 'boolean':
1038 1017 # Client should serialize False as '0', which is a non-empty string
1039 1018 # so it evaluates as a True bool.
1040 1019 if v == '0':
1041 1020 opts[k] = False
1042 1021 else:
1043 1022 opts[k] = bool(v)
1044 1023 elif keytype != 'plain':
1045 1024 raise KeyError('unknown getbundle option type %s'
1046 1025 % keytype)
1047 1026
1048 1027 if not bundle1allowed(repo, 'pull'):
1049 1028 if not exchange.bundle2requested(opts.get('bundlecaps')):
1050 1029 if proto.name == 'http-v1':
1051 1030 return wireprototypes.ooberror(bundle2required)
1052 1031 raise error.Abort(bundle2requiredmain,
1053 1032 hint=bundle2requiredhint)
1054 1033
1055 1034 prefercompressed = True
1056 1035
1057 1036 try:
1058 1037 clheads = set(repo.changelog.heads())
1059 1038 heads = set(opts.get('heads', set()))
1060 1039 common = set(opts.get('common', set()))
1061 1040 common.discard(nullid)
1062 1041 if (repo.ui.configbool('server', 'pullbundle') and
1063 1042 'partial-pull' in proto.getprotocaps()):
1064 1043 # Check if a pre-built bundle covers this request.
1065 1044 bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
1066 1045 if bundle:
1067 1046 return wireprototypes.streamres(gen=util.filechunkiter(bundle),
1068 1047 prefer_uncompressed=True)
1069 1048
1070 1049 if repo.ui.configbool('server', 'disablefullbundle'):
1071 1050 # Check to see if this is a full clone.
1072 1051 changegroup = opts.get('cg', True)
1073 1052 if changegroup and not common and clheads == heads:
1074 1053 raise error.Abort(
1075 1054 _('server has pull-based clones disabled'),
1076 1055 hint=_('remove --pull if specified or upgrade Mercurial'))
1077 1056
1078 1057 info, chunks = exchange.getbundlechunks(repo, 'serve',
1079 1058 **pycompat.strkwargs(opts))
1080 1059 prefercompressed = info.get('prefercompressed', True)
1081 1060 except error.Abort as exc:
1082 1061 # cleanly forward Abort error to the client
1083 1062 if not exchange.bundle2requested(opts.get('bundlecaps')):
1084 1063 if proto.name == 'http-v1':
1085 1064 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
1086 1065 raise # cannot do better for bundle1 + ssh
1087 1066 # bundle2 request expect a bundle2 reply
1088 1067 bundler = bundle2.bundle20(repo.ui)
1089 1068 manargs = [('message', pycompat.bytestr(exc))]
1090 1069 advargs = []
1091 1070 if exc.hint is not None:
1092 1071 advargs.append(('hint', exc.hint))
1093 1072 bundler.addpart(bundle2.bundlepart('error:abort',
1094 1073 manargs, advargs))
1095 1074 chunks = bundler.getchunks()
1096 1075 prefercompressed = False
1097 1076
1098 1077 return wireprototypes.streamres(
1099 1078 gen=chunks, prefer_uncompressed=not prefercompressed)
1100 1079
1101 1080 @wireprotocommand('heads', permission='pull', transportpolicy=POLICY_V1_ONLY)
1102 1081 def heads(repo, proto):
1103 1082 h = repo.heads()
1104 return wireprototypes.bytesresponse(encodelist(h) + '\n')
1083 return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
1105 1084
1106 1085 @wireprotocommand('hello', permission='pull', transportpolicy=POLICY_V1_ONLY)
1107 1086 def hello(repo, proto):
1108 1087 """Called as part of SSH handshake to obtain server info.
1109 1088
1110 1089 Returns a list of lines describing interesting things about the
1111 1090 server, in an RFC822-like format.
1112 1091
1113 1092 Currently, the only one defined is ``capabilities``, which consists of a
1114 1093 line of space separated tokens describing server abilities:
1115 1094
1116 1095 capabilities: <token0> <token1> <token2>
1117 1096 """
1118 1097 caps = capabilities(repo, proto).data
1119 1098 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1120 1099
1121 1100 @wireprotocommand('listkeys', 'namespace', permission='pull',
1122 1101 transportpolicy=POLICY_V1_ONLY)
1123 1102 def listkeys(repo, proto, namespace):
1124 1103 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1125 1104 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1126 1105
1127 1106 @wireprotocommand('lookup', 'key', permission='pull',
1128 1107 transportpolicy=POLICY_V1_ONLY)
1129 1108 def lookup(repo, proto, key):
1130 1109 try:
1131 1110 k = encoding.tolocal(key)
1132 1111 n = repo.lookup(k)
1133 1112 r = hex(n)
1134 1113 success = 1
1135 1114 except Exception as inst:
1136 1115 r = stringutil.forcebytestr(inst)
1137 1116 success = 0
1138 1117 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1139 1118
1140 1119 @wireprotocommand('known', 'nodes *', permission='pull',
1141 1120 transportpolicy=POLICY_V1_ONLY)
1142 1121 def known(repo, proto, nodes, others):
1143 v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
1122 v = ''.join(b and '1' or '0'
1123 for b in repo.known(wireprototypes.decodelist(nodes)))
1144 1124 return wireprototypes.bytesresponse(v)
1145 1125
1146 1126 @wireprotocommand('protocaps', 'caps', permission='pull',
1147 1127 transportpolicy=POLICY_V1_ONLY)
1148 1128 def protocaps(repo, proto, caps):
1149 1129 if proto.name == wireprototypes.SSHV1:
1150 1130 proto._protocaps = set(caps.split(' '))
1151 1131 return wireprototypes.bytesresponse('OK')
1152 1132
1153 1133 @wireprotocommand('pushkey', 'namespace key old new', permission='push',
1154 1134 transportpolicy=POLICY_V1_ONLY)
1155 1135 def pushkey(repo, proto, namespace, key, old, new):
1156 1136 # compatibility with pre-1.8 clients which were accidentally
1157 1137 # sending raw binary nodes rather than utf-8-encoded hex
1158 1138 if len(new) == 20 and stringutil.escapestr(new) != new:
1159 1139 # looks like it could be a binary node
1160 1140 try:
1161 1141 new.decode('utf-8')
1162 1142 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1163 1143 except UnicodeDecodeError:
1164 1144 pass # binary, leave unmodified
1165 1145 else:
1166 1146 new = encoding.tolocal(new) # normal path
1167 1147
1168 1148 with proto.mayberedirectstdio() as output:
1169 1149 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1170 1150 encoding.tolocal(old), new) or False
1171 1151
1172 1152 output = output.getvalue() if output else ''
1173 1153 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1174 1154
1175 1155 @wireprotocommand('stream_out', permission='pull',
1176 1156 transportpolicy=POLICY_V1_ONLY)
1177 1157 def stream(repo, proto):
1178 1158 '''If the server supports streaming clone, it advertises the "stream"
1179 1159 capability with a value representing the version and flags of the repo
1180 1160 it is serving. Client checks to see if it understands the format.
1181 1161 '''
1182 1162 return wireprototypes.streamreslegacy(
1183 1163 streamclone.generatev1wireproto(repo))
1184 1164
1185 1165 @wireprotocommand('unbundle', 'heads', permission='push',
1186 1166 transportpolicy=POLICY_V1_ONLY)
1187 1167 def unbundle(repo, proto, heads):
1188 their_heads = decodelist(heads)
1168 their_heads = wireprototypes.decodelist(heads)
1189 1169
1190 1170 with proto.mayberedirectstdio() as output:
1191 1171 try:
1192 1172 exchange.check_heads(repo, their_heads, 'preparing changes')
1193 1173 cleanup = lambda: None
1194 1174 try:
1195 1175 payload = proto.getpayload()
1196 1176 if repo.ui.configbool('server', 'streamunbundle'):
1197 1177 def cleanup():
1198 1178 # Ensure that the full payload is consumed, so
1199 1179 # that the connection doesn't contain trailing garbage.
1200 1180 for p in payload:
1201 1181 pass
1202 1182 fp = util.chunkbuffer(payload)
1203 1183 else:
1204 1184 # write bundle data to temporary file as it can be big
1205 1185 fp, tempname = None, None
1206 1186 def cleanup():
1207 1187 if fp:
1208 1188 fp.close()
1209 1189 if tempname:
1210 1190 os.unlink(tempname)
1211 1191 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1212 1192 repo.ui.debug('redirecting incoming bundle to %s\n' %
1213 1193 tempname)
1214 1194 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1215 1195 r = 0
1216 1196 for p in payload:
1217 1197 fp.write(p)
1218 1198 fp.seek(0)
1219 1199
1220 1200 gen = exchange.readbundle(repo.ui, fp, None)
1221 1201 if (isinstance(gen, changegroupmod.cg1unpacker)
1222 1202 and not bundle1allowed(repo, 'push')):
1223 1203 if proto.name == 'http-v1':
1224 1204 # need to special case http because stderr does not get to
1225 1205 # the http client on failed push so we need to abuse
1226 1206 # some other error type to make sure the message gets to
1227 1207 # the user.
1228 1208 return wireprototypes.ooberror(bundle2required)
1229 1209 raise error.Abort(bundle2requiredmain,
1230 1210 hint=bundle2requiredhint)
1231 1211
1232 1212 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1233 1213 proto.client())
1234 1214 if util.safehasattr(r, 'addpart'):
1235 1215 # The return looks streamable, we are in the bundle2 case
1236 1216 # and should return a stream.
1237 1217 return wireprototypes.streamreslegacy(gen=r.getchunks())
1238 1218 return wireprototypes.pushres(
1239 1219 r, output.getvalue() if output else '')
1240 1220
1241 1221 finally:
1242 1222 cleanup()
1243 1223
1244 1224 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1245 1225 # handle non-bundle2 case first
1246 1226 if not getattr(exc, 'duringunbundle2', False):
1247 1227 try:
1248 1228 raise
1249 1229 except error.Abort:
1250 1230 # The old code we moved used procutil.stderr directly.
1251 1231 # We did not change it to minimise code change.
1252 1232 # This needs to be moved to something proper.
1253 1233 # Feel free to do it.
1254 1234 procutil.stderr.write("abort: %s\n" % exc)
1255 1235 if exc.hint is not None:
1256 1236 procutil.stderr.write("(%s)\n" % exc.hint)
1257 1237 procutil.stderr.flush()
1258 1238 return wireprototypes.pushres(
1259 1239 0, output.getvalue() if output else '')
1260 1240 except error.PushRaced:
1261 1241 return wireprototypes.pusherr(
1262 1242 pycompat.bytestr(exc),
1263 1243 output.getvalue() if output else '')
1264 1244
1265 1245 bundler = bundle2.bundle20(repo.ui)
1266 1246 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1267 1247 bundler.addpart(out)
1268 1248 try:
1269 1249 try:
1270 1250 raise
1271 1251 except error.PushkeyFailed as exc:
1272 1252 # check client caps
1273 1253 remotecaps = getattr(exc, '_replycaps', None)
1274 1254 if (remotecaps is not None
1275 1255 and 'pushkey' not in remotecaps.get('error', ())):
1276 1256 # no support remote side, fallback to Abort handler.
1277 1257 raise
1278 1258 part = bundler.newpart('error:pushkey')
1279 1259 part.addparam('in-reply-to', exc.partid)
1280 1260 if exc.namespace is not None:
1281 1261 part.addparam('namespace', exc.namespace,
1282 1262 mandatory=False)
1283 1263 if exc.key is not None:
1284 1264 part.addparam('key', exc.key, mandatory=False)
1285 1265 if exc.new is not None:
1286 1266 part.addparam('new', exc.new, mandatory=False)
1287 1267 if exc.old is not None:
1288 1268 part.addparam('old', exc.old, mandatory=False)
1289 1269 if exc.ret is not None:
1290 1270 part.addparam('ret', exc.ret, mandatory=False)
1291 1271 except error.BundleValueError as exc:
1292 1272 errpart = bundler.newpart('error:unsupportedcontent')
1293 1273 if exc.parttype is not None:
1294 1274 errpart.addparam('parttype', exc.parttype)
1295 1275 if exc.params:
1296 1276 errpart.addparam('params', '\0'.join(exc.params))
1297 1277 except error.Abort as exc:
1298 1278 manargs = [('message', stringutil.forcebytestr(exc))]
1299 1279 advargs = []
1300 1280 if exc.hint is not None:
1301 1281 advargs.append(('hint', exc.hint))
1302 1282 bundler.addpart(bundle2.bundlepart('error:abort',
1303 1283 manargs, advargs))
1304 1284 except error.PushRaced as exc:
1305 1285 bundler.newpart('error:pushraced',
1306 1286 [('message', stringutil.forcebytestr(exc))])
1307 1287 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
@@ -1,171 +1,203
1 1 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
2 2 #
3 3 # This software may be used and distributed according to the terms of the
4 4 # GNU General Public License version 2 or any later version.
5 5
6 6 from __future__ import absolute_import
7 7
8 from .node import (
9 bin,
10 hex,
11 )
8 12 from .thirdparty.zope import (
9 13 interface as zi,
10 14 )
11 15
12 16 # Names of the SSH protocol implementations.
13 17 SSHV1 = 'ssh-v1'
14 18 # These are advertised over the wire. Increment the counters at the end
15 19 # to reflect BC breakages.
16 20 SSHV2 = 'exp-ssh-v2-0001'
17 21 HTTPV2 = 'exp-http-v2-0001'
18 22
19 23 # All available wire protocol transports.
20 24 TRANSPORTS = {
21 25 SSHV1: {
22 26 'transport': 'ssh',
23 27 'version': 1,
24 28 },
25 29 SSHV2: {
26 30 'transport': 'ssh',
27 31 # TODO mark as version 2 once all commands are implemented.
28 32 'version': 1,
29 33 },
30 34 'http-v1': {
31 35 'transport': 'http',
32 36 'version': 1,
33 37 },
34 38 HTTPV2: {
35 39 'transport': 'http',
36 40 'version': 2,
37 41 }
38 42 }
39 43
40 44 class bytesresponse(object):
41 45 """A wire protocol response consisting of raw bytes."""
42 46 def __init__(self, data):
43 47 self.data = data
44 48
45 49 class ooberror(object):
46 50 """wireproto reply: failure of a batch of operations
47 51
48 52 Something failed during a batch call. The error message is stored in
49 53 `self.message`.
50 54 """
51 55 def __init__(self, message):
52 56 self.message = message
53 57
54 58 class pushres(object):
55 59 """wireproto reply: success with simple integer return
56 60
57 61 The call was successful and returned an integer contained in `self.res`.
58 62 """
59 63 def __init__(self, res, output):
60 64 self.res = res
61 65 self.output = output
62 66
63 67 class pusherr(object):
64 68 """wireproto reply: failure
65 69
66 70 The call failed. The `self.res` attribute contains the error message.
67 71 """
68 72 def __init__(self, res, output):
69 73 self.res = res
70 74 self.output = output
71 75
72 76 class streamres(object):
73 77 """wireproto reply: binary stream
74 78
75 79 The call was successful and the result is a stream.
76 80
77 81 Accepts a generator containing chunks of data to be sent to the client.
78 82
79 83 ``prefer_uncompressed`` indicates that the data is expected to be
80 84 incompressible and that the stream should therefore use the ``none``
81 85 engine.
82 86 """
83 87 def __init__(self, gen=None, prefer_uncompressed=False):
84 88 self.gen = gen
85 89 self.prefer_uncompressed = prefer_uncompressed
86 90
87 91 class streamreslegacy(object):
88 92 """wireproto reply: uncompressed binary stream
89 93
90 94 The call was successful and the result is a stream.
91 95
92 96 Accepts a generator containing chunks of data to be sent to the client.
93 97
94 98 Like ``streamres``, but sends an uncompressed data for "version 1" clients
95 99 using the application/mercurial-0.1 media type.
96 100 """
97 101 def __init__(self, gen=None):
98 102 self.gen = gen
99 103
100 104 class cborresponse(object):
101 105 """Encode the response value as CBOR."""
102 106 def __init__(self, v):
103 107 self.value = v
104 108
109 # list of nodes encoding / decoding
110 def decodelist(l, sep=' '):
111 if l:
112 return [bin(v) for v in l.split(sep)]
113 return []
114
115 def encodelist(l, sep=' '):
116 try:
117 return sep.join(map(hex, l))
118 except TypeError:
119 raise
120
121 # batched call argument encoding
122
123 def escapebatcharg(plain):
124 return (plain
125 .replace(':', ':c')
126 .replace(',', ':o')
127 .replace(';', ':s')
128 .replace('=', ':e'))
129
130 def unescapebatcharg(escaped):
131 return (escaped
132 .replace(':e', '=')
133 .replace(':s', ';')
134 .replace(':o', ',')
135 .replace(':c', ':'))
136
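# Illustrative sketch, not part of this changeset: round-tripping the
# node-list helpers defined above. The 20-byte node values are invented.
nodes = ['\x11' * 20, '\x22' * 20]
wire = encodelist(nodes)             # '1111...1111 2222...2222'
assert decodelist(wire) == nodes
assert decodelist('') == []          # an empty payload decodes to []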
105 137 class baseprotocolhandler(zi.Interface):
106 138 """Abstract base class for wire protocol handlers.
107 139
108 140 A wire protocol handler serves as an interface between protocol command
109 141 handlers and the wire protocol transport layer. Protocol handlers provide
110 142 methods to read command arguments, redirect stdio for the duration of
111 143 the request, handle response types, etc.
112 144 """
113 145
114 146 name = zi.Attribute(
115 147 """The name of the protocol implementation.
116 148
117 149 Used for uniquely identifying the transport type.
118 150 """)
119 151
120 152 def getargs(args):
121 153 """return the value for arguments in <args>
122 154
123 155 For version 1 transports, returns a list of values in the same
124 156 order they appear in ``args``. For version 2 transports, returns
125 157 a dict mapping argument name to value.
126 158 """
127 159
128 160 def getprotocaps():
129 161 """Returns the list of protocol-level capabilities of client
130 162
131 163 Returns a list of capabilities as declared by the client for
132 164 the current request (or connection for stateful protocol handlers)."""
133 165
134 166 def getpayload():
135 167 """Provide a generator for the raw payload.
136 168
137 169 The caller is responsible for ensuring that the full payload is
138 170 processed.
139 171 """
140 172
141 173 def mayberedirectstdio():
142 174 """Context manager to possibly redirect stdio.
143 175
144 176 The context manager yields a file-object like object that receives
145 177 stdout and stderr output when the context manager is active. Or it
146 178 yields ``None`` if no I/O redirection occurs.
147 179
148 180 The intent of this context manager is to capture stdio output
149 181 so it may be sent in the response. Some transports support streaming
150 182 stdio to the client in real time. For these transports, stdio output
151 183 won't be captured.
152 184 """
153 185
154 186 def client():
155 187 """Returns a string representation of this client (as bytes)."""
156 188
157 189 def addcapabilities(repo, caps):
158 190 """Adds advertised capabilities specific to this protocol.
159 191
160 192 Receives the list of capabilities collected so far.
161 193
162 194 Returns a list of capabilities. The passed in argument can be returned.
163 195 """
164 196
165 197 def checkperm(perm):
166 198 """Validate that the client has permissions to perform a request.
167 199
168 200 The argument is the permission required to proceed. If the client
169 201 doesn't have that permission, the exception should raise or abort
170 202 in a protocol specific manner.
171 203 """