##// END OF EJS Templates
changegroup: don't accept odd chunk headers
Mads Kiilerich -
r13458:9f2c407c stable
parent child Browse files
Show More
@@ -1,199 +1,205 b''
1 # changegroup.py - Mercurial changegroup manipulation functions
1 # changegroup.py - Mercurial changegroup manipulation functions
2 #
2 #
3 # Copyright 2006 Matt Mackall <mpm@selenic.com>
3 # Copyright 2006 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from i18n import _
8 from i18n import _
9 import util
9 import util
10 import struct, os, bz2, zlib, tempfile
10 import struct, os, bz2, zlib, tempfile
11
11
def readexactly(stream, n):
    '''read n bytes from stream.read and abort if less was available

    stream is any object providing a read(n) method.  The data returned
    by that single read call is handed back unchanged when it is at
    least n long; otherwise util.Abort is raised, reporting how many
    bytes were received and how many were expected.
    '''
    s = stream.read(n)
    if len(s) < n:
        raise util.Abort(_("stream ended unexpectedly"
                           " (got %d bytes, expected %d)")
                         % (len(s), n))
    return s
20
20
def getchunk(stream):
    """return the next chunk from stream as a string

    A chunk starts with a 4-byte big-endian signed length that counts
    the header itself, so the payload is length - 4 bytes long.  A
    length of 0 marks an empty chunk and yields "".  Any other length
    that leaves no room for a payload (1..4, or negative) is rejected
    with util.Abort instead of being silently treated as empty.
    """
    d = readexactly(stream, 4)
    l = struct.unpack(">l", d)[0]
    if l <= 4:
        if l:
            # non-zero but too small (or negative): corrupt header
            raise util.Abort(_("invalid chunk length %d") % l)
        return ""
    return readexactly(stream, l - 4)
28
30
def chunkheader(length):
    """return a changegroup chunk header (string)

    length is the payload size; the packed value is length + 4 because
    the 4-byte big-endian header counts itself.
    """
    return struct.pack(">l", length + 4)
32
34
def closechunk():
    """return a changegroup chunk header (string) for a zero-length chunk

    This is the 4-byte big-endian encoding of 0, which getchunk reads
    back as an empty chunk.
    """
    return struct.pack(">l", 0)
36
38
class nocompress(object):
    """pass-through object offering the compress/flush pair of methods

    It is used in bundletypes wherever no compression is wanted.
    """
    def compress(self, x):
        # hand the data back untouched
        return x
    def flush(self):
        # nothing is ever buffered, so there is nothing left to emit
        return ""
42
44
# maps a bundle type name to (header written to the file, factory
# returning an object with compress(data) and flush() methods)
bundletypes = {
    "": ("", nocompress),
    "HG10UN": ("HG10UN", nocompress),
    # only "HG10" is written for bz2 bundles.
    # NOTE(review): presumably the compressed stream's own leading "BZ"
    # completes the 6-byte magic; decompressor() feeds "BZ" back to the
    # bz2 decompressor before the remaining data - confirm.
    "HG10BZ": ("HG10", lambda: bz2.BZ2Compressor()),
    "HG10GZ": ("HG10GZ", lambda: zlib.compressobj()),
}
49
51
def collector(cl, mmfs, files):
    """return a callback recording what a changeset node touches

    Gather information about changeset nodes going out in a bundle.
    We want to gather manifests needed and filelogs affected.

    cl must provide read(node); item 0 of its result is used as a key
    of mmfs and item 3 is merged into files.  mmfs only remembers the
    first node seen for each key.
    """
    def collect(node):
        c = cl.read(node)
        files.update(c[3])
        # setdefault: keep the first node that referenced this key
        mmfs.setdefault(c[0], node)
    return collect
58
60
# hgweb uses this list to communicate its preferred type
bundlepriority = ['HG10GZ', 'HG10BZ', 'HG10UN']
61
63
def writebundle(cg, filename, bundletype):
    """Write a bundle file and return its filename.

    cg is the changegroup stream the chunks are read from, filename the
    path to write and bundletype a key of bundletypes selecting the
    file header and the compression.

    If no filename is specified, a temporary file is created.
    A named file is opened in "wb" mode, so an existing file of that
    name is overwritten.
    The bundle file will be deleted in case of errors.
    """

    fh = None
    cleanup = None
    try:
        if filename:
            fh = open(filename, "wb")
        else:
            fd, filename = tempfile.mkstemp(prefix="hg-bundle-", suffix=".hg")
            fh = os.fdopen(fd, "wb")
        # from here on a failure must remove the partially written file
        cleanup = filename

        header, compressor = bundletypes[bundletype]
        fh.write(header)
        z = compressor()

        # parse the changegroup data, otherwise we will block
        # in case of sshrepo because we don't know the end of the stream

        # a changegroup has at least 2 chunkgroups (changelog and manifest).
        # after that, an empty chunkgroup is the end of the changegroup
        empty = False
        count = 0
        while not empty or count <= 2:
            empty = True
            count += 1
            while 1:
                chunk = getchunk(cg)
                if not chunk:
                    break
                empty = False
                fh.write(z.compress(chunkheader(len(chunk))))
                # feed the compressor at most 1 MiB at a time
                pos = 0
                while pos < len(chunk):
                    end = pos + 2**20
                    fh.write(z.compress(chunk[pos:end]))
                    pos = end
            # terminate the chunkgroup that was just copied
            fh.write(z.compress(closechunk()))
        fh.write(z.flush())
        # success: keep the file
        cleanup = None
        return filename
    finally:
        if fh is not None:
            fh.close()
        if cleanup is not None:
            os.unlink(cleanup)
