wireproto: declare version 2 commands via wireprotov2server.wireprotocommand

wireproto.wireprotocommand() now registers version 1 commands only: the
``transportpolicy`` argument and the POLICY_V1_ONLY / POLICY_V2_ONLY
constants are removed, and ``args`` must be a bytes string.

wireprotov2server.wireprotocommand() no longer forwards to
wireproto.wireprotocommand(); it validates its own arguments (``args``
must be a dict) and registers directly into wireproto.commandsv2.

Callers that passed transportpolicy= or referenced wireproto.POLICY_*
must switch to wireprotov2server.wireprotocommand(), as the dummy
commands in tests/wireprotohelpers.sh now do.

diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py
--- a/mercurial/wireproto.py
+++ b/mercurial/wireproto.py
@@ -251,32 +251,20 @@ class commanddict(dict):
 
         return True
 
-# Constants specifying which transports a wire protocol command should be
-# available on. For use with @wireprotocommand.
-POLICY_V1_ONLY = 'v1-only'
-POLICY_V2_ONLY = 'v2-only'
-
 # For version 1 transports.
 commands = commanddict()
 
 # For version 2 transports.
 commandsv2 = commanddict()
 
-def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
-                     permission='push'):
+def wireprotocommand(name, args=None, permission='push'):
     """Decorator to declare a wire protocol command.
 
     ``name`` is the name of the wire protocol command being provided.
 
     ``args`` defines the named arguments accepted by the command. It is
-    ideally a dict mapping argument names to their types. For backwards
-    compatibility, it can be a space-delimited list of argument names. For
-    version 1 transports, ``*`` denotes a special value that says to accept
-    all named arguments.
-
-    ``transportpolicy`` is a POLICY_* constant denoting which transports
-    this wire protocol command should be exposed to. By default, commands
-    are exposed to all wire protocol transports.
+    a space-delimited list of argument names. ``*`` denotes a special value
+    that says to accept all named arguments.
 
     ``permission`` defines the permission type needed to run this command.
     Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
@@ -284,17 +272,8 @@ def wireprotocommand(name, args=None, tr
     because otherwise commands not declaring their permissions could modify
     a repository that is supposed to be read-only.
     """
-    if transportpolicy == POLICY_V1_ONLY:
-        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
-                      if v['version'] == 1}
-        transportversion = 1
-    elif transportpolicy == POLICY_V2_ONLY:
-        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
-                      if v['version'] == 2}
-        transportversion = 2
-    else:
-        raise error.ProgrammingError('invalid transport policy value: %s' %
-                                     transportpolicy)
+    transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                  if v['version'] == 1}
 
     # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
     # SSHv2.
@@ -307,40 +286,20 @@ def wireprotocommand(name, args=None, tr
                                      'got %s; expected "push" or "pull"' %
                                      permission)
 
-    if transportversion == 1:
-        if args is None:
-            args = ''
+    if args is None:
+        args = ''
 
-        if not isinstance(args, bytes):
-            raise error.ProgrammingError('arguments for version 1 commands '
-                                         'must be declared as bytes')
-    elif transportversion == 2:
-        if args is None:
-            args = {}
-
-        if not isinstance(args, dict):
-            raise error.ProgrammingError('arguments for version 2 commands '
-                                         'must be declared as dicts')
+    if not isinstance(args, bytes):
+        raise error.ProgrammingError('arguments for version 1 commands '
+                                     'must be declared as bytes')
 
     def register(func):
-        if transportversion == 1:
-            if name in commands:
-                raise error.ProgrammingError('%s command already registered '
-                                             'for version 1' % name)
-            commands[name] = commandentry(func, args=args,
-                                          transports=transports,
-                                          permission=permission)
-        elif transportversion == 2:
-            if name in commandsv2:
-                raise error.ProgrammingError('%s command already registered '
-                                             'for version 2' % name)
-
-            commandsv2[name] = commandentry(func, args=args,
-                                            transports=transports,
-                                            permission=permission)
-        else:
-            raise error.ProgrammingError('unhandled transport version: %d' %
-                                         transportversion)
+        if name in commands:
+            raise error.ProgrammingError('%s command already registered '
+                                         'for version 1' % name)
+        commands[name] = commandentry(func, args=args,
+                                      transports=transports,
+                                      permission=permission)
 
         return func
     return register
diff --git a/mercurial/wireprotov2server.py b/mercurial/wireprotov2server.py
--- a/mercurial/wireprotov2server.py
+++ b/mercurial/wireprotov2server.py
@@ -405,10 +405,43 @@ def _capabilitiesv2(repo, proto):
 
     return proto.addcapabilities(repo, caps)
 
-def wireprotocommand(*args, **kwargs):
+def wireprotocommand(name, args=None, permission='push'):
+    """Decorator to declare a wire protocol command.
+
+    ``name`` is the name of the wire protocol command being provided.
+
+    ``args`` is a dict of argument names to example values.
+
+    ``permission`` defines the permission type needed to run this command.
+    Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
+    respectively. Default is to assume command requires ``push`` permissions
+    because otherwise commands not declaring their permissions could modify
+    a repository that is supposed to be read-only.
+    """
+    transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                  if v['version'] == 2}
+
+    if permission not in ('push', 'pull'):
+        raise error.ProgrammingError('invalid wire protocol permission; '
+                                     'got %s; expected "push" or "pull"' %
+                                     permission)
+
+    if args is None:
+        args = {}
+
+    if not isinstance(args, dict):
+        raise error.ProgrammingError('arguments for version 2 commands '
+                                     'must be declared as dicts')
+
     def register(func):
-        return wireproto.wireprotocommand(
-            *args, transportpolicy=wireproto.POLICY_V2_ONLY, **kwargs)(func)
+        if name in wireproto.commandsv2:
+            raise error.ProgrammingError('%s command already registered '
+                                         'for version 2' % name)
+
+        wireproto.commandsv2[name] = wireproto.commandentry(
+            func, args=args, transports=transports, permission=permission)
+
+        return func
 
     return register
 
diff --git a/tests/wireprotohelpers.sh b/tests/wireprotohelpers.sh
--- a/tests/wireprotohelpers.sh
+++ b/tests/wireprotohelpers.sh
@@ -16,6 +16,7 @@ sendhttpv2peerhandshake() {
 cat > dummycommands.py << EOF
 from mercurial import (
     wireprototypes,
+    wireprotov2server,
     wireproto,
 )
 
@@ -23,8 +24,7 @@ from mercurial import (
 def customreadonlyv1(repo, proto):
     return wireprototypes.bytesresponse(b'customreadonly bytes response')
 
-@wireproto.wireprotocommand('customreadonly', permission='pull',
-                            transportpolicy=wireproto.POLICY_V2_ONLY)
+@wireprotov2server.wireprotocommand('customreadonly', permission='pull')
 def customreadonlyv2(repo, proto):
     return wireprototypes.cborresponse(b'customreadonly bytes response')
 
@@ -32,8 +32,7 @@ def customreadonlyv2(repo, proto):
 def customreadwrite(repo, proto):
     return wireprototypes.bytesresponse(b'customreadwrite bytes response')
 
-@wireproto.wireprotocommand('customreadwrite', permission='push',
-                            transportpolicy=wireproto.POLICY_V2_ONLY)
+@wireprotov2server.wireprotocommand('customreadwrite', permission='push')
 def customreadwritev2(repo, proto):
     return wireprototypes.cborresponse(b'customreadwrite bytes response')
 EOF