##// END OF EJS Templates
py3: use int instead of pycompat.long...
Gregory Szorc -
r49787:176f1a0d default
parent child Browse files
Show More
@@ -1,410 +1,410 b''
1 # monotone.py - monotone support for the convert extension
1 # monotone.py - monotone support for the convert extension
2 #
2 #
3 # Copyright 2008, 2009 Mikkel Fahnoe Jorgensen <mikkel@dvide.com> and
3 # Copyright 2008, 2009 Mikkel Fahnoe Jorgensen <mikkel@dvide.com> and
4 # others
4 # others
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 import os
9 import os
10 import re
10 import re
11
11
12 from mercurial.i18n import _
12 from mercurial.i18n import _
13 from mercurial.pycompat import open
13 from mercurial.pycompat import open
14 from mercurial import (
14 from mercurial import (
15 error,
15 error,
16 pycompat,
16 pycompat,
17 )
17 )
18 from mercurial.utils import dateutil
18 from mercurial.utils import dateutil
19
19
20 from . import common
20 from . import common
21
21
22
22
class monotone_source(common.converter_source, common.commandline):
    def __init__(self, ui, repotype, path=None, revs=None):
        """Initialize a monotone conversion source rooted at *path*.

        *path* may be a checkout (containing ``_MTN``) or a bare monotone
        database (an SQLite file).  At most one revision is accepted in
        *revs*; more than one aborts.
        """
        common.converter_source.__init__(self, ui, repotype, path, revs)
        if revs and len(revs) > 1:
            raise error.Abort(
                _(
                    b'monotone source does not support specifying '
                    b'multiple revs'
                )
            )
        common.commandline.__init__(self, ui, b'mtn')

        self.ui = ui
        self.path = path
        self.automatestdio = False
        self.revs = revs

        norepo = common.NoRepo(
            _(b"%s does not look like a monotone repository") % path
        )
        if not os.path.exists(os.path.join(path, b'_MTN')):
            # Could be a monotone repository (SQLite db file)
            try:
                f = open(path, b'rb')
                header = f.read(16)
                f.close()
            except IOError:
                header = b''
            if header != b'SQLite format 3\x00':
                raise norepo

        # regular expressions for parsing monotone output
        space = br'\s*'
        name = br'\s+"((?:\\"|[^"])*)"\s*'
        value = name
        revision = br'\s+\[(\w+)\]\s*'
        lines = br'(?:.|\n)+'

        self.dir_re = re.compile(space + b"dir" + name)
        self.file_re = re.compile(
            space + b"file" + name + b"content" + revision
        )
        self.add_file_re = re.compile(
            space + b"add_file" + name + b"content" + revision
        )
        self.patch_re = re.compile(
            space + b"patch" + name + b"from" + revision + b"to" + revision
        )
        self.rename_re = re.compile(space + b"rename" + name + b"to" + name)
        self.delete_re = re.compile(space + b"delete" + name)
        self.tag_re = re.compile(space + b"tag" + name + b"revision" + revision)
        self.cert_re = re.compile(
            lines + space + b"name" + name + b"value" + value
        )

        attr = space + b"file" + lines + space + b"attr" + space
        self.attr_execute_re = re.compile(
            attr + b'"mtn:execute"' + space + b'"true"'
        )

        # cached manifest data, filled lazily by mtnloadmanifest()
        self.manifest_rev = None
        self.manifest = None
        self.files = None
        self.dirs = None

        common.checktool(b'mtn', abort=False)
90
90
91 def mtnrun(self, *args, **kwargs):
91 def mtnrun(self, *args, **kwargs):
92 if self.automatestdio:
92 if self.automatestdio:
93 return self.mtnrunstdio(*args, **kwargs)
93 return self.mtnrunstdio(*args, **kwargs)
94 else:
94 else:
95 return self.mtnrunsingle(*args, **kwargs)
95 return self.mtnrunsingle(*args, **kwargs)
96
96
97 def mtnrunsingle(self, *args, **kwargs):
97 def mtnrunsingle(self, *args, **kwargs):
98 kwargs['d'] = self.path
98 kwargs['d'] = self.path
99 return self.run0(b'automate', *args, **kwargs)
99 return self.run0(b'automate', *args, **kwargs)
100
100
101 def mtnrunstdio(self, *args, **kwargs):
101 def mtnrunstdio(self, *args, **kwargs):
102 # Prepare the command in automate stdio format
102 # Prepare the command in automate stdio format
103 kwargs = pycompat.byteskwargs(kwargs)
103 kwargs = pycompat.byteskwargs(kwargs)
104 command = []
104 command = []
105 for k, v in kwargs.items():
105 for k, v in kwargs.items():
106 command.append(b"%d:%s" % (len(k), k))
106 command.append(b"%d:%s" % (len(k), k))
107 if v:
107 if v:
108 command.append(b"%d:%s" % (len(v), v))
108 command.append(b"%d:%s" % (len(v), v))
109 if command:
109 if command:
110 command.insert(0, b'o')
110 command.insert(0, b'o')
111 command.append(b'e')
111 command.append(b'e')
112
112
113 command.append(b'l')
113 command.append(b'l')
114 for arg in args:
114 for arg in args:
115 command.append(b"%d:%s" % (len(arg), arg))
115 command.append(b"%d:%s" % (len(arg), arg))
116 command.append(b'e')
116 command.append(b'e')
117 command = b''.join(command)
117 command = b''.join(command)
118
118
119 self.ui.debug(b"mtn: sending '%s'\n" % command)
119 self.ui.debug(b"mtn: sending '%s'\n" % command)
120 self.mtnwritefp.write(command)
120 self.mtnwritefp.write(command)
121 self.mtnwritefp.flush()
121 self.mtnwritefp.flush()
122
122
123 return self.mtnstdioreadcommandoutput(command)
123 return self.mtnstdioreadcommandoutput(command)
124
124
125 def mtnstdioreadpacket(self):
125 def mtnstdioreadpacket(self):
126 read = None
126 read = None
127 commandnbr = b''
127 commandnbr = b''
128 while read != b':':
128 while read != b':':
129 read = self.mtnreadfp.read(1)
129 read = self.mtnreadfp.read(1)
130 if not read:
130 if not read:
131 raise error.Abort(_(b'bad mtn packet - no end of commandnbr'))
131 raise error.Abort(_(b'bad mtn packet - no end of commandnbr'))
132 commandnbr += read
132 commandnbr += read
133 commandnbr = commandnbr[:-1]
133 commandnbr = commandnbr[:-1]
134
134
135 stream = self.mtnreadfp.read(1)
135 stream = self.mtnreadfp.read(1)
136 if stream not in b'mewptl':
136 if stream not in b'mewptl':
137 raise error.Abort(
137 raise error.Abort(
138 _(b'bad mtn packet - bad stream type %s') % stream
138 _(b'bad mtn packet - bad stream type %s') % stream
139 )
139 )
140
140
141 read = self.mtnreadfp.read(1)
141 read = self.mtnreadfp.read(1)
142 if read != b':':
142 if read != b':':
143 raise error.Abort(_(b'bad mtn packet - no divider before size'))
143 raise error.Abort(_(b'bad mtn packet - no divider before size'))
144
144
145 read = None
145 read = None
146 lengthstr = b''
146 lengthstr = b''
147 while read != b':':
147 while read != b':':
148 read = self.mtnreadfp.read(1)
148 read = self.mtnreadfp.read(1)
149 if not read:
149 if not read:
150 raise error.Abort(_(b'bad mtn packet - no end of packet size'))
150 raise error.Abort(_(b'bad mtn packet - no end of packet size'))
151 lengthstr += read
151 lengthstr += read
152 try:
152 try:
153 length = pycompat.long(lengthstr[:-1])
153 length = int(lengthstr[:-1])
154 except TypeError:
154 except TypeError:
155 raise error.Abort(
155 raise error.Abort(
156 _(b'bad mtn packet - bad packet size %s') % lengthstr
156 _(b'bad mtn packet - bad packet size %s') % lengthstr
157 )
157 )
158
158
159 read = self.mtnreadfp.read(length)
159 read = self.mtnreadfp.read(length)
160 if len(read) != length:
160 if len(read) != length:
161 raise error.Abort(
161 raise error.Abort(
162 _(
162 _(
163 b"bad mtn packet - unable to read full packet "
163 b"bad mtn packet - unable to read full packet "
164 b"read %s of %s"
164 b"read %s of %s"
165 )
165 )
166 % (len(read), length)
166 % (len(read), length)
167 )
167 )
168
168
169 return (commandnbr, stream, length, read)
169 return (commandnbr, stream, length, read)
170
170
171 def mtnstdioreadcommandoutput(self, command):
171 def mtnstdioreadcommandoutput(self, command):
172 retval = []
172 retval = []
173 while True:
173 while True:
174 commandnbr, stream, length, output = self.mtnstdioreadpacket()
174 commandnbr, stream, length, output = self.mtnstdioreadpacket()
175 self.ui.debug(
175 self.ui.debug(
176 b'mtn: read packet %s:%s:%d\n' % (commandnbr, stream, length)
176 b'mtn: read packet %s:%s:%d\n' % (commandnbr, stream, length)
177 )
177 )
178
178
179 if stream == b'l':
179 if stream == b'l':
180 # End of command
180 # End of command
181 if output != b'0':
181 if output != b'0':
182 raise error.Abort(
182 raise error.Abort(
183 _(b"mtn command '%s' returned %s") % (command, output)
183 _(b"mtn command '%s' returned %s") % (command, output)
184 )
184 )
185 break
185 break
186 elif stream in b'ew':
186 elif stream in b'ew':
187 # Error, warning output
187 # Error, warning output
188 self.ui.warn(_(b'%s error:\n') % self.command)
188 self.ui.warn(_(b'%s error:\n') % self.command)
189 self.ui.warn(output)
189 self.ui.warn(output)
190 elif stream == b'p':
190 elif stream == b'p':
191 # Progress messages
191 # Progress messages
192 self.ui.debug(b'mtn: ' + output)
192 self.ui.debug(b'mtn: ' + output)
193 elif stream == b'm':
193 elif stream == b'm':
194 # Main stream - command output
194 # Main stream - command output
195 retval.append(output)
195 retval.append(output)
196
196
197 return b''.join(retval)
197 return b''.join(retval)
198
198
199 def mtnloadmanifest(self, rev):
199 def mtnloadmanifest(self, rev):
200 if self.manifest_rev == rev:
200 if self.manifest_rev == rev:
201 return
201 return
202 self.manifest = self.mtnrun(b"get_manifest_of", rev).split(b"\n\n")
202 self.manifest = self.mtnrun(b"get_manifest_of", rev).split(b"\n\n")
203 self.manifest_rev = rev
203 self.manifest_rev = rev
204 self.files = {}
204 self.files = {}
205 self.dirs = {}
205 self.dirs = {}
206
206
207 for e in self.manifest:
207 for e in self.manifest:
208 m = self.file_re.match(e)
208 m = self.file_re.match(e)
209 if m:
209 if m:
210 attr = b""
210 attr = b""
211 name = m.group(1)
211 name = m.group(1)
212 node = m.group(2)
212 node = m.group(2)
213 if self.attr_execute_re.match(e):
213 if self.attr_execute_re.match(e):
214 attr += b"x"
214 attr += b"x"
215 self.files[name] = (node, attr)
215 self.files[name] = (node, attr)
216 m = self.dir_re.match(e)
216 m = self.dir_re.match(e)
217 if m:
217 if m:
218 self.dirs[m.group(1)] = True
218 self.dirs[m.group(1)] = True
219
219
220 def mtnisfile(self, name, rev):
220 def mtnisfile(self, name, rev):
221 # a non-file could be a directory or a deleted or renamed file
221 # a non-file could be a directory or a deleted or renamed file
222 self.mtnloadmanifest(rev)
222 self.mtnloadmanifest(rev)
223 return name in self.files
223 return name in self.files
224
224
225 def mtnisdir(self, name, rev):
225 def mtnisdir(self, name, rev):
226 self.mtnloadmanifest(rev)
226 self.mtnloadmanifest(rev)
227 return name in self.dirs
227 return name in self.dirs
228
228
229 def mtngetcerts(self, rev):
229 def mtngetcerts(self, rev):
230 certs = {
230 certs = {
231 b"author": b"<missing>",
231 b"author": b"<missing>",
232 b"date": b"<missing>",
232 b"date": b"<missing>",
233 b"changelog": b"<missing>",
233 b"changelog": b"<missing>",
234 b"branch": b"<missing>",
234 b"branch": b"<missing>",
235 }
235 }
236 certlist = self.mtnrun(b"certs", rev)
236 certlist = self.mtnrun(b"certs", rev)
237 # mtn < 0.45:
237 # mtn < 0.45:
238 # key "test@selenic.com"
238 # key "test@selenic.com"
239 # mtn >= 0.45:
239 # mtn >= 0.45:
240 # key [ff58a7ffb771907c4ff68995eada1c4da068d328]
240 # key [ff58a7ffb771907c4ff68995eada1c4da068d328]
241 certlist = re.split(br'\n\n {6}key ["\[]', certlist)
241 certlist = re.split(br'\n\n {6}key ["\[]', certlist)
242 for e in certlist:
242 for e in certlist:
243 m = self.cert_re.match(e)
243 m = self.cert_re.match(e)
244 if m:
244 if m:
245 name, value = m.groups()
245 name, value = m.groups()
246 value = value.replace(br'\"', b'"')
246 value = value.replace(br'\"', b'"')
247 value = value.replace(br'\\', b'\\')
247 value = value.replace(br'\\', b'\\')
248 certs[name] = value
248 certs[name] = value
249 # Monotone may have subsecond dates: 2005-02-05T09:39:12.364306
249 # Monotone may have subsecond dates: 2005-02-05T09:39:12.364306
250 # and all times are stored in UTC
250 # and all times are stored in UTC
251 certs[b"date"] = certs[b"date"].split(b'.')[0] + b" UTC"
251 certs[b"date"] = certs[b"date"].split(b'.')[0] + b" UTC"
252 return certs
252 return certs
253
253
254 # implement the converter_source interface:
254 # implement the converter_source interface:
255
255
256 def getheads(self):
256 def getheads(self):
257 if not self.revs:
257 if not self.revs:
258 return self.mtnrun(b"leaves").splitlines()
258 return self.mtnrun(b"leaves").splitlines()
259 else:
259 else:
260 return self.revs
260 return self.revs
261
261
262 def getchanges(self, rev, full):
262 def getchanges(self, rev, full):
263 if full:
263 if full:
264 raise error.Abort(
264 raise error.Abort(
265 _(b"convert from monotone does not support --full")
265 _(b"convert from monotone does not support --full")
266 )
266 )
267 revision = self.mtnrun(b"get_revision", rev).split(b"\n\n")
267 revision = self.mtnrun(b"get_revision", rev).split(b"\n\n")
268 files = {}
268 files = {}
269 ignoremove = {}
269 ignoremove = {}
270 renameddirs = []
270 renameddirs = []
271 copies = {}
271 copies = {}
272 for e in revision:
272 for e in revision:
273 m = self.add_file_re.match(e)
273 m = self.add_file_re.match(e)
274 if m:
274 if m:
275 files[m.group(1)] = rev
275 files[m.group(1)] = rev
276 ignoremove[m.group(1)] = rev
276 ignoremove[m.group(1)] = rev
277 m = self.patch_re.match(e)
277 m = self.patch_re.match(e)
278 if m:
278 if m:
279 files[m.group(1)] = rev
279 files[m.group(1)] = rev
280 # Delete/rename is handled later when the convert engine
280 # Delete/rename is handled later when the convert engine
281 # discovers an IOError exception from getfile,
281 # discovers an IOError exception from getfile,
282 # but only if we add the "from" file to the list of changes.
282 # but only if we add the "from" file to the list of changes.
283 m = self.delete_re.match(e)
283 m = self.delete_re.match(e)
284 if m:
284 if m:
285 files[m.group(1)] = rev
285 files[m.group(1)] = rev
286 m = self.rename_re.match(e)
286 m = self.rename_re.match(e)
287 if m:
287 if m:
288 toname = m.group(2)
288 toname = m.group(2)
289 fromname = m.group(1)
289 fromname = m.group(1)
290 if self.mtnisfile(toname, rev):
290 if self.mtnisfile(toname, rev):
291 ignoremove[toname] = 1
291 ignoremove[toname] = 1
292 copies[toname] = fromname
292 copies[toname] = fromname
293 files[toname] = rev
293 files[toname] = rev
294 files[fromname] = rev
294 files[fromname] = rev
295 elif self.mtnisdir(toname, rev):
295 elif self.mtnisdir(toname, rev):
296 renameddirs.append((fromname, toname))
296 renameddirs.append((fromname, toname))
297
297
298 # Directory renames can be handled only once we have recorded
298 # Directory renames can be handled only once we have recorded
299 # all new files
299 # all new files
300 for fromdir, todir in renameddirs:
300 for fromdir, todir in renameddirs:
301 renamed = {}
301 renamed = {}
302 for tofile in self.files:
302 for tofile in self.files:
303 if tofile in ignoremove:
303 if tofile in ignoremove:
304 continue
304 continue
305 if tofile.startswith(todir + b'/'):
305 if tofile.startswith(todir + b'/'):
306 renamed[tofile] = fromdir + tofile[len(todir) :]
306 renamed[tofile] = fromdir + tofile[len(todir) :]
307 # Avoid chained moves like:
307 # Avoid chained moves like:
308 # d1(/a) => d3/d1(/a)
308 # d1(/a) => d3/d1(/a)
309 # d2 => d3
309 # d2 => d3
310 ignoremove[tofile] = 1
310 ignoremove[tofile] = 1
311 for tofile, fromfile in renamed.items():
311 for tofile, fromfile in renamed.items():
312 self.ui.debug(
312 self.ui.debug(
313 b"copying file in renamed directory from '%s' to '%s'"
313 b"copying file in renamed directory from '%s' to '%s'"
314 % (fromfile, tofile),
314 % (fromfile, tofile),
315 b'\n',
315 b'\n',
316 )
316 )
317 files[tofile] = rev
317 files[tofile] = rev
318 copies[tofile] = fromfile
318 copies[tofile] = fromfile
319 for fromfile in renamed.values():
319 for fromfile in renamed.values():
320 files[fromfile] = rev
320 files[fromfile] = rev
321
321
322 return (files.items(), copies, set())
322 return (files.items(), copies, set())
323
323
324 def getfile(self, name, rev):
324 def getfile(self, name, rev):
325 if not self.mtnisfile(name, rev):
325 if not self.mtnisfile(name, rev):
326 return None, None
326 return None, None
327 try:
327 try:
328 data = self.mtnrun(b"get_file_of", name, r=rev)
328 data = self.mtnrun(b"get_file_of", name, r=rev)
329 except Exception:
329 except Exception:
330 return None, None
330 return None, None
331 self.mtnloadmanifest(rev)
331 self.mtnloadmanifest(rev)
332 node, attr = self.files.get(name, (None, b""))
332 node, attr = self.files.get(name, (None, b""))
333 return data, attr
333 return data, attr
334
334
335 def getcommit(self, rev):
335 def getcommit(self, rev):
336 extra = {}
336 extra = {}
337 certs = self.mtngetcerts(rev)
337 certs = self.mtngetcerts(rev)
338 if certs.get(b'suspend') == certs[b"branch"]:
338 if certs.get(b'suspend') == certs[b"branch"]:
339 extra[b'close'] = b'1'
339 extra[b'close'] = b'1'
340 dateformat = b"%Y-%m-%dT%H:%M:%S"
340 dateformat = b"%Y-%m-%dT%H:%M:%S"
341 return common.commit(
341 return common.commit(
342 author=certs[b"author"],
342 author=certs[b"author"],
343 date=dateutil.datestr(dateutil.strdate(certs[b"date"], dateformat)),
343 date=dateutil.datestr(dateutil.strdate(certs[b"date"], dateformat)),
344 desc=certs[b"changelog"],
344 desc=certs[b"changelog"],
345 rev=rev,
345 rev=rev,
346 parents=self.mtnrun(b"parents", rev).splitlines(),
346 parents=self.mtnrun(b"parents", rev).splitlines(),
347 branch=certs[b"branch"],
347 branch=certs[b"branch"],
348 extra=extra,
348 extra=extra,
349 )
349 )
350
350
351 def gettags(self):
351 def gettags(self):
352 tags = {}
352 tags = {}
353 for e in self.mtnrun(b"tags").split(b"\n\n"):
353 for e in self.mtnrun(b"tags").split(b"\n\n"):
354 m = self.tag_re.match(e)
354 m = self.tag_re.match(e)
355 if m:
355 if m:
356 tags[m.group(1)] = m.group(2)
356 tags[m.group(1)] = m.group(2)
357 return tags
357 return tags
358
358
359 def getchangedfiles(self, rev, i):
359 def getchangedfiles(self, rev, i):
360 # This function is only needed to support --filemap
360 # This function is only needed to support --filemap
361 # ... and we don't support that
361 # ... and we don't support that
362 raise NotImplementedError
362 raise NotImplementedError
363
363
364 def before(self):
364 def before(self):
365 # Check if we have a new enough version to use automate stdio
365 # Check if we have a new enough version to use automate stdio
366 try:
366 try:
367 versionstr = self.mtnrunsingle(b"interface_version")
367 versionstr = self.mtnrunsingle(b"interface_version")
368 version = float(versionstr)
368 version = float(versionstr)
369 except Exception:
369 except Exception:
370 raise error.Abort(
370 raise error.Abort(
371 _(b"unable to determine mtn automate interface version")
371 _(b"unable to determine mtn automate interface version")
372 )
372 )
373
373
374 if version >= 12.0:
374 if version >= 12.0:
375 self.automatestdio = True
375 self.automatestdio = True
376 self.ui.debug(
376 self.ui.debug(
377 b"mtn automate version %f - using automate stdio\n" % version
377 b"mtn automate version %f - using automate stdio\n" % version
378 )
378 )
379
379
380 # launch the long-running automate stdio process
380 # launch the long-running automate stdio process
381 self.mtnwritefp, self.mtnreadfp = self._run2(
381 self.mtnwritefp, self.mtnreadfp = self._run2(
382 b'automate', b'stdio', b'-d', self.path
382 b'automate', b'stdio', b'-d', self.path
383 )
383 )
384 # read the headers
384 # read the headers
385 read = self.mtnreadfp.readline()
385 read = self.mtnreadfp.readline()
386 if read != b'format-version: 2\n':
386 if read != b'format-version: 2\n':
387 raise error.Abort(
387 raise error.Abort(
388 _(b'mtn automate stdio header unexpected: %s') % read
388 _(b'mtn automate stdio header unexpected: %s') % read
389 )
389 )
390 while read != b'\n':
390 while read != b'\n':
391 read = self.mtnreadfp.readline()
391 read = self.mtnreadfp.readline()
392 if not read:
392 if not read:
393 raise error.Abort(
393 raise error.Abort(
394 _(
394 _(
395 b"failed to reach end of mtn automate "
395 b"failed to reach end of mtn automate "
396 b"stdio headers"
396 b"stdio headers"
397 )
397 )
398 )
398 )
399 else:
399 else:
400 self.ui.debug(
400 self.ui.debug(
401 b"mtn automate version %s - not using automate stdio "
401 b"mtn automate version %s - not using automate stdio "
402 b"(automate >= 12.0 - mtn >= 0.46 is needed)\n" % version
402 b"(automate >= 12.0 - mtn >= 0.46 is needed)\n" % version
403 )
403 )
404
404
405 def after(self):
405 def after(self):
406 if self.automatestdio:
406 if self.automatestdio:
407 self.mtnwritefp.close()
407 self.mtnwritefp.close()
408 self.mtnwritefp = None
408 self.mtnwritefp = None
409 self.mtnreadfp.close()
409 self.mtnreadfp.close()
410 self.mtnreadfp = None
410 self.mtnreadfp = None
@@ -1,543 +1,543 b''
1 # shallowutil.py -- remotefilelog utilities
1 # shallowutil.py -- remotefilelog utilities
2 #
2 #
3 # Copyright 2014 Facebook, Inc.
3 # Copyright 2014 Facebook, Inc.
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 import collections
8 import collections
9 import errno
9 import errno
10 import os
10 import os
11 import stat
11 import stat
12 import struct
12 import struct
13 import tempfile
13 import tempfile
14
14
15 from mercurial.i18n import _
15 from mercurial.i18n import _
16 from mercurial.pycompat import open
16 from mercurial.pycompat import open
17 from mercurial.node import hex
17 from mercurial.node import hex
18 from mercurial import (
18 from mercurial import (
19 error,
19 error,
20 pycompat,
20 pycompat,
21 revlog,
21 revlog,
22 util,
22 util,
23 )
23 )
24 from mercurial.utils import (
24 from mercurial.utils import (
25 hashutil,
25 hashutil,
26 storageutil,
26 storageutil,
27 stringutil,
27 stringutil,
28 )
28 )
29 from . import constants
29 from . import constants
30
30
31 if not pycompat.iswindows:
31 if not pycompat.iswindows:
32 import grp
32 import grp
33
33
34
34
def isenabled(repo):
    """Return True when *repo* carries the remotefilelog requirement."""
    requirements = repo.requirements
    return constants.SHALLOWREPO_REQUIREMENT in requirements