116
118
def decompressor(fh, alg):
    """return a file-like object yielding the decompressed data of fh

    alg is the two-letter compression code: 'UN' returns fh itself,
    'GZ' and 'BZ' wrap fh in a util.chunkbuffer fed by a decompressing
    generator.  Any other value raises util.Abort.
    """
    if alg == 'UN':
        return fh
    elif alg == 'GZ':
        def generator(f):
            zd = zlib.decompressobj()
            for chunk in f:
                yield zd.decompress(chunk)
    elif alg == 'BZ':
        def generator(f):
            zd = bz2.BZ2Decompressor()
            # prime the decompressor with "BZ" before the data from f.
            # NOTE(review): presumably those two bytes were already
            # consumed by the caller as part of the bundle header - confirm.
            zd.decompress("BZ")
            for chunk in util.filechunkiter(f, 4096):
                yield zd.decompress(chunk)
    else:
        raise util.Abort("unknown bundle compression '%s'" % alg)
    return util.chunkbuffer(generator(fh))
134
136
class unbundle10(object):
    """reader for the chunks of a changegroup stream

    fh is the underlying file object and alg the two-letter compression
    code handed to decompressor().  callback, when set, is invoked once
    for every non-empty chunk header read by chunklength().
    """
    def __init__(self, fh, alg):
        self._stream = decompressor(fh, alg)
        self._type = alg
        self.callback = None
    def compressed(self):
        return self._type != 'UN'
    def read(self, l):
        return self._stream.read(l)
    def seek(self, pos):
        return self._stream.seek(pos)
    def tell(self):
        return self._stream.tell()
    def close(self):
        return self._stream.close()

    def chunklength(self):
        """read a chunk header and return the payload length

        Returns 0 for an empty chunk (header value 0).  Header values
        that leave no room for a payload (1..4, or negative) raise
        util.Abort.
        """
        # must read from self._stream: there is no bare 'stream' name in
        # this scope, using it would raise NameError
        d = readexactly(self._stream, 4)
        l = struct.unpack(">l", d)[0]
        if l <= 4:
            if l:
                raise util.Abort(_("invalid chunk length %d") % l)
            return 0
        if self.callback:
            self.callback()
        # the stored length counts the 4 header bytes
        return l - 4

    def chunk(self):
        """return the next chunk from changegroup 'source' as a string"""
        l = self.chunklength()
        return readexactly(self._stream, l)

    def parsechunk(self):
        """return the next chunk split into its fields, or {} if empty

        The payload starts with an 80-byte header of four 20-byte
        fields (node, p1, p2, cs); the rest of the payload is data.
        """
        l = self.chunklength()
        if not l:
            return {}
        h = readexactly(self._stream, 80)
        node, p1, p2, cs = struct.unpack("20s20s20s20s", h)
        data = readexactly(self._stream, l - 80)
        return dict(node=node, p1=p1, p2=p2, cs=cs, data=data)
171
177
class headerlessfixup(object):
    """file-like wrapper that replays already-read bytes before fh

    h holds bytes that were consumed from fh earlier; read() serves
    them first and only then continues with data from fh.
    """
    def __init__(self, fh, h):
        self._h = h
        self._fh = fh
    def read(self, n):
        if self._h:
            # serve from the saved bytes, topping up from fh if short
            d, self._h = self._h[:n], self._h[n:]
            if len(d) < n:
                d += readexactly(self._fh, n - len(d))
            return d
        return readexactly(self._fh, n)
183
189
def readbundle(fh, fname):
    """return an unbundle10 object reading the bundle in fh

    The first 6 bytes of fh are read as the header (magic, version and
    compression, two characters each).  A header starting with a NUL
    byte is treated as a headerless uncompressed stream: the bytes
    already read are pushed back through headerlessfixup and "HG10UN"
    is assumed.

    fname is only used in error messages and defaults to "stream".
    Raises util.Abort if the magic is not 'HG' or the version not '10'.
    """
    header = readexactly(fh, 6)

    if not fname:
        fname = "stream"
    if not header.startswith('HG') and header.startswith('\0'):
        fh = headerlessfixup(fh, header)
        header = "HG10UN"

    magic, version, alg = header[0:2], header[2:4], header[4:6]

    if magic != 'HG':
        raise util.Abort(_('%s: not a Mercurial bundle') % fname)
    if version != '10':
        raise util.Abort(_('%s: unknown bundle version %s') % (fname, version))
    return unbundle10(fh, alg)
General Comments 0
You need to be logged in to leave comments. Login now