diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py
--- a/mercurial/wireprotoserver.py
+++ b/mercurial/wireprotoserver.py
@@ -336,6 +336,26 @@ def _handlehttperror(e, req, cmd):
 
     return ''
 
+def _sshv1respondbytes(fout, value):
+    """Send a length-prefixed bytes response for protocol version 1."""
+    fout.write(b'%d\n' % len(value))
+    fout.write(value)
+    fout.flush()
+
+def _sshv1respondstream(fout, source):
+    """Send each chunk of source.gen as-is, without length framing."""
+    write = fout.write
+    for chunk in source.gen:
+        write(chunk)
+    fout.flush()
+
+def _sshv1respondooberror(fout, ferr, rsp):
+    """Write error message rsp to ferr, then a lone newline to fout."""
+    ferr.write(b'%s\n-\n' % rsp)
+    ferr.flush()
+    fout.write(b'\n')
+    fout.flush()
+
 class sshserver(baseprotocolhandler):
     def __init__(self, ui, repo):
         self._ui = ui
@@ -376,7 +396,7 @@ class sshserver(baseprotocolhandler):
         return [data[k] for k in keys]
 
     def getfile(self, fpout):
-        self._sendresponse('')
+        _sshv1respondbytes(self._fout, b'')
         count = int(self._fin.readline())
         while count:
             fpout.write(self._fin.read(count))
@@ -385,51 +405,35 @@ class sshserver(baseprotocolhandler):
     def redirect(self):
         pass
 
-    def _sendresponse(self, v):
-        self._fout.write("%d\n" % len(v))
-        self._fout.write(v)
-        self._fout.flush()
-
-    def _sendstream(self, source):
-        write = self._fout.write
-        for chunk in source.gen:
-            write(chunk)
-        self._fout.flush()
-
-    def _sendpushresponse(self, rsp):
-        self._sendresponse('')
-        self._sendresponse(str(rsp.res))
-
-    def _sendpusherror(self, rsp):
-        self._sendresponse(rsp.res)
-
-    def _sendooberror(self, rsp):
-        self._ui.ferr.write('%s\n-\n' % rsp.message)
-        self._ui.ferr.flush()
-        self._fout.write('\n')
-        self._fout.flush()
-
     def serve_forever(self):
         while self.serve_one():
             pass
         sys.exit(0)
 
-    _handlers = {
-        str: _sendresponse,
-        wireproto.streamres: _sendstream,
-        wireproto.streamres_legacy: _sendstream,
-        wireproto.pushres: _sendpushresponse,
-        wireproto.pusherr: _sendpusherror,
-        wireproto.ooberror: _sendooberror,
-    }
-
     def serve_one(self):
         cmd = self._fin.readline()[:-1]
         if cmd and wireproto.commands.commandavailable(cmd, self):
             rsp = wireproto.dispatch(self._repo, self, cmd)
-            self._handlers[rsp.__class__](self, rsp)
+
+            if isinstance(rsp, bytes):
+                _sshv1respondbytes(self._fout, rsp)
+            elif isinstance(rsp, wireproto.streamres):
+                _sshv1respondstream(self._fout, rsp)
+            elif isinstance(rsp, wireproto.streamres_legacy):
+                _sshv1respondstream(self._fout, rsp)
+            elif isinstance(rsp, wireproto.pushres):
+                # Format with %d: on Python 3, bytes(int) yields NUL bytes.
+                _sshv1respondbytes(self._fout, b'')
+                _sshv1respondbytes(self._fout, b'%d' % rsp.res)
+            elif isinstance(rsp, wireproto.pusherr):
+                _sshv1respondbytes(self._fout, rsp.res)
+            elif isinstance(rsp, wireproto.ooberror):
+                _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
+            else:
+                raise error.ProgrammingError('unhandled response type from '
+                                             'wire protocol command: %s' % rsp)
         elif cmd:
-            self._sendresponse("")
+            _sshv1respondbytes(self._fout, b'')
         return cmd != ''
 
     def _client(self):
diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -45,11 +45,11 @@ class prehelloserver(wireprotoserver.ssh
         l = self._fin.readline()
         assert l == b'hello\n'
         # Respond to unknown commands with an empty reply.
-        self._sendresponse(b'')
+        wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
         rsp = wireproto.dispatch(self._repo, self, b'between')
-        self._handlers[rsp.__class__](self, rsp)
+        wireprotoserver._sshv1respondbytes(self._fout, rsp)
 
         super(prehelloserver, self).serve_forever()
 