38
38
39
39
def getcachekey(reponame, file, id):
    """Return the shared-cache path for *file*@*id* under *reponame*.

    The path fans out on the first two hex digits of the file's sha1 to
    keep directory sizes bounded.
    """
    pathhash = hex(hashutil.sha1(file).digest())
    return os.path.join(reponame, pathhash[:2], pathhash[2:], id)
43
43
44
44
def getlocalkey(file, id):
    """Return the local-store path for *file*@*id* (sha1 of path + id)."""
    pathhash = hex(hashutil.sha1(file).digest())
    return os.path.join(pathhash, id)
48
48
49
49
def getcachepath(ui, allowempty=False):
    """Return the configured remotefilelog cache path, expanded.

    When unset: return None if *allowempty*, otherwise abort.
    """
    cachepath = ui.config(b"remotefilelog", b"cachepath")
    if cachepath:
        return util.expandpath(cachepath)
    if allowempty:
        return None
    raise error.Abort(
        _(b"could not find config option remotefilelog.cachepath")
    )
60
60
61
61
def getcachepackpath(repo, category):
    """Return the shared-cache pack directory for *category*."""
    cachepath = getcachepath(repo.ui)
    if category == constants.FILEPACK_CATEGORY:
        # file packs live directly under 'packs'
        return os.path.join(cachepath, repo.name, b'packs')
    return os.path.join(cachepath, repo.name, b'packs', category)
68
68
69
69
def getlocalpackpath(base, category):
    """Return the local pack directory for *category* under *base*."""
    return os.path.join(base, b'packs', category)
72
72
73
73
def createrevlogtext(text, copyfrom=None, copyrev=None):
    """Return *text* wrapped so it matches traditional revlog contents.

    Adds the filelog metadata envelope when copy info is given or when
    the text itself starts with the metadata marker and must be escaped.
    """
    meta = {}
    if copyfrom or text.startswith(b'\1\n'):
        if copyfrom:
            meta[b'copy'] = copyfrom
            meta[b'copyrev'] = copyrev
        text = storageutil.packmeta(meta, text)

    return text
86
86
87
87
def parsemeta(text):
    """Parse mercurial filelog metadata.

    Returns ``(meta, text)`` where *text* has the ``\\1\\n...\\1\\n``
    metadata envelope stripped when present.
    """
    meta, size = storageutil.parsemeta(text)
    if text.startswith(b'\1\n'):
        end = text.index(b'\1\n', 2)
        text = text[end + 2 :]
    return meta or {}, text
95
95
96
96
def sumdicts(*dicts):
    """Sum the values of *dicts* key-by-key into one dictionary.

    Assumes every value supports '+'.
    e.g. [{'a': 4, 'b': 2}, {'b': 3, 'c': 1}] -> {'a': 4, 'b': 5, 'c': 1}
    """
    totals = collections.defaultdict(lambda: 0)
    for mapping in dicts:
        for key, val in mapping.items():
            totals[key] += val
    return totals
108
108
109
109
def prefixkeys(dict, prefix):
    """Returns ``dict`` with ``prefix`` prepended to all its keys."""
    return {prefix + k: v for k, v in dict.items()}
116
116
117
117
def reportpackmetrics(ui, prefix, *stores):
    """Log the combined pack metrics of *stores* under *prefix*."""
    dicts = [s.getmetrics() for s in stores]
    # renamed from 'dict': don't shadow the builtin
    metrics = prefixkeys(sumdicts(*dicts), prefix + b'_')
    ui.log(prefix + b"_packsizes", b"\n", **pycompat.strkwargs(metrics))
122
122
123
123
def _parsepackmeta(metabuf):
    """parse datapack meta, bytes (<metadata-list>) -> dict

    Entries are (1-byte key, 2-byte big-endian length, value).  Both
    keys and values are returned raw; upper-level business may convert
    some of them to other types like integers, on their own.

    raise ValueError if the data is corrupted
    """
    metadict = {}
    offset = 0
    buflen = len(metabuf)
    while buflen - offset >= 3:
        key = metabuf[offset : offset + 1]
        offset += 1
        metalen = struct.unpack_from(b'!H', metabuf, offset)[0]
        offset += 2
        if offset + metalen > buflen:
            raise ValueError(b'corrupted metadata: incomplete buffer')
        metadict[key] = metabuf[offset : offset + metalen]
        offset += metalen
    if offset != buflen:
        raise ValueError(b'corrupted metadata: redundant data')
    return metadict
149
149
150
150
def _buildpackmeta(metadict):
    """reverse of _parsepackmeta, dict -> bytes (<metadata-list>)

    Keys must be single bytes; values must fit in a 16-bit length field.

    raise ProgrammingError when metadata key is illegal, or ValueError if
    length limit is exceeded
    """
    chunks = []
    for k, v in sorted((metadict or {}).items()):
        if len(k) != 1:
            raise error.ProgrammingError(b'packmeta: illegal key: %s' % k)
        if len(v) > 0xFFFE:
            raise ValueError(
                b'metadata value is too long: 0x%x > 0xfffe' % len(v)
            )
        chunks.append(k)
        chunks.append(struct.pack(b'!H', len(v)))
        chunks.append(v)
    # total length is representable in 4 bytes: at most 256 keys, and
    # for each value, len(value) <= 0xfffe
    return b''.join(chunks)
