diff --git a/mercurial/hgweb/protocol.py b/mercurial/hgweb/protocol.py --- a/mercurial/hgweb/protocol.py +++ b/mercurial/hgweb/protocol.py @@ -48,13 +48,20 @@ class webproto(object): self.response = s def sendstream(self, source): self.req.respond(HTTP_OK, HGTYPE) - for chunk in source: - self.req.write(str(chunk)) - def sendpushresponse(self, ret): + for chunk in source.gen: + self.req.write(chunk) + def sendpushresponse(self, rsp): val = sys.stdout.getvalue() sys.stdout, sys.stderr = self.oldio self.req.respond(HTTP_OK, HGTYPE) - self.response = '%d\n%s' % (ret, val) + self.response = '%d\n%s' % (rsp.res, val) + + handlers = { + str: sendresponse, + wireproto.streamres: sendstream, + wireproto.pushres: sendpushresponse, + } + def _client(self): return 'remote:%s:%s:%s' % ( self.req.env.get('wsgi.url_scheme') or 'http', @@ -66,5 +73,6 @@ def iscmd(cmd): def call(repo, req, cmd): p = webproto(req) - wireproto.dispatch(repo, p, cmd) - yield p.response + rsp = wireproto.dispatch(repo, p, cmd) + webproto.handlers[rsp.__class__](p, rsp) + return [p.response] diff --git a/mercurial/sshserver.py b/mercurial/sshserver.py --- a/mercurial/sshserver.py +++ b/mercurial/sshserver.py @@ -72,13 +72,13 @@ class sshserver(object): self.fout.flush() def sendstream(self, source): - for chunk in source: + for chunk in source.gen: self.fout.write(chunk) self.fout.flush() - def sendpushresponse(self, ret): + def sendpushresponse(self, rsp): self.sendresponse('') - self.sendresponse(str(ret)) + self.sendresponse(str(rsp.res)) def serve_forever(self): try: @@ -89,10 +89,17 @@ class sshserver(object): self.lock.release() sys.exit(0) + handlers = { + str: sendresponse, + wireproto.streamres: sendstream, + wireproto.pushres: sendpushresponse, + } + def serve_one(self): cmd = self.fin.readline()[:-1] if cmd and cmd in wireproto.commands: - wireproto.dispatch(self.repo, self, cmd) + rsp = wireproto.dispatch(self.repo, self, cmd) + self.handlers[rsp.__class__](self, rsp) elif cmd: impl = getattr(self, 'do_' + cmd, None) if impl: diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py --- a/mercurial/wireproto.py +++ b/mercurial/wireproto.py @@ -133,12 +133,18 @@ class wirerepository(repo.repository): # server side +class streamres(object): + def __init__(self, gen): + self.gen = gen + +class pushres(object): + def __init__(self, res): + self.res = res + def dispatch(repo, proto, command): func, spec = commands[command] args = proto.getargs(spec) - r = func(repo, proto, *args) - if r != None: - proto.sendresponse(r) + return func(repo, proto, *args) def between(repo, proto, pairs): pairs = [decodelist(p, '-') for p in pairs.split(" ")] @@ -173,13 +179,13 @@ def capabilities(repo, proto): def changegroup(repo, proto, roots): nodes = decodelist(roots) cg = repo.changegroup(nodes, 'serve') - proto.sendstream(proto.groupchunks(cg)) + return streamres(proto.groupchunks(cg)) def changegroupsubset(repo, proto, bases, heads): bases = decodelist(bases) heads = decodelist(heads) cg = repo.changegroupsubset(bases, heads, 'serve') - proto.sendstream(proto.groupchunks(cg)) + return streamres(proto.groupchunks(cg)) def heads(repo, proto): h = repo.heads() @@ -215,7 +221,7 @@ def pushkey(repo, proto, namespace, key, return '%s\n' % int(r) def stream(repo, proto): - proto.sendstream(streamclone.stream_out(repo)) + return streamres(streamclone.stream_out(repo)) def unbundle(repo, proto, heads): their_heads = decodelist(heads) @@ -259,7 +265,7 @@ def unbundle(repo, proto, heads): sys.stderr.write("abort: %s\n" % inst) finally: lock.release() - proto.sendpushresponse(r) + return pushres(r) finally: fp.close()