diff --git a/mercurial/changegroup.py b/mercurial/changegroup.py --- a/mercurial/changegroup.py +++ b/mercurial/changegroup.py @@ -108,22 +108,34 @@ def writebundle(cg, filename, bundletype if cleanup is not None: os.unlink(cleanup) -def readbundle(fh, fname): - header = fh.read(6) - if not header.startswith("HG"): - raise util.Abort(_("%s: not a Mercurial bundle file") % fname) - elif not header.startswith("HG10"): - raise util.Abort(_("%s: unknown bundle version") % fname) - - if header == "HG10BZ": +def unbundle(header, fh): + if header == 'HG10UN': + return fh + elif not header.startswith('HG'): + # old client with uncompressed bundle + def generator(f): + yield header + for chunk in f: + yield chunk + elif header == 'HG10GZ': + def generator(f): + zd = zlib.decompressobj() + for chunk in f: + yield zd.decompress(chunk) + elif header == 'HG10BZ': def generator(f): zd = bz2.BZ2Decompressor() zd.decompress("BZ") for chunk in util.filechunkiter(f, 4096): yield zd.decompress(chunk) - return util.chunkbuffer(generator(fh)) - elif header == "HG10UN": - return fh + return util.chunkbuffer(generator(fh)) - raise util.Abort(_("%s: unknown bundle compression type") - % fname) +def readbundle(fh, fname): + header = fh.read(6) + if not header.startswith('HG'): + raise util.Abort(_('%s: not a Mercurial bundle file') % fname) + if not header.startswith('HG10'): + raise util.Abort(_('%s: unknown bundle version') % fname) + elif header not in bundletypes: + raise util.Abort(_('%s: unknown bundle compression type') % fname) + return unbundle(header, fh) diff --git a/mercurial/hgweb/protocol.py b/mercurial/hgweb/protocol.py --- a/mercurial/hgweb/protocol.py +++ b/mercurial/hgweb/protocol.py @@ -9,6 +9,7 @@ import cStringIO, zlib, bz2, tempfile, e from mercurial import util, streamclone from mercurial.i18n import gettext as _ from mercurial.node import * +from mercurial import changegroup as changegroupmod from common import HTTP_OK, HTTP_NOT_FOUND, HTTP_SERVER_ERROR # __all__ is populated with the allowed commands. Be sure to add to it if @@ -167,36 +168,11 @@ def unbundle(web, req): fp.seek(0) header = fp.read(6) - if not header.startswith("HG"): - # old client with uncompressed bundle - def generator(f): - yield header - for chunk in f: - yield chunk - elif not header.startswith("HG10"): - req.write("0\n") - req.write(_("unknown bundle version\n")) - return - elif header == "HG10GZ": - def generator(f): - zd = zlib.decompressobj() - for chunk in f: - yield zd.decompress(chunk) - elif header == "HG10BZ": - def generator(f): - zd = bz2.BZ2Decompressor() - zd.decompress("BZ") - for chunk in f: - yield zd.decompress(chunk) - elif header == "HG10UN": - def generator(f): - for chunk in f: - yield chunk - else: - req.write("0\n") - req.write(_("unknown bundle compression type\n")) - return - gen = generator(util.filechunkiter(fp, 4096)) + if header.startswith('HG') and not header.startswith('HG10'): + raise ValueError('unknown bundle version') + elif header not in changegroupmod.bundletypes: + raise ValueError('unknown bundle compression type') + gen = changegroupmod.unbundle(header, fp) # send addchangegroup output to client @@ -207,8 +183,7 @@ def unbundle(web, req): url = 'remote:%s:%s' % (proto, req.env.get('REMOTE_HOST', '')) try: - ret = web.repo.addchangegroup( - util.chunkbuffer(gen), 'serve', url) + ret = web.repo.addchangegroup(gen, 'serve', url) except util.Abort, inst: sys.stdout.write("abort: %s\n" % inst) ret = 0 @@ -219,6 +194,9 @@ def unbundle(web, req): req.write(val) finally: del lock + except ValueError, inst: + req.write('0\n') + req.write(str(inst) + '\n') except (OSError, IOError), inst: req.write('0\n') filename = getattr(inst, 'filename', '')