175
175
176
176
177 _metaitemtypes = {
177 _metaitemtypes = {
178 constants.METAKEYFLAG: (int, pycompat.long),
178 constants.METAKEYFLAG: (int, int),
179 constants.METAKEYSIZE: (int, pycompat.long),
179 constants.METAKEYSIZE: (int, int),
180 }
180 }
181
181
182
182
183 def buildpackmeta(metadict):
183 def buildpackmeta(metadict):
184 """like _buildpackmeta, but typechecks metadict and normalize it.
184 """like _buildpackmeta, but typechecks metadict and normalize it.
185
185
186 This means, METAKEYSIZE and METAKEYSIZE should have integers as values,
186 This means, METAKEYSIZE and METAKEYSIZE should have integers as values,
187 and METAKEYFLAG will be dropped if its value is 0.
187 and METAKEYFLAG will be dropped if its value is 0.
188 """
188 """
189 newmeta = {}
189 newmeta = {}
190 for k, v in (metadict or {}).items():
190 for k, v in (metadict or {}).items():
191 expectedtype = _metaitemtypes.get(k, (bytes,))
191 expectedtype = _metaitemtypes.get(k, (bytes,))
192 if not isinstance(v, expectedtype):
192 if not isinstance(v, expectedtype):
193 raise error.ProgrammingError(b'packmeta: wrong type of key %s' % k)
193 raise error.ProgrammingError(b'packmeta: wrong type of key %s' % k)
194 # normalize int to binary buffer
194 # normalize int to binary buffer
195 if int in expectedtype:
195 if int in expectedtype:
196 # optimization: remove flag if it's 0 to save space
196 # optimization: remove flag if it's 0 to save space
197 if k == constants.METAKEYFLAG and v == 0:
197 if k == constants.METAKEYFLAG and v == 0:
198 continue
198 continue
199 v = int2bin(v)
199 v = int2bin(v)
200 newmeta[k] = v
200 newmeta[k] = v
201 return _buildpackmeta(newmeta)
201 return _buildpackmeta(newmeta)
202
202
203
203
204 def parsepackmeta(metabuf):
204 def parsepackmeta(metabuf):
205 """like _parsepackmeta, but convert fields to desired types automatically.
205 """like _parsepackmeta, but convert fields to desired types automatically.
206
206
207 This means, METAKEYFLAG and METAKEYSIZE fields will be converted to
207 This means, METAKEYFLAG and METAKEYSIZE fields will be converted to
208 integers.
208 integers.
209 """
209 """
210 metadict = _parsepackmeta(metabuf)
210 metadict = _parsepackmeta(metabuf)
211 for k, v in metadict.items():
211 for k, v in metadict.items():
212 if k in _metaitemtypes and int in _metaitemtypes[k]:
212 if k in _metaitemtypes and int in _metaitemtypes[k]:
213 metadict[k] = bin2int(v)
213 metadict[k] = bin2int(v)
214 return metadict
214 return metadict
215
215
216
216
217 def int2bin(n):
217 def int2bin(n):
218 """convert a non-negative integer to raw binary buffer"""
218 """convert a non-negative integer to raw binary buffer"""
219 buf = bytearray()
219 buf = bytearray()
220 while n > 0:
220 while n > 0:
221 buf.insert(0, n & 0xFF)
221 buf.insert(0, n & 0xFF)
222 n >>= 8
222 n >>= 8
223 return bytes(buf)
223 return bytes(buf)
224
224
225
225
226 def bin2int(buf):
226 def bin2int(buf):
227 """the reverse of int2bin, convert a binary buffer to an integer"""
227 """the reverse of int2bin, convert a binary buffer to an integer"""
228 x = 0
228 x = 0
229 for b in bytearray(buf):
229 for b in bytearray(buf):
230 x <<= 8
230 x <<= 8
231 x |= b
231 x |= b
232 return x
232 return x
233
233
234
234
235 class BadRemotefilelogHeader(error.StorageError):
235 class BadRemotefilelogHeader(error.StorageError):
236 """Exception raised when parsing a remotefilelog blob header fails."""
236 """Exception raised when parsing a remotefilelog blob header fails."""
237
237
238
238
239 def parsesizeflags(raw):
239 def parsesizeflags(raw):
240 """given a remotefilelog blob, return (headersize, rawtextsize, flags)
240 """given a remotefilelog blob, return (headersize, rawtextsize, flags)
241
241
242 see remotefilelogserver.createfileblob for the format.
242 see remotefilelogserver.createfileblob for the format.
243 raise RuntimeError if the content is illformed.
243 raise RuntimeError if the content is illformed.
244 """
244 """
245 flags = revlog.REVIDX_DEFAULT_FLAGS
245 flags = revlog.REVIDX_DEFAULT_FLAGS
246 size = None
246 size = None
247 try:
247 try:
248 index = raw.index(b'\0')
248 index = raw.index(b'\0')
249 except ValueError:
249 except ValueError:
250 raise BadRemotefilelogHeader(
250 raise BadRemotefilelogHeader(
251 "unexpected remotefilelog header: illegal format"
251 "unexpected remotefilelog header: illegal format"
252 )
252 )
253 header = raw[:index]
253 header = raw[:index]
254 if header.startswith(b'v'):
254 if header.startswith(b'v'):
255 # v1 and above, header starts with 'v'
255 # v1 and above, header starts with 'v'
256 if header.startswith(b'v1\n'):
256 if header.startswith(b'v1\n'):
257 for s in header.split(b'\n'):
257 for s in header.split(b'\n'):
258 if s.startswith(constants.METAKEYSIZE):
258 if s.startswith(constants.METAKEYSIZE):
259 size = int(s[len(constants.METAKEYSIZE) :])
259 size = int(s[len(constants.METAKEYSIZE) :])
260 elif s.startswith(constants.METAKEYFLAG):
260 elif s.startswith(constants.METAKEYFLAG):
261 flags = int(s[len(constants.METAKEYFLAG) :])
261 flags = int(s[len(constants.METAKEYFLAG) :])
262 else:
262 else:
263 raise BadRemotefilelogHeader(
263 raise BadRemotefilelogHeader(
264 b'unsupported remotefilelog header: %s' % header
264 b'unsupported remotefilelog header: %s' % header
265 )
265 )
266 else:
266 else:
267 # v0, str(int(size)) is the header
267 # v0, str(int(size)) is the header
268 size = int(header)
268 size = int(header)
269 if size is None:
269 if size is None:
270 raise BadRemotefilelogHeader(
270 raise BadRemotefilelogHeader(
271 "unexpected remotefilelog header: no size found"
271 "unexpected remotefilelog header: no size found"
272 )
272 )
273 return index + 1, size, flags
273 return index + 1, size, flags
274
274
275
275
276 def buildfileblobheader(size, flags, version=None):
276 def buildfileblobheader(size, flags, version=None):
277 """return the header of a remotefilelog blob.
277 """return the header of a remotefilelog blob.
278
278
279 see remotefilelogserver.createfileblob for the format.
279 see remotefilelogserver.createfileblob for the format.
280 approximately the reverse of parsesizeflags.
280 approximately the reverse of parsesizeflags.
281
281
282 version could be 0 or 1, or None (auto decide).
282 version could be 0 or 1, or None (auto decide).
283 """
283 """
284 # choose v0 if flags is empty, otherwise v1
284 # choose v0 if flags is empty, otherwise v1
285 if version is None:
285 if version is None:
286 version = int(bool(flags))
286 version = int(bool(flags))
287 if version == 1:
287 if version == 1:
288 header = b'v1\n%s%d\n%s%d' % (
288 header = b'v1\n%s%d\n%s%d' % (
289 constants.METAKEYSIZE,
289 constants.METAKEYSIZE,
290 size,
290 size,
291 constants.METAKEYFLAG,
291 constants.METAKEYFLAG,
292 flags,
292 flags,
293 )
293 )
294 elif version == 0:
294 elif version == 0:
295 if flags:
295 if flags:
296 raise error.ProgrammingError(b'fileblob v0 does not support flag')
296 raise error.ProgrammingError(b'fileblob v0 does not support flag')
297 header = b'%d' % size
297 header = b'%d' % size
298 else:
298 else:
299 raise error.ProgrammingError(b'unknown fileblob version %d' % version)
299 raise error.ProgrammingError(b'unknown fileblob version %d' % version)
300 return header
300 return header
301
301
302
302
303 def ancestormap(raw):
303 def ancestormap(raw):
304 offset, size, flags = parsesizeflags(raw)
304 offset, size, flags = parsesizeflags(raw)
305 start = offset + size
305 start = offset + size
306
306
307 mapping = {}
307 mapping = {}
308 while start < len(raw):
308 while start < len(raw):
309 divider = raw.index(b'\0', start + 80)
309 divider = raw.index(b'\0', start + 80)
310
310
311 currentnode = raw[start : (start + 20)]
311 currentnode = raw[start : (start + 20)]
312 p1 = raw[(start + 20) : (start + 40)]
312 p1 = raw[(start + 20) : (start + 40)]
313 p2 = raw[(start + 40) : (start + 60)]
313 p2 = raw[(start + 40) : (start + 60)]
314 linknode = raw[(start + 60) : (start + 80)]
314 linknode = raw[(start + 60) : (start + 80)]
315 copyfrom = raw[(start + 80) : divider]
315 copyfrom = raw[(start + 80) : divider]
316
316
317 mapping[currentnode] = (p1, p2, linknode, copyfrom)
317 mapping[currentnode] = (p1, p2, linknode, copyfrom)
318 start = divider + 1
318 start = divider + 1
319
319
320 return mapping
320 return mapping
321
321
322
322
323 def readfile(path):
323 def readfile(path):
324 f = open(path, b'rb')
324 f = open(path, b'rb')
325 try:
325 try:
326 result = f.read()
326 result = f.read()
327
327
328 # we should never have empty files
328 # we should never have empty files
329 if not result:
329 if not result:
330 os.remove(path)
330 os.remove(path)
331 raise IOError(b"empty file: %s" % path)
331 raise IOError(b"empty file: %s" % path)
332
332
333 return result
333 return result
334 finally:
334 finally:
335 f.close()
335 f.close()
336
336
337
337
338 def unlinkfile(filepath):
338 def unlinkfile(filepath):
339 if pycompat.iswindows:
339 if pycompat.iswindows:
340 # On Windows, os.unlink cannnot delete readonly files
340 # On Windows, os.unlink cannnot delete readonly files
341 os.chmod(filepath, stat.S_IWUSR)
341 os.chmod(filepath, stat.S_IWUSR)
342 os.unlink(filepath)
342 os.unlink(filepath)
343
343
344
344
345 def renamefile(source, destination):
345 def renamefile(source, destination):
346 if pycompat.iswindows:
346 if pycompat.iswindows:
347 # On Windows, os.rename cannot rename readonly files
347 # On Windows, os.rename cannot rename readonly files
348 # and cannot overwrite destination if it exists
348 # and cannot overwrite destination if it exists
349 os.chmod(source, stat.S_IWUSR)
349 os.chmod(source, stat.S_IWUSR)
350 if os.path.isfile(destination):
350 if os.path.isfile(destination):
351 os.chmod(destination, stat.S_IWUSR)
351 os.chmod(destination, stat.S_IWUSR)
352 os.unlink(destination)
352 os.unlink(destination)
353
353
354 os.rename(source, destination)
354 os.rename(source, destination)
355
355
356
356
357 def writefile(path, content, readonly=False):
357 def writefile(path, content, readonly=False):
358 dirname, filename = os.path.split(path)
358 dirname, filename = os.path.split(path)
359 if not os.path.exists(dirname):
359 if not os.path.exists(dirname):
360 try:
360 try:
361 os.makedirs(dirname)
361 os.makedirs(dirname)
362 except OSError as ex:
362 except OSError as ex:
363 if ex.errno != errno.EEXIST:
363 if ex.errno != errno.EEXIST:
364 raise
364 raise
365
365
366 fd, temp = tempfile.mkstemp(prefix=b'.%s-' % filename, dir=dirname)
366 fd, temp = tempfile.mkstemp(prefix=b'.%s-' % filename, dir=dirname)
367 os.close(fd)
367 os.close(fd)
368
368
369 try:
369 try:
370 f = util.posixfile(temp, b'wb')
370 f = util.posixfile(temp, b'wb')
371 f.write(content)
371 f.write(content)
372 f.close()
372 f.close()
373
373
374 if readonly:
374 if readonly:
375 mode = 0o444
375 mode = 0o444
376 else:
376 else:
377 # tempfiles are created with 0o600, so we need to manually set the
377 # tempfiles are created with 0o600, so we need to manually set the
378 # mode.
378 # mode.
379 oldumask = os.umask(0)
379 oldumask = os.umask(0)
380 # there's no way to get the umask without modifying it, so set it
380 # there's no way to get the umask without modifying it, so set it
381 # back
381 # back
382 os.umask(oldumask)
382 os.umask(oldumask)
383 mode = ~oldumask
383 mode = ~oldumask
384
384
385 renamefile(temp, path)
385 renamefile(temp, path)
386 os.chmod(path, mode)
386 os.chmod(path, mode)
387 except Exception:
387 except Exception:
388 try:
388 try:
389 unlinkfile(temp)
389 unlinkfile(temp)
390 except OSError:
390 except OSError:
391 pass
391 pass
392 raise
392 raise
393
393
394
394
395 def sortnodes(nodes, parentfunc):
395 def sortnodes(nodes, parentfunc):
396 """Topologically sorts the nodes, using the parentfunc to find
396 """Topologically sorts the nodes, using the parentfunc to find
397 the parents of nodes."""
397 the parents of nodes."""
398 nodes = set(nodes)
398 nodes = set(nodes)
399 childmap = {}
399 childmap = {}
400 parentmap = {}
400 parentmap = {}
401 roots = []
401 roots = []
402
402
403 # Build a child and parent map
403 # Build a child and parent map
404 for n in nodes:
404 for n in nodes:
405 parents = [p for p in parentfunc(n) if p in nodes]
405 parents = [p for p in parentfunc(n) if p in nodes]
406 parentmap[n] = set(parents)
406 parentmap[n] = set(parents)
407 for p in parents:
407 for p in parents:
408 childmap.setdefault(p, set()).add(n)
408 childmap.setdefault(p, set()).add(n)
409 if not parents:
409 if not parents:
410 roots.append(n)
410 roots.append(n)
411
411
412 roots.sort()
412 roots.sort()
413 # Process roots, adding children to the queue as they become roots
413 # Process roots, adding children to the queue as they become roots
414 results = []
414 results = []
415 while roots:
415 while roots:
416 n = roots.pop(0)
416 n = roots.pop(0)
417 results.append(n)
417 results.append(n)
418 if n in childmap:
418 if n in childmap:
419 children = childmap[n]
419 children = childmap[n]
420 for c in children:
420 for c in children:
421 childparents = parentmap[c]
421 childparents = parentmap[c]
422 childparents.remove(n)
422 childparents.remove(n)
423 if len(childparents) == 0:
423 if len(childparents) == 0:
424 # insert at the beginning, that way child nodes
424 # insert at the beginning, that way child nodes
425 # are likely to be output immediately after their
425 # are likely to be output immediately after their
426 # parents. This gives better compression results.
426 # parents. This gives better compression results.
427 roots.insert(0, c)
427 roots.insert(0, c)
428
428
429 return results
429 return results
430
430
431
431
432 def readexactly(stream, n):
432 def readexactly(stream, n):
433 '''read n bytes from stream.read and abort if less was available'''
433 '''read n bytes from stream.read and abort if less was available'''
434 s = stream.read(n)
434 s = stream.read(n)
435 if len(s) < n:
435 if len(s) < n:
436 raise error.Abort(
436 raise error.Abort(
437 _(b"stream ended unexpectedly (got %d bytes, expected %d)")
437 _(b"stream ended unexpectedly (got %d bytes, expected %d)")
438 % (len(s), n)
438 % (len(s), n)
439 )
439 )
440 return s
440 return s
441
441
442
442
443 def readunpack(stream, fmt):
443 def readunpack(stream, fmt):
444 data = readexactly(stream, struct.calcsize(fmt))
444 data = readexactly(stream, struct.calcsize(fmt))
445 return struct.unpack(fmt, data)
445 return struct.unpack(fmt, data)
446
446
447
447
448 def readpath(stream):
448 def readpath(stream):
449 rawlen = readexactly(stream, constants.FILENAMESIZE)
449 rawlen = readexactly(stream, constants.FILENAMESIZE)
450 pathlen = struct.unpack(constants.FILENAMESTRUCT, rawlen)[0]
450 pathlen = struct.unpack(constants.FILENAMESTRUCT, rawlen)[0]
451 return readexactly(stream, pathlen)
451 return readexactly(stream, pathlen)
452
452
453
453
454 def readnodelist(stream):
454 def readnodelist(stream):
455 rawlen = readexactly(stream, constants.NODECOUNTSIZE)
455 rawlen = readexactly(stream, constants.NODECOUNTSIZE)
456 nodecount = struct.unpack(constants.NODECOUNTSTRUCT, rawlen)[0]
456 nodecount = struct.unpack(constants.NODECOUNTSTRUCT, rawlen)[0]
457 for i in pycompat.xrange(nodecount):
457 for i in pycompat.xrange(nodecount):
458 yield readexactly(stream, constants.NODESIZE)
458 yield readexactly(stream, constants.NODESIZE)
459
459
460
460
461 def readpathlist(stream):
461 def readpathlist(stream):
462 rawlen = readexactly(stream, constants.PATHCOUNTSIZE)
462 rawlen = readexactly(stream, constants.PATHCOUNTSIZE)
463 pathcount = struct.unpack(constants.PATHCOUNTSTRUCT, rawlen)[0]
463 pathcount = struct.unpack(constants.PATHCOUNTSTRUCT, rawlen)[0]
464 for i in pycompat.xrange(pathcount):
464 for i in pycompat.xrange(pathcount):
465 yield readpath(stream)
465 yield readpath(stream)
466
466
467
467
468 def getgid(groupname):
468 def getgid(groupname):
469 try:
469 try:
470 gid = grp.getgrnam(pycompat.fsdecode(groupname)).gr_gid
470 gid = grp.getgrnam(pycompat.fsdecode(groupname)).gr_gid
471 return gid
471 return gid
472 except KeyError:
472 except KeyError:
473 return None
473 return None
474
474
475
475
476 def setstickygroupdir(path, gid, warn=None):
476 def setstickygroupdir(path, gid, warn=None):
477 if gid is None:
477 if gid is None:
478 return
478 return
479 try:
479 try:
480 os.chown(path, -1, gid)
480 os.chown(path, -1, gid)
481 os.chmod(path, 0o2775)
481 os.chmod(path, 0o2775)
482 except (IOError, OSError) as ex:
482 except (IOError, OSError) as ex:
483 if warn:
483 if warn:
484 warn(_(b'unable to chown/chmod on %s: %s\n') % (path, ex))
484 warn(_(b'unable to chown/chmod on %s: %s\n') % (path, ex))
485
485
486
486
487 def mkstickygroupdir(ui, path):
487 def mkstickygroupdir(ui, path):
488 """Creates the given directory (if it doesn't exist) and give it a
488 """Creates the given directory (if it doesn't exist) and give it a
489 particular group with setgid enabled."""
489 particular group with setgid enabled."""
490 gid = None
490 gid = None
491 groupname = ui.config(b"remotefilelog", b"cachegroup")
491 groupname = ui.config(b"remotefilelog", b"cachegroup")
492 if groupname:
492 if groupname:
493 gid = getgid(groupname)
493 gid = getgid(groupname)
494 if gid is None:
494 if gid is None:
495 ui.warn(_(b'unable to resolve group name: %s\n') % groupname)
495 ui.warn(_(b'unable to resolve group name: %s\n') % groupname)
496
496
497 # we use a single stat syscall to test the existence and mode / group bit
497 # we use a single stat syscall to test the existence and mode / group bit
498 st = None
498 st = None
499 try:
499 try:
500 st = os.stat(path)
500 st = os.stat(path)
501 except OSError:
501 except OSError:
502 pass
502 pass
503
503
504 if st:
504 if st:
505 # exists
505 # exists
506 if (st.st_mode & 0o2775) != 0o2775 or st.st_gid != gid:
506 if (st.st_mode & 0o2775) != 0o2775 or st.st_gid != gid:
507 # permission needs to be fixed
507 # permission needs to be fixed
508 setstickygroupdir(path, gid, ui.warn)
508 setstickygroupdir(path, gid, ui.warn)
509 return
509 return
510
510
511 oldumask = os.umask(0o002)
511 oldumask = os.umask(0o002)
512 try:
512 try:
513 missingdirs = [path]
513 missingdirs = [path]
514 path = os.path.dirname(path)
514 path = os.path.dirname(path)
515 while path and not os.path.exists(path):
515 while path and not os.path.exists(path):
516 missingdirs.append(path)
516 missingdirs.append(path)
517 path = os.path.dirname(path)
517 path = os.path.dirname(path)
518
518
519 for path in reversed(missingdirs):
519 for path in reversed(missingdirs):
520 try:
520 try:
521 os.mkdir(path)
521 os.mkdir(path)
522 except OSError as ex:
522 except OSError as ex:
523 if ex.errno != errno.EEXIST:
523 if ex.errno != errno.EEXIST:
524 raise
524 raise
525
525
526 for path in missingdirs:
526 for path in missingdirs:
527 setstickygroupdir(path, gid, ui.warn)
527 setstickygroupdir(path, gid, ui.warn)
528 finally:
528 finally:
529 os.umask(oldumask)
529 os.umask(oldumask)
530
530
531
531
532 def getusername(ui):
532 def getusername(ui):
533 try:
533 try:
534 return stringutil.shortuser(ui.username())
534 return stringutil.shortuser(ui.username())
535 except Exception:
535 except Exception:
536 return b'unknown'
536 return b'unknown'
537
537
538
538
539 def getreponame(ui):
539 def getreponame(ui):
540 reponame = ui.config(b'paths', b'default')
540 reponame = ui.config(b'paths', b'default')
541 if reponame:
541 if reponame:
542 return os.path.basename(reponame)
542 return os.path.basename(reponame)
543 return b"unknown"
543 return b"unknown"
@@ -1,870 +1,870 b''
1 # formatter.py - generic output formatting for mercurial
1 # formatter.py - generic output formatting for mercurial
2 #
2 #
3 # Copyright 2012 Olivia Mackall <olivia@selenic.com>
3 # Copyright 2012 Olivia Mackall <olivia@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 """Generic output formatting for Mercurial
8 """Generic output formatting for Mercurial
9
9
10 The formatter provides API to show data in various ways. The following
10 The formatter provides API to show data in various ways. The following
11 functions should be used in place of ui.write():
11 functions should be used in place of ui.write():
12
12
13 - fm.write() for unconditional output
13 - fm.write() for unconditional output
14 - fm.condwrite() to show some extra data conditionally in plain output
14 - fm.condwrite() to show some extra data conditionally in plain output
15 - fm.context() to provide changectx to template output
15 - fm.context() to provide changectx to template output
16 - fm.data() to provide extra data to JSON or template output
16 - fm.data() to provide extra data to JSON or template output
17 - fm.plain() to show raw text that isn't provided to JSON or template output
17 - fm.plain() to show raw text that isn't provided to JSON or template output
18
18
19 To show structured data (e.g. date tuples, dicts, lists), apply fm.format*()
19 To show structured data (e.g. date tuples, dicts, lists), apply fm.format*()
20 beforehand so the data is converted to the appropriate data type. Use
20 beforehand so the data is converted to the appropriate data type. Use
21 fm.isplain() if you need to convert or format data conditionally which isn't
21 fm.isplain() if you need to convert or format data conditionally which isn't
22 supported by the formatter API.
22 supported by the formatter API.
23
23
24 To build nested structure (i.e. a list of dicts), use fm.nested().
24 To build nested structure (i.e. a list of dicts), use fm.nested().
25
25
26 See also https://www.mercurial-scm.org/wiki/GenericTemplatingPlan
26 See also https://www.mercurial-scm.org/wiki/GenericTemplatingPlan
27
27
28 fm.condwrite() vs 'if cond:':
28 fm.condwrite() vs 'if cond:':
29
29
30 In most cases, use fm.condwrite() so users can selectively show the data
30 In most cases, use fm.condwrite() so users can selectively show the data
31 in template output. If it's costly to build data, use plain 'if cond:' with
31 in template output. If it's costly to build data, use plain 'if cond:' with
32 fm.write().
32 fm.write().
33
33
34 fm.nested() vs fm.formatdict() (or fm.formatlist()):
34 fm.nested() vs fm.formatdict() (or fm.formatlist()):
35
35
36 fm.nested() should be used to form a tree structure (a list of dicts of
36 fm.nested() should be used to form a tree structure (a list of dicts of
37 lists of dicts...) which can be accessed through template keywords, e.g.
37 lists of dicts...) which can be accessed through template keywords, e.g.
38 "{foo % "{bar % {...}} {baz % {...}}"}". On the other hand, fm.formatdict()
38 "{foo % "{bar % {...}} {baz % {...}}"}". On the other hand, fm.formatdict()
39 exports a dict-type object to template, which can be accessed by e.g.
39 exports a dict-type object to template, which can be accessed by e.g.
40 "{get(foo, key)}" function.
40 "{get(foo, key)}" function.
41
41
42 Doctest helper:
42 Doctest helper:
43
43
44 >>> def show(fn, verbose=False, **opts):
44 >>> def show(fn, verbose=False, **opts):
45 ... import sys
45 ... import sys
46 ... from . import ui as uimod
46 ... from . import ui as uimod
47 ... ui = uimod.ui()
47 ... ui = uimod.ui()
48 ... ui.verbose = verbose
48 ... ui.verbose = verbose
49 ... ui.pushbuffer()
49 ... ui.pushbuffer()
50 ... try:
50 ... try:
51 ... return fn(ui, ui.formatter(pycompat.sysbytes(fn.__name__),
51 ... return fn(ui, ui.formatter(pycompat.sysbytes(fn.__name__),
52 ... pycompat.byteskwargs(opts)))
52 ... pycompat.byteskwargs(opts)))
53 ... finally:
53 ... finally:
54 ... print(pycompat.sysstr(ui.popbuffer()), end='')
54 ... print(pycompat.sysstr(ui.popbuffer()), end='')
55
55
56 Basic example:
56 Basic example:
57
57
58 >>> def files(ui, fm):
58 >>> def files(ui, fm):
59 ... files = [(b'foo', 123, (0, 0)), (b'bar', 456, (1, 0))]
59 ... files = [(b'foo', 123, (0, 0)), (b'bar', 456, (1, 0))]
60 ... for f in files:
60 ... for f in files:
61 ... fm.startitem()
61 ... fm.startitem()
62 ... fm.write(b'path', b'%s', f[0])
62 ... fm.write(b'path', b'%s', f[0])
63 ... fm.condwrite(ui.verbose, b'date', b' %s',
63 ... fm.condwrite(ui.verbose, b'date', b' %s',
64 ... fm.formatdate(f[2], b'%Y-%m-%d %H:%M:%S'))
64 ... fm.formatdate(f[2], b'%Y-%m-%d %H:%M:%S'))
65 ... fm.data(size=f[1])
65 ... fm.data(size=f[1])
66 ... fm.plain(b'\\n')
66 ... fm.plain(b'\\n')
67 ... fm.end()
67 ... fm.end()
68 >>> show(files)
68 >>> show(files)
69 foo
69 foo
70 bar
70 bar
71 >>> show(files, verbose=True)
71 >>> show(files, verbose=True)
72 foo 1970-01-01 00:00:00
72 foo 1970-01-01 00:00:00
73 bar 1970-01-01 00:00:01
73 bar 1970-01-01 00:00:01
74 >>> show(files, template=b'json')
74 >>> show(files, template=b'json')
75 [
75 [
76 {
76 {
77 "date": [0, 0],
77 "date": [0, 0],
78 "path": "foo",
78 "path": "foo",
79 "size": 123
79 "size": 123
80 },
80 },
81 {
81 {
82 "date": [1, 0],
82 "date": [1, 0],
83 "path": "bar",
83 "path": "bar",
84 "size": 456
84 "size": 456
85 }
85 }
86 ]
86 ]
87 >>> show(files, template=b'path: {path}\\ndate: {date|rfc3339date}\\n')
87 >>> show(files, template=b'path: {path}\\ndate: {date|rfc3339date}\\n')
88 path: foo
88 path: foo
89 date: 1970-01-01T00:00:00+00:00
89 date: 1970-01-01T00:00:00+00:00
90 path: bar
90 path: bar
91 date: 1970-01-01T00:00:01+00:00
91 date: 1970-01-01T00:00:01+00:00
92
92
93 Nested example:
93 Nested example:
94
94
95 >>> def subrepos(ui, fm):
95 >>> def subrepos(ui, fm):
96 ... fm.startitem()
96 ... fm.startitem()
97 ... fm.write(b'reponame', b'[%s]\\n', b'baz')
97 ... fm.write(b'reponame', b'[%s]\\n', b'baz')
98 ... files(ui, fm.nested(b'files', tmpl=b'{reponame}'))
98 ... files(ui, fm.nested(b'files', tmpl=b'{reponame}'))
99 ... fm.end()
99 ... fm.end()
100 >>> show(subrepos)
100 >>> show(subrepos)
101 [baz]
101 [baz]
102 foo
102 foo
103 bar
103 bar
104 >>> show(subrepos, template=b'{reponame}: {join(files % "{path}", ", ")}\\n')
104 >>> show(subrepos, template=b'{reponame}: {join(files % "{path}", ", ")}\\n')
105 baz: foo, bar
105 baz: foo, bar
106 """
106 """
107
107
108
108
109 import contextlib
109 import contextlib
110 import itertools
110 import itertools
111 import os
111 import os
112 import pickle
112 import pickle
113
113
114 from .i18n import _
114 from .i18n import _
115 from .node import (
115 from .node import (
116 hex,
116 hex,
117 short,
117 short,
118 )
118 )
119 from .thirdparty import attr
119 from .thirdparty import attr
120
120
121 from . import (
121 from . import (
122 error,
122 error,
123 pycompat,
123 pycompat,
124 templatefilters,
124 templatefilters,
125 templatekw,
125 templatekw,
126 templater,
126 templater,
127 templateutil,
127 templateutil,
128 util,
128 util,
129 )
129 )
130 from .utils import (
130 from .utils import (
131 cborutil,
131 cborutil,
132 dateutil,
132 dateutil,
133 stringutil,
133 stringutil,
134 )
134 )
135
135
136
136
137 def isprintable(obj):
137 def isprintable(obj):
138 """Check if the given object can be directly passed in to formatter's
138 """Check if the given object can be directly passed in to formatter's
139 write() and data() functions
139 write() and data() functions
140
140
141 Returns False if the object is unsupported or must be pre-processed by
141 Returns False if the object is unsupported or must be pre-processed by
142 formatdate(), formatdict(), or formatlist().
142 formatdate(), formatdict(), or formatlist().
143 """
143 """
144 return isinstance(obj, (type(None), bool, int, pycompat.long, float, bytes))
144 return isinstance(obj, (type(None), bool, int, int, float, bytes))
145
145
146
146
147 class _nullconverter(object):
147 class _nullconverter(object):
148 '''convert non-primitive data types to be processed by formatter'''
148 '''convert non-primitive data types to be processed by formatter'''
149
149
150 # set to True if context object should be stored as item
150 # set to True if context object should be stored as item
151 storecontext = False
151 storecontext = False
152
152
153 @staticmethod
153 @staticmethod
154 def wrapnested(data, tmpl, sep):
154 def wrapnested(data, tmpl, sep):
155 '''wrap nested data by appropriate type'''
155 '''wrap nested data by appropriate type'''
156 return data
156 return data
157
157
158 @staticmethod
158 @staticmethod
159 def formatdate(date, fmt):
159 def formatdate(date, fmt):
160 '''convert date tuple to appropriate format'''
160 '''convert date tuple to appropriate format'''
161 # timestamp can be float, but the canonical form should be int
161 # timestamp can be float, but the canonical form should be int
162 ts, tz = date
162 ts, tz = date
163 return (int(ts), tz)
163 return (int(ts), tz)
164
164
165 @staticmethod
165 @staticmethod
166 def formatdict(data, key, value, fmt, sep):
166 def formatdict(data, key, value, fmt, sep):
167 '''convert dict or key-value pairs to appropriate dict format'''
167 '''convert dict or key-value pairs to appropriate dict format'''
168 # use plain dict instead of util.sortdict so that data can be
168 # use plain dict instead of util.sortdict so that data can be
169 # serialized as a builtin dict in pickle output
169 # serialized as a builtin dict in pickle output
170 return dict(data)
170 return dict(data)
171
171
172 @staticmethod
172 @staticmethod
173 def formatlist(data, name, fmt, sep):
173 def formatlist(data, name, fmt, sep):
174 '''convert iterable to appropriate list format'''
174 '''convert iterable to appropriate list format'''
175 return list(data)
175 return list(data)
176
176
177
177
178 class baseformatter(object):
178 class baseformatter(object):
179
179
180 # set to True if the formater output a strict format that does not support
180 # set to True if the formater output a strict format that does not support
181 # arbitrary output in the stream.
181 # arbitrary output in the stream.
182 strict_format = False
182 strict_format = False
183
183
184 def __init__(self, ui, topic, opts, converter):
184 def __init__(self, ui, topic, opts, converter):
185 self._ui = ui
185 self._ui = ui
186 self._topic = topic
186 self._topic = topic
187 self._opts = opts
187 self._opts = opts
188 self._converter = converter
188 self._converter = converter
189 self._item = None
189 self._item = None
190 # function to convert node to string suitable for this output
190 # function to convert node to string suitable for this output
191 self.hexfunc = hex
191 self.hexfunc = hex
192
192
193 def __enter__(self):
193 def __enter__(self):
194 return self
194 return self
195
195
196 def __exit__(self, exctype, excvalue, traceback):
196 def __exit__(self, exctype, excvalue, traceback):
197 if exctype is None:
197 if exctype is None:
198 self.end()
198 self.end()
199
199
200 def _showitem(self):
200 def _showitem(self):
201 '''show a formatted item once all data is collected'''
201 '''show a formatted item once all data is collected'''
202
202
203 def startitem(self):
203 def startitem(self):
204 '''begin an item in the format list'''
204 '''begin an item in the format list'''
205 if self._item is not None:
205 if self._item is not None:
206 self._showitem()
206 self._showitem()
207 self._item = {}
207 self._item = {}
208
208
209 def formatdate(self, date, fmt=b'%a %b %d %H:%M:%S %Y %1%2'):
209 def formatdate(self, date, fmt=b'%a %b %d %H:%M:%S %Y %1%2'):
210 '''convert date tuple to appropriate format'''
210 '''convert date tuple to appropriate format'''
211 return self._converter.formatdate(date, fmt)
211 return self._converter.formatdate(date, fmt)
212
212
213 def formatdict(self, data, key=b'key', value=b'value', fmt=None, sep=b' '):
213 def formatdict(self, data, key=b'key', value=b'value', fmt=None, sep=b' '):
214 '''convert dict or key-value pairs to appropriate dict format'''
214 '''convert dict or key-value pairs to appropriate dict format'''
215 return self._converter.formatdict(data, key, value, fmt, sep)
215 return self._converter.formatdict(data, key, value, fmt, sep)
216
216
217 def formatlist(self, data, name, fmt=None, sep=b' '):
217 def formatlist(self, data, name, fmt=None, sep=b' '):
218 '''convert iterable to appropriate list format'''
218 '''convert iterable to appropriate list format'''
219 # name is mandatory argument for now, but it could be optional if
219 # name is mandatory argument for now, but it could be optional if
220 # we have default template keyword, e.g. {item}
220 # we have default template keyword, e.g. {item}
221 return self._converter.formatlist(data, name, fmt, sep)
221 return self._converter.formatlist(data, name, fmt, sep)
222
222
223 def context(self, **ctxs):
223 def context(self, **ctxs):
224 '''insert context objects to be used to render template keywords'''
224 '''insert context objects to be used to render template keywords'''
225 ctxs = pycompat.byteskwargs(ctxs)
225 ctxs = pycompat.byteskwargs(ctxs)
226 assert all(k in {b'repo', b'ctx', b'fctx'} for k in ctxs)
226 assert all(k in {b'repo', b'ctx', b'fctx'} for k in ctxs)
227 if self._converter.storecontext:
227 if self._converter.storecontext:
228 # populate missing resources in fctx -> ctx -> repo order
228 # populate missing resources in fctx -> ctx -> repo order
229 if b'fctx' in ctxs and b'ctx' not in ctxs:
229 if b'fctx' in ctxs and b'ctx' not in ctxs:
230 ctxs[b'ctx'] = ctxs[b'fctx'].changectx()
230 ctxs[b'ctx'] = ctxs[b'fctx'].changectx()
231 if b'ctx' in ctxs and b'repo' not in ctxs:
231 if b'ctx' in ctxs and b'repo' not in ctxs:
232 ctxs[b'repo'] = ctxs[b'ctx'].repo()
232 ctxs[b'repo'] = ctxs[b'ctx'].repo()
233 self._item.update(ctxs)
233 self._item.update(ctxs)
234
234
235 def datahint(self):
235 def datahint(self):
236 '''set of field names to be referenced'''
236 '''set of field names to be referenced'''
237 return set()
237 return set()
238
238
239 def data(self, **data):
239 def data(self, **data):
240 '''insert data into item that's not shown in default output'''
240 '''insert data into item that's not shown in default output'''
241 data = pycompat.byteskwargs(data)
241 data = pycompat.byteskwargs(data)
242 self._item.update(data)
242 self._item.update(data)
243
243
244 def write(self, fields, deftext, *fielddata, **opts):
244 def write(self, fields, deftext, *fielddata, **opts):
245 '''do default text output while assigning data to item'''
245 '''do default text output while assigning data to item'''
246 fieldkeys = fields.split()
246 fieldkeys = fields.split()
247 assert len(fieldkeys) == len(fielddata), (fieldkeys, fielddata)
247 assert len(fieldkeys) == len(fielddata), (fieldkeys, fielddata)
248 self._item.update(zip(fieldkeys, fielddata))
248 self._item.update(zip(fieldkeys, fielddata))
249
249
250 def condwrite(self, cond, fields, deftext, *fielddata, **opts):
250 def condwrite(self, cond, fields, deftext, *fielddata, **opts):
251 '''do conditional write (primarily for plain formatter)'''
251 '''do conditional write (primarily for plain formatter)'''
252 fieldkeys = fields.split()
252 fieldkeys = fields.split()
253 assert len(fieldkeys) == len(fielddata)
253 assert len(fieldkeys) == len(fielddata)
254 self._item.update(zip(fieldkeys, fielddata))
254 self._item.update(zip(fieldkeys, fielddata))
255
255
256 def plain(self, text, **opts):
256 def plain(self, text, **opts):
257 '''show raw text for non-templated mode'''
257 '''show raw text for non-templated mode'''
258
258
259 def isplain(self):
259 def isplain(self):
260 '''check for plain formatter usage'''
260 '''check for plain formatter usage'''
261 return False
261 return False
262
262
263 def nested(self, field, tmpl=None, sep=b''):
263 def nested(self, field, tmpl=None, sep=b''):
264 '''sub formatter to store nested data in the specified field'''
264 '''sub formatter to store nested data in the specified field'''
265 data = []
265 data = []
266 self._item[field] = self._converter.wrapnested(data, tmpl, sep)
266 self._item[field] = self._converter.wrapnested(data, tmpl, sep)
267 return _nestedformatter(self._ui, self._converter, data)
267 return _nestedformatter(self._ui, self._converter, data)
268
268
269 def end(self):
269 def end(self):
270 '''end output for the formatter'''
270 '''end output for the formatter'''
271 if self._item is not None:
271 if self._item is not None:
272 self._showitem()
272 self._showitem()
273
273
274
274
275 def nullformatter(ui, topic, opts):
275 def nullformatter(ui, topic, opts):
276 '''formatter that prints nothing'''
276 '''formatter that prints nothing'''
277 return baseformatter(ui, topic, opts, converter=_nullconverter)
277 return baseformatter(ui, topic, opts, converter=_nullconverter)
278
278
279
279
280 class _nestedformatter(baseformatter):
280 class _nestedformatter(baseformatter):
281 '''build sub items and store them in the parent formatter'''
281 '''build sub items and store them in the parent formatter'''
282
282
283 def __init__(self, ui, converter, data):
283 def __init__(self, ui, converter, data):
284 baseformatter.__init__(
284 baseformatter.__init__(
285 self, ui, topic=b'', opts={}, converter=converter
285 self, ui, topic=b'', opts={}, converter=converter
286 )
286 )
287 self._data = data
287 self._data = data
288
288
289 def _showitem(self):
289 def _showitem(self):
290 self._data.append(self._item)
290 self._data.append(self._item)
291
291
292
292
293 def _iteritems(data):
293 def _iteritems(data):
294 '''iterate key-value pairs in stable order'''
294 '''iterate key-value pairs in stable order'''
295 if isinstance(data, dict):
295 if isinstance(data, dict):
296 return sorted(data.items())
296 return sorted(data.items())
297 return data
297 return data
298
298
299
299
300 class _plainconverter(object):
300 class _plainconverter(object):
301 '''convert non-primitive data types to text'''
301 '''convert non-primitive data types to text'''
302
302
303 storecontext = False
303 storecontext = False
304
304
305 @staticmethod
305 @staticmethod
306 def wrapnested(data, tmpl, sep):
306 def wrapnested(data, tmpl, sep):
307 raise error.ProgrammingError(b'plainformatter should never be nested')
307 raise error.ProgrammingError(b'plainformatter should never be nested')
308
308
309 @staticmethod
309 @staticmethod
310 def formatdate(date, fmt):
310 def formatdate(date, fmt):
311 '''stringify date tuple in the given format'''
311 '''stringify date tuple in the given format'''
312 return dateutil.datestr(date, fmt)
312 return dateutil.datestr(date, fmt)
313
313
314 @staticmethod
314 @staticmethod
315 def formatdict(data, key, value, fmt, sep):
315 def formatdict(data, key, value, fmt, sep):
316 '''stringify key-value pairs separated by sep'''
316 '''stringify key-value pairs separated by sep'''
317 prefmt = pycompat.identity
317 prefmt = pycompat.identity
318 if fmt is None:
318 if fmt is None:
319 fmt = b'%s=%s'
319 fmt = b'%s=%s'
320 prefmt = pycompat.bytestr
320 prefmt = pycompat.bytestr
321 return sep.join(
321 return sep.join(
322 fmt % (prefmt(k), prefmt(v)) for k, v in _iteritems(data)
322 fmt % (prefmt(k), prefmt(v)) for k, v in _iteritems(data)
323 )
323 )
324
324
325 @staticmethod
325 @staticmethod
326 def formatlist(data, name, fmt, sep):
326 def formatlist(data, name, fmt, sep):
327 '''stringify iterable separated by sep'''
327 '''stringify iterable separated by sep'''
328 prefmt = pycompat.identity
328 prefmt = pycompat.identity
329 if fmt is None:
329 if fmt is None:
330 fmt = b'%s'
330 fmt = b'%s'
331 prefmt = pycompat.bytestr
331 prefmt = pycompat.bytestr
332 return sep.join(fmt % prefmt(e) for e in data)
332 return sep.join(fmt % prefmt(e) for e in data)
333
333
334
334
335 class plainformatter(baseformatter):
335 class plainformatter(baseformatter):
336 '''the default text output scheme'''
336 '''the default text output scheme'''
337
337
338 def __init__(self, ui, out, topic, opts):
338 def __init__(self, ui, out, topic, opts):
339 baseformatter.__init__(self, ui, topic, opts, _plainconverter)
339 baseformatter.__init__(self, ui, topic, opts, _plainconverter)
340 if ui.debugflag:
340 if ui.debugflag:
341 self.hexfunc = hex
341 self.hexfunc = hex
342 else:
342 else:
343 self.hexfunc = short
343 self.hexfunc = short
344 if ui is out:
344 if ui is out:
345 self._write = ui.write
345 self._write = ui.write
346 else:
346 else:
347 self._write = lambda s, **opts: out.write(s)
347 self._write = lambda s, **opts: out.write(s)
348
348
349 def startitem(self):
349 def startitem(self):
350 pass
350 pass
351
351
352 def data(self, **data):
352 def data(self, **data):
353 pass
353 pass
354
354
355 def write(self, fields, deftext, *fielddata, **opts):
355 def write(self, fields, deftext, *fielddata, **opts):
356 self._write(deftext % fielddata, **opts)
356 self._write(deftext % fielddata, **opts)
357
357
358 def condwrite(self, cond, fields, deftext, *fielddata, **opts):
358 def condwrite(self, cond, fields, deftext, *fielddata, **opts):
359 '''do conditional write'''
359 '''do conditional write'''
360 if cond:
360 if cond:
361 self._write(deftext % fielddata, **opts)
361 self._write(deftext % fielddata, **opts)
362
362
363 def plain(self, text, **opts):
363 def plain(self, text, **opts):
364 self._write(text, **opts)
364 self._write(text, **opts)
365
365
366 def isplain(self):
366 def isplain(self):
367 return True
367 return True
368
368
369 def nested(self, field, tmpl=None, sep=b''):
369 def nested(self, field, tmpl=None, sep=b''):
370 # nested data will be directly written to ui
370 # nested data will be directly written to ui
371 return self
371 return self
372
372
373 def end(self):
373 def end(self):
374 pass
374 pass
375
375
376
376
377 class debugformatter(baseformatter):
377 class debugformatter(baseformatter):
378 def __init__(self, ui, out, topic, opts):
378 def __init__(self, ui, out, topic, opts):
379 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
379 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
380 self._out = out
380 self._out = out
381 self._out.write(b"%s = [\n" % self._topic)
381 self._out.write(b"%s = [\n" % self._topic)
382
382
383 def _showitem(self):
383 def _showitem(self):
384 self._out.write(
384 self._out.write(
385 b' %s,\n' % stringutil.pprint(self._item, indent=4, level=1)
385 b' %s,\n' % stringutil.pprint(self._item, indent=4, level=1)
386 )
386 )
387
387
388 def end(self):
388 def end(self):
389 baseformatter.end(self)
389 baseformatter.end(self)
390 self._out.write(b"]\n")
390 self._out.write(b"]\n")
391
391
392
392
393 class pickleformatter(baseformatter):
393 class pickleformatter(baseformatter):
394 def __init__(self, ui, out, topic, opts):
394 def __init__(self, ui, out, topic, opts):
395 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
395 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
396 self._out = out
396 self._out = out
397 self._data = []
397 self._data = []
398
398
399 def _showitem(self):
399 def _showitem(self):
400 self._data.append(self._item)
400 self._data.append(self._item)
401
401
402 def end(self):
402 def end(self):
403 baseformatter.end(self)
403 baseformatter.end(self)
404 self._out.write(pickle.dumps(self._data))
404 self._out.write(pickle.dumps(self._data))
405
405
406
406
407 class cborformatter(baseformatter):
407 class cborformatter(baseformatter):
408 '''serialize items as an indefinite-length CBOR array'''
408 '''serialize items as an indefinite-length CBOR array'''
409
409
410 def __init__(self, ui, out, topic, opts):
410 def __init__(self, ui, out, topic, opts):
411 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
411 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
412 self._out = out
412 self._out = out
413 self._out.write(cborutil.BEGIN_INDEFINITE_ARRAY)
413 self._out.write(cborutil.BEGIN_INDEFINITE_ARRAY)
414
414
415 def _showitem(self):
415 def _showitem(self):
416 self._out.write(b''.join(cborutil.streamencode(self._item)))
416 self._out.write(b''.join(cborutil.streamencode(self._item)))
417
417
418 def end(self):
418 def end(self):
419 baseformatter.end(self)
419 baseformatter.end(self)
420 self._out.write(cborutil.BREAK)
420 self._out.write(cborutil.BREAK)
421
421
422
422
423 class jsonformatter(baseformatter):
423 class jsonformatter(baseformatter):
424
424
425 strict_format = True
425 strict_format = True
426
426
427 def __init__(self, ui, out, topic, opts):
427 def __init__(self, ui, out, topic, opts):
428 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
428 baseformatter.__init__(self, ui, topic, opts, _nullconverter)
429 self._out = out
429 self._out = out
430 self._out.write(b"[")
430 self._out.write(b"[")
431 self._first = True
431 self._first = True
432
432
433 def _showitem(self):
433 def _showitem(self):
434 if self._first:
434 if self._first:
435 self._first = False
435 self._first = False
436 else:
436 else:
437 self._out.write(b",")
437 self._out.write(b",")
438
438
439 self._out.write(b"\n {\n")
439 self._out.write(b"\n {\n")
440 first = True
440 first = True
441 for k, v in sorted(self._item.items()):
441 for k, v in sorted(self._item.items()):
442 if first:
442 if first:
443 first = False
443 first = False
444 else:
444 else:
445 self._out.write(b",\n")
445 self._out.write(b",\n")
446 u = templatefilters.json(v, paranoid=False)
446 u = templatefilters.json(v, paranoid=False)
447 self._out.write(b' "%s": %s' % (k, u))
447 self._out.write(b' "%s": %s' % (k, u))
448 self._out.write(b"\n }")
448 self._out.write(b"\n }")
449
449
450 def end(self):
450 def end(self):
451 baseformatter.end(self)
451 baseformatter.end(self)
452 self._out.write(b"\n]\n")
452 self._out.write(b"\n]\n")
453
453
454
454
455 class _templateconverter(object):
455 class _templateconverter(object):
456 '''convert non-primitive data types to be processed by templater'''
456 '''convert non-primitive data types to be processed by templater'''
457
457
458 storecontext = True
458 storecontext = True
459
459
460 @staticmethod
460 @staticmethod
461 def wrapnested(data, tmpl, sep):
461 def wrapnested(data, tmpl, sep):
462 '''wrap nested data by templatable type'''
462 '''wrap nested data by templatable type'''
463 return templateutil.mappinglist(data, tmpl=tmpl, sep=sep)
463 return templateutil.mappinglist(data, tmpl=tmpl, sep=sep)
464
464
465 @staticmethod
465 @staticmethod
466 def formatdate(date, fmt):
466 def formatdate(date, fmt):
467 '''return date tuple'''
467 '''return date tuple'''
468 return templateutil.date(date)
468 return templateutil.date(date)
469
469
470 @staticmethod
470 @staticmethod
471 def formatdict(data, key, value, fmt, sep):
471 def formatdict(data, key, value, fmt, sep):
472 '''build object that can be evaluated as either plain string or dict'''
472 '''build object that can be evaluated as either plain string or dict'''
473 data = util.sortdict(_iteritems(data))
473 data = util.sortdict(_iteritems(data))
474
474
475 def f():
475 def f():
476 yield _plainconverter.formatdict(data, key, value, fmt, sep)
476 yield _plainconverter.formatdict(data, key, value, fmt, sep)
477
477
478 return templateutil.hybriddict(
478 return templateutil.hybriddict(
479 data, key=key, value=value, fmt=fmt, gen=f
479 data, key=key, value=value, fmt=fmt, gen=f
480 )
480 )
481
481
482 @staticmethod
482 @staticmethod
483 def formatlist(data, name, fmt, sep):
483 def formatlist(data, name, fmt, sep):
484 '''build object that can be evaluated as either plain string or list'''
484 '''build object that can be evaluated as either plain string or list'''
485 data = list(data)
485 data = list(data)
486
486
487 def f():
487 def f():
488 yield _plainconverter.formatlist(data, name, fmt, sep)
488 yield _plainconverter.formatlist(data, name, fmt, sep)
489
489
490 return templateutil.hybridlist(data, name=name, fmt=fmt, gen=f)
490 return templateutil.hybridlist(data, name=name, fmt=fmt, gen=f)
491
491
492
492
493 class templateformatter(baseformatter):
493 class templateformatter(baseformatter):
494 def __init__(self, ui, out, topic, opts, spec, overridetemplates=None):
494 def __init__(self, ui, out, topic, opts, spec, overridetemplates=None):
495 baseformatter.__init__(self, ui, topic, opts, _templateconverter)
495 baseformatter.__init__(self, ui, topic, opts, _templateconverter)
496 self._out = out
496 self._out = out
497 self._tref = spec.ref
497 self._tref = spec.ref
498 self._t = loadtemplater(
498 self._t = loadtemplater(
499 ui,
499 ui,
500 spec,
500 spec,
501 defaults=templatekw.keywords,
501 defaults=templatekw.keywords,
502 resources=templateresources(ui),
502 resources=templateresources(ui),
503 cache=templatekw.defaulttempl,
503 cache=templatekw.defaulttempl,
504 )
504 )
505 if overridetemplates:
505 if overridetemplates:
506 self._t.cache.update(overridetemplates)
506 self._t.cache.update(overridetemplates)
507 self._parts = templatepartsmap(
507 self._parts = templatepartsmap(
508 spec, self._t, [b'docheader', b'docfooter', b'separator']
508 spec, self._t, [b'docheader', b'docfooter', b'separator']
509 )
509 )
510 self._counter = itertools.count()
510 self._counter = itertools.count()
511 self._renderitem(b'docheader', {})
511 self._renderitem(b'docheader', {})
512
512
513 def _showitem(self):
513 def _showitem(self):
514 item = self._item.copy()
514 item = self._item.copy()
515 item[b'index'] = index = next(self._counter)
515 item[b'index'] = index = next(self._counter)
516 if index > 0:
516 if index > 0:
517 self._renderitem(b'separator', {})
517 self._renderitem(b'separator', {})
518 self._renderitem(self._tref, item)
518 self._renderitem(self._tref, item)
519
519
520 def _renderitem(self, part, item):
520 def _renderitem(self, part, item):
521 if part not in self._parts:
521 if part not in self._parts:
522 return
522 return
523 ref = self._parts[part]
523 ref = self._parts[part]
524 # None can't be put in the mapping dict since it means <unset>
524 # None can't be put in the mapping dict since it means <unset>
525 for k, v in item.items():
525 for k, v in item.items():
526 if v is None:
526 if v is None:
527 item[k] = templateutil.wrappedvalue(v)
527 item[k] = templateutil.wrappedvalue(v)
528 self._out.write(self._t.render(ref, item))
528 self._out.write(self._t.render(ref, item))
529
529
530 @util.propertycache
530 @util.propertycache
531 def _symbolsused(self):
531 def _symbolsused(self):
532 return self._t.symbolsused(self._tref)
532 return self._t.symbolsused(self._tref)
533
533
534 def datahint(self):
534 def datahint(self):
535 '''set of field names to be referenced from the template'''
535 '''set of field names to be referenced from the template'''
536 return self._symbolsused[0]
536 return self._symbolsused[0]
537
537
538 def end(self):
538 def end(self):
539 baseformatter.end(self)
539 baseformatter.end(self)
540 self._renderitem(b'docfooter', {})
540 self._renderitem(b'docfooter', {})
541
541
542
542
543 @attr.s(frozen=True)
543 @attr.s(frozen=True)
544 class templatespec(object):
544 class templatespec(object):
545 ref = attr.ib()
545 ref = attr.ib()
546 tmpl = attr.ib()
546 tmpl = attr.ib()
547 mapfile = attr.ib()
547 mapfile = attr.ib()
548 refargs = attr.ib(default=None)
548 refargs = attr.ib(default=None)
549 fp = attr.ib(default=None)
549 fp = attr.ib(default=None)
550
550
551
551
552 def empty_templatespec():
552 def empty_templatespec():
553 return templatespec(None, None, None)
553 return templatespec(None, None, None)
554
554
555
555
556 def reference_templatespec(ref, refargs=None):
556 def reference_templatespec(ref, refargs=None):
557 return templatespec(ref, None, None, refargs)
557 return templatespec(ref, None, None, refargs)
558
558
559
559
560 def literal_templatespec(tmpl):
560 def literal_templatespec(tmpl):
561 assert not isinstance(tmpl, str), b'tmpl must not be a str'
561 assert not isinstance(tmpl, str), b'tmpl must not be a str'
562 return templatespec(b'', tmpl, None)
562 return templatespec(b'', tmpl, None)
563
563
564
564
565 def mapfile_templatespec(topic, mapfile, fp=None):
565 def mapfile_templatespec(topic, mapfile, fp=None):
566 return templatespec(topic, None, mapfile, fp=fp)
566 return templatespec(topic, None, mapfile, fp=fp)
567
567
568
568
569 def lookuptemplate(ui, topic, tmpl):
569 def lookuptemplate(ui, topic, tmpl):
570 """Find the template matching the given -T/--template spec 'tmpl'
570 """Find the template matching the given -T/--template spec 'tmpl'
571
571
572 'tmpl' can be any of the following:
572 'tmpl' can be any of the following:
573
573
574 - a literal template (e.g. '{rev}')
574 - a literal template (e.g. '{rev}')
575 - a reference to built-in template (i.e. formatter)
575 - a reference to built-in template (i.e. formatter)
576 - a map-file name or path (e.g. 'changelog')
576 - a map-file name or path (e.g. 'changelog')
577 - a reference to [templates] in config file
577 - a reference to [templates] in config file
578 - a path to raw template file
578 - a path to raw template file
579
579
580 A map file defines a stand-alone template environment. If a map file
580 A map file defines a stand-alone template environment. If a map file
581 selected, all templates defined in the file will be loaded, and the
581 selected, all templates defined in the file will be loaded, and the
582 template matching the given topic will be rendered. Aliases won't be
582 template matching the given topic will be rendered. Aliases won't be
583 loaded from user config, but from the map file.
583 loaded from user config, but from the map file.
584
584
585 If no map file selected, all templates in [templates] section will be
585 If no map file selected, all templates in [templates] section will be
586 available as well as aliases in [templatealias].
586 available as well as aliases in [templatealias].
587 """
587 """
588
588
589 if not tmpl:
589 if not tmpl:
590 return empty_templatespec()
590 return empty_templatespec()
591
591
592 # looks like a literal template?
592 # looks like a literal template?
593 if b'{' in tmpl:
593 if b'{' in tmpl:
594 return literal_templatespec(tmpl)
594 return literal_templatespec(tmpl)
595
595
596 # a reference to built-in (formatter) template
596 # a reference to built-in (formatter) template
597 if tmpl in {b'cbor', b'json', b'pickle', b'debug'}:
597 if tmpl in {b'cbor', b'json', b'pickle', b'debug'}:
598 return reference_templatespec(tmpl)
598 return reference_templatespec(tmpl)
599
599
600 # a function-style reference to built-in template
600 # a function-style reference to built-in template
601 func, fsep, ftail = tmpl.partition(b'(')
601 func, fsep, ftail = tmpl.partition(b'(')
602 if func in {b'cbor', b'json'} and fsep and ftail.endswith(b')'):
602 if func in {b'cbor', b'json'} and fsep and ftail.endswith(b')'):
603 templater.parseexpr(tmpl) # make sure syntax errors are confined
603 templater.parseexpr(tmpl) # make sure syntax errors are confined
604 return reference_templatespec(func, refargs=ftail[:-1])
604 return reference_templatespec(func, refargs=ftail[:-1])
605
605
606 # perhaps a stock style?
606 # perhaps a stock style?
607 if not os.path.split(tmpl)[0]:
607 if not os.path.split(tmpl)[0]:
608 (mapname, fp) = templater.try_open_template(
608 (mapname, fp) = templater.try_open_template(
609 b'map-cmdline.' + tmpl
609 b'map-cmdline.' + tmpl
610 ) or templater.try_open_template(tmpl)
610 ) or templater.try_open_template(tmpl)
611 if mapname:
611 if mapname:
612 return mapfile_templatespec(topic, mapname, fp)
612 return mapfile_templatespec(topic, mapname, fp)
613
613
614 # perhaps it's a reference to [templates]
614 # perhaps it's a reference to [templates]
615 if ui.config(b'templates', tmpl):
615 if ui.config(b'templates', tmpl):
616 return reference_templatespec(tmpl)
616 return reference_templatespec(tmpl)
617
617
618 if tmpl == b'list':
618 if tmpl == b'list':
619 ui.write(_(b"available styles: %s\n") % templater.stylelist())
619 ui.write(_(b"available styles: %s\n") % templater.stylelist())
620 raise error.Abort(_(b"specify a template"))
620 raise error.Abort(_(b"specify a template"))
621
621
622 # perhaps it's a path to a map or a template
622 # perhaps it's a path to a map or a template
623 if (b'/' in tmpl or b'\\' in tmpl) and os.path.isfile(tmpl):
623 if (b'/' in tmpl or b'\\' in tmpl) and os.path.isfile(tmpl):
624 # is it a mapfile for a style?
624 # is it a mapfile for a style?
625 if os.path.basename(tmpl).startswith(b"map-"):
625 if os.path.basename(tmpl).startswith(b"map-"):
626 return mapfile_templatespec(topic, os.path.realpath(tmpl))
626 return mapfile_templatespec(topic, os.path.realpath(tmpl))
627 with util.posixfile(tmpl, b'rb') as f:
627 with util.posixfile(tmpl, b'rb') as f:
628 tmpl = f.read()
628 tmpl = f.read()
629 return literal_templatespec(tmpl)
629 return literal_templatespec(tmpl)
630
630
631 # constant string?
631 # constant string?
632 return literal_templatespec(tmpl)
632 return literal_templatespec(tmpl)
633
633
634
634
635 def templatepartsmap(spec, t, partnames):
635 def templatepartsmap(spec, t, partnames):
636 """Create a mapping of {part: ref}"""
636 """Create a mapping of {part: ref}"""
637 partsmap = {spec.ref: spec.ref} # initial ref must exist in t
637 partsmap = {spec.ref: spec.ref} # initial ref must exist in t
638 if spec.mapfile:
638 if spec.mapfile:
639 partsmap.update((p, p) for p in partnames if p in t)
639 partsmap.update((p, p) for p in partnames if p in t)
640 elif spec.ref:
640 elif spec.ref:
641 for part in partnames:
641 for part in partnames:
642 ref = b'%s:%s' % (spec.ref, part) # select config sub-section
642 ref = b'%s:%s' % (spec.ref, part) # select config sub-section
643 if ref in t:
643 if ref in t:
644 partsmap[part] = ref
644 partsmap[part] = ref
645 return partsmap
645 return partsmap
646
646
647
647
648 def loadtemplater(ui, spec, defaults=None, resources=None, cache=None):
648 def loadtemplater(ui, spec, defaults=None, resources=None, cache=None):
649 """Create a templater from either a literal template or loading from
649 """Create a templater from either a literal template or loading from
650 a map file"""
650 a map file"""
651 assert not (spec.tmpl and spec.mapfile)
651 assert not (spec.tmpl and spec.mapfile)
652 if spec.mapfile:
652 if spec.mapfile:
653 return templater.templater.frommapfile(
653 return templater.templater.frommapfile(
654 spec.mapfile,
654 spec.mapfile,
655 spec.fp,
655 spec.fp,
656 defaults=defaults,
656 defaults=defaults,
657 resources=resources,
657 resources=resources,
658 cache=cache,
658 cache=cache,
659 )
659 )
660 return maketemplater(
660 return maketemplater(
661 ui, spec.tmpl, defaults=defaults, resources=resources, cache=cache
661 ui, spec.tmpl, defaults=defaults, resources=resources, cache=cache
662 )
662 )
663
663
664
664
665 def maketemplater(ui, tmpl, defaults=None, resources=None, cache=None):
665 def maketemplater(ui, tmpl, defaults=None, resources=None, cache=None):
666 """Create a templater from a string template 'tmpl'"""
666 """Create a templater from a string template 'tmpl'"""
667 aliases = ui.configitems(b'templatealias')
667 aliases = ui.configitems(b'templatealias')
668 t = templater.templater(
668 t = templater.templater(
669 defaults=defaults, resources=resources, cache=cache, aliases=aliases
669 defaults=defaults, resources=resources, cache=cache, aliases=aliases
670 )
670 )
671 t.cache.update(
671 t.cache.update(
672 (k, templater.unquotestring(v)) for k, v in ui.configitems(b'templates')
672 (k, templater.unquotestring(v)) for k, v in ui.configitems(b'templates')
673 )
673 )
674 if tmpl:
674 if tmpl:
675 t.cache[b''] = tmpl
675 t.cache[b''] = tmpl
676 return t
676 return t
677
677
678
678
679 # marker to denote a resource to be loaded on demand based on mapping values
679 # marker to denote a resource to be loaded on demand based on mapping values
680 # (e.g. (ctx, path) -> fctx)
680 # (e.g. (ctx, path) -> fctx)
681 _placeholder = object()
681 _placeholder = object()
682
682
683
683
684 class templateresources(templater.resourcemapper):
684 class templateresources(templater.resourcemapper):
685 """Resource mapper designed for the default templatekw and function"""
685 """Resource mapper designed for the default templatekw and function"""
686
686
687 def __init__(self, ui, repo=None):
687 def __init__(self, ui, repo=None):
688 self._resmap = {
688 self._resmap = {
689 b'cache': {}, # for templatekw/funcs to store reusable data
689 b'cache': {}, # for templatekw/funcs to store reusable data
690 b'repo': repo,
690 b'repo': repo,
691 b'ui': ui,
691 b'ui': ui,
692 }
692 }
693
693
694 def availablekeys(self, mapping):
694 def availablekeys(self, mapping):
695 return {
695 return {
696 k for k in self.knownkeys() if self._getsome(mapping, k) is not None
696 k for k in self.knownkeys() if self._getsome(mapping, k) is not None
697 }
697 }
698
698
699 def knownkeys(self):
699 def knownkeys(self):
700 return {b'cache', b'ctx', b'fctx', b'repo', b'revcache', b'ui'}
700 return {b'cache', b'ctx', b'fctx', b'repo', b'revcache', b'ui'}
701
701
702 def lookup(self, mapping, key):
702 def lookup(self, mapping, key):
703 if key not in self.knownkeys():
703 if key not in self.knownkeys():
704 return None
704 return None
705 v = self._getsome(mapping, key)
705 v = self._getsome(mapping, key)
706 if v is _placeholder:
706 if v is _placeholder:
707 v = mapping[key] = self._loadermap[key](self, mapping)
707 v = mapping[key] = self._loadermap[key](self, mapping)
708 return v
708 return v
709
709
710 def populatemap(self, context, origmapping, newmapping):
710 def populatemap(self, context, origmapping, newmapping):
711 mapping = {}
711 mapping = {}
712 if self._hasnodespec(newmapping):
712 if self._hasnodespec(newmapping):
713 mapping[b'revcache'] = {} # per-ctx cache
713 mapping[b'revcache'] = {} # per-ctx cache
714 if self._hasnodespec(origmapping) and self._hasnodespec(newmapping):
714 if self._hasnodespec(origmapping) and self._hasnodespec(newmapping):
715 orignode = templateutil.runsymbol(context, origmapping, b'node')
715 orignode = templateutil.runsymbol(context, origmapping, b'node')
716 mapping[b'originalnode'] = orignode
716 mapping[b'originalnode'] = orignode
717 # put marker to override 'ctx'/'fctx' in mapping if any, and flag
717 # put marker to override 'ctx'/'fctx' in mapping if any, and flag
718 # its existence to be reported by availablekeys()
718 # its existence to be reported by availablekeys()
719 if b'ctx' not in newmapping and self._hasliteral(newmapping, b'node'):
719 if b'ctx' not in newmapping and self._hasliteral(newmapping, b'node'):
720 mapping[b'ctx'] = _placeholder
720 mapping[b'ctx'] = _placeholder
721 if b'fctx' not in newmapping and self._hasliteral(newmapping, b'path'):
721 if b'fctx' not in newmapping and self._hasliteral(newmapping, b'path'):
722 mapping[b'fctx'] = _placeholder
722 mapping[b'fctx'] = _placeholder
723 return mapping
723 return mapping
724
724
725 def _getsome(self, mapping, key):
725 def _getsome(self, mapping, key):
726 v = mapping.get(key)
726 v = mapping.get(key)
727 if v is not None:
727 if v is not None:
728 return v
728 return v
729 return self._resmap.get(key)
729 return self._resmap.get(key)
730
730
731 def _hasliteral(self, mapping, key):
731 def _hasliteral(self, mapping, key):
732 """Test if a literal value is set or unset in the given mapping"""
732 """Test if a literal value is set or unset in the given mapping"""
733 return key in mapping and not callable(mapping[key])
733 return key in mapping and not callable(mapping[key])
734
734
735 def _getliteral(self, mapping, key):
735 def _getliteral(self, mapping, key):
736 """Return value of the given name if it is a literal"""
736 """Return value of the given name if it is a literal"""
737 v = mapping.get(key)
737 v = mapping.get(key)
738 if callable(v):
738 if callable(v):
739 return None
739 return None
740 return v
740 return v
741
741
742 def _hasnodespec(self, mapping):
742 def _hasnodespec(self, mapping):
743 """Test if context revision is set or unset in the given mapping"""
743 """Test if context revision is set or unset in the given mapping"""
744 return b'node' in mapping or b'ctx' in mapping
744 return b'node' in mapping or b'ctx' in mapping
745
745
746 def _loadctx(self, mapping):
746 def _loadctx(self, mapping):
747 repo = self._getsome(mapping, b'repo')
747 repo = self._getsome(mapping, b'repo')
748 node = self._getliteral(mapping, b'node')
748 node = self._getliteral(mapping, b'node')
749 if repo is None or node is None:
749 if repo is None or node is None:
750 return
750 return
751 try:
751 try:
752 return repo[node]
752 return repo[node]
753 except error.RepoLookupError:
753 except error.RepoLookupError:
754 return None # maybe hidden/non-existent node
754 return None # maybe hidden/non-existent node
755
755
756 def _loadfctx(self, mapping):
756 def _loadfctx(self, mapping):
757 ctx = self._getsome(mapping, b'ctx')
757 ctx = self._getsome(mapping, b'ctx')
758 path = self._getliteral(mapping, b'path')
758 path = self._getliteral(mapping, b'path')
759 if ctx is None or path is None:
759 if ctx is None or path is None:
760 return None
760 return None
761 try:
761 try:
762 return ctx[path]
762 return ctx[path]
763 except error.LookupError:
763 except error.LookupError:
764 return None # maybe removed file?
764 return None # maybe removed file?
765
765
766 _loadermap = {
766 _loadermap = {
767 b'ctx': _loadctx,
767 b'ctx': _loadctx,
768 b'fctx': _loadfctx,
768 b'fctx': _loadfctx,
769 }
769 }
770
770
771
771
772 def _internaltemplateformatter(
772 def _internaltemplateformatter(
773 ui,
773 ui,
774 out,
774 out,
775 topic,
775 topic,
776 opts,
776 opts,
777 spec,
777 spec,
778 tmpl,
778 tmpl,
779 docheader=b'',
779 docheader=b'',
780 docfooter=b'',
780 docfooter=b'',
781 separator=b'',
781 separator=b'',
782 ):
782 ):
783 """Build template formatter that handles customizable built-in templates
783 """Build template formatter that handles customizable built-in templates
784 such as -Tjson(...)"""
784 such as -Tjson(...)"""
785 templates = {spec.ref: tmpl}
785 templates = {spec.ref: tmpl}
786 if docheader:
786 if docheader:
787 templates[b'%s:docheader' % spec.ref] = docheader
787 templates[b'%s:docheader' % spec.ref] = docheader
788 if docfooter:
788 if docfooter:
789 templates[b'%s:docfooter' % spec.ref] = docfooter
789 templates[b'%s:docfooter' % spec.ref] = docfooter
790 if separator:
790 if separator:
791 templates[b'%s:separator' % spec.ref] = separator
791 templates[b'%s:separator' % spec.ref] = separator
792 return templateformatter(
792 return templateformatter(
793 ui, out, topic, opts, spec, overridetemplates=templates
793 ui, out, topic, opts, spec, overridetemplates=templates
794 )
794 )
795
795
796
796
797 def formatter(ui, out, topic, opts):
797 def formatter(ui, out, topic, opts):
798 spec = lookuptemplate(ui, topic, opts.get(b'template', b''))
798 spec = lookuptemplate(ui, topic, opts.get(b'template', b''))
799 if spec.ref == b"cbor" and spec.refargs is not None:
799 if spec.ref == b"cbor" and spec.refargs is not None:
800 return _internaltemplateformatter(
800 return _internaltemplateformatter(
801 ui,
801 ui,
802 out,
802 out,
803 topic,
803 topic,
804 opts,
804 opts,
805 spec,
805 spec,
806 tmpl=b'{dict(%s)|cbor}' % spec.refargs,
806 tmpl=b'{dict(%s)|cbor}' % spec.refargs,
807 docheader=cborutil.BEGIN_INDEFINITE_ARRAY,
807 docheader=cborutil.BEGIN_INDEFINITE_ARRAY,
808 docfooter=cborutil.BREAK,
808 docfooter=cborutil.BREAK,
809 )
809 )
810 elif spec.ref == b"cbor":
810 elif spec.ref == b"cbor":
811 return cborformatter(ui, out, topic, opts)
811 return cborformatter(ui, out, topic, opts)
812 elif spec.ref == b"json" and spec.refargs is not None:
812 elif spec.ref == b"json" and spec.refargs is not None:
813 return _internaltemplateformatter(
813 return _internaltemplateformatter(
814 ui,
814 ui,
815 out,
815 out,
816 topic,
816 topic,
817 opts,
817 opts,
818 spec,
818 spec,
819 tmpl=b'{dict(%s)|json}' % spec.refargs,
819 tmpl=b'{dict(%s)|json}' % spec.refargs,
820 docheader=b'[\n ',
820 docheader=b'[\n ',
821 docfooter=b'\n]\n',
821 docfooter=b'\n]\n',
822 separator=b',\n ',
822 separator=b',\n ',
823 )
823 )
824 elif spec.ref == b"json":
824 elif spec.ref == b"json":
825 return jsonformatter(ui, out, topic, opts)
825 return jsonformatter(ui, out, topic, opts)
826 elif spec.ref == b"pickle":
826 elif spec.ref == b"pickle":
827 assert spec.refargs is None, r'function-style not supported'
827 assert spec.refargs is None, r'function-style not supported'
828 return pickleformatter(ui, out, topic, opts)
828 return pickleformatter(ui, out, topic, opts)
829 elif spec.ref == b"debug":
829 elif spec.ref == b"debug":
830 assert spec.refargs is None, r'function-style not supported'
830 assert spec.refargs is None, r'function-style not supported'
831 return debugformatter(ui, out, topic, opts)
831 return debugformatter(ui, out, topic, opts)
832 elif spec.ref or spec.tmpl or spec.mapfile:
832 elif spec.ref or spec.tmpl or spec.mapfile:
833 assert spec.refargs is None, r'function-style not supported'
833 assert spec.refargs is None, r'function-style not supported'
834 return templateformatter(ui, out, topic, opts, spec)
834 return templateformatter(ui, out, topic, opts, spec)
835 # developer config: ui.formatdebug
835 # developer config: ui.formatdebug
836 elif ui.configbool(b'ui', b'formatdebug'):
836 elif ui.configbool(b'ui', b'formatdebug'):
837 return debugformatter(ui, out, topic, opts)
837 return debugformatter(ui, out, topic, opts)
838 # deprecated config: ui.formatjson
838 # deprecated config: ui.formatjson
839 elif ui.configbool(b'ui', b'formatjson'):
839 elif ui.configbool(b'ui', b'formatjson'):
840 return jsonformatter(ui, out, topic, opts)
840 return jsonformatter(ui, out, topic, opts)
841 return plainformatter(ui, out, topic, opts)
841 return plainformatter(ui, out, topic, opts)
842
842
843
843
844 @contextlib.contextmanager
844 @contextlib.contextmanager
845 def openformatter(ui, filename, topic, opts):
845 def openformatter(ui, filename, topic, opts):
846 """Create a formatter that writes outputs to the specified file
846 """Create a formatter that writes outputs to the specified file
847
847
848 Must be invoked using the 'with' statement.
848 Must be invoked using the 'with' statement.
849 """
849 """
850 with util.posixfile(filename, b'wb') as out:
850 with util.posixfile(filename, b'wb') as out:
851 with formatter(ui, out, topic, opts) as fm:
851 with formatter(ui, out, topic, opts) as fm:
852 yield fm
852 yield fm
853
853
854
854
855 @contextlib.contextmanager
855 @contextlib.contextmanager
856 def _neverending(fm):
856 def _neverending(fm):
857 yield fm
857 yield fm
858
858
859
859
860 def maybereopen(fm, filename):
860 def maybereopen(fm, filename):
861 """Create a formatter backed by file if filename specified, else return
861 """Create a formatter backed by file if filename specified, else return
862 the given formatter
862 the given formatter
863
863
864 Must be invoked using the 'with' statement. This will never call fm.end()
864 Must be invoked using the 'with' statement. This will never call fm.end()
865 of the given formatter.
865 of the given formatter.
866 """
866 """
867 if filename:
867 if filename:
868 return openformatter(fm._ui, filename, fm._topic, fm._opts)
868 return openformatter(fm._ui, filename, fm._topic, fm._opts)
869 else:
869 else:
870 return _neverending(fm)
870 return _neverending(fm)
@@ -1,556 +1,556 b''
1 # templatefilters.py - common template expansion filters
1 # templatefilters.py - common template expansion filters
2 #
2 #
3 # Copyright 2005-2008 Olivia Mackall <olivia@selenic.com>
3 # Copyright 2005-2008 Olivia Mackall <olivia@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8
8
9 import os
9 import os
10 import re
10 import re
11 import time
11 import time
12
12
13 from .i18n import _
13 from .i18n import _
14 from .node import hex
14 from .node import hex
15 from . import (
15 from . import (
16 encoding,
16 encoding,
17 error,
17 error,
18 pycompat,
18 pycompat,
19 registrar,
19 registrar,
20 smartset,
20 smartset,
21 templateutil,
21 templateutil,
22 url,
22 url,
23 util,
23 util,
24 )
24 )
25 from .utils import (
25 from .utils import (
26 cborutil,
26 cborutil,
27 dateutil,
27 dateutil,
28 stringutil,
28 stringutil,
29 )
29 )
30
30
31 urlerr = util.urlerr
31 urlerr = util.urlerr
32 urlreq = util.urlreq
32 urlreq = util.urlreq
33
33
34 # filters are callables like:
34 # filters are callables like:
35 # fn(obj)
35 # fn(obj)
36 # with:
36 # with:
37 # obj - object to be filtered (text, date, list and so on)
37 # obj - object to be filtered (text, date, list and so on)
38 filters = {}
38 filters = {}
39
39
40 templatefilter = registrar.templatefilter(filters)
40 templatefilter = registrar.templatefilter(filters)
41
41
42
42
43 @templatefilter(b'addbreaks', intype=bytes)
43 @templatefilter(b'addbreaks', intype=bytes)
44 def addbreaks(text):
44 def addbreaks(text):
45 """Any text. Add an XHTML "<br />" tag before the end of
45 """Any text. Add an XHTML "<br />" tag before the end of
46 every line except the last.
46 every line except the last.
47 """
47 """
48 return text.replace(b'\n', b'<br/>\n')
48 return text.replace(b'\n', b'<br/>\n')
49
49
50
50
51 agescales = [
51 agescales = [
52 (b"year", 3600 * 24 * 365, b'Y'),
52 (b"year", 3600 * 24 * 365, b'Y'),
53 (b"month", 3600 * 24 * 30, b'M'),
53 (b"month", 3600 * 24 * 30, b'M'),
54 (b"week", 3600 * 24 * 7, b'W'),
54 (b"week", 3600 * 24 * 7, b'W'),
55 (b"day", 3600 * 24, b'd'),
55 (b"day", 3600 * 24, b'd'),
56 (b"hour", 3600, b'h'),
56 (b"hour", 3600, b'h'),
57 (b"minute", 60, b'm'),
57 (b"minute", 60, b'm'),
58 (b"second", 1, b's'),
58 (b"second", 1, b's'),
59 ]
59 ]
60
60
61
61
62 @templatefilter(b'age', intype=templateutil.date)
62 @templatefilter(b'age', intype=templateutil.date)
63 def age(date, abbrev=False):
63 def age(date, abbrev=False):
64 """Date. Returns a human-readable date/time difference between the
64 """Date. Returns a human-readable date/time difference between the
65 given date/time and the current date/time.
65 given date/time and the current date/time.
66 """
66 """
67
67
68 def plural(t, c):
68 def plural(t, c):
69 if c == 1:
69 if c == 1:
70 return t
70 return t
71 return t + b"s"
71 return t + b"s"
72
72
73 def fmt(t, c, a):
73 def fmt(t, c, a):
74 if abbrev:
74 if abbrev:
75 return b"%d%s" % (c, a)
75 return b"%d%s" % (c, a)
76 return b"%d %s" % (c, plural(t, c))
76 return b"%d %s" % (c, plural(t, c))
77
77
78 now = time.time()
78 now = time.time()
79 then = date[0]
79 then = date[0]
80 future = False
80 future = False
81 if then > now:
81 if then > now:
82 future = True
82 future = True
83 delta = max(1, int(then - now))
83 delta = max(1, int(then - now))
84 if delta > agescales[0][1] * 30:
84 if delta > agescales[0][1] * 30:
85 return b'in the distant future'
85 return b'in the distant future'
86 else:
86 else:
87 delta = max(1, int(now - then))
87 delta = max(1, int(now - then))
88 if delta > agescales[0][1] * 2:
88 if delta > agescales[0][1] * 2:
89 return dateutil.shortdate(date)
89 return dateutil.shortdate(date)
90
90
91 for t, s, a in agescales:
91 for t, s, a in agescales:
92 n = delta // s
92 n = delta // s
93 if n >= 2 or s == 1:
93 if n >= 2 or s == 1:
94 if future:
94 if future:
95 return b'%s from now' % fmt(t, n, a)
95 return b'%s from now' % fmt(t, n, a)
96 return b'%s ago' % fmt(t, n, a)
96 return b'%s ago' % fmt(t, n, a)
97
97
98
98
99 @templatefilter(b'basename', intype=bytes)
99 @templatefilter(b'basename', intype=bytes)
100 def basename(path):
100 def basename(path):
101 """Any text. Treats the text as a path, and returns the last
101 """Any text. Treats the text as a path, and returns the last
102 component of the path after splitting by the path separator.
102 component of the path after splitting by the path separator.
103 For example, "foo/bar/baz" becomes "baz" and "foo/bar//" becomes "".
103 For example, "foo/bar/baz" becomes "baz" and "foo/bar//" becomes "".
104 """
104 """
105 return os.path.basename(path)
105 return os.path.basename(path)
106
106
107
107
108 def _tocborencodable(obj):
108 def _tocborencodable(obj):
109 if isinstance(obj, smartset.abstractsmartset):
109 if isinstance(obj, smartset.abstractsmartset):
110 return list(obj)
110 return list(obj)
111 return obj
111 return obj
112
112
113
113
114 @templatefilter(b'cbor')
114 @templatefilter(b'cbor')
115 def cbor(obj):
115 def cbor(obj):
116 """Any object. Serializes the object to CBOR bytes."""
116 """Any object. Serializes the object to CBOR bytes."""
117 # cborutil is stricter about type than json() filter
117 # cborutil is stricter about type than json() filter
118 obj = pycompat.rapply(_tocborencodable, obj)
118 obj = pycompat.rapply(_tocborencodable, obj)
119 return b''.join(cborutil.streamencode(obj))
119 return b''.join(cborutil.streamencode(obj))
120
120
121
121
122 @templatefilter(b'commondir')
122 @templatefilter(b'commondir')
123 def commondir(filelist):
123 def commondir(filelist):
124 """List of text. Treats each list item as file name with /
124 """List of text. Treats each list item as file name with /
125 as path separator and returns the longest common directory
125 as path separator and returns the longest common directory
126 prefix shared by all list items.
126 prefix shared by all list items.
127 Returns the empty string if no common prefix exists.
127 Returns the empty string if no common prefix exists.
128
128
129 The list items are not normalized, i.e. "foo/../bar" is handled as
129 The list items are not normalized, i.e. "foo/../bar" is handled as
130 file "bar" in the directory "foo/..". Leading slashes are ignored.
130 file "bar" in the directory "foo/..". Leading slashes are ignored.
131
131
132 For example, ["foo/bar/baz", "foo/baz/bar"] becomes "foo" and
132 For example, ["foo/bar/baz", "foo/baz/bar"] becomes "foo" and
133 ["foo/bar", "baz"] becomes "".
133 ["foo/bar", "baz"] becomes "".
134 """
134 """
135
135
136 def common(a, b):
136 def common(a, b):
137 if len(a) > len(b):
137 if len(a) > len(b):
138 a = b[: len(a)]
138 a = b[: len(a)]
139 elif len(b) > len(a):
139 elif len(b) > len(a):
140 b = b[: len(a)]
140 b = b[: len(a)]
141 if a == b:
141 if a == b:
142 return a
142 return a
143 for i in pycompat.xrange(len(a)):
143 for i in pycompat.xrange(len(a)):
144 if a[i] != b[i]:
144 if a[i] != b[i]:
145 return a[:i]
145 return a[:i]
146 return a
146 return a
147
147
148 try:
148 try:
149 if not filelist:
149 if not filelist:
150 return b""
150 return b""
151 dirlist = [f.lstrip(b'/').split(b'/')[:-1] for f in filelist]
151 dirlist = [f.lstrip(b'/').split(b'/')[:-1] for f in filelist]
152 if len(dirlist) == 1:
152 if len(dirlist) == 1:
153 return b'/'.join(dirlist[0])
153 return b'/'.join(dirlist[0])
154 a = min(dirlist)
154 a = min(dirlist)
155 b = max(dirlist)
155 b = max(dirlist)
156 # The common prefix of a and b is shared with all
156 # The common prefix of a and b is shared with all
157 # elements of the list since Python sorts lexicographical
157 # elements of the list since Python sorts lexicographical
158 # and [1, x] after [1].
158 # and [1, x] after [1].
159 return b'/'.join(common(a, b))
159 return b'/'.join(common(a, b))
160 except TypeError:
160 except TypeError:
161 raise error.ParseError(_(b'argument is not a list of text'))
161 raise error.ParseError(_(b'argument is not a list of text'))
162
162
163
163
164 @templatefilter(b'count')
164 @templatefilter(b'count')
165 def count(i):
165 def count(i):
166 """List or text. Returns the length as an integer."""
166 """List or text. Returns the length as an integer."""
167 try:
167 try:
168 return len(i)
168 return len(i)
169 except TypeError:
169 except TypeError:
170 raise error.ParseError(_(b'not countable'))
170 raise error.ParseError(_(b'not countable'))
171
171
172
172
173 @templatefilter(b'dirname', intype=bytes)
173 @templatefilter(b'dirname', intype=bytes)
174 def dirname(path):
174 def dirname(path):
175 """Any text. Treats the text as a path, and strips the last
175 """Any text. Treats the text as a path, and strips the last
176 component of the path after splitting by the path separator.
176 component of the path after splitting by the path separator.
177 """
177 """
178 return os.path.dirname(path)
178 return os.path.dirname(path)
179
179
180
180
181 @templatefilter(b'domain', intype=bytes)
181 @templatefilter(b'domain', intype=bytes)
182 def domain(author):
182 def domain(author):
183 """Any text. Finds the first string that looks like an email
183 """Any text. Finds the first string that looks like an email
184 address, and extracts just the domain component. Example: ``User
184 address, and extracts just the domain component. Example: ``User
185 <user@example.com>`` becomes ``example.com``.
185 <user@example.com>`` becomes ``example.com``.
186 """
186 """
187 f = author.find(b'@')
187 f = author.find(b'@')
188 if f == -1:
188 if f == -1:
189 return b''
189 return b''
190 author = author[f + 1 :]
190 author = author[f + 1 :]
191 f = author.find(b'>')
191 f = author.find(b'>')
192 if f >= 0:
192 if f >= 0:
193 author = author[:f]
193 author = author[:f]
194 return author
194 return author
195
195
196
196
197 @templatefilter(b'email', intype=bytes)
197 @templatefilter(b'email', intype=bytes)
198 def email(text):
198 def email(text):
199 """Any text. Extracts the first string that looks like an email
199 """Any text. Extracts the first string that looks like an email
200 address. Example: ``User <user@example.com>`` becomes
200 address. Example: ``User <user@example.com>`` becomes
201 ``user@example.com``.
201 ``user@example.com``.
202 """
202 """
203 return stringutil.email(text)
203 return stringutil.email(text)
204
204
205
205
206 @templatefilter(b'escape', intype=bytes)
206 @templatefilter(b'escape', intype=bytes)
207 def escape(text):
207 def escape(text):
208 """Any text. Replaces the special XML/XHTML characters "&", "<"
208 """Any text. Replaces the special XML/XHTML characters "&", "<"
209 and ">" with XML entities, and filters out NUL characters.
209 and ">" with XML entities, and filters out NUL characters.
210 """
210 """
211 return url.escape(text.replace(b'\0', b''), True)
211 return url.escape(text.replace(b'\0', b''), True)
212
212
213
213
214 para_re = None
214 para_re = None
215 space_re = None
215 space_re = None
216
216
217
217
218 def fill(text, width, initindent=b'', hangindent=b''):
218 def fill(text, width, initindent=b'', hangindent=b''):
219 '''fill many paragraphs with optional indentation.'''
219 '''fill many paragraphs with optional indentation.'''
220 global para_re, space_re
220 global para_re, space_re
221 if para_re is None:
221 if para_re is None:
222 para_re = re.compile(b'(\n\n|\n\\s*[-*]\\s*)', re.M)
222 para_re = re.compile(b'(\n\n|\n\\s*[-*]\\s*)', re.M)
223 space_re = re.compile(br' +')
223 space_re = re.compile(br' +')
224
224
225 def findparas():
225 def findparas():
226 start = 0
226 start = 0
227 while True:
227 while True:
228 m = para_re.search(text, start)
228 m = para_re.search(text, start)
229 if not m:
229 if not m:
230 uctext = encoding.unifromlocal(text[start:])
230 uctext = encoding.unifromlocal(text[start:])
231 w = len(uctext)
231 w = len(uctext)
232 while w > 0 and uctext[w - 1].isspace():
232 while w > 0 and uctext[w - 1].isspace():
233 w -= 1
233 w -= 1
234 yield (
234 yield (
235 encoding.unitolocal(uctext[:w]),
235 encoding.unitolocal(uctext[:w]),
236 encoding.unitolocal(uctext[w:]),
236 encoding.unitolocal(uctext[w:]),
237 )
237 )
238 break
238 break
239 yield text[start : m.start(0)], m.group(1)
239 yield text[start : m.start(0)], m.group(1)
240 start = m.end(1)
240 start = m.end(1)
241
241
242 return b"".join(
242 return b"".join(
243 [
243 [
244 stringutil.wrap(
244 stringutil.wrap(
245 space_re.sub(b' ', stringutil.wrap(para, width)),
245 space_re.sub(b' ', stringutil.wrap(para, width)),
246 width,
246 width,
247 initindent,
247 initindent,
248 hangindent,
248 hangindent,
249 )
249 )
250 + rest
250 + rest
251 for para, rest in findparas()
251 for para, rest in findparas()
252 ]
252 ]
253 )
253 )
254
254
255
255
256 @templatefilter(b'fill68', intype=bytes)
256 @templatefilter(b'fill68', intype=bytes)
257 def fill68(text):
257 def fill68(text):
258 """Any text. Wraps the text to fit in 68 columns."""
258 """Any text. Wraps the text to fit in 68 columns."""
259 return fill(text, 68)
259 return fill(text, 68)
260
260
261
261
262 @templatefilter(b'fill76', intype=bytes)
262 @templatefilter(b'fill76', intype=bytes)
263 def fill76(text):
263 def fill76(text):
264 """Any text. Wraps the text to fit in 76 columns."""
264 """Any text. Wraps the text to fit in 76 columns."""
265 return fill(text, 76)
265 return fill(text, 76)
266
266
267
267
268 @templatefilter(b'firstline', intype=bytes)
268 @templatefilter(b'firstline', intype=bytes)
269 def firstline(text):
269 def firstline(text):
270 """Any text. Returns the first line of text."""
270 """Any text. Returns the first line of text."""
271 try:
271 try:
272 return text.splitlines(True)[0].rstrip(b'\r\n')
272 return text.splitlines(True)[0].rstrip(b'\r\n')
273 except IndexError:
273 except IndexError:
274 return b''
274 return b''
275
275
276
276
277 @templatefilter(b'hex', intype=bytes)
277 @templatefilter(b'hex', intype=bytes)
278 def hexfilter(text):
278 def hexfilter(text):
279 """Any text. Convert a binary Mercurial node identifier into
279 """Any text. Convert a binary Mercurial node identifier into
280 its long hexadecimal representation.
280 its long hexadecimal representation.
281 """
281 """
282 return hex(text)
282 return hex(text)
283
283
284
284
285 @templatefilter(b'hgdate', intype=templateutil.date)
285 @templatefilter(b'hgdate', intype=templateutil.date)
286 def hgdate(text):
286 def hgdate(text):
287 """Date. Returns the date as a pair of numbers: "1157407993
287 """Date. Returns the date as a pair of numbers: "1157407993
288 25200" (Unix timestamp, timezone offset).
288 25200" (Unix timestamp, timezone offset).
289 """
289 """
290 return b"%d %d" % text
290 return b"%d %d" % text
291
291
292
292
293 @templatefilter(b'isodate', intype=templateutil.date)
293 @templatefilter(b'isodate', intype=templateutil.date)
294 def isodate(text):
294 def isodate(text):
295 """Date. Returns the date in ISO 8601 format: "2009-08-18 13:00
295 """Date. Returns the date in ISO 8601 format: "2009-08-18 13:00
296 +0200".
296 +0200".
297 """
297 """
298 return dateutil.datestr(text, b'%Y-%m-%d %H:%M %1%2')
298 return dateutil.datestr(text, b'%Y-%m-%d %H:%M %1%2')
299
299
300
300
301 @templatefilter(b'isodatesec', intype=templateutil.date)
301 @templatefilter(b'isodatesec', intype=templateutil.date)
302 def isodatesec(text):
302 def isodatesec(text):
303 """Date. Returns the date in ISO 8601 format, including
303 """Date. Returns the date in ISO 8601 format, including
304 seconds: "2009-08-18 13:00:13 +0200". See also the rfc3339date
304 seconds: "2009-08-18 13:00:13 +0200". See also the rfc3339date
305 filter.
305 filter.
306 """
306 """
307 return dateutil.datestr(text, b'%Y-%m-%d %H:%M:%S %1%2')
307 return dateutil.datestr(text, b'%Y-%m-%d %H:%M:%S %1%2')
308
308
309
309
310 def indent(text, prefix, firstline=b''):
310 def indent(text, prefix, firstline=b''):
311 '''indent each non-empty line of text after first with prefix.'''
311 '''indent each non-empty line of text after first with prefix.'''
312 lines = text.splitlines()
312 lines = text.splitlines()
313 num_lines = len(lines)
313 num_lines = len(lines)
314 endswithnewline = text[-1:] == b'\n'
314 endswithnewline = text[-1:] == b'\n'
315
315
316 def indenter():
316 def indenter():
317 for i in pycompat.xrange(num_lines):
317 for i in pycompat.xrange(num_lines):
318 l = lines[i]
318 l = lines[i]
319 if l.strip():
319 if l.strip():
320 yield prefix if i else firstline
320 yield prefix if i else firstline
321 yield l
321 yield l
322 if i < num_lines - 1 or endswithnewline:
322 if i < num_lines - 1 or endswithnewline:
323 yield b'\n'
323 yield b'\n'
324
324
325 return b"".join(indenter())
325 return b"".join(indenter())
326
326
327
327
328 @templatefilter(b'json')
328 @templatefilter(b'json')
329 def json(obj, paranoid=True):
329 def json(obj, paranoid=True):
330 """Any object. Serializes the object to a JSON formatted text."""
330 """Any object. Serializes the object to a JSON formatted text."""
331 if obj is None:
331 if obj is None:
332 return b'null'
332 return b'null'
333 elif obj is False:
333 elif obj is False:
334 return b'false'
334 return b'false'
335 elif obj is True:
335 elif obj is True:
336 return b'true'
336 return b'true'
337 elif isinstance(obj, (int, pycompat.long, float)):
337 elif isinstance(obj, (int, int, float)):
338 return pycompat.bytestr(obj)
338 return pycompat.bytestr(obj)
339 elif isinstance(obj, bytes):
339 elif isinstance(obj, bytes):
340 return b'"%s"' % encoding.jsonescape(obj, paranoid=paranoid)
340 return b'"%s"' % encoding.jsonescape(obj, paranoid=paranoid)
341 elif isinstance(obj, type(u'')):
341 elif isinstance(obj, type(u'')):
342 raise error.ProgrammingError(
342 raise error.ProgrammingError(
343 b'Mercurial only does output with bytes: %r' % obj
343 b'Mercurial only does output with bytes: %r' % obj
344 )
344 )
345 elif util.safehasattr(obj, b'keys'):
345 elif util.safehasattr(obj, b'keys'):
346 out = [
346 out = [
347 b'"%s": %s'
347 b'"%s": %s'
348 % (encoding.jsonescape(k, paranoid=paranoid), json(v, paranoid))
348 % (encoding.jsonescape(k, paranoid=paranoid), json(v, paranoid))
349 for k, v in sorted(obj.items())
349 for k, v in sorted(obj.items())
350 ]
350 ]
351 return b'{' + b', '.join(out) + b'}'
351 return b'{' + b', '.join(out) + b'}'
352 elif util.safehasattr(obj, b'__iter__'):
352 elif util.safehasattr(obj, b'__iter__'):
353 out = [json(i, paranoid) for i in obj]
353 out = [json(i, paranoid) for i in obj]
354 return b'[' + b', '.join(out) + b']'
354 return b'[' + b', '.join(out) + b']'
355 raise error.ProgrammingError(b'cannot encode %r' % obj)
355 raise error.ProgrammingError(b'cannot encode %r' % obj)
356
356
357
357
358 @templatefilter(b'lower', intype=bytes)
358 @templatefilter(b'lower', intype=bytes)
359 def lower(text):
359 def lower(text):
360 """Any text. Converts the text to lowercase."""
360 """Any text. Converts the text to lowercase."""
361 return encoding.lower(text)
361 return encoding.lower(text)
362
362
363
363
364 @templatefilter(b'nonempty', intype=bytes)
364 @templatefilter(b'nonempty', intype=bytes)
365 def nonempty(text):
365 def nonempty(text):
366 """Any text. Returns '(none)' if the string is empty."""
366 """Any text. Returns '(none)' if the string is empty."""
367 return text or b"(none)"
367 return text or b"(none)"
368
368
369
369
370 @templatefilter(b'obfuscate', intype=bytes)
370 @templatefilter(b'obfuscate', intype=bytes)
371 def obfuscate(text):
371 def obfuscate(text):
372 """Any text. Returns the input text rendered as a sequence of
372 """Any text. Returns the input text rendered as a sequence of
373 XML entities.
373 XML entities.
374 """
374 """
375 text = pycompat.unicode(
375 text = pycompat.unicode(
376 text, pycompat.sysstr(encoding.encoding), r'replace'
376 text, pycompat.sysstr(encoding.encoding), r'replace'
377 )
377 )
378 return b''.join([b'&#%d;' % ord(c) for c in text])
378 return b''.join([b'&#%d;' % ord(c) for c in text])
379
379
380
380
381 @templatefilter(b'permissions', intype=bytes)
381 @templatefilter(b'permissions', intype=bytes)
382 def permissions(flags):
382 def permissions(flags):
383 if b"l" in flags:
383 if b"l" in flags:
384 return b"lrwxrwxrwx"
384 return b"lrwxrwxrwx"
385 if b"x" in flags:
385 if b"x" in flags:
386 return b"-rwxr-xr-x"
386 return b"-rwxr-xr-x"
387 return b"-rw-r--r--"
387 return b"-rw-r--r--"
388
388
389
389
390 @templatefilter(b'person', intype=bytes)
390 @templatefilter(b'person', intype=bytes)
391 def person(author):
391 def person(author):
392 """Any text. Returns the name before an email address,
392 """Any text. Returns the name before an email address,
393 interpreting it as per RFC 5322.
393 interpreting it as per RFC 5322.
394 """
394 """
395 return stringutil.person(author)
395 return stringutil.person(author)
396
396
397
397
398 @templatefilter(b'revescape', intype=bytes)
398 @templatefilter(b'revescape', intype=bytes)
399 def revescape(text):
399 def revescape(text):
400 """Any text. Escapes all "special" characters, except @.
400 """Any text. Escapes all "special" characters, except @.
401 Forward slashes are escaped twice to prevent web servers from prematurely
401 Forward slashes are escaped twice to prevent web servers from prematurely
402 unescaping them. For example, "@foo bar/baz" becomes "@foo%20bar%252Fbaz".
402 unescaping them. For example, "@foo bar/baz" becomes "@foo%20bar%252Fbaz".
403 """
403 """
404 return urlreq.quote(text, safe=b'/@').replace(b'/', b'%252F')
404 return urlreq.quote(text, safe=b'/@').replace(b'/', b'%252F')
405
405
406
406
407 @templatefilter(b'rfc3339date', intype=templateutil.date)
407 @templatefilter(b'rfc3339date', intype=templateutil.date)
408 def rfc3339date(text):
408 def rfc3339date(text):
409 """Date. Returns a date using the Internet date format
409 """Date. Returns a date using the Internet date format
410 specified in RFC 3339: "2009-08-18T13:00:13+02:00".
410 specified in RFC 3339: "2009-08-18T13:00:13+02:00".
411 """
411 """
412 return dateutil.datestr(text, b"%Y-%m-%dT%H:%M:%S%1:%2")
412 return dateutil.datestr(text, b"%Y-%m-%dT%H:%M:%S%1:%2")
413
413
414
414
415 @templatefilter(b'rfc822date', intype=templateutil.date)
415 @templatefilter(b'rfc822date', intype=templateutil.date)
416 def rfc822date(text):
416 def rfc822date(text):
417 """Date. Returns a date using the same format used in email
417 """Date. Returns a date using the same format used in email
418 headers: "Tue, 18 Aug 2009 13:00:13 +0200".
418 headers: "Tue, 18 Aug 2009 13:00:13 +0200".
419 """
419 """
420 return dateutil.datestr(text, b"%a, %d %b %Y %H:%M:%S %1%2")
420 return dateutil.datestr(text, b"%a, %d %b %Y %H:%M:%S %1%2")
421
421
422
422
423 @templatefilter(b'short', intype=bytes)
423 @templatefilter(b'short', intype=bytes)
424 def short(text):
424 def short(text):
425 """Changeset hash. Returns the short form of a changeset hash,
425 """Changeset hash. Returns the short form of a changeset hash,
426 i.e. a 12 hexadecimal digit string.
426 i.e. a 12 hexadecimal digit string.
427 """
427 """
428 return text[:12]
428 return text[:12]
429
429
430
430
431 @templatefilter(b'shortbisect', intype=bytes)
431 @templatefilter(b'shortbisect', intype=bytes)
432 def shortbisect(label):
432 def shortbisect(label):
433 """Any text. Treats `label` as a bisection status, and
433 """Any text. Treats `label` as a bisection status, and
434 returns a single-character representing the status (G: good, B: bad,
434 returns a single-character representing the status (G: good, B: bad,
435 S: skipped, U: untested, I: ignored). Returns single space if `text`
435 S: skipped, U: untested, I: ignored). Returns single space if `text`
436 is not a valid bisection status.
436 is not a valid bisection status.
437 """
437 """
438 if label:
438 if label:
439 return label[0:1].upper()
439 return label[0:1].upper()
440 return b' '
440 return b' '
441
441
442
442
443 @templatefilter(b'shortdate', intype=templateutil.date)
443 @templatefilter(b'shortdate', intype=templateutil.date)
444 def shortdate(text):
444 def shortdate(text):
445 """Date. Returns a date like "2006-09-18"."""
445 """Date. Returns a date like "2006-09-18"."""
446 return dateutil.shortdate(text)
446 return dateutil.shortdate(text)
447
447
448
448
449 @templatefilter(b'slashpath', intype=bytes)
449 @templatefilter(b'slashpath', intype=bytes)
450 def slashpath(path):
450 def slashpath(path):
451 """Any text. Replaces the native path separator with slash."""
451 """Any text. Replaces the native path separator with slash."""
452 return util.pconvert(path)
452 return util.pconvert(path)
453
453
454
454
455 @templatefilter(b'splitlines', intype=bytes)
455 @templatefilter(b'splitlines', intype=bytes)
456 def splitlines(text):
456 def splitlines(text):
457 """Any text. Split text into a list of lines."""
457 """Any text. Split text into a list of lines."""
458 return templateutil.hybridlist(text.splitlines(), name=b'line')
458 return templateutil.hybridlist(text.splitlines(), name=b'line')
459
459
460
460
461 @templatefilter(b'stringescape', intype=bytes)
461 @templatefilter(b'stringescape', intype=bytes)
462 def stringescape(text):
462 def stringescape(text):
463 return stringutil.escapestr(text)
463 return stringutil.escapestr(text)
464
464
465
465
466 @templatefilter(b'stringify', intype=bytes)
466 @templatefilter(b'stringify', intype=bytes)
467 def stringify(thing):
467 def stringify(thing):
468 """Any type. Turns the value into text by converting values into
468 """Any type. Turns the value into text by converting values into
469 text and concatenating them.
469 text and concatenating them.
470 """
470 """
471 return thing # coerced by the intype
471 return thing # coerced by the intype
472
472
473
473
474 @templatefilter(b'stripdir', intype=bytes)
474 @templatefilter(b'stripdir', intype=bytes)
475 def stripdir(text):
475 def stripdir(text):
476 """Treat the text as path and strip a directory level, if
476 """Treat the text as path and strip a directory level, if
477 possible. For example, "foo" and "foo/bar" becomes "foo".
477 possible. For example, "foo" and "foo/bar" becomes "foo".
478 """
478 """
479 dir = os.path.dirname(text)
479 dir = os.path.dirname(text)
480 if dir == b"":
480 if dir == b"":
481 return os.path.basename(text)
481 return os.path.basename(text)
482 else:
482 else:
483 return dir
483 return dir
484
484
485
485
486 @templatefilter(b'tabindent', intype=bytes)
486 @templatefilter(b'tabindent', intype=bytes)
487 def tabindent(text):
487 def tabindent(text):
488 """Any text. Returns the text, with every non-empty line
488 """Any text. Returns the text, with every non-empty line
489 except the first starting with a tab character.
489 except the first starting with a tab character.
490 """
490 """
491 return indent(text, b'\t')
491 return indent(text, b'\t')
492
492
493
493
494 @templatefilter(b'upper', intype=bytes)
494 @templatefilter(b'upper', intype=bytes)
495 def upper(text):
495 def upper(text):
496 """Any text. Converts the text to uppercase."""
496 """Any text. Converts the text to uppercase."""
497 return encoding.upper(text)
497 return encoding.upper(text)
498
498
499
499
500 @templatefilter(b'urlescape', intype=bytes)
500 @templatefilter(b'urlescape', intype=bytes)
501 def urlescape(text):
501 def urlescape(text):
502 """Any text. Escapes all "special" characters. For example,
502 """Any text. Escapes all "special" characters. For example,
503 "foo bar" becomes "foo%20bar".
503 "foo bar" becomes "foo%20bar".
504 """
504 """
505 return urlreq.quote(text)
505 return urlreq.quote(text)
506
506
507
507
508 @templatefilter(b'user', intype=bytes)
508 @templatefilter(b'user', intype=bytes)
509 def userfilter(text):
509 def userfilter(text):
510 """Any text. Returns a short representation of a user name or email
510 """Any text. Returns a short representation of a user name or email
511 address."""
511 address."""
512 return stringutil.shortuser(text)
512 return stringutil.shortuser(text)
513
513
514
514
515 @templatefilter(b'emailuser', intype=bytes)
515 @templatefilter(b'emailuser', intype=bytes)
516 def emailuser(text):
516 def emailuser(text):
517 """Any text. Returns the user portion of an email address."""
517 """Any text. Returns the user portion of an email address."""
518 return stringutil.emailuser(text)
518 return stringutil.emailuser(text)
519
519
520
520
521 @templatefilter(b'utf8', intype=bytes)
521 @templatefilter(b'utf8', intype=bytes)
522 def utf8(text):
522 def utf8(text):
523 """Any text. Converts from the local character encoding to UTF-8."""
523 """Any text. Converts from the local character encoding to UTF-8."""
524 return encoding.fromlocal(text)
524 return encoding.fromlocal(text)
525
525
526
526
527 @templatefilter(b'xmlescape', intype=bytes)
527 @templatefilter(b'xmlescape', intype=bytes)
528 def xmlescape(text):
528 def xmlescape(text):
529 text = (
529 text = (
530 text.replace(b'&', b'&amp;')
530 text.replace(b'&', b'&amp;')
531 .replace(b'<', b'&lt;')
531 .replace(b'<', b'&lt;')
532 .replace(b'>', b'&gt;')
532 .replace(b'>', b'&gt;')
533 .replace(b'"', b'&quot;')
533 .replace(b'"', b'&quot;')
534 .replace(b"'", b'&#39;')
534 .replace(b"'", b'&#39;')
535 ) # &apos; invalid in HTML
535 ) # &apos; invalid in HTML
536 return re.sub(b'[\x00-\x08\x0B\x0C\x0E-\x1F]', b' ', text)
536 return re.sub(b'[\x00-\x08\x0B\x0C\x0E-\x1F]', b' ', text)
537
537
538
538
539 def websub(text, websubtable):
539 def websub(text, websubtable):
540 """:websub: Any text. Only applies to hgweb. Applies the regular
540 """:websub: Any text. Only applies to hgweb. Applies the regular
541 expression replacements defined in the websub section.
541 expression replacements defined in the websub section.
542 """
542 """
543 if websubtable:
543 if websubtable:
544 for regexp, format in websubtable:
544 for regexp, format in websubtable:
545 text = regexp.sub(format, text)
545 text = regexp.sub(format, text)
546 return text
546 return text
547
547
548
548
549 def loadfilter(ui, extname, registrarobj):
549 def loadfilter(ui, extname, registrarobj):
550 """Load template filter from specified registrarobj"""
550 """Load template filter from specified registrarobj"""
551 for name, func in registrarobj._table.items():
551 for name, func in registrarobj._table.items():
552 filters[name] = func
552 filters[name] = func
553
553
554
554
555 # tell hggettext to extract docstrings from these functions:
555 # tell hggettext to extract docstrings from these functions:
556 i18nfunctions = filters.values()
556 i18nfunctions = filters.values()
@@ -1,1082 +1,1081 b''
1 # cborutil.py - CBOR extensions
1 # cborutil.py - CBOR extensions
2 #
2 #
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8
8
9 import struct
9 import struct
10 import sys
10 import sys
11
11
12 from .. import pycompat
13
12
14 # Very short very of RFC 7049...
13 # Very short very of RFC 7049...
15 #
14 #
16 # Each item begins with a byte. The 3 high bits of that byte denote the
15 # Each item begins with a byte. The 3 high bits of that byte denote the
17 # "major type." The lower 5 bits denote the "subtype." Each major type
16 # "major type." The lower 5 bits denote the "subtype." Each major type
18 # has its own encoding mechanism.
17 # has its own encoding mechanism.
19 #
18 #
20 # Most types have lengths. However, bytestring, string, array, and map
19 # Most types have lengths. However, bytestring, string, array, and map
21 # can be indefinite length. These are denotes by a subtype with value 31.
20 # can be indefinite length. These are denotes by a subtype with value 31.
22 # Sub-components of those types then come afterwards and are terminated
21 # Sub-components of those types then come afterwards and are terminated
23 # by a "break" byte.
22 # by a "break" byte.
24
23
25 MAJOR_TYPE_UINT = 0
24 MAJOR_TYPE_UINT = 0
26 MAJOR_TYPE_NEGINT = 1
25 MAJOR_TYPE_NEGINT = 1
27 MAJOR_TYPE_BYTESTRING = 2
26 MAJOR_TYPE_BYTESTRING = 2
28 MAJOR_TYPE_STRING = 3
27 MAJOR_TYPE_STRING = 3
29 MAJOR_TYPE_ARRAY = 4
28 MAJOR_TYPE_ARRAY = 4
30 MAJOR_TYPE_MAP = 5
29 MAJOR_TYPE_MAP = 5
31 MAJOR_TYPE_SEMANTIC = 6
30 MAJOR_TYPE_SEMANTIC = 6
32 MAJOR_TYPE_SPECIAL = 7
31 MAJOR_TYPE_SPECIAL = 7
33
32
34 SUBTYPE_MASK = 0b00011111
33 SUBTYPE_MASK = 0b00011111
35
34
36 SUBTYPE_FALSE = 20
35 SUBTYPE_FALSE = 20
37 SUBTYPE_TRUE = 21
36 SUBTYPE_TRUE = 21
38 SUBTYPE_NULL = 22
37 SUBTYPE_NULL = 22
39 SUBTYPE_HALF_FLOAT = 25
38 SUBTYPE_HALF_FLOAT = 25
40 SUBTYPE_SINGLE_FLOAT = 26
39 SUBTYPE_SINGLE_FLOAT = 26
41 SUBTYPE_DOUBLE_FLOAT = 27
40 SUBTYPE_DOUBLE_FLOAT = 27
42 SUBTYPE_INDEFINITE = 31
41 SUBTYPE_INDEFINITE = 31
43
42
44 SEMANTIC_TAG_FINITE_SET = 258
43 SEMANTIC_TAG_FINITE_SET = 258
45
44
46 # Indefinite types begin with their major type ORd with information value 31.
45 # Indefinite types begin with their major type ORd with information value 31.
47 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
46 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
48 '>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE
47 '>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE
49 )
48 )
50 BEGIN_INDEFINITE_ARRAY = struct.pack(
49 BEGIN_INDEFINITE_ARRAY = struct.pack(
51 '>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE
50 '>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE
52 )
51 )
53 BEGIN_INDEFINITE_MAP = struct.pack(
52 BEGIN_INDEFINITE_MAP = struct.pack(
54 '>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE
53 '>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE
55 )
54 )
56
55
57 ENCODED_LENGTH_1 = struct.Struct('>B')
56 ENCODED_LENGTH_1 = struct.Struct('>B')
58 ENCODED_LENGTH_2 = struct.Struct('>BB')
57 ENCODED_LENGTH_2 = struct.Struct('>BB')
59 ENCODED_LENGTH_3 = struct.Struct('>BH')
58 ENCODED_LENGTH_3 = struct.Struct('>BH')
60 ENCODED_LENGTH_4 = struct.Struct('>BL')
59 ENCODED_LENGTH_4 = struct.Struct('>BL')
61 ENCODED_LENGTH_5 = struct.Struct('>BQ')
60 ENCODED_LENGTH_5 = struct.Struct('>BQ')
62
61
63 # The break ends an indefinite length item.
62 # The break ends an indefinite length item.
64 BREAK = b'\xff'
63 BREAK = b'\xff'
65 BREAK_INT = 255
64 BREAK_INT = 255
66
65
67
66
68 def encodelength(majortype, length):
67 def encodelength(majortype, length):
69 """Obtain a value encoding the major type and its length."""
68 """Obtain a value encoding the major type and its length."""
70 if length < 24:
69 if length < 24:
71 return ENCODED_LENGTH_1.pack(majortype << 5 | length)
70 return ENCODED_LENGTH_1.pack(majortype << 5 | length)
72 elif length < 256:
71 elif length < 256:
73 return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
72 return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
74 elif length < 65536:
73 elif length < 65536:
75 return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
74 return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
76 elif length < 4294967296:
75 elif length < 4294967296:
77 return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
76 return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
78 else:
77 else:
79 return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
78 return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
80
79
81
80
82 def streamencodebytestring(v):
81 def streamencodebytestring(v):
83 yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
82 yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
84 yield v
83 yield v
85
84
86
85
87 def streamencodebytestringfromiter(it):
86 def streamencodebytestringfromiter(it):
88 """Convert an iterator of chunks to an indefinite bytestring.
87 """Convert an iterator of chunks to an indefinite bytestring.
89
88
90 Given an input that is iterable and each element in the iterator is
89 Given an input that is iterable and each element in the iterator is
91 representable as bytes, emit an indefinite length bytestring.
90 representable as bytes, emit an indefinite length bytestring.
92 """
91 """
93 yield BEGIN_INDEFINITE_BYTESTRING
92 yield BEGIN_INDEFINITE_BYTESTRING
94
93
95 for chunk in it:
94 for chunk in it:
96 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
95 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
97 yield chunk
96 yield chunk
98
97
99 yield BREAK
98 yield BREAK
100
99
101
100
102 def streamencodeindefinitebytestring(source, chunksize=65536):
101 def streamencodeindefinitebytestring(source, chunksize=65536):
103 """Given a large source buffer, emit as an indefinite length bytestring.
102 """Given a large source buffer, emit as an indefinite length bytestring.
104
103
105 This is a generator of chunks constituting the encoded CBOR data.
104 This is a generator of chunks constituting the encoded CBOR data.
106 """
105 """
107 yield BEGIN_INDEFINITE_BYTESTRING
106 yield BEGIN_INDEFINITE_BYTESTRING
108
107
109 i = 0
108 i = 0
110 l = len(source)
109 l = len(source)
111
110
112 while True:
111 while True:
113 chunk = source[i : i + chunksize]
112 chunk = source[i : i + chunksize]
114 i += len(chunk)
113 i += len(chunk)
115
114
116 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
115 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
117 yield chunk
116 yield chunk
118
117
119 if i >= l:
118 if i >= l:
120 break
119 break
121
120
122 yield BREAK
121 yield BREAK
123
122
124
123
125 def streamencodeint(v):
124 def streamencodeint(v):
126 if v >= 18446744073709551616 or v < -18446744073709551616:
125 if v >= 18446744073709551616 or v < -18446744073709551616:
127 raise ValueError(b'big integers not supported')
126 raise ValueError(b'big integers not supported')
128
127
129 if v >= 0:
128 if v >= 0:
130 yield encodelength(MAJOR_TYPE_UINT, v)
129 yield encodelength(MAJOR_TYPE_UINT, v)
131 else:
130 else:
132 yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
131 yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
133
132
134
133
135 def streamencodearray(l):
134 def streamencodearray(l):
136 """Encode a known size iterable to an array."""
135 """Encode a known size iterable to an array."""
137
136
138 yield encodelength(MAJOR_TYPE_ARRAY, len(l))
137 yield encodelength(MAJOR_TYPE_ARRAY, len(l))
139
138
140 for i in l:
139 for i in l:
141 for chunk in streamencode(i):
140 for chunk in streamencode(i):
142 yield chunk
141 yield chunk
143
142
144
143
145 def streamencodearrayfromiter(it):
144 def streamencodearrayfromiter(it):
146 """Encode an iterator of items to an indefinite length array."""
145 """Encode an iterator of items to an indefinite length array."""
147
146
148 yield BEGIN_INDEFINITE_ARRAY
147 yield BEGIN_INDEFINITE_ARRAY
149
148
150 for i in it:
149 for i in it:
151 for chunk in streamencode(i):
150 for chunk in streamencode(i):
152 yield chunk
151 yield chunk
153
152
154 yield BREAK
153 yield BREAK
155
154
156
155
157 def _mixedtypesortkey(v):
156 def _mixedtypesortkey(v):
158 return type(v).__name__, v
157 return type(v).__name__, v
159
158
160
159
161 def streamencodeset(s):
160 def streamencodeset(s):
162 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
161 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
163 # semantic tag 258 for finite sets.
162 # semantic tag 258 for finite sets.
164 yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
163 yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
165
164
166 for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
165 for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
167 yield chunk
166 yield chunk
168
167
169
168
170 def streamencodemap(d):
169 def streamencodemap(d):
171 """Encode dictionary to a generator.
170 """Encode dictionary to a generator.
172
171
173 Does not supporting indefinite length dictionaries.
172 Does not supporting indefinite length dictionaries.
174 """
173 """
175 yield encodelength(MAJOR_TYPE_MAP, len(d))
174 yield encodelength(MAJOR_TYPE_MAP, len(d))
176
175
177 for key, value in sorted(d.items(), key=lambda x: _mixedtypesortkey(x[0])):
176 for key, value in sorted(d.items(), key=lambda x: _mixedtypesortkey(x[0])):
178 for chunk in streamencode(key):
177 for chunk in streamencode(key):
179 yield chunk
178 yield chunk
180 for chunk in streamencode(value):
179 for chunk in streamencode(value):
181 yield chunk
180 yield chunk
182
181
183
182
184 def streamencodemapfromiter(it):
183 def streamencodemapfromiter(it):
185 """Given an iterable of (key, value), encode to an indefinite length map."""
184 """Given an iterable of (key, value), encode to an indefinite length map."""
186 yield BEGIN_INDEFINITE_MAP
185 yield BEGIN_INDEFINITE_MAP
187
186
188 for key, value in it:
187 for key, value in it:
189 for chunk in streamencode(key):
188 for chunk in streamencode(key):
190 yield chunk
189 yield chunk
191 for chunk in streamencode(value):
190 for chunk in streamencode(value):
192 yield chunk
191 yield chunk
193
192
194 yield BREAK
193 yield BREAK
195
194
196
195
197 def streamencodebool(b):
196 def streamencodebool(b):
198 # major type 7, simple value 20 and 21.
197 # major type 7, simple value 20 and 21.
199 yield b'\xf5' if b else b'\xf4'
198 yield b'\xf5' if b else b'\xf4'
200
199
201
200
202 def streamencodenone(v):
201 def streamencodenone(v):
203 # major type 7, simple value 22.
202 # major type 7, simple value 22.
204 yield b'\xf6'
203 yield b'\xf6'
205
204
206
205
207 STREAM_ENCODERS = {
206 STREAM_ENCODERS = {
208 bytes: streamencodebytestring,
207 bytes: streamencodebytestring,
209 int: streamencodeint,
208 int: streamencodeint,
210 pycompat.long: streamencodeint,
209 int: streamencodeint,
211 list: streamencodearray,
210 list: streamencodearray,
212 tuple: streamencodearray,
211 tuple: streamencodearray,
213 dict: streamencodemap,
212 dict: streamencodemap,
214 set: streamencodeset,
213 set: streamencodeset,
215 bool: streamencodebool,
214 bool: streamencodebool,
216 type(None): streamencodenone,
215 type(None): streamencodenone,
217 }
216 }
218
217
219
218
220 def streamencode(v):
219 def streamencode(v):
221 """Encode a value in a streaming manner.
220 """Encode a value in a streaming manner.
222
221
223 Given an input object, encode it to CBOR recursively.
222 Given an input object, encode it to CBOR recursively.
224
223
225 Returns a generator of CBOR encoded bytes. There is no guarantee
224 Returns a generator of CBOR encoded bytes. There is no guarantee
226 that each emitted chunk fully decodes to a value or sub-value.
225 that each emitted chunk fully decodes to a value or sub-value.
227
226
228 Encoding is deterministic - unordered collections are sorted.
227 Encoding is deterministic - unordered collections are sorted.
229 """
228 """
230 fn = STREAM_ENCODERS.get(v.__class__)
229 fn = STREAM_ENCODERS.get(v.__class__)
231
230
232 if not fn:
231 if not fn:
233 # handle subtypes such as encoding.localstr and util.sortdict
232 # handle subtypes such as encoding.localstr and util.sortdict
234 for ty in STREAM_ENCODERS:
233 for ty in STREAM_ENCODERS:
235 if not isinstance(v, ty):
234 if not isinstance(v, ty):
236 continue
235 continue
237 fn = STREAM_ENCODERS[ty]
236 fn = STREAM_ENCODERS[ty]
238 break
237 break
239
238
240 if not fn:
239 if not fn:
241 raise ValueError(b'do not know how to encode %s' % type(v))
240 raise ValueError(b'do not know how to encode %s' % type(v))
242
241
243 return fn(v)
242 return fn(v)
244
243
245
244
246 class CBORDecodeError(Exception):
245 class CBORDecodeError(Exception):
247 """Represents an error decoding CBOR."""
246 """Represents an error decoding CBOR."""
248
247
249
248
250 if sys.version_info.major >= 3:
249 if sys.version_info.major >= 3:
251
250
252 def _elementtointeger(b, i):
251 def _elementtointeger(b, i):
253 return b[i]
252 return b[i]
254
253
255
254
256 else:
255 else:
257
256
258 def _elementtointeger(b, i):
257 def _elementtointeger(b, i):
259 return ord(b[i])
258 return ord(b[i])
260
259
261
260
262 STRUCT_BIG_UBYTE = struct.Struct('>B')
261 STRUCT_BIG_UBYTE = struct.Struct('>B')
263 STRUCT_BIG_USHORT = struct.Struct(b'>H')
262 STRUCT_BIG_USHORT = struct.Struct(b'>H')
264 STRUCT_BIG_ULONG = struct.Struct(b'>L')
263 STRUCT_BIG_ULONG = struct.Struct(b'>L')
265 STRUCT_BIG_ULONGLONG = struct.Struct(b'>Q')
264 STRUCT_BIG_ULONGLONG = struct.Struct(b'>Q')
266
265
267 SPECIAL_NONE = 0
266 SPECIAL_NONE = 0
268 SPECIAL_START_INDEFINITE_BYTESTRING = 1
267 SPECIAL_START_INDEFINITE_BYTESTRING = 1
269 SPECIAL_START_ARRAY = 2
268 SPECIAL_START_ARRAY = 2
270 SPECIAL_START_MAP = 3
269 SPECIAL_START_MAP = 3
271 SPECIAL_START_SET = 4
270 SPECIAL_START_SET = 4
272 SPECIAL_INDEFINITE_BREAK = 5
271 SPECIAL_INDEFINITE_BREAK = 5
273
272
274
273
275 def decodeitem(b, offset=0):
274 def decodeitem(b, offset=0):
276 """Decode a new CBOR value from a buffer at offset.
275 """Decode a new CBOR value from a buffer at offset.
277
276
278 This function attempts to decode up to one complete CBOR value
277 This function attempts to decode up to one complete CBOR value
279 from ``b`` starting at offset ``offset``.
278 from ``b`` starting at offset ``offset``.
280
279
281 The beginning of a collection (such as an array, map, set, or
280 The beginning of a collection (such as an array, map, set, or
282 indefinite length bytestring) counts as a single value. For these
281 indefinite length bytestring) counts as a single value. For these
283 special cases, a state flag will indicate that a special value was seen.
282 special cases, a state flag will indicate that a special value was seen.
284
283
285 When called, the function either returns a decoded value or gives
284 When called, the function either returns a decoded value or gives
286 a hint as to how many more bytes are needed to do so. By calling
285 a hint as to how many more bytes are needed to do so. By calling
287 the function repeatedly given a stream of bytes, the caller can
286 the function repeatedly given a stream of bytes, the caller can
288 build up the original values.
287 build up the original values.
289
288
290 Returns a tuple with the following elements:
289 Returns a tuple with the following elements:
291
290
292 * Bool indicating whether a complete value was decoded.
291 * Bool indicating whether a complete value was decoded.
293 * A decoded value if first value is True otherwise None
292 * A decoded value if first value is True otherwise None
294 * Integer number of bytes. If positive, the number of bytes
293 * Integer number of bytes. If positive, the number of bytes
295 read. If negative, the number of bytes we need to read to
294 read. If negative, the number of bytes we need to read to
296 decode this value or the next chunk in this value.
295 decode this value or the next chunk in this value.
297 * One of the ``SPECIAL_*`` constants indicating special treatment
296 * One of the ``SPECIAL_*`` constants indicating special treatment
298 for this value. ``SPECIAL_NONE`` means this is a fully decoded
297 for this value. ``SPECIAL_NONE`` means this is a fully decoded
299 simple value (such as an integer or bool).
298 simple value (such as an integer or bool).
300 """
299 """
301
300
302 initial = _elementtointeger(b, offset)
301 initial = _elementtointeger(b, offset)
303 offset += 1
302 offset += 1
304
303
305 majortype = initial >> 5
304 majortype = initial >> 5
306 subtype = initial & SUBTYPE_MASK
305 subtype = initial & SUBTYPE_MASK
307
306
308 if majortype == MAJOR_TYPE_UINT:
307 if majortype == MAJOR_TYPE_UINT:
309 complete, value, readcount = decodeuint(subtype, b, offset)
308 complete, value, readcount = decodeuint(subtype, b, offset)
310
309
311 if complete:
310 if complete:
312 return True, value, readcount + 1, SPECIAL_NONE
311 return True, value, readcount + 1, SPECIAL_NONE
313 else:
312 else:
314 return False, None, readcount, SPECIAL_NONE
313 return False, None, readcount, SPECIAL_NONE
315
314
316 elif majortype == MAJOR_TYPE_NEGINT:
315 elif majortype == MAJOR_TYPE_NEGINT:
317 # Negative integers are the same as UINT except inverted minus 1.
316 # Negative integers are the same as UINT except inverted minus 1.
318 complete, value, readcount = decodeuint(subtype, b, offset)
317 complete, value, readcount = decodeuint(subtype, b, offset)
319
318
320 if complete:
319 if complete:
321 return True, -value - 1, readcount + 1, SPECIAL_NONE
320 return True, -value - 1, readcount + 1, SPECIAL_NONE
322 else:
321 else:
323 return False, None, readcount, SPECIAL_NONE
322 return False, None, readcount, SPECIAL_NONE
324
323
325 elif majortype == MAJOR_TYPE_BYTESTRING:
324 elif majortype == MAJOR_TYPE_BYTESTRING:
326 # Beginning of bytestrings are treated as uints in order to
325 # Beginning of bytestrings are treated as uints in order to
327 # decode their length, which may be indefinite.
326 # decode their length, which may be indefinite.
328 complete, size, readcount = decodeuint(
327 complete, size, readcount = decodeuint(
329 subtype, b, offset, allowindefinite=True
328 subtype, b, offset, allowindefinite=True
330 )
329 )
331
330
332 # We don't know the size of the bytestring. It must be a definitive
331 # We don't know the size of the bytestring. It must be a definitive
333 # length since the indefinite subtype would be encoded in the initial
332 # length since the indefinite subtype would be encoded in the initial
334 # byte.
333 # byte.
335 if not complete:
334 if not complete:
336 return False, None, readcount, SPECIAL_NONE
335 return False, None, readcount, SPECIAL_NONE
337
336
338 # We know the length of the bytestring.
337 # We know the length of the bytestring.
339 if size is not None:
338 if size is not None:
340 # And the data is available in the buffer.
339 # And the data is available in the buffer.
341 if offset + readcount + size <= len(b):
340 if offset + readcount + size <= len(b):
342 value = b[offset + readcount : offset + readcount + size]
341 value = b[offset + readcount : offset + readcount + size]
343 return True, value, readcount + size + 1, SPECIAL_NONE
342 return True, value, readcount + size + 1, SPECIAL_NONE
344
343
345 # And we need more data in order to return the bytestring.
344 # And we need more data in order to return the bytestring.
346 else:
345 else:
347 wanted = len(b) - offset - readcount - size
346 wanted = len(b) - offset - readcount - size
348 return False, None, wanted, SPECIAL_NONE
347 return False, None, wanted, SPECIAL_NONE
349
348
350 # It is an indefinite length bytestring.
349 # It is an indefinite length bytestring.
351 else:
350 else:
352 return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
351 return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
353
352
354 elif majortype == MAJOR_TYPE_STRING:
353 elif majortype == MAJOR_TYPE_STRING:
355 raise CBORDecodeError(b'string major type not supported')
354 raise CBORDecodeError(b'string major type not supported')
356
355
357 elif majortype == MAJOR_TYPE_ARRAY:
356 elif majortype == MAJOR_TYPE_ARRAY:
358 # Beginning of arrays are treated as uints in order to decode their
357 # Beginning of arrays are treated as uints in order to decode their
359 # length. We don't allow indefinite length arrays.
358 # length. We don't allow indefinite length arrays.
360 complete, size, readcount = decodeuint(subtype, b, offset)
359 complete, size, readcount = decodeuint(subtype, b, offset)
361
360
362 if complete:
361 if complete:
363 return True, size, readcount + 1, SPECIAL_START_ARRAY
362 return True, size, readcount + 1, SPECIAL_START_ARRAY
364 else:
363 else:
365 return False, None, readcount, SPECIAL_NONE
364 return False, None, readcount, SPECIAL_NONE
366
365
367 elif majortype == MAJOR_TYPE_MAP:
366 elif majortype == MAJOR_TYPE_MAP:
368 # Beginning of maps are treated as uints in order to decode their
367 # Beginning of maps are treated as uints in order to decode their
369 # number of elements. We don't allow indefinite length arrays.
368 # number of elements. We don't allow indefinite length arrays.
370 complete, size, readcount = decodeuint(subtype, b, offset)
369 complete, size, readcount = decodeuint(subtype, b, offset)
371
370
372 if complete:
371 if complete:
373 return True, size, readcount + 1, SPECIAL_START_MAP
372 return True, size, readcount + 1, SPECIAL_START_MAP
374 else:
373 else:
375 return False, None, readcount, SPECIAL_NONE
374 return False, None, readcount, SPECIAL_NONE
376
375
377 elif majortype == MAJOR_TYPE_SEMANTIC:
376 elif majortype == MAJOR_TYPE_SEMANTIC:
378 # Semantic tag value is read the same as a uint.
377 # Semantic tag value is read the same as a uint.
379 complete, tagvalue, readcount = decodeuint(subtype, b, offset)
378 complete, tagvalue, readcount = decodeuint(subtype, b, offset)
380
379
381 if not complete:
380 if not complete:
382 return False, None, readcount, SPECIAL_NONE
381 return False, None, readcount, SPECIAL_NONE
383
382
384 # This behavior here is a little wonky. The main type being "decorated"
383 # This behavior here is a little wonky. The main type being "decorated"
385 # by this semantic tag follows. A more robust parser would probably emit
384 # by this semantic tag follows. A more robust parser would probably emit
386 # a special flag indicating this as a semantic tag and let the caller
385 # a special flag indicating this as a semantic tag and let the caller
387 # deal with the types that follow. But since we don't support many
386 # deal with the types that follow. But since we don't support many
388 # semantic tags, it is easier to deal with the special cases here and
387 # semantic tags, it is easier to deal with the special cases here and
389 # hide complexity from the caller. If we add support for more semantic
388 # hide complexity from the caller. If we add support for more semantic
390 # tags, we should probably move semantic tag handling into the caller.
389 # tags, we should probably move semantic tag handling into the caller.
391 if tagvalue == SEMANTIC_TAG_FINITE_SET:
390 if tagvalue == SEMANTIC_TAG_FINITE_SET:
392 if offset + readcount >= len(b):
391 if offset + readcount >= len(b):
393 return False, None, -1, SPECIAL_NONE
392 return False, None, -1, SPECIAL_NONE
394
393
395 complete, size, readcount2, special = decodeitem(
394 complete, size, readcount2, special = decodeitem(
396 b, offset + readcount
395 b, offset + readcount
397 )
396 )
398
397
399 if not complete:
398 if not complete:
400 return False, None, readcount2, SPECIAL_NONE
399 return False, None, readcount2, SPECIAL_NONE
401
400
402 if special != SPECIAL_START_ARRAY:
401 if special != SPECIAL_START_ARRAY:
403 raise CBORDecodeError(
402 raise CBORDecodeError(
404 b'expected array after finite set semantic tag'
403 b'expected array after finite set semantic tag'
405 )
404 )
406
405
407 return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
406 return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
408
407
409 else:
408 else:
410 raise CBORDecodeError(b'semantic tag %d not allowed' % tagvalue)
409 raise CBORDecodeError(b'semantic tag %d not allowed' % tagvalue)
411
410
412 elif majortype == MAJOR_TYPE_SPECIAL:
411 elif majortype == MAJOR_TYPE_SPECIAL:
413 # Only specific values for the information field are allowed.
412 # Only specific values for the information field are allowed.
414 if subtype == SUBTYPE_FALSE:
413 if subtype == SUBTYPE_FALSE:
415 return True, False, 1, SPECIAL_NONE
414 return True, False, 1, SPECIAL_NONE
416 elif subtype == SUBTYPE_TRUE:
415 elif subtype == SUBTYPE_TRUE:
417 return True, True, 1, SPECIAL_NONE
416 return True, True, 1, SPECIAL_NONE
418 elif subtype == SUBTYPE_NULL:
417 elif subtype == SUBTYPE_NULL:
419 return True, None, 1, SPECIAL_NONE
418 return True, None, 1, SPECIAL_NONE
420 elif subtype == SUBTYPE_INDEFINITE:
419 elif subtype == SUBTYPE_INDEFINITE:
421 return True, None, 1, SPECIAL_INDEFINITE_BREAK
420 return True, None, 1, SPECIAL_INDEFINITE_BREAK
422 # If value is 24, subtype is in next byte.
421 # If value is 24, subtype is in next byte.
423 else:
422 else:
424 raise CBORDecodeError(b'special type %d not allowed' % subtype)
423 raise CBORDecodeError(b'special type %d not allowed' % subtype)
425 else:
424 else:
426 assert False
425 assert False
427
426
428
427
def decodeuint(subtype, b, offset=0, allowindefinite=False):
    """Decode an unsigned integer from a CBOR item header.

    ``subtype`` is the lower 5 bits from the initial byte of the CBOR
    item. ``b`` is a buffer containing bytes and ``offset`` is the index
    of the first byte after the byte that ``subtype`` was derived from.

    ``allowindefinite`` permits the special indefinite length value
    indicator (31), which decodes to ``None``.

    Returns a 3-tuple of (successful, value, count).

    The first element is a bool indicating whether decoding completed.
    The 2nd is the decoded integer value, or None if not fully decoded
    or if the subtype is 31 and ``allowindefinite`` is True. The 3rd is
    a byte count: positive for additional bytes consumed, negative for
    additional bytes still needed to decode this value.
    """

    # Values 0-23 are encoded directly in the subtype bits.
    if subtype < 24:
        return True, subtype, 0

    # 31 is the indefinite length marker; it is only legal in contexts
    # where the caller opted in.
    if subtype == 31:
        if not allowindefinite:
            raise CBORDecodeError(b'indefinite length uint not allowed here')
        return True, None, 0

    # Subtypes 28-30 are reserved by the CBOR spec.
    if subtype >= 28:
        raise CBORDecodeError(
            b'unsupported subtype on integer type: %d' % subtype
        )

    # Subtypes 24-27 select a big-endian fixed-width integer that
    # immediately follows the initial byte.
    if subtype == 24:
        fmt = STRUCT_BIG_UBYTE
    elif subtype == 25:
        fmt = STRUCT_BIG_USHORT
    elif subtype == 26:
        fmt = STRUCT_BIG_ULONG
    elif subtype == 27:
        fmt = STRUCT_BIG_ULONGLONG
    else:
        raise CBORDecodeError(b'bounds condition checking violation')

    available = len(b) - offset

    if available >= fmt.size:
        return True, fmt.unpack_from(b, offset)[0], fmt.size

    # Not enough data; report how many more bytes are required.
    return False, None, available - fmt.size
478
477
479
478
class bytestringchunk(bytes):
    """A chunk/segment of an indefinite length bytestring.

    Instances behave exactly like ``bytes`` but carry two extra
    attributes — ``isfirst`` and ``islast`` — recording whether the
    chunk is the first and/or last piece of the indefinite length
    bytestring it belongs to.
    """

    def __new__(cls, v, first=False, last=False):
        inst = super(bytestringchunk, cls).__new__(cls, v)
        inst.isfirst = first
        inst.islast = last

        return inst
494
493
495
494
496 class sansiodecoder(object):
495 class sansiodecoder(object):
497 """A CBOR decoder that doesn't perform its own I/O.
496 """A CBOR decoder that doesn't perform its own I/O.
498
497
499 To use, construct an instance and feed it segments containing
498 To use, construct an instance and feed it segments containing
500 CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
499 CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
501 indicates whether a fully-decoded value is available, how many bytes
500 indicates whether a fully-decoded value is available, how many bytes
502 were consumed, and offers a hint as to how many bytes should be fed
501 were consumed, and offers a hint as to how many bytes should be fed
503 in next time to decode the next value.
502 in next time to decode the next value.
504
503
505 The decoder assumes it will decode N discrete CBOR values, not just
504 The decoder assumes it will decode N discrete CBOR values, not just
506 a single value. i.e. if the bytestream contains uints packed one after
505 a single value. i.e. if the bytestream contains uints packed one after
507 the other, the decoder will decode them all, rather than just the initial
506 the other, the decoder will decode them all, rather than just the initial
508 one.
507 one.
509
508
510 When ``decode()`` indicates a value is available, call ``getavailable()``
509 When ``decode()`` indicates a value is available, call ``getavailable()``
511 to return all fully decoded values.
510 to return all fully decoded values.
512
511
513 ``decode()`` can partially decode input. It is up to the caller to keep
512 ``decode()`` can partially decode input. It is up to the caller to keep
514 track of what data was consumed and to pass unconsumed data in on the
513 track of what data was consumed and to pass unconsumed data in on the
515 next invocation.
514 next invocation.
516
515
517 The decoder decodes atomically at the *item* level. See ``decodeitem()``.
516 The decoder decodes atomically at the *item* level. See ``decodeitem()``.
518 If an *item* cannot be fully decoded, the decoder won't record it as
517 If an *item* cannot be fully decoded, the decoder won't record it as
519 partially consumed. Instead, the caller will be instructed to pass in
518 partially consumed. Instead, the caller will be instructed to pass in
520 the initial bytes of this item on the next invocation. This does result
519 the initial bytes of this item on the next invocation. This does result
521 in some redundant parsing. But the overhead should be minimal.
520 in some redundant parsing. But the overhead should be minimal.
522
521
523 This decoder only supports a subset of CBOR as required by Mercurial.
522 This decoder only supports a subset of CBOR as required by Mercurial.
524 It lacks support for:
523 It lacks support for:
525
524
526 * Indefinite length arrays
525 * Indefinite length arrays
527 * Indefinite length maps
526 * Indefinite length maps
528 * Use of indefinite length bytestrings as keys or values within
527 * Use of indefinite length bytestrings as keys or values within
529 arrays, maps, or sets.
528 arrays, maps, or sets.
530 * Nested arrays, maps, or sets within sets
529 * Nested arrays, maps, or sets within sets
531 * Any semantic tag that isn't a mathematical finite set
530 * Any semantic tag that isn't a mathematical finite set
532 * Floating point numbers
531 * Floating point numbers
533 * Undefined special value
532 * Undefined special value
534
533
535 CBOR types are decoded to Python types as follows:
534 CBOR types are decoded to Python types as follows:
536
535
537 uint -> int
536 uint -> int
538 negint -> int
537 negint -> int
539 bytestring -> bytes
538 bytestring -> bytes
540 map -> dict
539 map -> dict
541 array -> list
540 array -> list
542 True -> bool
541 True -> bool
543 False -> bool
542 False -> bool
544 null -> None
543 null -> None
545 indefinite length bytestring chunk -> [bytestringchunk]
544 indefinite length bytestring chunk -> [bytestringchunk]
546
545
547 The only non-obvious mapping here is an indefinite length bytestring
546 The only non-obvious mapping here is an indefinite length bytestring
548 to the ``bytestringchunk`` type. This is to facilitate streaming
547 to the ``bytestringchunk`` type. This is to facilitate streaming
549 indefinite length bytestrings out of the decoder and to differentiate
548 indefinite length bytestrings out of the decoder and to differentiate
550 a regular bytestring from an indefinite length bytestring.
549 a regular bytestring from an indefinite length bytestring.
551 """
550 """
552
551
553 _STATE_NONE = 0
552 _STATE_NONE = 0
554 _STATE_WANT_MAP_KEY = 1
553 _STATE_WANT_MAP_KEY = 1
555 _STATE_WANT_MAP_VALUE = 2
554 _STATE_WANT_MAP_VALUE = 2
556 _STATE_WANT_ARRAY_VALUE = 3
555 _STATE_WANT_ARRAY_VALUE = 3
557 _STATE_WANT_SET_VALUE = 4
556 _STATE_WANT_SET_VALUE = 4
558 _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
557 _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
559 _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
558 _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
560
559
561 def __init__(self):
560 def __init__(self):
562 # TODO add support for limiting size of bytestrings
561 # TODO add support for limiting size of bytestrings
563 # TODO add support for limiting number of keys / values in collections
562 # TODO add support for limiting number of keys / values in collections
564 # TODO add support for limiting size of buffered partial values
563 # TODO add support for limiting size of buffered partial values
565
564
566 self.decodedbytecount = 0
565 self.decodedbytecount = 0
567
566
568 self._state = self._STATE_NONE
567 self._state = self._STATE_NONE
569
568
570 # Stack of active nested collections. Each entry is a dict describing
569 # Stack of active nested collections. Each entry is a dict describing
571 # the collection.
570 # the collection.
572 self._collectionstack = []
571 self._collectionstack = []
573
572
574 # Fully decoded key to use for the current map.
573 # Fully decoded key to use for the current map.
575 self._currentmapkey = None
574 self._currentmapkey = None
576
575
577 # Fully decoded values available for retrieval.
576 # Fully decoded values available for retrieval.
578 self._decodedvalues = []
577 self._decodedvalues = []
579
578
580 @property
579 @property
581 def inprogress(self):
580 def inprogress(self):
582 """Whether the decoder has partially decoded a value."""
581 """Whether the decoder has partially decoded a value."""
583 return self._state != self._STATE_NONE
582 return self._state != self._STATE_NONE
584
583
585 def decode(self, b, offset=0):
584 def decode(self, b, offset=0):
586 """Attempt to decode bytes from an input buffer.
585 """Attempt to decode bytes from an input buffer.
587
586
588 ``b`` is a collection of bytes and ``offset`` is the byte
587 ``b`` is a collection of bytes and ``offset`` is the byte
589 offset within that buffer from which to begin reading data.
588 offset within that buffer from which to begin reading data.
590
589
591 ``b`` must support ``len()`` and accessing bytes slices via
590 ``b`` must support ``len()`` and accessing bytes slices via
592 ``__slice__``. Typically ``bytes`` instances are used.
591 ``__slice__``. Typically ``bytes`` instances are used.
593
592
594 Returns a tuple with the following fields:
593 Returns a tuple with the following fields:
595
594
596 * Bool indicating whether values are available for retrieval.
595 * Bool indicating whether values are available for retrieval.
597 * Integer indicating the number of bytes that were fully consumed,
596 * Integer indicating the number of bytes that were fully consumed,
598 starting from ``offset``.
597 starting from ``offset``.
599 * Integer indicating the number of bytes that are desired for the
598 * Integer indicating the number of bytes that are desired for the
600 next call in order to decode an item.
599 next call in order to decode an item.
601 """
600 """
602 if not b:
601 if not b:
603 return bool(self._decodedvalues), 0, 0
602 return bool(self._decodedvalues), 0, 0
604
603
605 initialoffset = offset
604 initialoffset = offset
606
605
607 # We could easily split the body of this loop into a function. But
606 # We could easily split the body of this loop into a function. But
608 # Python performance is sensitive to function calls and collections
607 # Python performance is sensitive to function calls and collections
609 # are composed of many items. So leaving as a while loop could help
608 # are composed of many items. So leaving as a while loop could help
610 # with performance. One thing that may not help is the use of
609 # with performance. One thing that may not help is the use of
611 # if..elif versus a lookup/dispatch table. There may be value
610 # if..elif versus a lookup/dispatch table. There may be value
612 # in switching that.
611 # in switching that.
613 while offset < len(b):
612 while offset < len(b):
614 # Attempt to decode an item. This could be a whole value or a
613 # Attempt to decode an item. This could be a whole value or a
615 # special value indicating an event, such as start or end of a
614 # special value indicating an event, such as start or end of a
616 # collection or indefinite length type.
615 # collection or indefinite length type.
617 complete, value, readcount, special = decodeitem(b, offset)
616 complete, value, readcount, special = decodeitem(b, offset)
618
617
619 if readcount > 0:
618 if readcount > 0:
620 self.decodedbytecount += readcount
619 self.decodedbytecount += readcount
621
620
622 if not complete:
621 if not complete:
623 assert readcount < 0
622 assert readcount < 0
624 return (
623 return (
625 bool(self._decodedvalues),
624 bool(self._decodedvalues),
626 offset - initialoffset,
625 offset - initialoffset,
627 -readcount,
626 -readcount,
628 )
627 )
629
628
630 offset += readcount
629 offset += readcount
631
630
632 # No nested state. We either have a full value or beginning of a
631 # No nested state. We either have a full value or beginning of a
633 # complex value to deal with.
632 # complex value to deal with.
634 if self._state == self._STATE_NONE:
633 if self._state == self._STATE_NONE:
635 # A normal value.
634 # A normal value.
636 if special == SPECIAL_NONE:
635 if special == SPECIAL_NONE:
637 self._decodedvalues.append(value)
636 self._decodedvalues.append(value)
638
637
639 elif special == SPECIAL_START_ARRAY:
638 elif special == SPECIAL_START_ARRAY:
640 self._collectionstack.append(
639 self._collectionstack.append(
641 {
640 {
642 b'remaining': value,
641 b'remaining': value,
643 b'v': [],
642 b'v': [],
644 }
643 }
645 )
644 )
646 self._state = self._STATE_WANT_ARRAY_VALUE
645 self._state = self._STATE_WANT_ARRAY_VALUE
647
646
648 elif special == SPECIAL_START_MAP:
647 elif special == SPECIAL_START_MAP:
649 self._collectionstack.append(
648 self._collectionstack.append(
650 {
649 {
651 b'remaining': value,
650 b'remaining': value,
652 b'v': {},
651 b'v': {},
653 }
652 }
654 )
653 )
655 self._state = self._STATE_WANT_MAP_KEY
654 self._state = self._STATE_WANT_MAP_KEY
656
655
657 elif special == SPECIAL_START_SET:
656 elif special == SPECIAL_START_SET:
658 self._collectionstack.append(
657 self._collectionstack.append(
659 {
658 {
660 b'remaining': value,
659 b'remaining': value,
661 b'v': set(),
660 b'v': set(),
662 }
661 }
663 )
662 )
664 self._state = self._STATE_WANT_SET_VALUE
663 self._state = self._STATE_WANT_SET_VALUE
665
664
666 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
665 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
667 self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
666 self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
668
667
669 else:
668 else:
670 raise CBORDecodeError(
669 raise CBORDecodeError(
671 b'unhandled special state: %d' % special
670 b'unhandled special state: %d' % special
672 )
671 )
673
672
674 # This value becomes an element of the current array.
673 # This value becomes an element of the current array.
675 elif self._state == self._STATE_WANT_ARRAY_VALUE:
674 elif self._state == self._STATE_WANT_ARRAY_VALUE:
676 # Simple values get appended.
675 # Simple values get appended.
677 if special == SPECIAL_NONE:
676 if special == SPECIAL_NONE:
678 c = self._collectionstack[-1]
677 c = self._collectionstack[-1]
679 c[b'v'].append(value)
678 c[b'v'].append(value)
680 c[b'remaining'] -= 1
679 c[b'remaining'] -= 1
681
680
682 # self._state doesn't need changed.
681 # self._state doesn't need changed.
683
682
684 # An array nested within an array.
683 # An array nested within an array.
685 elif special == SPECIAL_START_ARRAY:
684 elif special == SPECIAL_START_ARRAY:
686 lastc = self._collectionstack[-1]
685 lastc = self._collectionstack[-1]
687 newvalue = []
686 newvalue = []
688
687
689 lastc[b'v'].append(newvalue)
688 lastc[b'v'].append(newvalue)
690 lastc[b'remaining'] -= 1
689 lastc[b'remaining'] -= 1
691
690
692 self._collectionstack.append(
691 self._collectionstack.append(
693 {
692 {
694 b'remaining': value,
693 b'remaining': value,
695 b'v': newvalue,
694 b'v': newvalue,
696 }
695 }
697 )
696 )
698
697
699 # self._state doesn't need changed.
698 # self._state doesn't need changed.
700
699
701 # A map nested within an array.
700 # A map nested within an array.
702 elif special == SPECIAL_START_MAP:
701 elif special == SPECIAL_START_MAP:
703 lastc = self._collectionstack[-1]
702 lastc = self._collectionstack[-1]
704 newvalue = {}
703 newvalue = {}
705
704
706 lastc[b'v'].append(newvalue)
705 lastc[b'v'].append(newvalue)
707 lastc[b'remaining'] -= 1
706 lastc[b'remaining'] -= 1
708
707
709 self._collectionstack.append(
708 self._collectionstack.append(
710 {b'remaining': value, b'v': newvalue}
709 {b'remaining': value, b'v': newvalue}
711 )
710 )
712
711
713 self._state = self._STATE_WANT_MAP_KEY
712 self._state = self._STATE_WANT_MAP_KEY
714
713
715 elif special == SPECIAL_START_SET:
714 elif special == SPECIAL_START_SET:
716 lastc = self._collectionstack[-1]
715 lastc = self._collectionstack[-1]
717 newvalue = set()
716 newvalue = set()
718
717
719 lastc[b'v'].append(newvalue)
718 lastc[b'v'].append(newvalue)
720 lastc[b'remaining'] -= 1
719 lastc[b'remaining'] -= 1
721
720
722 self._collectionstack.append(
721 self._collectionstack.append(
723 {
722 {
724 b'remaining': value,
723 b'remaining': value,
725 b'v': newvalue,
724 b'v': newvalue,
726 }
725 }
727 )
726 )
728
727
729 self._state = self._STATE_WANT_SET_VALUE
728 self._state = self._STATE_WANT_SET_VALUE
730
729
731 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
730 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
732 raise CBORDecodeError(
731 raise CBORDecodeError(
733 b'indefinite length bytestrings '
732 b'indefinite length bytestrings '
734 b'not allowed as array values'
733 b'not allowed as array values'
735 )
734 )
736
735
737 else:
736 else:
738 raise CBORDecodeError(
737 raise CBORDecodeError(
739 b'unhandled special item when '
738 b'unhandled special item when '
740 b'expecting array value: %d' % special
739 b'expecting array value: %d' % special
741 )
740 )
742
741
743 # This value becomes the key of the current map instance.
742 # This value becomes the key of the current map instance.
744 elif self._state == self._STATE_WANT_MAP_KEY:
743 elif self._state == self._STATE_WANT_MAP_KEY:
745 if special == SPECIAL_NONE:
744 if special == SPECIAL_NONE:
746 self._currentmapkey = value
745 self._currentmapkey = value
747 self._state = self._STATE_WANT_MAP_VALUE
746 self._state = self._STATE_WANT_MAP_VALUE
748
747
749 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
748 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
750 raise CBORDecodeError(
749 raise CBORDecodeError(
751 b'indefinite length bytestrings '
750 b'indefinite length bytestrings '
752 b'not allowed as map keys'
751 b'not allowed as map keys'
753 )
752 )
754
753
755 elif special in (
754 elif special in (
756 SPECIAL_START_ARRAY,
755 SPECIAL_START_ARRAY,
757 SPECIAL_START_MAP,
756 SPECIAL_START_MAP,
758 SPECIAL_START_SET,
757 SPECIAL_START_SET,
759 ):
758 ):
760 raise CBORDecodeError(
759 raise CBORDecodeError(
761 b'collections not supported as map keys'
760 b'collections not supported as map keys'
762 )
761 )
763
762
764 # We do not allow special values to be used as map keys.
763 # We do not allow special values to be used as map keys.
765 else:
764 else:
766 raise CBORDecodeError(
765 raise CBORDecodeError(
767 b'unhandled special item when '
766 b'unhandled special item when '
768 b'expecting map key: %d' % special
767 b'expecting map key: %d' % special
769 )
768 )
770
769
771 # This value becomes the value of the current map key.
770 # This value becomes the value of the current map key.
772 elif self._state == self._STATE_WANT_MAP_VALUE:
771 elif self._state == self._STATE_WANT_MAP_VALUE:
773 # Simple values simply get inserted into the map.
772 # Simple values simply get inserted into the map.
774 if special == SPECIAL_NONE:
773 if special == SPECIAL_NONE:
775 lastc = self._collectionstack[-1]
774 lastc = self._collectionstack[-1]
776 lastc[b'v'][self._currentmapkey] = value
775 lastc[b'v'][self._currentmapkey] = value
777 lastc[b'remaining'] -= 1
776 lastc[b'remaining'] -= 1
778
777
779 self._state = self._STATE_WANT_MAP_KEY
778 self._state = self._STATE_WANT_MAP_KEY
780
779
781 # A new array is used as the map value.
780 # A new array is used as the map value.
782 elif special == SPECIAL_START_ARRAY:
781 elif special == SPECIAL_START_ARRAY:
783 lastc = self._collectionstack[-1]
782 lastc = self._collectionstack[-1]
784 newvalue = []
783 newvalue = []
785
784
786 lastc[b'v'][self._currentmapkey] = newvalue
785 lastc[b'v'][self._currentmapkey] = newvalue
787 lastc[b'remaining'] -= 1
786 lastc[b'remaining'] -= 1
788
787
789 self._collectionstack.append(
788 self._collectionstack.append(
790 {
789 {
791 b'remaining': value,
790 b'remaining': value,
792 b'v': newvalue,
791 b'v': newvalue,
793 }
792 }
794 )
793 )
795
794
796 self._state = self._STATE_WANT_ARRAY_VALUE
795 self._state = self._STATE_WANT_ARRAY_VALUE
797
796
798 # A new map is used as the map value.
797 # A new map is used as the map value.
799 elif special == SPECIAL_START_MAP:
798 elif special == SPECIAL_START_MAP:
800 lastc = self._collectionstack[-1]
799 lastc = self._collectionstack[-1]
801 newvalue = {}
800 newvalue = {}
802
801
803 lastc[b'v'][self._currentmapkey] = newvalue
802 lastc[b'v'][self._currentmapkey] = newvalue
804 lastc[b'remaining'] -= 1
803 lastc[b'remaining'] -= 1
805
804
806 self._collectionstack.append(
805 self._collectionstack.append(
807 {
806 {
808 b'remaining': value,
807 b'remaining': value,
809 b'v': newvalue,
808 b'v': newvalue,
810 }
809 }
811 )
810 )
812
811
813 self._state = self._STATE_WANT_MAP_KEY
812 self._state = self._STATE_WANT_MAP_KEY
814
813
815 # A new set is used as the map value.
814 # A new set is used as the map value.
816 elif special == SPECIAL_START_SET:
815 elif special == SPECIAL_START_SET:
817 lastc = self._collectionstack[-1]
816 lastc = self._collectionstack[-1]
818 newvalue = set()
817 newvalue = set()
819
818
820 lastc[b'v'][self._currentmapkey] = newvalue
819 lastc[b'v'][self._currentmapkey] = newvalue
821 lastc[b'remaining'] -= 1
820 lastc[b'remaining'] -= 1
822
821
823 self._collectionstack.append(
822 self._collectionstack.append(
824 {
823 {
825 b'remaining': value,
824 b'remaining': value,
826 b'v': newvalue,
825 b'v': newvalue,
827 }
826 }
828 )
827 )
829
828
830 self._state = self._STATE_WANT_SET_VALUE
829 self._state = self._STATE_WANT_SET_VALUE
831
830
832 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
831 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
833 raise CBORDecodeError(
832 raise CBORDecodeError(
834 b'indefinite length bytestrings not '
833 b'indefinite length bytestrings not '
835 b'allowed as map values'
834 b'allowed as map values'
836 )
835 )
837
836
838 else:
837 else:
839 raise CBORDecodeError(
838 raise CBORDecodeError(
840 b'unhandled special item when '
839 b'unhandled special item when '
841 b'expecting map value: %d' % special
840 b'expecting map value: %d' % special
842 )
841 )
843
842
844 self._currentmapkey = None
843 self._currentmapkey = None
845
844
846 # This value is added to the current set.
845 # This value is added to the current set.
847 elif self._state == self._STATE_WANT_SET_VALUE:
846 elif self._state == self._STATE_WANT_SET_VALUE:
848 if special == SPECIAL_NONE:
847 if special == SPECIAL_NONE:
849 lastc = self._collectionstack[-1]
848 lastc = self._collectionstack[-1]
850 lastc[b'v'].add(value)
849 lastc[b'v'].add(value)
851 lastc[b'remaining'] -= 1
850 lastc[b'remaining'] -= 1
852
851
853 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
852 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
854 raise CBORDecodeError(
853 raise CBORDecodeError(
855 b'indefinite length bytestrings not '
854 b'indefinite length bytestrings not '
856 b'allowed as set values'
855 b'allowed as set values'
857 )
856 )
858
857
859 elif special in (
858 elif special in (
860 SPECIAL_START_ARRAY,
859 SPECIAL_START_ARRAY,
861 SPECIAL_START_MAP,
860 SPECIAL_START_MAP,
862 SPECIAL_START_SET,
861 SPECIAL_START_SET,
863 ):
862 ):
864 raise CBORDecodeError(
863 raise CBORDecodeError(
865 b'collections not allowed as set values'
864 b'collections not allowed as set values'
866 )
865 )
867
866
868 # We don't allow non-trivial types to exist as set values.
867 # We don't allow non-trivial types to exist as set values.
869 else:
868 else:
870 raise CBORDecodeError(
869 raise CBORDecodeError(
871 b'unhandled special item when '
870 b'unhandled special item when '
872 b'expecting set value: %d' % special
871 b'expecting set value: %d' % special
873 )
872 )
874
873
875 # This value represents the first chunk in an indefinite length
874 # This value represents the first chunk in an indefinite length
876 # bytestring.
875 # bytestring.
877 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
876 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
878 # We received a full chunk.
877 # We received a full chunk.
879 if special == SPECIAL_NONE:
878 if special == SPECIAL_NONE:
880 self._decodedvalues.append(
879 self._decodedvalues.append(
881 bytestringchunk(value, first=True)
880 bytestringchunk(value, first=True)
882 )
881 )
883
882
884 self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
883 self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
885
884
886 # The end of stream marker. This means it is an empty
885 # The end of stream marker. This means it is an empty
887 # indefinite length bytestring.
886 # indefinite length bytestring.
888 elif special == SPECIAL_INDEFINITE_BREAK:
887 elif special == SPECIAL_INDEFINITE_BREAK:
889 # We /could/ convert this to a b''. But we want to preserve
888 # We /could/ convert this to a b''. But we want to preserve
890 # the nature of the underlying data so consumers expecting
889 # the nature of the underlying data so consumers expecting
891 # an indefinite length bytestring get one.
890 # an indefinite length bytestring get one.
892 self._decodedvalues.append(
891 self._decodedvalues.append(
893 bytestringchunk(b'', first=True, last=True)
892 bytestringchunk(b'', first=True, last=True)
894 )
893 )
895
894
896 # Since indefinite length bytestrings can't be used in
895 # Since indefinite length bytestrings can't be used in
897 # collections, we must be at the root level.
896 # collections, we must be at the root level.
898 assert not self._collectionstack
897 assert not self._collectionstack
899 self._state = self._STATE_NONE
898 self._state = self._STATE_NONE
900
899
901 else:
900 else:
902 raise CBORDecodeError(
901 raise CBORDecodeError(
903 b'unexpected special value when '
902 b'unexpected special value when '
904 b'expecting bytestring chunk: %d' % special
903 b'expecting bytestring chunk: %d' % special
905 )
904 )
906
905
907 # This value represents the non-initial chunk in an indefinite
906 # This value represents the non-initial chunk in an indefinite
908 # length bytestring.
907 # length bytestring.
909 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
908 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
910 # We received a full chunk.
909 # We received a full chunk.
911 if special == SPECIAL_NONE:
910 if special == SPECIAL_NONE:
912 self._decodedvalues.append(bytestringchunk(value))
911 self._decodedvalues.append(bytestringchunk(value))
913
912
914 # The end of stream marker.
913 # The end of stream marker.
915 elif special == SPECIAL_INDEFINITE_BREAK:
914 elif special == SPECIAL_INDEFINITE_BREAK:
916 self._decodedvalues.append(bytestringchunk(b'', last=True))
915 self._decodedvalues.append(bytestringchunk(b'', last=True))
917
916
918 # Since indefinite length bytestrings can't be used in
917 # Since indefinite length bytestrings can't be used in
919 # collections, we must be at the root level.
918 # collections, we must be at the root level.
920 assert not self._collectionstack
919 assert not self._collectionstack
921 self._state = self._STATE_NONE
920 self._state = self._STATE_NONE
922
921
923 else:
922 else:
924 raise CBORDecodeError(
923 raise CBORDecodeError(
925 b'unexpected special value when '
924 b'unexpected special value when '
926 b'expecting bytestring chunk: %d' % special
925 b'expecting bytestring chunk: %d' % special
927 )
926 )
928
927
929 else:
928 else:
930 raise CBORDecodeError(
929 raise CBORDecodeError(
931 b'unhandled decoder state: %d' % self._state
930 b'unhandled decoder state: %d' % self._state
932 )
931 )
933
932
934 # We could have just added the final value in a collection. End
933 # We could have just added the final value in a collection. End
935 # all complete collections at the top of the stack.
934 # all complete collections at the top of the stack.
936 while True:
935 while True:
937 # Bail if we're not waiting on a new collection item.
936 # Bail if we're not waiting on a new collection item.
938 if self._state not in (
937 if self._state not in (
939 self._STATE_WANT_ARRAY_VALUE,
938 self._STATE_WANT_ARRAY_VALUE,
940 self._STATE_WANT_MAP_KEY,
939 self._STATE_WANT_MAP_KEY,
941 self._STATE_WANT_SET_VALUE,
940 self._STATE_WANT_SET_VALUE,
942 ):
941 ):
943 break
942 break
944
943
945 # Or we are expecting more items for this collection.
944 # Or we are expecting more items for this collection.
946 lastc = self._collectionstack[-1]
945 lastc = self._collectionstack[-1]
947
946
948 if lastc[b'remaining']:
947 if lastc[b'remaining']:
949 break
948 break
950
949
951 # The collection at the top of the stack is complete.
950 # The collection at the top of the stack is complete.
952
951
953 # Discard it, as it isn't needed for future items.
952 # Discard it, as it isn't needed for future items.
954 self._collectionstack.pop()
953 self._collectionstack.pop()
955
954
956 # If this is a nested collection, we don't emit it, since it
955 # If this is a nested collection, we don't emit it, since it
957 # will be emitted by its parent collection. But we do need to
956 # will be emitted by its parent collection. But we do need to
958 # update state to reflect what the new top-most collection
957 # update state to reflect what the new top-most collection
959 # on the stack is.
958 # on the stack is.
960 if self._collectionstack:
959 if self._collectionstack:
961 self._state = {
960 self._state = {
962 list: self._STATE_WANT_ARRAY_VALUE,
961 list: self._STATE_WANT_ARRAY_VALUE,
963 dict: self._STATE_WANT_MAP_KEY,
962 dict: self._STATE_WANT_MAP_KEY,
964 set: self._STATE_WANT_SET_VALUE,
963 set: self._STATE_WANT_SET_VALUE,
965 }[type(self._collectionstack[-1][b'v'])]
964 }[type(self._collectionstack[-1][b'v'])]
966
965
967 # If this is the root collection, emit it.
966 # If this is the root collection, emit it.
968 else:
967 else:
969 self._decodedvalues.append(lastc[b'v'])
968 self._decodedvalues.append(lastc[b'v'])
970 self._state = self._STATE_NONE
969 self._state = self._STATE_NONE
971
970
972 return (
971 return (
973 bool(self._decodedvalues),
972 bool(self._decodedvalues),
974 offset - initialoffset,
973 offset - initialoffset,
975 0,
974 0,
976 )
975 )
977
976
def getavailable(self):
    """Return the list of fully decoded values and clear the queue.

    Values handed out by this call will not be returned again by
    subsequent calls.
    """
    # Hand the accumulated list to the caller and start a fresh one.
    pending, self._decodedvalues = self._decodedvalues, []
    return pending
987
986
988
987
class bufferingdecoder(object):
    """A CBOR decoder that buffers undecoded input.

    This is a glorified wrapper around ``sansiodecoder`` adding a buffering
    layer: any input the wrapped decoder does not consume is retained and
    logically prepended to data supplied by later calls.

    TODO consider adding limits as to the maximum amount of data that can
    be buffered.
    """

    def __init__(self):
        # Underlying incremental decoder.
        self._decoder = sansiodecoder()
        # Undecoded byte chunks carried over from previous calls.
        self._chunks = []
        # Bytes still needed before the next item can be decoded.
        self._wanted = 0

    def decode(self, b):
        """Attempt to decode bytes to CBOR values.

        Returns a tuple with the following fields:

        * Bool indicating whether new values are available for retrieval.
        * Integer number of bytes decoded from the new input.
        * Integer number of bytes wanted to decode the next value.
        """
        # We /might/ be able to support passing a bytearray all the
        # way through. For now, let's cheat.
        if isinstance(b, bytearray):
            b = bytes(b)

        # Buffering strategy: accumulate incoming chunks in a list until
        # enough data has arrived to decode the next item. Compared with
        # an ``io.BytesIO`` or repeated concatenation, this avoids
        # constantly reallocating a growing backing buffer, preventing
        # memory thrashing — a significant win when many input chunks
        # arrive that don't complete a full item.
        if self._chunks:
            # A previous call reported that N more bytes were needed to
            # decode the next item. If this chunk still isn't enough,
            # buffer it without attempting to decode.
            if len(b) < self._wanted:
                self._chunks.append(b)
                self._wanted -= len(b)
                return False, 0, self._wanted

            # Otherwise decoding may now succeed: join the buffered data
            # with the new chunk and reset the buffer.
            newlen = len(b)
            self._chunks.append(b)
            b = b''.join(self._chunks)
            self._chunks = []
            carried = len(b) - newlen
        else:
            carried = 0

        havevalues, consumed, self._wanted = self._decoder.decode(b)

        # Retain whatever the decoder didn't consume for the next call.
        if consumed < len(b):
            self._chunks.append(b[consumed:])

        # Report only bytes consumed from the *new* input, not the
        # previously buffered prefix.
        return havevalues, consumed - carried, self._wanted

    def getavailable(self):
        return self._decoder.getavailable()
1058
1057
1059
1058
def decodeall(b):
    """Decode all CBOR items present in an iterable of bytes.

    In addition to regular decode errors, raises CBORDecodeError if the
    entirety of the passed buffer does not fully decode to complete CBOR
    values. This includes failure to decode any value, incomplete collection
    types, incomplete indefinite length items, and extra data at the end of
    the buffer.
    """
    # Nothing to decode; return no values.
    if not b:
        return []

    decoder = sansiodecoder()
    _, consumed, _ = decoder.decode(b)

    # Trailing bytes that didn't decode into a value are an error.
    if consumed != len(b):
        raise CBORDecodeError(b'input data not fully consumed')

    # So is a value (or collection) left partially decoded.
    if decoder.inprogress:
        raise CBORDecodeError(b'input data not complete')

    return decoder.getavailable()
General Comments 0
You need to be logged in to leave comments. Login now