##// END OF EJS Templates
wireprotoserver: add context manager mechanism for redirecting stdio...
Gregory Szorc -
r36083:2ad145fb default
parent child Browse files
Show More
@@ -1,1101 +1,1093 b''
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import os
12 12 import tempfile
13 13
14 14 from .i18n import _
15 15 from .node import (
16 16 bin,
17 17 hex,
18 18 nullid,
19 19 )
20 20
21 21 from . import (
22 22 bundle2,
23 23 changegroup as changegroupmod,
24 24 discovery,
25 25 encoding,
26 26 error,
27 27 exchange,
28 28 peer,
29 29 pushkey as pushkeymod,
30 30 pycompat,
31 31 repository,
32 32 streamclone,
33 33 util,
34 34 )
35 35
36 36 urlerr = util.urlerr
37 37 urlreq = util.urlreq
38 38
39 39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
40 40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
41 41 'IncompatibleClient')
42 42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
43 43
44 44 class remoteiterbatcher(peer.iterbatcher):
45 45 def __init__(self, remote):
46 46 super(remoteiterbatcher, self).__init__()
47 47 self._remote = remote
48 48
49 49 def __getattr__(self, name):
50 50 # Validate this method is batchable, since submit() only supports
51 51 # batchable methods.
52 52 fn = getattr(self._remote, name)
53 53 if not getattr(fn, 'batchable', None):
54 54 raise error.ProgrammingError('Attempted to batch a non-batchable '
55 55 'call to %r' % name)
56 56
57 57 return super(remoteiterbatcher, self).__getattr__(name)
58 58
59 59 def submit(self):
60 60 """Break the batch request into many patch calls and pipeline them.
61 61
62 62 This is mostly valuable over http where request sizes can be
63 63 limited, but can be used in other places as well.
64 64 """
65 65 # 2-tuple of (command, arguments) that represents what will be
66 66 # sent over the wire.
67 67 requests = []
68 68
69 69 # 4-tuple of (command, final future, @batchable generator, remote
70 70 # future).
71 71 results = []
72 72
73 73 for command, args, opts, finalfuture in self.calls:
74 74 mtd = getattr(self._remote, command)
75 75 batchable = mtd.batchable(mtd.__self__, *args, **opts)
76 76
77 77 commandargs, fremote = next(batchable)
78 78 assert fremote
79 79 requests.append((command, commandargs))
80 80 results.append((command, finalfuture, batchable, fremote))
81 81
82 82 if requests:
83 83 self._resultiter = self._remote._submitbatch(requests)
84 84
85 85 self._results = results
86 86
87 87 def results(self):
88 88 for command, finalfuture, batchable, remotefuture in self._results:
89 89 # Get the raw result, set it in the remote future, feed it
90 90 # back into the @batchable generator so it can be decoded, and
91 91 # set the result on the final future to this value.
92 92 remoteresult = next(self._resultiter)
93 93 remotefuture.set(remoteresult)
94 94 finalfuture.set(next(batchable))
95 95
96 96 # Verify our @batchable generators only emit 2 values.
97 97 try:
98 98 next(batchable)
99 99 except StopIteration:
100 100 pass
101 101 else:
102 102 raise error.ProgrammingError('%s @batchable generator emitted '
103 103 'unexpected value count' % command)
104 104
105 105 yield finalfuture.value
106 106
107 107 # Forward a couple of names from peer to make wireproto interactions
108 108 # slightly more sensible.
109 109 batchable = peer.batchable
110 110 future = peer.future
111 111
112 112 # list of nodes encoding / decoding
113 113
114 114 def decodelist(l, sep=' '):
115 115 if l:
116 116 return [bin(v) for v in l.split(sep)]
117 117 return []
118 118
119 119 def encodelist(l, sep=' '):
120 120 try:
121 121 return sep.join(map(hex, l))
122 122 except TypeError:
123 123 raise
124 124
125 125 # batched call argument encoding
126 126
127 127 def escapearg(plain):
128 128 return (plain
129 129 .replace(':', ':c')
130 130 .replace(',', ':o')
131 131 .replace(';', ':s')
132 132 .replace('=', ':e'))
133 133
134 134 def unescapearg(escaped):
135 135 return (escaped
136 136 .replace(':e', '=')
137 137 .replace(':s', ';')
138 138 .replace(':o', ',')
139 139 .replace(':c', ':'))
140 140
141 141 def encodebatchcmds(req):
142 142 """Return a ``cmds`` argument value for the ``batch`` command."""
143 143 cmds = []
144 144 for op, argsdict in req:
145 145 # Old servers didn't properly unescape argument names. So prevent
146 146 # the sending of argument names that may not be decoded properly by
147 147 # servers.
148 148 assert all(escapearg(k) == k for k in argsdict)
149 149
150 150 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
151 151 for k, v in argsdict.iteritems())
152 152 cmds.append('%s %s' % (op, args))
153 153
154 154 return ';'.join(cmds)
155 155
156 156 # mapping of options accepted by getbundle and their types
157 157 #
158 158 # Meant to be extended by extensions. It is extensions responsibility to ensure
159 159 # such options are properly processed in exchange.getbundle.
160 160 #
161 161 # supported types are:
162 162 #
163 163 # :nodes: list of binary nodes
164 164 # :csv: list of comma-separated values
165 165 # :scsv: list of comma-separated values return as set
166 166 # :plain: string with no transformation needed.
167 167 gboptsmap = {'heads': 'nodes',
168 168 'bookmarks': 'boolean',
169 169 'common': 'nodes',
170 170 'obsmarkers': 'boolean',
171 171 'phases': 'boolean',
172 172 'bundlecaps': 'scsv',
173 173 'listkeys': 'csv',
174 174 'cg': 'boolean',
175 175 'cbattempted': 'boolean',
176 176 'stream': 'boolean',
177 177 }
178 178
179 179 # client side
180 180
181 181 class wirepeer(repository.legacypeer):
182 182 """Client-side interface for communicating with a peer repository.
183 183
184 184 Methods commonly call wire protocol commands of the same name.
185 185
186 186 See also httppeer.py and sshpeer.py for protocol-specific
187 187 implementations of this interface.
188 188 """
189 189 # Begin of basewirepeer interface.
190 190
191 191 def iterbatch(self):
192 192 return remoteiterbatcher(self)
193 193
194 194 @batchable
195 195 def lookup(self, key):
196 196 self.requirecap('lookup', _('look up remote revision'))
197 197 f = future()
198 198 yield {'key': encoding.fromlocal(key)}, f
199 199 d = f.value
200 200 success, data = d[:-1].split(" ", 1)
201 201 if int(success):
202 202 yield bin(data)
203 203 else:
204 204 self._abort(error.RepoError(data))
205 205
206 206 @batchable
207 207 def heads(self):
208 208 f = future()
209 209 yield {}, f
210 210 d = f.value
211 211 try:
212 212 yield decodelist(d[:-1])
213 213 except ValueError:
214 214 self._abort(error.ResponseError(_("unexpected response:"), d))
215 215
216 216 @batchable
217 217 def known(self, nodes):
218 218 f = future()
219 219 yield {'nodes': encodelist(nodes)}, f
220 220 d = f.value
221 221 try:
222 222 yield [bool(int(b)) for b in d]
223 223 except ValueError:
224 224 self._abort(error.ResponseError(_("unexpected response:"), d))
225 225
226 226 @batchable
227 227 def branchmap(self):
228 228 f = future()
229 229 yield {}, f
230 230 d = f.value
231 231 try:
232 232 branchmap = {}
233 233 for branchpart in d.splitlines():
234 234 branchname, branchheads = branchpart.split(' ', 1)
235 235 branchname = encoding.tolocal(urlreq.unquote(branchname))
236 236 branchheads = decodelist(branchheads)
237 237 branchmap[branchname] = branchheads
238 238 yield branchmap
239 239 except TypeError:
240 240 self._abort(error.ResponseError(_("unexpected response:"), d))
241 241
242 242 @batchable
243 243 def listkeys(self, namespace):
244 244 if not self.capable('pushkey'):
245 245 yield {}, None
246 246 f = future()
247 247 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
248 248 yield {'namespace': encoding.fromlocal(namespace)}, f
249 249 d = f.value
250 250 self.ui.debug('received listkey for "%s": %i bytes\n'
251 251 % (namespace, len(d)))
252 252 yield pushkeymod.decodekeys(d)
253 253
254 254 @batchable
255 255 def pushkey(self, namespace, key, old, new):
256 256 if not self.capable('pushkey'):
257 257 yield False, None
258 258 f = future()
259 259 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
260 260 yield {'namespace': encoding.fromlocal(namespace),
261 261 'key': encoding.fromlocal(key),
262 262 'old': encoding.fromlocal(old),
263 263 'new': encoding.fromlocal(new)}, f
264 264 d = f.value
265 265 d, output = d.split('\n', 1)
266 266 try:
267 267 d = bool(int(d))
268 268 except ValueError:
269 269 raise error.ResponseError(
270 270 _('push failed (unexpected response):'), d)
271 271 for l in output.splitlines(True):
272 272 self.ui.status(_('remote: '), l)
273 273 yield d
274 274
275 275 def stream_out(self):
276 276 return self._callstream('stream_out')
277 277
278 278 def getbundle(self, source, **kwargs):
279 279 kwargs = pycompat.byteskwargs(kwargs)
280 280 self.requirecap('getbundle', _('look up remote changes'))
281 281 opts = {}
282 282 bundlecaps = kwargs.get('bundlecaps')
283 283 if bundlecaps is not None:
284 284 kwargs['bundlecaps'] = sorted(bundlecaps)
285 285 else:
286 286 bundlecaps = () # kwargs could have it to None
287 287 for key, value in kwargs.iteritems():
288 288 if value is None:
289 289 continue
290 290 keytype = gboptsmap.get(key)
291 291 if keytype is None:
292 292 raise error.ProgrammingError(
293 293 'Unexpectedly None keytype for key %s' % key)
294 294 elif keytype == 'nodes':
295 295 value = encodelist(value)
296 296 elif keytype in ('csv', 'scsv'):
297 297 value = ','.join(value)
298 298 elif keytype == 'boolean':
299 299 value = '%i' % bool(value)
300 300 elif keytype != 'plain':
301 301 raise KeyError('unknown getbundle option type %s'
302 302 % keytype)
303 303 opts[key] = value
304 304 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
305 305 if any((cap.startswith('HG2') for cap in bundlecaps)):
306 306 return bundle2.getunbundler(self.ui, f)
307 307 else:
308 308 return changegroupmod.cg1unpacker(f, 'UN')
309 309
310 310 def unbundle(self, cg, heads, url):
311 311 '''Send cg (a readable file-like object representing the
312 312 changegroup to push, typically a chunkbuffer object) to the
313 313 remote server as a bundle.
314 314
315 315 When pushing a bundle10 stream, return an integer indicating the
316 316 result of the push (see changegroup.apply()).
317 317
318 318 When pushing a bundle20 stream, return a bundle20 stream.
319 319
320 320 `url` is the url the client thinks it's pushing to, which is
321 321 visible to hooks.
322 322 '''
323 323
324 324 if heads != ['force'] and self.capable('unbundlehash'):
325 325 heads = encodelist(['hashed',
326 326 hashlib.sha1(''.join(sorted(heads))).digest()])
327 327 else:
328 328 heads = encodelist(heads)
329 329
330 330 if util.safehasattr(cg, 'deltaheader'):
331 331 # this a bundle10, do the old style call sequence
332 332 ret, output = self._callpush("unbundle", cg, heads=heads)
333 333 if ret == "":
334 334 raise error.ResponseError(
335 335 _('push failed:'), output)
336 336 try:
337 337 ret = int(ret)
338 338 except ValueError:
339 339 raise error.ResponseError(
340 340 _('push failed (unexpected response):'), ret)
341 341
342 342 for l in output.splitlines(True):
343 343 self.ui.status(_('remote: '), l)
344 344 else:
345 345 # bundle2 push. Send a stream, fetch a stream.
346 346 stream = self._calltwowaystream('unbundle', cg, heads=heads)
347 347 ret = bundle2.getunbundler(self.ui, stream)
348 348 return ret
349 349
350 350 # End of basewirepeer interface.
351 351
352 352 # Begin of baselegacywirepeer interface.
353 353
354 354 def branches(self, nodes):
355 355 n = encodelist(nodes)
356 356 d = self._call("branches", nodes=n)
357 357 try:
358 358 br = [tuple(decodelist(b)) for b in d.splitlines()]
359 359 return br
360 360 except ValueError:
361 361 self._abort(error.ResponseError(_("unexpected response:"), d))
362 362
363 363 def between(self, pairs):
364 364 batch = 8 # avoid giant requests
365 365 r = []
366 366 for i in xrange(0, len(pairs), batch):
367 367 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
368 368 d = self._call("between", pairs=n)
369 369 try:
370 370 r.extend(l and decodelist(l) or [] for l in d.splitlines())
371 371 except ValueError:
372 372 self._abort(error.ResponseError(_("unexpected response:"), d))
373 373 return r
374 374
375 375 def changegroup(self, nodes, kind):
376 376 n = encodelist(nodes)
377 377 f = self._callcompressable("changegroup", roots=n)
378 378 return changegroupmod.cg1unpacker(f, 'UN')
379 379
380 380 def changegroupsubset(self, bases, heads, kind):
381 381 self.requirecap('changegroupsubset', _('look up remote changes'))
382 382 bases = encodelist(bases)
383 383 heads = encodelist(heads)
384 384 f = self._callcompressable("changegroupsubset",
385 385 bases=bases, heads=heads)
386 386 return changegroupmod.cg1unpacker(f, 'UN')
387 387
388 388 # End of baselegacywirepeer interface.
389 389
390 390 def _submitbatch(self, req):
391 391 """run batch request <req> on the server
392 392
393 393 Returns an iterator of the raw responses from the server.
394 394 """
395 395 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
396 396 chunk = rsp.read(1024)
397 397 work = [chunk]
398 398 while chunk:
399 399 while ';' not in chunk and chunk:
400 400 chunk = rsp.read(1024)
401 401 work.append(chunk)
402 402 merged = ''.join(work)
403 403 while ';' in merged:
404 404 one, merged = merged.split(';', 1)
405 405 yield unescapearg(one)
406 406 chunk = rsp.read(1024)
407 407 work = [merged, chunk]
408 408 yield unescapearg(''.join(work))
409 409
410 410 def _submitone(self, op, args):
411 411 return self._call(op, **pycompat.strkwargs(args))
412 412
413 413 def debugwireargs(self, one, two, three=None, four=None, five=None):
414 414 # don't pass optional arguments left at their default value
415 415 opts = {}
416 416 if three is not None:
417 417 opts[r'three'] = three
418 418 if four is not None:
419 419 opts[r'four'] = four
420 420 return self._call('debugwireargs', one=one, two=two, **opts)
421 421
422 422 def _call(self, cmd, **args):
423 423 """execute <cmd> on the server
424 424
425 425 The command is expected to return a simple string.
426 426
427 427 returns the server reply as a string."""
428 428 raise NotImplementedError()
429 429
430 430 def _callstream(self, cmd, **args):
431 431 """execute <cmd> on the server
432 432
433 433 The command is expected to return a stream. Note that if the
434 434 command doesn't return a stream, _callstream behaves
435 435 differently for ssh and http peers.
436 436
437 437 returns the server reply as a file like object.
438 438 """
439 439 raise NotImplementedError()
440 440
441 441 def _callcompressable(self, cmd, **args):
442 442 """execute <cmd> on the server
443 443
444 444 The command is expected to return a stream.
445 445
446 446 The stream may have been compressed in some implementations. This
447 447 function takes care of the decompression. This is the only difference
448 448 with _callstream.
449 449
450 450 returns the server reply as a file like object.
451 451 """
452 452 raise NotImplementedError()
453 453
454 454 def _callpush(self, cmd, fp, **args):
455 455 """execute a <cmd> on server
456 456
457 457 The command is expected to be related to a push. Push has a special
458 458 return method.
459 459
460 460 returns the server reply as a (ret, output) tuple. ret is either
461 461 empty (error) or a stringified int.
462 462 """
463 463 raise NotImplementedError()
464 464
465 465 def _calltwowaystream(self, cmd, fp, **args):
466 466 """execute <cmd> on server
467 467
468 468 The command will send a stream to the server and get a stream in reply.
469 469 """
470 470 raise NotImplementedError()
471 471
472 472 def _abort(self, exception):
473 473 """clearly abort the wire protocol connection and raise the exception
474 474 """
475 475 raise NotImplementedError()
476 476
477 477 # server side
478 478
479 479 # wire protocol command can either return a string or one of these classes.
480 480 class streamres(object):
481 481 """wireproto reply: binary stream
482 482
483 483 The call was successful and the result is a stream.
484 484
485 485 Accepts a generator containing chunks of data to be sent to the client.
486 486
487 487 ``prefer_uncompressed`` indicates that the data is expected to be
488 488 uncompressable and that the stream should therefore use the ``none``
489 489 engine.
490 490 """
491 491 def __init__(self, gen=None, prefer_uncompressed=False):
492 492 self.gen = gen
493 493 self.prefer_uncompressed = prefer_uncompressed
494 494
495 495 class streamres_legacy(object):
496 496 """wireproto reply: uncompressed binary stream
497 497
498 498 The call was successful and the result is a stream.
499 499
500 500 Accepts a generator containing chunks of data to be sent to the client.
501 501
502 502 Like ``streamres``, but sends an uncompressed data for "version 1" clients
503 503 using the application/mercurial-0.1 media type.
504 504 """
505 505 def __init__(self, gen=None):
506 506 self.gen = gen
507 507
508 508 class pushres(object):
509 509 """wireproto reply: success with simple integer return
510 510
511 511 The call was successful and returned an integer contained in `self.res`.
512 512 """
513 513 def __init__(self, res):
514 514 self.res = res
515 515
516 516 class pusherr(object):
517 517 """wireproto reply: failure
518 518
519 519 The call failed. The `self.res` attribute contains the error message.
520 520 """
521 521 def __init__(self, res):
522 522 self.res = res
523 523
524 524 class ooberror(object):
525 525 """wireproto reply: failure of a batch of operation
526 526
527 527 Something failed during a batch call. The error message is stored in
528 528 `self.message`.
529 529 """
530 530 def __init__(self, message):
531 531 self.message = message
532 532
533 533 def getdispatchrepo(repo, proto, command):
534 534 """Obtain the repo used for processing wire protocol commands.
535 535
536 536 The intent of this function is to serve as a monkeypatch point for
537 537 extensions that need commands to operate on different repo views under
538 538 specialized circumstances.
539 539 """
540 540 return repo.filtered('served')
541 541
542 542 def dispatch(repo, proto, command):
543 543 repo = getdispatchrepo(repo, proto, command)
544 544 func, spec = commands[command]
545 545 args = proto.getargs(spec)
546 546 return func(repo, proto, *args)
547 547
548 548 def options(cmd, keys, others):
549 549 opts = {}
550 550 for k in keys:
551 551 if k in others:
552 552 opts[k] = others[k]
553 553 del others[k]
554 554 if others:
555 555 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
556 556 % (cmd, ",".join(others)))
557 557 return opts
558 558
559 559 def bundle1allowed(repo, action):
560 560 """Whether a bundle1 operation is allowed from the server.
561 561
562 562 Priority is:
563 563
564 564 1. server.bundle1gd.<action> (if generaldelta active)
565 565 2. server.bundle1.<action>
566 566 3. server.bundle1gd (if generaldelta active)
567 567 4. server.bundle1
568 568 """
569 569 ui = repo.ui
570 570 gd = 'generaldelta' in repo.requirements
571 571
572 572 if gd:
573 573 v = ui.configbool('server', 'bundle1gd.%s' % action)
574 574 if v is not None:
575 575 return v
576 576
577 577 v = ui.configbool('server', 'bundle1.%s' % action)
578 578 if v is not None:
579 579 return v
580 580
581 581 if gd:
582 582 v = ui.configbool('server', 'bundle1gd')
583 583 if v is not None:
584 584 return v
585 585
586 586 return ui.configbool('server', 'bundle1')
587 587
588 588 def supportedcompengines(ui, proto, role):
589 589 """Obtain the list of supported compression engines for a request."""
590 590 assert role in (util.CLIENTROLE, util.SERVERROLE)
591 591
592 592 compengines = util.compengines.supportedwireengines(role)
593 593
594 594 # Allow config to override default list and ordering.
595 595 if role == util.SERVERROLE:
596 596 configengines = ui.configlist('server', 'compressionengines')
597 597 config = 'server.compressionengines'
598 598 else:
599 599 # This is currently implemented mainly to facilitate testing. In most
600 600 # cases, the server should be in charge of choosing a compression engine
601 601 # because a server has the most to lose from a sub-optimal choice. (e.g.
602 602 # CPU DoS due to an expensive engine or a network DoS due to poor
603 603 # compression ratio).
604 604 configengines = ui.configlist('experimental',
605 605 'clientcompressionengines')
606 606 config = 'experimental.clientcompressionengines'
607 607
608 608 # No explicit config. Filter out the ones that aren't supposed to be
609 609 # advertised and return default ordering.
610 610 if not configengines:
611 611 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
612 612 return [e for e in compengines
613 613 if getattr(e.wireprotosupport(), attr) > 0]
614 614
615 615 # If compression engines are listed in the config, assume there is a good
616 616 # reason for it (like server operators wanting to achieve specific
617 617 # performance characteristics). So fail fast if the config references
618 618 # unusable compression engines.
619 619 validnames = set(e.name() for e in compengines)
620 620 invalidnames = set(e for e in configengines if e not in validnames)
621 621 if invalidnames:
622 622 raise error.Abort(_('invalid compression engine defined in %s: %s') %
623 623 (config, ', '.join(sorted(invalidnames))))
624 624
625 625 compengines = [e for e in compengines if e.name() in configengines]
626 626 compengines = sorted(compengines,
627 627 key=lambda e: configengines.index(e.name()))
628 628
629 629 if not compengines:
630 630 raise error.Abort(_('%s config option does not specify any known '
631 631 'compression engines') % config,
632 632 hint=_('usable compression engines: %s') %
633 633 ', '.sorted(validnames))
634 634
635 635 return compengines
636 636
637 637 class commandentry(object):
638 638 """Represents a declared wire protocol command."""
639 639 def __init__(self, func, args=''):
640 640 self.func = func
641 641 self.args = args
642 642
643 643 def _merge(self, func, args):
644 644 """Merge this instance with an incoming 2-tuple.
645 645
646 646 This is called when a caller using the old 2-tuple API attempts
647 647 to replace an instance. The incoming values are merged with
648 648 data not captured by the 2-tuple and a new instance containing
649 649 the union of the two objects is returned.
650 650 """
651 651 return commandentry(func, args)
652 652
653 653 # Old code treats instances as 2-tuples. So expose that interface.
654 654 def __iter__(self):
655 655 yield self.func
656 656 yield self.args
657 657
658 658 def __getitem__(self, i):
659 659 if i == 0:
660 660 return self.func
661 661 elif i == 1:
662 662 return self.args
663 663 else:
664 664 raise IndexError('can only access elements 0 and 1')
665 665
666 666 class commanddict(dict):
667 667 """Container for registered wire protocol commands.
668 668
669 669 It behaves like a dict. But __setitem__ is overwritten to allow silent
670 670 coercion of values from 2-tuples for API compatibility.
671 671 """
672 672 def __setitem__(self, k, v):
673 673 if isinstance(v, commandentry):
674 674 pass
675 675 # Cast 2-tuples to commandentry instances.
676 676 elif isinstance(v, tuple):
677 677 if len(v) != 2:
678 678 raise ValueError('command tuples must have exactly 2 elements')
679 679
680 680 # It is common for extensions to wrap wire protocol commands via
681 681 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
682 682 # doing this aren't aware of the new API that uses objects to store
683 683 # command entries, we automatically merge old state with new.
684 684 if k in self:
685 685 v = self[k]._merge(v[0], v[1])
686 686 else:
687 687 v = commandentry(v[0], v[1])
688 688 else:
689 689 raise ValueError('command entries must be commandentry instances '
690 690 'or 2-tuples')
691 691
692 692 return super(commanddict, self).__setitem__(k, v)
693 693
694 694 def commandavailable(self, command, proto):
695 695 """Determine if a command is available for the requested protocol."""
696 696 # For now, commands are available for all protocols. So do a simple
697 697 # membership test.
698 698 return command in self
699 699
700 700 commands = commanddict()
701 701
702 702 def wireprotocommand(name, args=''):
703 703 """Decorator to declare a wire protocol command.
704 704
705 705 ``name`` is the name of the wire protocol command being provided.
706 706
707 707 ``args`` is a space-delimited list of named arguments that the command
708 708 accepts. ``*`` is a special value that says to accept all arguments.
709 709 """
710 710 def register(func):
711 711 commands[name] = commandentry(func, args)
712 712 return func
713 713 return register
714 714
715 715 @wireprotocommand('batch', 'cmds *')
716 716 def batch(repo, proto, cmds, others):
717 717 repo = repo.filtered("served")
718 718 res = []
719 719 for pair in cmds.split(';'):
720 720 op, args = pair.split(' ', 1)
721 721 vals = {}
722 722 for a in args.split(','):
723 723 if a:
724 724 n, v = a.split('=')
725 725 vals[unescapearg(n)] = unescapearg(v)
726 726 func, spec = commands[op]
727 727 if spec:
728 728 keys = spec.split()
729 729 data = {}
730 730 for k in keys:
731 731 if k == '*':
732 732 star = {}
733 733 for key in vals.keys():
734 734 if key not in keys:
735 735 star[key] = vals[key]
736 736 data['*'] = star
737 737 else:
738 738 data[k] = vals[k]
739 739 result = func(repo, proto, *[data[k] for k in keys])
740 740 else:
741 741 result = func(repo, proto)
742 742 if isinstance(result, ooberror):
743 743 return result
744 744 res.append(escapearg(result))
745 745 return ';'.join(res)
746 746
747 747 @wireprotocommand('between', 'pairs')
748 748 def between(repo, proto, pairs):
749 749 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
750 750 r = []
751 751 for b in repo.between(pairs):
752 752 r.append(encodelist(b) + "\n")
753 753 return "".join(r)
754 754
755 755 @wireprotocommand('branchmap')
756 756 def branchmap(repo, proto):
757 757 branchmap = repo.branchmap()
758 758 heads = []
759 759 for branch, nodes in branchmap.iteritems():
760 760 branchname = urlreq.quote(encoding.fromlocal(branch))
761 761 branchnodes = encodelist(nodes)
762 762 heads.append('%s %s' % (branchname, branchnodes))
763 763 return '\n'.join(heads)
764 764
765 765 @wireprotocommand('branches', 'nodes')
766 766 def branches(repo, proto, nodes):
767 767 nodes = decodelist(nodes)
768 768 r = []
769 769 for b in repo.branches(nodes):
770 770 r.append(encodelist(b) + "\n")
771 771 return "".join(r)
772 772
773 773 @wireprotocommand('clonebundles', '')
774 774 def clonebundles(repo, proto):
775 775 """Server command for returning info for available bundles to seed clones.
776 776
777 777 Clients will parse this response and determine what bundle to fetch.
778 778
779 779 Extensions may wrap this command to filter or dynamically emit data
780 780 depending on the request. e.g. you could advertise URLs for the closest
781 781 data center given the client's IP address.
782 782 """
783 783 return repo.vfs.tryread('clonebundles.manifest')
784 784
785 785 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
786 786 'known', 'getbundle', 'unbundlehash', 'batch']
787 787
788 788 def _capabilities(repo, proto):
789 789 """return a list of capabilities for a repo
790 790
791 791 This function exists to allow extensions to easily wrap capabilities
792 792 computation
793 793
794 794 - returns a lists: easy to alter
795 795 - change done here will be propagated to both `capabilities` and `hello`
796 796 command without any other action needed.
797 797 """
798 798 # copy to prevent modification of the global list
799 799 caps = list(wireprotocaps)
800 800 if streamclone.allowservergeneration(repo):
801 801 if repo.ui.configbool('server', 'preferuncompressed'):
802 802 caps.append('stream-preferred')
803 803 requiredformats = repo.requirements & repo.supportedformats
804 804 # if our local revlogs are just revlogv1, add 'stream' cap
805 805 if not requiredformats - {'revlogv1'}:
806 806 caps.append('stream')
807 807 # otherwise, add 'streamreqs' detailing our local revlog format
808 808 else:
809 809 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
810 810 if repo.ui.configbool('experimental', 'bundle2-advertise'):
811 811 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
812 812 caps.append('bundle2=' + urlreq.quote(capsblob))
813 813 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
814 814
815 815 if proto.name == 'http':
816 816 caps.append('httpheader=%d' %
817 817 repo.ui.configint('server', 'maxhttpheaderlen'))
818 818 if repo.ui.configbool('experimental', 'httppostargs'):
819 819 caps.append('httppostargs')
820 820
821 821 # FUTURE advertise 0.2rx once support is implemented
822 822 # FUTURE advertise minrx and mintx after consulting config option
823 823 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
824 824
825 825 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
826 826 if compengines:
827 827 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
828 828 for e in compengines)
829 829 caps.append('compression=%s' % comptypes)
830 830
831 831 return caps
832 832
833 833 # If you are writing an extension and consider wrapping this function. Wrap
834 834 # `_capabilities` instead.
835 835 @wireprotocommand('capabilities')
836 836 def capabilities(repo, proto):
837 837 return ' '.join(_capabilities(repo, proto))
838 838
839 839 @wireprotocommand('changegroup', 'roots')
840 840 def changegroup(repo, proto, roots):
841 841 nodes = decodelist(roots)
842 842 outgoing = discovery.outgoing(repo, missingroots=nodes,
843 843 missingheads=repo.heads())
844 844 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
845 845 gen = iter(lambda: cg.read(32768), '')
846 846 return streamres(gen=gen)
847 847
848 848 @wireprotocommand('changegroupsubset', 'bases heads')
849 849 def changegroupsubset(repo, proto, bases, heads):
850 850 bases = decodelist(bases)
851 851 heads = decodelist(heads)
852 852 outgoing = discovery.outgoing(repo, missingroots=bases,
853 853 missingheads=heads)
854 854 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
855 855 gen = iter(lambda: cg.read(32768), '')
856 856 return streamres(gen=gen)
857 857
858 858 @wireprotocommand('debugwireargs', 'one two *')
859 859 def debugwireargs(repo, proto, one, two, others):
860 860 # only accept optional args from the known set
861 861 opts = options('debugwireargs', ['three', 'four'], others)
862 862 return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
863 863
864 864 @wireprotocommand('getbundle', '*')
865 865 def getbundle(repo, proto, others):
866 866 opts = options('getbundle', gboptsmap.keys(), others)
867 867 for k, v in opts.iteritems():
868 868 keytype = gboptsmap[k]
869 869 if keytype == 'nodes':
870 870 opts[k] = decodelist(v)
871 871 elif keytype == 'csv':
872 872 opts[k] = list(v.split(','))
873 873 elif keytype == 'scsv':
874 874 opts[k] = set(v.split(','))
875 875 elif keytype == 'boolean':
876 876 # Client should serialize False as '0', which is a non-empty string
877 877 # so it evaluates as a True bool.
878 878 if v == '0':
879 879 opts[k] = False
880 880 else:
881 881 opts[k] = bool(v)
882 882 elif keytype != 'plain':
883 883 raise KeyError('unknown getbundle option type %s'
884 884 % keytype)
885 885
886 886 if not bundle1allowed(repo, 'pull'):
887 887 if not exchange.bundle2requested(opts.get('bundlecaps')):
888 888 if proto.name == 'http':
889 889 return ooberror(bundle2required)
890 890 raise error.Abort(bundle2requiredmain,
891 891 hint=bundle2requiredhint)
892 892
893 893 prefercompressed = True
894 894
895 895 try:
896 896 if repo.ui.configbool('server', 'disablefullbundle'):
897 897 # Check to see if this is a full clone.
898 898 clheads = set(repo.changelog.heads())
899 899 changegroup = opts.get('cg', True)
900 900 heads = set(opts.get('heads', set()))
901 901 common = set(opts.get('common', set()))
902 902 common.discard(nullid)
903 903 if changegroup and not common and clheads == heads:
904 904 raise error.Abort(
905 905 _('server has pull-based clones disabled'),
906 906 hint=_('remove --pull if specified or upgrade Mercurial'))
907 907
908 908 info, chunks = exchange.getbundlechunks(repo, 'serve',
909 909 **pycompat.strkwargs(opts))
910 910 prefercompressed = info.get('prefercompressed', True)
911 911 except error.Abort as exc:
912 912 # cleanly forward Abort error to the client
913 913 if not exchange.bundle2requested(opts.get('bundlecaps')):
914 914 if proto.name == 'http':
915 915 return ooberror(str(exc) + '\n')
916 916 raise # cannot do better for bundle1 + ssh
917 917 # bundle2 request expect a bundle2 reply
918 918 bundler = bundle2.bundle20(repo.ui)
919 919 manargs = [('message', str(exc))]
920 920 advargs = []
921 921 if exc.hint is not None:
922 922 advargs.append(('hint', exc.hint))
923 923 bundler.addpart(bundle2.bundlepart('error:abort',
924 924 manargs, advargs))
925 925 chunks = bundler.getchunks()
926 926 prefercompressed = False
927 927
928 928 return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
929 929
930 930 @wireprotocommand('heads')
931 931 def heads(repo, proto):
932 932 h = repo.heads()
933 933 return encodelist(h) + "\n"
934 934
935 935 @wireprotocommand('hello')
936 936 def hello(repo, proto):
937 937 '''the hello command returns a set of lines describing various
938 938 interesting things about the server, in an RFC822-like format.
939 939 Currently the only one defined is "capabilities", which
940 940 consists of a line in the form:
941 941
942 942 capabilities: space separated list of tokens
943 943 '''
944 944 return "capabilities: %s\n" % (capabilities(repo, proto))
945 945
946 946 @wireprotocommand('listkeys', 'namespace')
947 947 def listkeys(repo, proto, namespace):
948 948 d = repo.listkeys(encoding.tolocal(namespace)).items()
949 949 return pushkeymod.encodekeys(d)
950 950
951 951 @wireprotocommand('lookup', 'key')
952 952 def lookup(repo, proto, key):
953 953 try:
954 954 k = encoding.tolocal(key)
955 955 c = repo[k]
956 956 r = c.hex()
957 957 success = 1
958 958 except Exception as inst:
959 959 r = str(inst)
960 960 success = 0
961 961 return "%d %s\n" % (success, r)
962 962
963 963 @wireprotocommand('known', 'nodes *')
964 964 def known(repo, proto, nodes, others):
965 965 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
966 966
967 967 @wireprotocommand('pushkey', 'namespace key old new')
968 968 def pushkey(repo, proto, namespace, key, old, new):
969 969 # compatibility with pre-1.8 clients which were accidentally
970 970 # sending raw binary nodes rather than utf-8-encoded hex
971 971 if len(new) == 20 and util.escapestr(new) != new:
972 972 # looks like it could be a binary node
973 973 try:
974 974 new.decode('utf-8')
975 975 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
976 976 except UnicodeDecodeError:
977 977 pass # binary, leave unmodified
978 978 else:
979 979 new = encoding.tolocal(new) # normal path
980 980
981 if util.safehasattr(proto, 'restore'):
982
983 proto.redirect()
984
981 with proto.mayberedirectstdio() as output:
985 982 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
986 983 encoding.tolocal(old), new) or False
987 984
988 output = proto.restore()
989
985 output = output.getvalue() if output else ''
990 986 return '%s\n%s' % (int(r), output)
991 987
992 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
993 encoding.tolocal(old), new)
994 return '%s\n' % int(r)
995
996 988 @wireprotocommand('stream_out')
997 989 def stream(repo, proto):
998 990 '''If the server supports streaming clone, it advertises the "stream"
999 991 capability with a value representing the version and flags of the repo
1000 992 it is serving. Client checks to see if it understands the format.
1001 993 '''
1002 994 return streamres_legacy(streamclone.generatev1wireproto(repo))
1003 995
1004 996 @wireprotocommand('unbundle', 'heads')
1005 997 def unbundle(repo, proto, heads):
1006 998 their_heads = decodelist(heads)
1007 999
1008 1000 try:
1009 1001 proto.redirect()
1010 1002
1011 1003 exchange.check_heads(repo, their_heads, 'preparing changes')
1012 1004
1013 1005 # write bundle data to temporary file because it can be big
1014 1006 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1015 1007 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1016 1008 r = 0
1017 1009 try:
1018 1010 proto.getfile(fp)
1019 1011 fp.seek(0)
1020 1012 gen = exchange.readbundle(repo.ui, fp, None)
1021 1013 if (isinstance(gen, changegroupmod.cg1unpacker)
1022 1014 and not bundle1allowed(repo, 'push')):
1023 1015 if proto.name == 'http':
1024 1016 # need to special case http because stderr do not get to
1025 1017 # the http client on failed push so we need to abuse some
1026 1018 # other error type to make sure the message get to the
1027 1019 # user.
1028 1020 return ooberror(bundle2required)
1029 1021 raise error.Abort(bundle2requiredmain,
1030 1022 hint=bundle2requiredhint)
1031 1023
1032 1024 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1033 1025 proto._client())
1034 1026 if util.safehasattr(r, 'addpart'):
1035 1027 # The return looks streamable, we are in the bundle2 case and
1036 1028 # should return a stream.
1037 1029 return streamres_legacy(gen=r.getchunks())
1038 1030 return pushres(r)
1039 1031
1040 1032 finally:
1041 1033 fp.close()
1042 1034 os.unlink(tempname)
1043 1035
1044 1036 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1045 1037 # handle non-bundle2 case first
1046 1038 if not getattr(exc, 'duringunbundle2', False):
1047 1039 try:
1048 1040 raise
1049 1041 except error.Abort:
1050 1042 # The old code we moved used util.stderr directly.
1051 1043 # We did not change it to minimise code change.
1052 1044 # This need to be moved to something proper.
1053 1045 # Feel free to do it.
1054 1046 util.stderr.write("abort: %s\n" % exc)
1055 1047 if exc.hint is not None:
1056 1048 util.stderr.write("(%s)\n" % exc.hint)
1057 1049 return pushres(0)
1058 1050 except error.PushRaced:
1059 1051 return pusherr(str(exc))
1060 1052
1061 1053 bundler = bundle2.bundle20(repo.ui)
1062 1054 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1063 1055 bundler.addpart(out)
1064 1056 try:
1065 1057 try:
1066 1058 raise
1067 1059 except error.PushkeyFailed as exc:
1068 1060 # check client caps
1069 1061 remotecaps = getattr(exc, '_replycaps', None)
1070 1062 if (remotecaps is not None
1071 1063 and 'pushkey' not in remotecaps.get('error', ())):
1072 1064 # no support remote side, fallback to Abort handler.
1073 1065 raise
1074 1066 part = bundler.newpart('error:pushkey')
1075 1067 part.addparam('in-reply-to', exc.partid)
1076 1068 if exc.namespace is not None:
1077 1069 part.addparam('namespace', exc.namespace, mandatory=False)
1078 1070 if exc.key is not None:
1079 1071 part.addparam('key', exc.key, mandatory=False)
1080 1072 if exc.new is not None:
1081 1073 part.addparam('new', exc.new, mandatory=False)
1082 1074 if exc.old is not None:
1083 1075 part.addparam('old', exc.old, mandatory=False)
1084 1076 if exc.ret is not None:
1085 1077 part.addparam('ret', exc.ret, mandatory=False)
1086 1078 except error.BundleValueError as exc:
1087 1079 errpart = bundler.newpart('error:unsupportedcontent')
1088 1080 if exc.parttype is not None:
1089 1081 errpart.addparam('parttype', exc.parttype)
1090 1082 if exc.params:
1091 1083 errpart.addparam('params', '\0'.join(exc.params))
1092 1084 except error.Abort as exc:
1093 1085 manargs = [('message', str(exc))]
1094 1086 advargs = []
1095 1087 if exc.hint is not None:
1096 1088 advargs.append(('hint', exc.hint))
1097 1089 bundler.addpart(bundle2.bundlepart('error:abort',
1098 1090 manargs, advargs))
1099 1091 except error.PushRaced as exc:
1100 1092 bundler.newpart('error:pushraced', [('message', str(exc))])
1101 1093 return streamres_legacy(gen=bundler.getchunks())
@@ -1,447 +1,481 b''
1 1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 3 #
4 4 # This software may be used and distributed according to the terms of the
5 5 # GNU General Public License version 2 or any later version.
6 6
7 7 from __future__ import absolute_import
8 8
9 9 import abc
10 10 import cgi
11 import contextlib
11 12 import struct
12 13 import sys
13 14
14 15 from .i18n import _
15 16 from . import (
16 17 encoding,
17 18 error,
18 19 hook,
19 20 pycompat,
20 21 util,
21 22 wireproto,
22 23 )
23 24
24 25 stringio = util.stringio
25 26
26 27 urlerr = util.urlerr
27 28 urlreq = util.urlreq
28 29
29 30 HTTP_OK = 200
30 31
31 32 HGTYPE = 'application/mercurial-0.1'
32 33 HGTYPE2 = 'application/mercurial-0.2'
33 34 HGERRTYPE = 'application/hg-error'
34 35
35 36 # Names of the SSH protocol implementations.
36 37 SSHV1 = 'ssh-v1'
37 38 # This is advertised over the wire. Incremental the counter at the end
38 39 # to reflect BC breakages.
39 40 SSHV2 = 'exp-ssh-v2-0001'
40 41
41 42 class baseprotocolhandler(object):
42 43 """Abstract base class for wire protocol handlers.
43 44
44 45 A wire protocol handler serves as an interface between protocol command
45 46 handlers and the wire protocol transport layer. Protocol handlers provide
46 47 methods to read command arguments, redirect stdio for the duration of
47 48 the request, handle response types, etc.
48 49 """
49 50
50 51 __metaclass__ = abc.ABCMeta
51 52
52 53 @abc.abstractproperty
53 54 def name(self):
54 55 """The name of the protocol implementation.
55 56
56 57 Used for uniquely identifying the transport type.
57 58 """
58 59
59 60 @abc.abstractmethod
60 61 def getargs(self, args):
61 62 """return the value for arguments in <args>
62 63
63 64 returns a list of values (same order as <args>)"""
64 65
65 66 @abc.abstractmethod
66 67 def getfile(self, fp):
67 68 """write the whole content of a file into a file like object
68 69
69 70 The file is in the form::
70 71
71 72 (<chunk-size>\n<chunk>)+0\n
72 73
73 74 chunk size is the ascii version of the int.
74 75 """
75 76
76 77 @abc.abstractmethod
78 def mayberedirectstdio(self):
79 """Context manager to possibly redirect stdio.
80
81 The context manager yields a file-object like object that receives
82 stdout and stderr output when the context manager is active. Or it
83 yields ``None`` if no I/O redirection occurs.
84
85 The intent of this context manager is to capture stdio output
86 so it may be sent in the response. Some transports support streaming
87 stdio to the client in real time. For these transports, stdio output
88 won't be captured.
89 """
90
91 @abc.abstractmethod
77 92 def redirect(self):
78 93 """may setup interception for stdout and stderr
79 94
80 95 See also the `restore` method."""
81 96
82 97 # If the `redirect` function does install interception, the `restore`
83 98 # function MUST be defined. If interception is not used, this function
84 99 # MUST NOT be defined.
85 100 #
86 101 # left commented here on purpose
87 102 #
88 103 #def restore(self):
89 104 # """reinstall previous stdout and stderr and return intercepted stdout
90 105 # """
91 106 # raise NotImplementedError()
92 107
93 108 def decodevaluefromheaders(req, headerprefix):
94 109 """Decode a long value from multiple HTTP request headers.
95 110
96 111 Returns the value as a bytes, not a str.
97 112 """
98 113 chunks = []
99 114 i = 1
100 115 prefix = headerprefix.upper().replace(r'-', r'_')
101 116 while True:
102 117 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
103 118 if v is None:
104 119 break
105 120 chunks.append(pycompat.bytesurl(v))
106 121 i += 1
107 122
108 123 return ''.join(chunks)
109 124
110 125 class webproto(baseprotocolhandler):
111 126 def __init__(self, req, ui):
112 127 self._req = req
113 128 self._ui = ui
114 129
115 130 @property
116 131 def name(self):
117 132 return 'http'
118 133
119 134 def getargs(self, args):
120 135 knownargs = self._args()
121 136 data = {}
122 137 keys = args.split()
123 138 for k in keys:
124 139 if k == '*':
125 140 star = {}
126 141 for key in knownargs.keys():
127 142 if key != 'cmd' and key not in keys:
128 143 star[key] = knownargs[key][0]
129 144 data['*'] = star
130 145 else:
131 146 data[k] = knownargs[k][0]
132 147 return [data[k] for k in keys]
133 148
134 149 def _args(self):
135 150 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
136 151 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
137 152 if postlen:
138 153 args.update(cgi.parse_qs(
139 154 self._req.read(postlen), keep_blank_values=True))
140 155 return args
141 156
142 157 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
143 158 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
144 159 return args
145 160
146 161 def getfile(self, fp):
147 162 length = int(self._req.env[r'CONTENT_LENGTH'])
148 163 # If httppostargs is used, we need to read Content-Length
149 164 # minus the amount that was consumed by args.
150 165 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
151 166 for s in util.filechunkiter(self._req, limit=length):
152 167 fp.write(s)
153 168
169 @contextlib.contextmanager
170 def mayberedirectstdio(self):
171 oldout = self._ui.fout
172 olderr = self._ui.ferr
173
174 out = util.stringio()
175
176 try:
177 self._ui.fout = out
178 self._ui.ferr = out
179 yield out
180 finally:
181 self._ui.fout = oldout
182 self._ui.ferr = olderr
183
154 184 def redirect(self):
155 185 self._oldio = self._ui.fout, self._ui.ferr
156 186 self._ui.ferr = self._ui.fout = stringio()
157 187
158 188 def restore(self):
159 189 val = self._ui.fout.getvalue()
160 190 self._ui.ferr, self._ui.fout = self._oldio
161 191 return val
162 192
163 193 def _client(self):
164 194 return 'remote:%s:%s:%s' % (
165 195 self._req.env.get('wsgi.url_scheme') or 'http',
166 196 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
167 197 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
168 198
169 199 def responsetype(self, prefer_uncompressed):
170 200 """Determine the appropriate response type and compression settings.
171 201
172 202 Returns a tuple of (mediatype, compengine, engineopts).
173 203 """
174 204 # Determine the response media type and compression engine based
175 205 # on the request parameters.
176 206 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
177 207
178 208 if '0.2' in protocaps:
179 209 # All clients are expected to support uncompressed data.
180 210 if prefer_uncompressed:
181 211 return HGTYPE2, util._noopengine(), {}
182 212
183 213 # Default as defined by wire protocol spec.
184 214 compformats = ['zlib', 'none']
185 215 for cap in protocaps:
186 216 if cap.startswith('comp='):
187 217 compformats = cap[5:].split(',')
188 218 break
189 219
190 220 # Now find an agreed upon compression format.
191 221 for engine in wireproto.supportedcompengines(self._ui, self,
192 222 util.SERVERROLE):
193 223 if engine.wireprotosupport().name in compformats:
194 224 opts = {}
195 225 level = self._ui.configint('server',
196 226 '%slevel' % engine.name())
197 227 if level is not None:
198 228 opts['level'] = level
199 229
200 230 return HGTYPE2, engine, opts
201 231
202 232 # No mutually supported compression format. Fall back to the
203 233 # legacy protocol.
204 234
205 235 # Don't allow untrusted settings because disabling compression or
206 236 # setting a very high compression level could lead to flooding
207 237 # the server's network or CPU.
208 238 opts = {'level': self._ui.configint('server', 'zliblevel')}
209 239 return HGTYPE, util.compengines['zlib'], opts
210 240
211 241 def iscmd(cmd):
212 242 return cmd in wireproto.commands
213 243
214 244 def parsehttprequest(repo, req, query):
215 245 """Parse the HTTP request for a wire protocol request.
216 246
217 247 If the current request appears to be a wire protocol request, this
218 248 function returns a dict with details about that request, including
219 249 an ``abstractprotocolserver`` instance suitable for handling the
220 250 request. Otherwise, ``None`` is returned.
221 251
222 252 ``req`` is a ``wsgirequest`` instance.
223 253 """
224 254 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
225 255 # string parameter. If it isn't present, this isn't a wire protocol
226 256 # request.
227 257 if r'cmd' not in req.form:
228 258 return None
229 259
230 260 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
231 261
232 262 # The "cmd" request parameter is used by both the wire protocol and hgweb.
233 263 # While not all wire protocol commands are available for all transports,
234 264 # if we see a "cmd" value that resembles a known wire protocol command, we
235 265 # route it to a protocol handler. This is better than routing possible
236 266 # wire protocol requests to hgweb because it prevents hgweb from using
237 267 # known wire protocol commands and it is less confusing for machine
238 268 # clients.
239 269 if cmd not in wireproto.commands:
240 270 return None
241 271
242 272 proto = webproto(req, repo.ui)
243 273
244 274 return {
245 275 'cmd': cmd,
246 276 'proto': proto,
247 277 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
248 278 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
249 279 }
250 280
251 281 def _callhttp(repo, req, proto, cmd):
252 282 def genversion2(gen, engine, engineopts):
253 283 # application/mercurial-0.2 always sends a payload header
254 284 # identifying the compression engine.
255 285 name = engine.wireprotosupport().name
256 286 assert 0 < len(name) < 256
257 287 yield struct.pack('B', len(name))
258 288 yield name
259 289
260 290 for chunk in gen:
261 291 yield chunk
262 292
263 293 rsp = wireproto.dispatch(repo, proto, cmd)
264 294
265 295 if not wireproto.commands.commandavailable(cmd, proto):
266 296 req.respond(HTTP_OK, HGERRTYPE,
267 297 body=_('requested wire protocol command is not available '
268 298 'over HTTP'))
269 299 return []
270 300
271 301 if isinstance(rsp, bytes):
272 302 req.respond(HTTP_OK, HGTYPE, body=rsp)
273 303 return []
274 304 elif isinstance(rsp, wireproto.streamres_legacy):
275 305 gen = rsp.gen
276 306 req.respond(HTTP_OK, HGTYPE)
277 307 return gen
278 308 elif isinstance(rsp, wireproto.streamres):
279 309 gen = rsp.gen
280 310
281 311 # This code for compression should not be streamres specific. It
282 312 # is here because we only compress streamres at the moment.
283 313 mediatype, engine, engineopts = proto.responsetype(
284 314 rsp.prefer_uncompressed)
285 315 gen = engine.compressstream(gen, engineopts)
286 316
287 317 if mediatype == HGTYPE2:
288 318 gen = genversion2(gen, engine, engineopts)
289 319
290 320 req.respond(HTTP_OK, mediatype)
291 321 return gen
292 322 elif isinstance(rsp, wireproto.pushres):
293 323 val = proto.restore()
294 324 rsp = '%d\n%s' % (rsp.res, val)
295 325 req.respond(HTTP_OK, HGTYPE, body=rsp)
296 326 return []
297 327 elif isinstance(rsp, wireproto.pusherr):
298 328 # This is the httplib workaround documented in _handlehttperror().
299 329 req.drain()
300 330
301 331 proto.restore()
302 332 rsp = '0\n%s\n' % rsp.res
303 333 req.respond(HTTP_OK, HGTYPE, body=rsp)
304 334 return []
305 335 elif isinstance(rsp, wireproto.ooberror):
306 336 rsp = rsp.message
307 337 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
308 338 return []
309 339 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
310 340
311 341 def _handlehttperror(e, req, cmd):
312 342 """Called when an ErrorResponse is raised during HTTP request processing."""
313 343
314 344 # Clients using Python's httplib are stateful: the HTTP client
315 345 # won't process an HTTP response until all request data is
316 346 # sent to the server. The intent of this code is to ensure
317 347 # we always read HTTP request data from the client, thus
318 348 # ensuring httplib transitions to a state that allows it to read
319 349 # the HTTP response. In other words, it helps prevent deadlocks
320 350 # on clients using httplib.
321 351
322 352 if (req.env[r'REQUEST_METHOD'] == r'POST' and
323 353 # But not if Expect: 100-continue is being used.
324 354 (req.env.get('HTTP_EXPECT',
325 355 '').lower() != '100-continue') or
326 356 # Or the non-httplib HTTP library is being advertised by
327 357 # the client.
328 358 req.env.get('X-HgHttp2', '')):
329 359 req.drain()
330 360 else:
331 361 req.headers.append((r'Connection', r'Close'))
332 362
333 363 # TODO This response body assumes the failed command was
334 364 # "unbundle." That assumption is not always valid.
335 365 req.respond(e, HGTYPE, body='0\n%s\n' % e)
336 366
337 367 return ''
338 368
339 369 def _sshv1respondbytes(fout, value):
340 370 """Send a bytes response for protocol version 1."""
341 371 fout.write('%d\n' % len(value))
342 372 fout.write(value)
343 373 fout.flush()
344 374
345 375 def _sshv1respondstream(fout, source):
346 376 write = fout.write
347 377 for chunk in source.gen:
348 378 write(chunk)
349 379 fout.flush()
350 380
351 381 def _sshv1respondooberror(fout, ferr, rsp):
352 382 ferr.write(b'%s\n-\n' % rsp)
353 383 ferr.flush()
354 384 fout.write(b'\n')
355 385 fout.flush()
356 386
357 387 class sshv1protocolhandler(baseprotocolhandler):
358 388 """Handler for requests services via version 1 of SSH protocol."""
359 389 def __init__(self, ui, fin, fout):
360 390 self._ui = ui
361 391 self._fin = fin
362 392 self._fout = fout
363 393
364 394 @property
365 395 def name(self):
366 396 return 'ssh'
367 397
368 398 def getargs(self, args):
369 399 data = {}
370 400 keys = args.split()
371 401 for n in xrange(len(keys)):
372 402 argline = self._fin.readline()[:-1]
373 403 arg, l = argline.split()
374 404 if arg not in keys:
375 405 raise error.Abort(_("unexpected parameter %r") % arg)
376 406 if arg == '*':
377 407 star = {}
378 408 for k in xrange(int(l)):
379 409 argline = self._fin.readline()[:-1]
380 410 arg, l = argline.split()
381 411 val = self._fin.read(int(l))
382 412 star[arg] = val
383 413 data['*'] = star
384 414 else:
385 415 val = self._fin.read(int(l))
386 416 data[arg] = val
387 417 return [data[k] for k in keys]
388 418
389 419 def getfile(self, fpout):
390 420 _sshv1respondbytes(self._fout, b'')
391 421 count = int(self._fin.readline())
392 422 while count:
393 423 fpout.write(self._fin.read(count))
394 424 count = int(self._fin.readline())
395 425
426 @contextlib.contextmanager
427 def mayberedirectstdio(self):
428 yield None
429
396 430 def redirect(self):
397 431 pass
398 432
399 433 def _client(self):
400 434 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
401 435 return 'remote:ssh:' + client
402 436
403 437 class sshserver(object):
404 438 def __init__(self, ui, repo):
405 439 self._ui = ui
406 440 self._repo = repo
407 441 self._fin = ui.fin
408 442 self._fout = ui.fout
409 443
410 444 hook.redirect(True)
411 445 ui.fout = repo.ui.fout = ui.ferr
412 446
413 447 # Prevent insertion/deletion of CRs
414 448 util.setbinary(self._fin)
415 449 util.setbinary(self._fout)
416 450
417 451 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
418 452
419 453 def serve_forever(self):
420 454 while self.serve_one():
421 455 pass
422 456 sys.exit(0)
423 457
424 458 def serve_one(self):
425 459 cmd = self._fin.readline()[:-1]
426 460 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
427 461 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
428 462
429 463 if isinstance(rsp, bytes):
430 464 _sshv1respondbytes(self._fout, rsp)
431 465 elif isinstance(rsp, wireproto.streamres):
432 466 _sshv1respondstream(self._fout, rsp)
433 467 elif isinstance(rsp, wireproto.streamres_legacy):
434 468 _sshv1respondstream(self._fout, rsp)
435 469 elif isinstance(rsp, wireproto.pushres):
436 470 _sshv1respondbytes(self._fout, b'')
437 471 _sshv1respondbytes(self._fout, bytes(rsp.res))
438 472 elif isinstance(rsp, wireproto.pusherr):
439 473 _sshv1respondbytes(self._fout, rsp.res)
440 474 elif isinstance(rsp, wireproto.ooberror):
441 475 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
442 476 else:
443 477 raise error.ProgrammingError('unhandled response type from '
444 478 'wire protocol command: %s' % rsp)
445 479 elif cmd:
446 480 _sshv1respondbytes(self._fout, b'')
447 481 return cmd != ''
General Comments 0
You need to be logged in to leave comments. Login now