##// END OF EJS Templates
wireproto: separate commands tables for version 1 and 2 commands...
Gregory Szorc -
r37311:45b39c69 default
parent child Browse files
Show More
@@ -175,6 +175,7 b' def uisetup(ui):'
175 175
176 176 # ... and wrap some existing ones
177 177 wireproto.commands['heads'].func = proto.heads
178 # TODO also wrap wireproto.commandsv2 once heads is implemented there.
178 179
179 180 extensions.wrapfunction(webcommands, 'decodepath', overrides.decodepath)
180 181
@@ -502,7 +502,11 b' def getdispatchrepo(repo, proto, command'
502 502
503 503 def dispatch(repo, proto, command):
504 504 repo = getdispatchrepo(repo, proto, command)
505 func, spec = commands[command]
505
506 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
507 commandtable = commandsv2 if transportversion == 2 else commands
508 func, spec = commandtable[command]
509
506 510 args = proto.getargs(spec)
507 511 return func(repo, proto, *args)
508 512
@@ -679,8 +683,12 b" POLICY_ALL = 'all'"
679 683 POLICY_V1_ONLY = 'v1-only'
680 684 POLICY_V2_ONLY = 'v2-only'
681 685
686 # For version 1 transports.
682 687 commands = commanddict()
683 688
689 # For version 2 transports.
690 commandsv2 = commanddict()
691
684 692 def wireprotocommand(name, args='', transportpolicy=POLICY_ALL,
685 693 permission='push'):
686 694 """Decorator to declare a wire protocol command.
@@ -702,12 +710,15 b" def wireprotocommand(name, args='', tran"
702 710 """
703 711 if transportpolicy == POLICY_ALL:
704 712 transports = set(wireprototypes.TRANSPORTS)
713 transportversions = {1, 2}
705 714 elif transportpolicy == POLICY_V1_ONLY:
706 715 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
707 716 if v['version'] == 1}
717 transportversions = {1}
708 718 elif transportpolicy == POLICY_V2_ONLY:
709 719 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
710 720 if v['version'] == 2}
721 transportversions = {2}
711 722 else:
712 723 raise error.ProgrammingError('invalid transport policy value: %s' %
713 724 transportpolicy)
@@ -724,8 +735,21 b" def wireprotocommand(name, args='', tran"
724 735 permission)
725 736
726 737 def register(func):
727 commands[name] = commandentry(func, args=args, transports=transports,
728 permission=permission)
738 if 1 in transportversions:
739 if name in commands:
740 raise error.ProgrammingError('%s command already registered '
741 'for version 1' % name)
742 commands[name] = commandentry(func, args=args,
743 transports=transports,
744 permission=permission)
745 if 2 in transportversions:
746 if name in commandsv2:
747 raise error.ProgrammingError('%s command already registered '
748 'for version 2' % name)
749 commandsv2[name] = commandentry(func, args=args,
750 transports=transports,
751 permission=permission)
752
729 753 return func
730 754 return register
731 755
@@ -335,7 +335,7 b' def _handlehttpv2request(rctx, req, res,'
335 335 # extension.
336 336 extracommands = {'multirequest'}
337 337
338 if command not in wireproto.commands and command not in extracommands:
338 if command not in wireproto.commandsv2 and command not in extracommands:
339 339 res.status = b'404 Not Found'
340 340 res.headers[b'Content-Type'] = b'text/plain'
341 341 res.setbodybytes(_('unknown wire protocol command: %s\n') % command)
@@ -346,7 +346,7 b' def _handlehttpv2request(rctx, req, res,'
346 346
347 347 proto = httpv2protocolhandler(req, ui)
348 348
349 if (not wireproto.commands.commandavailable(command, proto)
349 if (not wireproto.commandsv2.commandavailable(command, proto)
350 350 and command not in extracommands):
351 351 res.status = b'404 Not Found'
352 352 res.headers[b'Content-Type'] = b'text/plain'
@@ -502,7 +502,7 b' def _httpv2runcommand(ui, repo, req, res'
502 502 proto = httpv2protocolhandler(req, ui, args=command['args'])
503 503
504 504 if reqcommand == b'multirequest':
505 if not wireproto.commands.commandavailable(command['command'], proto):
505 if not wireproto.commandsv2.commandavailable(command['command'], proto):
506 506 # TODO proper error mechanism
507 507 res.status = b'200 OK'
508 508 res.headers[b'Content-Type'] = b'text/plain'
@@ -512,7 +512,7 b' def _httpv2runcommand(ui, repo, req, res'
512 512
513 513 # TODO don't use assert here, since it may be elided by -O.
514 514 assert authedperm in (b'ro', b'rw')
515 wirecommand = wireproto.commands[command['command']]
515 wirecommand = wireproto.commandsv2[command['command']]
516 516 assert wirecommand.permission in ('push', 'pull')
517 517
518 518 if authedperm == b'ro' and wirecommand.permission != 'pull':
@@ -13,6 +13,8 b' stringio = util.stringio'
13 13 class proto(object):
14 14 def __init__(self, args):
15 15 self.args = args
16 self.name = 'dummyproto'
17
16 18 def getargs(self, spec):
17 19 args = self.args
18 20 args.setdefault(b'*', {})
@@ -22,6 +24,11 b' class proto(object):'
22 24 def checkperm(self, perm):
23 25 pass
24 26
27 wireprototypes.TRANSPORTS['dummyproto'] = {
28 'transport': 'dummy',
29 'version': 1,
30 }
31
25 32 class clientpeer(wireproto.wirepeer):
26 33 def __init__(self, serverrepo, ui):
27 34 self.serverrepo = serverrepo
General Comments 0
You need to be logged in to leave comments. Login now