diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py --- a/mercurial/utils/cborutil.py +++ b/mercurial/utils/cborutil.py @@ -8,6 +8,7 @@ from __future__ import absolute_import import struct +import sys from ..thirdparty.cbor.cbor2 import ( decoder as decodermod, @@ -35,11 +36,16 @@ MAJOR_TYPE_SPECIAL = 7 SUBTYPE_MASK = 0b00011111 +SUBTYPE_FALSE = 20 +SUBTYPE_TRUE = 21 +SUBTYPE_NULL = 22 SUBTYPE_HALF_FLOAT = 25 SUBTYPE_SINGLE_FLOAT = 26 SUBTYPE_DOUBLE_FLOAT = 27 SUBTYPE_INDEFINITE = 31 +SEMANTIC_TAG_FINITE_SET = 258 + # Indefinite types begin with their major type ORd with information value 31. BEGIN_INDEFINITE_BYTESTRING = struct.pack( r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE) @@ -146,7 +152,7 @@ def _mixedtypesortkey(v): def streamencodeset(s): # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines # semantic tag 258 for finite sets. - yield encodelength(MAJOR_TYPE_SEMANTIC, 258) + yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET) for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)): yield chunk @@ -260,3 +266,710 @@ def readindefinitebytestringtoiter(fh, e len(chunk), length)) yield chunk + +class CBORDecodeError(Exception): + """Represents an error decoding CBOR.""" + +if sys.version_info.major >= 3: + def _elementtointeger(b, i): + return b[i] +else: + def _elementtointeger(b, i): + return ord(b[i]) + +STRUCT_BIG_UBYTE = struct.Struct(r'>B') +STRUCT_BIG_USHORT = struct.Struct('>H') +STRUCT_BIG_ULONG = struct.Struct('>L') +STRUCT_BIG_ULONGLONG = struct.Struct('>Q') + +SPECIAL_NONE = 0 +SPECIAL_START_INDEFINITE_BYTESTRING = 1 +SPECIAL_START_ARRAY = 2 +SPECIAL_START_MAP = 3 +SPECIAL_START_SET = 4 +SPECIAL_INDEFINITE_BREAK = 5 + +def decodeitem(b, offset=0): + """Decode a new CBOR value from a buffer at offset. + + This function attempts to decode up to one complete CBOR value + from ``b`` starting at offset ``offset``. + + The beginning of a collection (such as an array, map, set, or + indefinite length bytestring) counts as a single value. For these + special cases, a state flag will indicate that a special value was seen. + + When called, the function either returns a decoded value or gives + a hint as to how many more bytes are needed to do so. By calling + the function repeatedly given a stream of bytes, the caller can + build up the original values. + + Returns a tuple with the following elements: + + * Bool indicating whether a complete value was decoded. + * A decoded value if first value is True otherwise None + * Integer number of bytes. If positive, the number of bytes + read. If negative, the number of bytes we need to read to + decode this value or the next chunk in this value. + * One of the ``SPECIAL_*`` constants indicating special treatment + for this value. ``SPECIAL_NONE`` means this is a fully decoded + simple value (such as an integer or bool). + """ + + initial = _elementtointeger(b, offset) + offset += 1 + + majortype = initial >> 5 + subtype = initial & SUBTYPE_MASK + + if majortype == MAJOR_TYPE_UINT: + complete, value, readcount = decodeuint(subtype, b, offset) + + if complete: + return True, value, readcount + 1, SPECIAL_NONE + else: + return False, None, readcount, SPECIAL_NONE + + elif majortype == MAJOR_TYPE_NEGINT: + # Negative integers are the same as UINT except inverted minus 1. + complete, value, readcount = decodeuint(subtype, b, offset) + + if complete: + return True, -value - 1, readcount + 1, SPECIAL_NONE + else: + return False, None, readcount, SPECIAL_NONE + + elif majortype == MAJOR_TYPE_BYTESTRING: + # Beginning of bytestrings are treated as uints in order to + # decode their length, which may be indefinite. + complete, size, readcount = decodeuint(subtype, b, offset, + allowindefinite=True) + + # We don't know the size of the bytestring. It must be a definitive + # length since the indefinite subtype would be encoded in the initial + # byte. + if not complete: + return False, None, readcount, SPECIAL_NONE + + # We know the length of the bytestring. + if size is not None: + # And the data is available in the buffer. + if offset + readcount + size <= len(b): + value = b[offset + readcount:offset + readcount + size] + return True, value, readcount + size + 1, SPECIAL_NONE + + # And we need more data in order to return the bytestring. + else: + wanted = len(b) - offset - readcount - size + return False, None, wanted, SPECIAL_NONE + + # It is an indefinite length bytestring. + else: + return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING + + elif majortype == MAJOR_TYPE_STRING: + raise CBORDecodeError('string major type not supported') + + elif majortype == MAJOR_TYPE_ARRAY: + # Beginning of arrays are treated as uints in order to decode their + # length. We don't allow indefinite length arrays. + complete, size, readcount = decodeuint(subtype, b, offset) + + if complete: + return True, size, readcount + 1, SPECIAL_START_ARRAY + else: + return False, None, readcount, SPECIAL_NONE + + elif majortype == MAJOR_TYPE_MAP: + # Beginning of maps are treated as uints in order to decode their + # number of elements. We don't allow indefinite length arrays. + complete, size, readcount = decodeuint(subtype, b, offset) + + if complete: + return True, size, readcount + 1, SPECIAL_START_MAP + else: + return False, None, readcount, SPECIAL_NONE + + elif majortype == MAJOR_TYPE_SEMANTIC: + # Semantic tag value is read the same as a uint. + complete, tagvalue, readcount = decodeuint(subtype, b, offset) + + if not complete: + return False, None, readcount, SPECIAL_NONE + + # This behavior here is a little wonky. The main type being "decorated" + # by this semantic tag follows. A more robust parser would probably emit + # a special flag indicating this as a semantic tag and let the caller + # deal with the types that follow. But since we don't support many + # semantic tags, it is easier to deal with the special cases here and + # hide complexity from the caller. If we add support for more semantic + # tags, we should probably move semantic tag handling into the caller. + if tagvalue == SEMANTIC_TAG_FINITE_SET: + if offset + readcount >= len(b): + return False, None, -1, SPECIAL_NONE + + complete, size, readcount2, special = decodeitem(b, + offset + readcount) + + if not complete: + return False, None, readcount2, SPECIAL_NONE + + if special != SPECIAL_START_ARRAY: + raise CBORDecodeError('expected array after finite set ' + 'semantic tag') + + return True, size, readcount + readcount2 + 1, SPECIAL_START_SET + + else: + raise CBORDecodeError('semantic tag %d not allowed' % tagvalue) + + elif majortype == MAJOR_TYPE_SPECIAL: + # Only specific values for the information field are allowed. + if subtype == SUBTYPE_FALSE: + return True, False, 1, SPECIAL_NONE + elif subtype == SUBTYPE_TRUE: + return True, True, 1, SPECIAL_NONE + elif subtype == SUBTYPE_NULL: + return True, None, 1, SPECIAL_NONE + elif subtype == SUBTYPE_INDEFINITE: + return True, None, 1, SPECIAL_INDEFINITE_BREAK + # If value is 24, subtype is in next byte. + else: + raise CBORDecodeError('special type %d not allowed' % subtype) + else: + assert False + +def decodeuint(subtype, b, offset=0, allowindefinite=False): + """Decode an unsigned integer. + + ``subtype`` is the lower 5 bits from the initial byte CBOR item + "header." ``b`` is a buffer containing bytes. ``offset`` points to + the index of the first byte after the byte that ``subtype`` was + derived from. + + ``allowindefinite`` allows the special indefinite length value + indicator. + + Returns a 3-tuple of (successful, value, count). + + The first element is a bool indicating if decoding completed. The 2nd + is the decoded integer value or None if not fully decoded or the subtype + is 31 and ``allowindefinite`` is True. The 3rd value is the count of bytes. + If positive, it is the number of additional bytes decoded. If negative, + it is the number of additional bytes needed to decode this value. + """ + + # Small values are inline. + if subtype < 24: + return True, subtype, 0 + # Indefinite length specifier. + elif subtype == 31: + if allowindefinite: + return True, None, 0 + else: + raise CBORDecodeError('indefinite length uint not allowed here') + elif subtype >= 28: + raise CBORDecodeError('unsupported subtype on integer type: %d' % + subtype) + + if subtype == 24: + s = STRUCT_BIG_UBYTE + elif subtype == 25: + s = STRUCT_BIG_USHORT + elif subtype == 26: + s = STRUCT_BIG_ULONG + elif subtype == 27: + s = STRUCT_BIG_ULONGLONG + else: + raise CBORDecodeError('bounds condition checking violation') + + if len(b) - offset >= s.size: + return True, s.unpack_from(b, offset)[0], s.size + else: + return False, None, len(b) - offset - s.size + +class bytestringchunk(bytes): + """Represents a chunk/segment in an indefinite length bytestring. + + This behaves like a ``bytes`` but in addition has the ``isfirst`` + and ``islast`` attributes indicating whether this chunk is the first + or last in an indefinite length bytestring. + """ + + def __new__(cls, v, first=False, last=False): + self = bytes.__new__(cls, v) + self.isfirst = first + self.islast = last + + return self + +class sansiodecoder(object): + """A CBOR decoder that doesn't perform its own I/O. + + To use, construct an instance and feed it segments containing + CBOR-encoded bytes via ``decode()``. The return value from ``decode()`` + indicates whether a fully-decoded value is available, how many bytes + were consumed, and offers a hint as to how many bytes should be fed + in next time to decode the next value. + + The decoder assumes it will decode N discrete CBOR values, not just + a single value. i.e. if the bytestream contains uints packed one after + the other, the decoder will decode them all, rather than just the initial + one. + + When ``decode()`` indicates a value is available, call ``getavailable()`` + to return all fully decoded values. + + ``decode()`` can partially decode input. It is up to the caller to keep + track of what data was consumed and to pass unconsumed data in on the + next invocation. + + The decoder decodes atomically at the *item* level. See ``decodeitem()``. + If an *item* cannot be fully decoded, the decoder won't record it as + partially consumed. Instead, the caller will be instructed to pass in + the initial bytes of this item on the next invocation. This does result + in some redundant parsing. But the overhead should be minimal. + + This decoder only supports a subset of CBOR as required by Mercurial. + It lacks support for: + + * Indefinite length arrays + * Indefinite length maps + * Use of indefinite length bytestrings as keys or values within + arrays, maps, or sets. + * Nested arrays, maps, or sets within sets + * Any semantic tag that isn't a mathematical finite set + * Floating point numbers + * Undefined special value + + CBOR types are decoded to Python types as follows: + + uint -> int + negint -> int + bytestring -> bytes + map -> dict + array -> list + True -> bool + False -> bool + null -> None + indefinite length bytestring chunk -> [bytestringchunk] + + The only non-obvious mapping here is an indefinite length bytestring + to the ``bytestringchunk`` type. This is to facilitate streaming + indefinite length bytestrings out of the decoder and to differentiate + a regular bytestring from an indefinite length bytestring. + """ + + _STATE_NONE = 0 + _STATE_WANT_MAP_KEY = 1 + _STATE_WANT_MAP_VALUE = 2 + _STATE_WANT_ARRAY_VALUE = 3 + _STATE_WANT_SET_VALUE = 4 + _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5 + _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6 + + def __init__(self): + # TODO add support for limiting size of bytestrings + # TODO add support for limiting number of keys / values in collections + # TODO add support for limiting size of buffered partial values + + self.decodedbytecount = 0 + + self._state = self._STATE_NONE + + # Stack of active nested collections. Each entry is a dict describing + # the collection. + self._collectionstack = [] + + # Fully decoded key to use for the current map. + self._currentmapkey = None + + # Fully decoded values available for retrieval. + self._decodedvalues = [] + + @property + def inprogress(self): + """Whether the decoder has partially decoded a value.""" + return self._state != self._STATE_NONE + + def decode(self, b, offset=0): + """Attempt to decode bytes from an input buffer. + + ``b`` is a collection of bytes and ``offset`` is the byte + offset within that buffer from which to begin reading data. + + ``b`` must support ``len()`` and accessing bytes slices via + ``__slice__``. Typically ``bytes`` instances are used. + + Returns a tuple with the following fields: + + * Bool indicating whether values are available for retrieval. + * Integer indicating the number of bytes that were fully consumed, + starting from ``offset``. + * Integer indicating the number of bytes that are desired for the + next call in order to decode an item. + """ + if not b: + return bool(self._decodedvalues), 0, 0 + + initialoffset = offset + + # We could easily split the body of this loop into a function. But + # Python performance is sensitive to function calls and collections + # are composed of many items. So leaving as a while loop could help + # with performance. One thing that may not help is the use of + # if..elif versus a lookup/dispatch table. There may be value + # in switching that. + while offset < len(b): + # Attempt to decode an item. This could be a whole value or a + # special value indicating an event, such as start or end of a + # collection or indefinite length type. + complete, value, readcount, special = decodeitem(b, offset) + + if readcount > 0: + self.decodedbytecount += readcount + + if not complete: + assert readcount < 0 + return ( + bool(self._decodedvalues), + offset - initialoffset, + -readcount, + ) + + offset += readcount + + # No nested state. We either have a full value or beginning of a + # complex value to deal with. + if self._state == self._STATE_NONE: + # A normal value. + if special == SPECIAL_NONE: + self._decodedvalues.append(value) + + elif special == SPECIAL_START_ARRAY: + self._collectionstack.append({ + 'remaining': value, + 'v': [], + }) + self._state = self._STATE_WANT_ARRAY_VALUE + + elif special == SPECIAL_START_MAP: + self._collectionstack.append({ + 'remaining': value, + 'v': {}, + }) + self._state = self._STATE_WANT_MAP_KEY + + elif special == SPECIAL_START_SET: + self._collectionstack.append({ + 'remaining': value, + 'v': set(), + }) + self._state = self._STATE_WANT_SET_VALUE + + elif special == SPECIAL_START_INDEFINITE_BYTESTRING: + self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST + + else: + raise CBORDecodeError('unhandled special state: %d' % + special) + + # This value becomes an element of the current array. + elif self._state == self._STATE_WANT_ARRAY_VALUE: + # Simple values get appended. + if special == SPECIAL_NONE: + c = self._collectionstack[-1] + c['v'].append(value) + c['remaining'] -= 1 + + # self._state doesn't need changed. + + # An array nested within an array. + elif special == SPECIAL_START_ARRAY: + lastc = self._collectionstack[-1] + newvalue = [] + + lastc['v'].append(newvalue) + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue, + }) + + # self._state doesn't need changed. + + # A map nested within an array. + elif special == SPECIAL_START_MAP: + lastc = self._collectionstack[-1] + newvalue = {} + + lastc['v'].append(newvalue) + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue + }) + + self._state = self._STATE_WANT_MAP_KEY + + elif special == SPECIAL_START_SET: + lastc = self._collectionstack[-1] + newvalue = set() + + lastc['v'].append(newvalue) + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue, + }) + + self._state = self._STATE_WANT_SET_VALUE + + elif special == SPECIAL_START_INDEFINITE_BYTESTRING: + raise CBORDecodeError('indefinite length bytestrings ' + 'not allowed as array values') + + else: + raise CBORDecodeError('unhandled special item when ' + 'expecting array value: %d' % special) + + # This value becomes the key of the current map instance. + elif self._state == self._STATE_WANT_MAP_KEY: + if special == SPECIAL_NONE: + self._currentmapkey = value + self._state = self._STATE_WANT_MAP_VALUE + + elif special == SPECIAL_START_INDEFINITE_BYTESTRING: + raise CBORDecodeError('indefinite length bytestrings ' + 'not allowed as map keys') + + elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP, + SPECIAL_START_SET): + raise CBORDecodeError('collections not supported as map ' + 'keys') + + # We do not allow special values to be used as map keys. + else: + raise CBORDecodeError('unhandled special item when ' + 'expecting map key: %d' % special) + + # This value becomes the value of the current map key. + elif self._state == self._STATE_WANT_MAP_VALUE: + # Simple values simply get inserted into the map. + if special == SPECIAL_NONE: + lastc = self._collectionstack[-1] + lastc['v'][self._currentmapkey] = value + lastc['remaining'] -= 1 + + self._state = self._STATE_WANT_MAP_KEY + + # A new array is used as the map value. + elif special == SPECIAL_START_ARRAY: + lastc = self._collectionstack[-1] + newvalue = [] + + lastc['v'][self._currentmapkey] = newvalue + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue, + }) + + self._state = self._STATE_WANT_ARRAY_VALUE + + # A new map is used as the map value. + elif special == SPECIAL_START_MAP: + lastc = self._collectionstack[-1] + newvalue = {} + + lastc['v'][self._currentmapkey] = newvalue + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue, + }) + + self._state = self._STATE_WANT_MAP_KEY + + # A new set is used as the map value. + elif special == SPECIAL_START_SET: + lastc = self._collectionstack[-1] + newvalue = set() + + lastc['v'][self._currentmapkey] = newvalue + lastc['remaining'] -= 1 + + self._collectionstack.append({ + 'remaining': value, + 'v': newvalue, + }) + + self._state = self._STATE_WANT_SET_VALUE + + elif special == SPECIAL_START_INDEFINITE_BYTESTRING: + raise CBORDecodeError('indefinite length bytestrings not ' + 'allowed as map values') + + else: + raise CBORDecodeError('unhandled special item when ' + 'expecting map value: %d' % special) + + self._currentmapkey = None + + # This value is added to the current set. + elif self._state == self._STATE_WANT_SET_VALUE: + if special == SPECIAL_NONE: + lastc = self._collectionstack[-1] + lastc['v'].add(value) + lastc['remaining'] -= 1 + + elif special == SPECIAL_START_INDEFINITE_BYTESTRING: + raise CBORDecodeError('indefinite length bytestrings not ' + 'allowed as set values') + + elif special in (SPECIAL_START_ARRAY, + SPECIAL_START_MAP, + SPECIAL_START_SET): + raise CBORDecodeError('collections not allowed as set ' + 'values') + + # We don't allow non-trivial types to exist as set values. + else: + raise CBORDecodeError('unhandled special item when ' + 'expecting set value: %d' % special) + + # This value represents the first chunk in an indefinite length + # bytestring. + elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST: + # We received a full chunk. + if special == SPECIAL_NONE: + self._decodedvalues.append(bytestringchunk(value, + first=True)) + + self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT + + # The end of stream marker. This means it is an empty + # indefinite length bytestring. + elif special == SPECIAL_INDEFINITE_BREAK: + # We /could/ convert this to a b''. But we want to preserve + # the nature of the underlying data so consumers expecting + # an indefinite length bytestring get one. + self._decodedvalues.append(bytestringchunk(b'', + first=True, + last=True)) + + # Since indefinite length bytestrings can't be used in + # collections, we must be at the root level. + assert not self._collectionstack + self._state = self._STATE_NONE + + else: + raise CBORDecodeError('unexpected special value when ' + 'expecting bytestring chunk: %d' % + special) + + # This value represents the non-initial chunk in an indefinite + # length bytestring. + elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT: + # We received a full chunk. + if special == SPECIAL_NONE: + self._decodedvalues.append(bytestringchunk(value)) + + # The end of stream marker. + elif special == SPECIAL_INDEFINITE_BREAK: + self._decodedvalues.append(bytestringchunk(b'', last=True)) + + # Since indefinite length bytestrings can't be used in + # collections, we must be at the root level. + assert not self._collectionstack + self._state = self._STATE_NONE + + else: + raise CBORDecodeError('unexpected special value when ' + 'expecting bytestring chunk: %d' % + special) + + else: + raise CBORDecodeError('unhandled decoder state: %d' % + self._state) + + # We could have just added the final value in a collection. End + # all complete collections at the top of the stack. + while True: + # Bail if we're not waiting on a new collection item. + if self._state not in (self._STATE_WANT_ARRAY_VALUE, + self._STATE_WANT_MAP_KEY, + self._STATE_WANT_SET_VALUE): + break + + # Or we are expecting more items for this collection. + lastc = self._collectionstack[-1] + + if lastc['remaining']: + break + + # The collection at the top of the stack is complete. + + # Discard it, as it isn't needed for future items. + self._collectionstack.pop() + + # If this is a nested collection, we don't emit it, since it + # will be emitted by its parent collection. But we do need to + # update state to reflect what the new top-most collection + # on the stack is. + if self._collectionstack: + self._state = { + list: self._STATE_WANT_ARRAY_VALUE, + dict: self._STATE_WANT_MAP_KEY, + set: self._STATE_WANT_SET_VALUE, + }[type(self._collectionstack[-1]['v'])] + + # If this is the root collection, emit it. + else: + self._decodedvalues.append(lastc['v']) + self._state = self._STATE_NONE + + return ( + bool(self._decodedvalues), + offset - initialoffset, + 0, + ) + + def getavailable(self): + """Returns an iterator over fully decoded values. + + Once values are retrieved, they won't be available on the next call. + """ + + l = list(self._decodedvalues) + self._decodedvalues = [] + return l + +def decodeall(b): + """Decode all CBOR items present in an iterable of bytes. + + In addition to regular decode errors, raises CBORDecodeError if the + entirety of the passed buffer does not fully decode to complete CBOR + values. This includes failure to decode any value, incomplete collection + types, incomplete indefinite length items, and extra data at the end of + the buffer. + """ + if not b: + return [] + + decoder = sansiodecoder() + + havevalues, readcount, wantbytes = decoder.decode(b) + + if readcount != len(b): + raise CBORDecodeError('input data not fully consumed') + + if decoder.inprogress: + raise CBORDecodeError('input data not complete') + + return decoder.getavailable() diff --git a/tests/test-cbor.py b/tests/test-cbor.py --- a/tests/test-cbor.py +++ b/tests/test-cbor.py @@ -10,10 +10,17 @@ from mercurial.utils import ( cborutil, ) +class TestCase(unittest.TestCase): + if not getattr(unittest.TestCase, 'assertRaisesRegex', False): + # Python 3.7 deprecates the regex*p* version, but 2.7 lacks + # the regex version. + assertRaisesRegex = (# camelcase-required + unittest.TestCase.assertRaisesRegexp) + def loadit(it): return cbor.loads(b''.join(it)) -class BytestringTests(unittest.TestCase): +class BytestringTests(TestCase): def testsimple(self): self.assertEqual( list(cborutil.streamencode(b'foobar')), @@ -23,11 +30,20 @@ class BytestringTests(unittest.TestCase) loadit(cborutil.streamencode(b'foobar')), b'foobar') + self.assertEqual(cborutil.decodeall(b'\x46foobar'), + [b'foobar']) + + self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'), + [b'foobar', b'fizbi']) + def testlong(self): source = b'x' * 1048576 self.assertEqual(loadit(cborutil.streamencode(source)), source) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + def testfromiter(self): # This is the example from RFC 7049 Section 2.2.2. source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99'] @@ -47,6 +63,25 @@ class BytestringTests(unittest.TestCase) loadit(cborutil.streamencodebytestringfromiter(source)), b''.join(source)) + self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd' + b'\x43\xee\xff\x99\xff'), + [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b'']) + + for i, chunk in enumerate( + cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd' + b'\x43\xee\xff\x99\xff')): + self.assertIsInstance(chunk, cborutil.bytestringchunk) + + if i == 0: + self.assertTrue(chunk.isfirst) + else: + self.assertFalse(chunk.isfirst) + + if i == 2: + self.assertTrue(chunk.islast) + else: + self.assertFalse(chunk.islast) + def testfromiterlarge(self): source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576] @@ -71,6 +106,18 @@ class BytestringTests(unittest.TestCase) source, chunksize=42)) self.assertEqual(cbor.loads(dest), source) + self.assertEqual(b''.join(cborutil.decodeall(dest)), source) + + for chunk in cborutil.decodeall(dest): + self.assertIsInstance(chunk, cborutil.bytestringchunk) + self.assertIn(len(chunk), (0, 8, 42)) + + encoded = b'\x5f\xff' + b = cborutil.decodeall(encoded) + self.assertEqual(b, [b'']) + self.assertTrue(b[0].isfirst) + self.assertTrue(b[0].islast) + def testreadtoiter(self): source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff') @@ -81,42 +128,405 @@ class BytestringTests(unittest.TestCase) with self.assertRaises(StopIteration): next(it) -class IntTests(unittest.TestCase): + def testdecodevariouslengths(self): + for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536): + source = b'x' * i + encoded = b''.join(cborutil.streamencode(source)) + + if len(source) < 24: + hlen = 1 + elif len(source) < 256: + hlen = 2 + elif len(source) < 65536: + hlen = 3 + elif len(source) < 1048576: + hlen = 5 + + self.assertEqual(cborutil.decodeitem(encoded), + (True, source, hlen + len(source), + cborutil.SPECIAL_NONE)) + + def testpartialdecode(self): + encoded = b''.join(cborutil.streamencode(b'foobar')) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -6, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -5, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:7]), + (True, b'foobar', 7, cborutil.SPECIAL_NONE)) + + def testpartialdecodevariouslengths(self): + lens = [ + 2, + 3, + 10, + 23, + 24, + 25, + 31, + 100, + 254, + 255, + 256, + 257, + 16384, + 65534, + 65535, + 65536, + 65537, + 131071, + 131072, + 131073, + 1048575, + 1048576, + 1048577, + ] + + for size in lens: + if size < 24: + hlen = 1 + elif size < 2**8: + hlen = 2 + elif size < 2**16: + hlen = 3 + elif size < 2**32: + hlen = 5 + else: + assert False + + source = b'x' * size + encoded = b''.join(cborutil.streamencode(source)) + + res = cborutil.decodeitem(encoded[0:1]) + + if hlen > 1: + self.assertEqual(res, (False, None, -(hlen - 1), + cborutil.SPECIAL_NONE)) + else: + self.assertEqual(res, (False, None, -(size + hlen - 1), + cborutil.SPECIAL_NONE)) + + # Decoding partial header reports remaining header size. + for i in range(hlen - 1): + self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]), + (False, None, -(hlen - i - 1), + cborutil.SPECIAL_NONE)) + + # Decoding complete header reports item size. + self.assertEqual(cborutil.decodeitem(encoded[0:hlen]), + (False, None, -size, cborutil.SPECIAL_NONE)) + + # Decoding single byte after header reports item size - 1 + self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]), + (False, None, -(size - 1), cborutil.SPECIAL_NONE)) + + # Decoding all but the last byte reports -1 needed. + self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + + # Decoding last byte retrieves value. + self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]), + (True, source, hlen + size, cborutil.SPECIAL_NONE)) + + def testindefinitepartialdecode(self): + encoded = b''.join(cborutil.streamencodebytestringfromiter( + [b'foobar', b'biz'])) + + # First item should be begin of bytestring special. + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (True, None, 1, + cborutil.SPECIAL_START_INDEFINITE_BYTESTRING)) + + # Second item should be the first chunk. But only available when + # we give it 7 bytes (1 byte header + 6 byte chunk). + self.assertEqual(cborutil.decodeitem(encoded[1:2]), + (False, None, -6, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[1:3]), + (False, None, -5, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[1:4]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[1:5]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[1:6]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[1:7]), + (False, None, -1, cborutil.SPECIAL_NONE)) + + self.assertEqual(cborutil.decodeitem(encoded[1:8]), + (True, b'foobar', 7, cborutil.SPECIAL_NONE)) + + # Third item should be second chunk. But only available when + # we give it 4 bytes (1 byte header + 3 byte chunk). + self.assertEqual(cborutil.decodeitem(encoded[8:9]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[8:10]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[8:11]), + (False, None, -1, cborutil.SPECIAL_NONE)) + + self.assertEqual(cborutil.decodeitem(encoded[8:12]), + (True, b'biz', 4, cborutil.SPECIAL_NONE)) + + # Fourth item should be end of indefinite stream marker. + self.assertEqual(cborutil.decodeitem(encoded[12:13]), + (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK)) + + # Now test the behavior when going through the decoder. + + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]), + (False, 1, 0)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]), + (False, 1, 6)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]), + (False, 1, 5)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]), + (False, 1, 4)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]), + (False, 1, 3)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]), + (False, 1, 2)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]), + (False, 1, 1)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]), + (True, 8, 0)) + + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]), + (True, 8, 3)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]), + (True, 8, 2)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]), + (True, 8, 1)) + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]), + (True, 12, 0)) + + self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]), + (True, 13, 0)) + + decoder = cborutil.sansiodecoder() + decoder.decode(encoded[0:8]) + values = decoder.getavailable() + self.assertEqual(values, [b'foobar']) + self.assertTrue(values[0].isfirst) + self.assertFalse(values[0].islast) + + self.assertEqual(decoder.decode(encoded[8:12]), + (True, 4, 0)) + values = decoder.getavailable() + self.assertEqual(values, [b'biz']) + self.assertFalse(values[0].isfirst) + self.assertFalse(values[0].islast) + + self.assertEqual(decoder.decode(encoded[12:]), + (True, 1, 0)) + values = decoder.getavailable() + self.assertEqual(values, [b'']) + self.assertFalse(values[0].isfirst) + self.assertTrue(values[0].islast) + +class StringTests(TestCase): + def testdecodeforbidden(self): + encoded = b'\x63foo' + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'string major type not supported'): + cborutil.decodeall(encoded) + +class IntTests(TestCase): def testsmall(self): self.assertEqual(list(cborutil.streamencode(0)), [b'\x00']) + self.assertEqual(cborutil.decodeall(b'\x00'), [0]) + self.assertEqual(list(cborutil.streamencode(1)), [b'\x01']) + self.assertEqual(cborutil.decodeall(b'\x01'), [1]) + self.assertEqual(list(cborutil.streamencode(2)), [b'\x02']) + self.assertEqual(cborutil.decodeall(b'\x02'), [2]) + self.assertEqual(list(cborutil.streamencode(3)), [b'\x03']) + self.assertEqual(cborutil.decodeall(b'\x03'), [3]) + self.assertEqual(list(cborutil.streamencode(4)), [b'\x04']) + self.assertEqual(cborutil.decodeall(b'\x04'), [4]) + + # Multiple value decode works. + self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'), + [0, 1, 2, 3, 4]) def testnegativesmall(self): self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20']) + self.assertEqual(cborutil.decodeall(b'\x20'), [-1]) + self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21']) + self.assertEqual(cborutil.decodeall(b'\x21'), [-2]) + self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22']) + self.assertEqual(cborutil.decodeall(b'\x22'), [-3]) + self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23']) + self.assertEqual(cborutil.decodeall(b'\x23'), [-4]) + self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24']) + self.assertEqual(cborutil.decodeall(b'\x24'), [-5]) + + # Multiple value decode works. + self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'), + [-1, -2, -3, -4, -5]) def testrange(self): for i in range(-70000, 70000, 10): - self.assertEqual( - b''.join(cborutil.streamencode(i)), - cbor.dumps(i)) + encoded = b''.join(cborutil.streamencode(i)) + + self.assertEqual(encoded, cbor.dumps(i)) + self.assertEqual(cborutil.decodeall(encoded), [i]) + + def testdecodepartialubyte(self): + encoded = b''.join(cborutil.streamencode(250)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 250, 2, cborutil.SPECIAL_NONE)) + + def testdecodepartialbyte(self): + encoded = b''.join(cborutil.streamencode(-42)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, -42, 2, cborutil.SPECIAL_NONE)) + + def testdecodepartialushort(self): + encoded = b''.join(cborutil.streamencode(2**15)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, 2**15, 3, cborutil.SPECIAL_NONE)) + + def testdecodepartialshort(self): + encoded = b''.join(cborutil.streamencode(-1024)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (True, -1024, 3, cborutil.SPECIAL_NONE)) + + def testdecodepartialulong(self): + encoded = b''.join(cborutil.streamencode(2**28)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, 2**28, 5, cborutil.SPECIAL_NONE)) + + def testdecodepartiallong(self): + encoded = b''.join(cborutil.streamencode(-1048580)) -class ArrayTests(unittest.TestCase): + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, -1048580, 5, cborutil.SPECIAL_NONE)) + + def testdecodepartialulonglong(self): + encoded = b''.join(cborutil.streamencode(2**32)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -8, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -7, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -6, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -5, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:7]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:8]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:9]), + (True, 2**32, 9, cborutil.SPECIAL_NONE)) + + with self.assertRaisesRegex( + cborutil.CBORDecodeError, 'input data not fully consumed'): + cborutil.decodeall(encoded[0:1]) + + with self.assertRaisesRegex( + cborutil.CBORDecodeError, 'input data not fully consumed'): + cborutil.decodeall(encoded[0:2]) + + def testdecodepartiallonglong(self): + encoded = b''.join(cborutil.streamencode(-7000000000)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -8, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -7, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -6, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -5, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:7]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:8]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:9]), + (True, -7000000000, 9, cborutil.SPECIAL_NONE)) + +class ArrayTests(TestCase): def testempty(self): self.assertEqual(list(cborutil.streamencode([])), [b'\x80']) self.assertEqual(loadit(cborutil.streamencode([])), []) + self.assertEqual(cborutil.decodeall(b'\x80'), [[]]) + def testbasic(self): source = [b'foo', b'bar', 1, -10] - self.assertEqual(list(cborutil.streamencode(source)), [ - b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']) + chunks = [ + b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'] + + self.assertEqual(list(cborutil.streamencode(source)), chunks) + + self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source]) def testemptyfromiter(self): self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])), b'\x9f\xff') + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length uint not allowed'): + cborutil.decodeall(b'\x9f\xff') + def testfromiter1(self): source = [b'foo'] @@ -129,26 +539,193 @@ class ArrayTests(unittest.TestCase): dest = b''.join(cborutil.streamencodearrayfromiter(source)) self.assertEqual(cbor.loads(dest), source) + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length uint not allowed'): + cborutil.decodeall(dest) + def testtuple(self): source = (b'foo', None, 42) + encoded = b''.join(cborutil.streamencode(source)) - self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))), - list(source)) + self.assertEqual(cbor.loads(encoded), list(source)) + + self.assertEqual(cborutil.decodeall(encoded), [list(source)]) + + def testpartialdecode(self): + source = list(range(4)) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (True, 4, 1, cborutil.SPECIAL_START_ARRAY)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 4, 1, cborutil.SPECIAL_START_ARRAY)) + + source = list(range(23)) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (True, 23, 1, cborutil.SPECIAL_START_ARRAY)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 23, 1, cborutil.SPECIAL_START_ARRAY)) + + source = list(range(24)) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 24, 2, cborutil.SPECIAL_START_ARRAY)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (True, 24, 2, cborutil.SPECIAL_START_ARRAY)) -class SetTests(unittest.TestCase): + source = list(range(256)) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (True, 256, 3, cborutil.SPECIAL_START_ARRAY)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (True, 256, 3, cborutil.SPECIAL_START_ARRAY)) + + def testnested(self): + source = [[], [], [[], [], []]] + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + source = [True, None, [True, 0, 2], [None], [], [[[]], -87]] + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + # A set within an array. + source = [None, {b'foo', b'bar', None, False}, set()] + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + # A map within an array. + source = [None, {}, {b'foo': b'bar', True: False}, [{}]] + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + def testindefinitebytestringvalues(self): + # Single value array whose value is an empty indefinite bytestring. + encoded = b'\x81\x5f\x40\xff' + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length bytestrings not ' + 'allowed as array values'): + cborutil.decodeall(encoded) + +class SetTests(TestCase): def testempty(self): self.assertEqual(list(cborutil.streamencode(set())), [ b'\xd9\x01\x02', b'\x80', ]) + self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()]) + def testset(self): source = {b'foo', None, 42} + encoded = b''.join(cborutil.streamencode(source)) - self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))), - source) + self.assertEqual(cbor.loads(encoded), source) + + self.assertEqual(cborutil.decodeall(encoded), [source]) + + def testinvalidtag(self): + # Must use array to encode sets. + encoded = b'\xd9\x01\x02\xa0' + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'expected array after finite set ' + 'semantic tag'): + cborutil.decodeall(encoded) + + def testpartialdecode(self): + # Semantic tag item will be 3 bytes. Set header will be variable + # depending on length. + encoded = b''.join(cborutil.streamencode({i for i in range(23)})) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (True, 23, 4, cborutil.SPECIAL_START_SET)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, 23, 4, cborutil.SPECIAL_START_SET)) + + encoded = b''.join(cborutil.streamencode({i for i in range(24)})) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, 24, 5, cborutil.SPECIAL_START_SET)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (True, 24, 5, cborutil.SPECIAL_START_SET)) -class BoolTests(unittest.TestCase): + encoded = b''.join(cborutil.streamencode({i for i in range(256)})) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (True, 256, 6, cborutil.SPECIAL_START_SET)) + + def testinvalidvalue(self): + encoded = b''.join([ + b'\xd9\x01\x02', # semantic tag + b'\x81', # array of size 1 + b'\x5f\x43foo\xff', # indefinite length bytestring "foo" + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length bytestrings not ' + 'allowed as set values'): + cborutil.decodeall(encoded) + + encoded = b''.join([ + b'\xd9\x01\x02', + b'\x81', + b'\x80', # empty array + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'collections not allowed as set values'): + cborutil.decodeall(encoded) + + encoded = b''.join([ + b'\xd9\x01\x02', + b'\x81', + b'\xa0', # empty map + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'collections not allowed as set values'): + cborutil.decodeall(encoded) + + encoded = b''.join([ + b'\xd9\x01\x02', + b'\x81', + b'\xd9\x01\x02\x81\x01', # set with integer 1 + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'collections not allowed as set values'): + cborutil.decodeall(encoded) + +class BoolTests(TestCase): def testbasic(self): self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5']) self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4']) @@ -156,23 +733,38 @@ class BoolTests(unittest.TestCase): self.assertIs(loadit(cborutil.streamencode(True)), True) self.assertIs(loadit(cborutil.streamencode(False)), False) -class NoneTests(unittest.TestCase): + self.assertEqual(cborutil.decodeall(b'\xf4'), [False]) + self.assertEqual(cborutil.decodeall(b'\xf5'), [True]) + + self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'), + [False, True, True, False]) + +class NoneTests(TestCase): def testbasic(self): self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6']) self.assertIs(loadit(cborutil.streamencode(None)), None) -class MapTests(unittest.TestCase): + self.assertEqual(cborutil.decodeall(b'\xf6'), [None]) + self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None]) + +class MapTests(TestCase): def testempty(self): self.assertEqual(list(cborutil.streamencode({})), [b'\xa0']) self.assertEqual(loadit(cborutil.streamencode({})), {}) + self.assertEqual(cborutil.decodeall(b'\xa0'), [{}]) + def testemptyindefinite(self): self.assertEqual(list(cborutil.streamencodemapfromiter([])), [ b'\xbf', b'\xff']) self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {}) + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length uint not allowed'): + cborutil.decodeall(b'\xbf\xff') + def testone(self): source = {b'foo': b'bar'} self.assertEqual(list(cborutil.streamencode(source)), [ @@ -180,6 +772,8 @@ class MapTests(unittest.TestCase): self.assertEqual(loadit(cborutil.streamencode(source)), source) + self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source]) + def testmultiple(self): source = { b'foo': b'bar', @@ -192,6 +786,9 @@ class MapTests(unittest.TestCase): loadit(cborutil.streamencodemapfromiter(source.items())), source) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + def testcomplex(self): source = { b'key': 1, @@ -205,6 +802,170 @@ class MapTests(unittest.TestCase): loadit(cborutil.streamencodemapfromiter(source.items())), source) + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + def testnested(self): + source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}} + encoded = b''.join(cborutil.streamencode(source)) + + self.assertEqual(cborutil.decodeall(encoded), [source]) + + source = { + b'key1': [], + b'key2': [None, False], + b'key3': {b'foo', b'bar'}, + b'key4': {}, + } + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeall(encoded), [source]) + + def testillegalkey(self): + encoded = b''.join([ + # map header + len 1 + b'\xa1', + # indefinite length bytestring "foo" in key position + b'\x5f\x03foo\xff' + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length bytestrings not ' + 'allowed as map keys'): + cborutil.decodeall(encoded) + + encoded = b''.join([ + b'\xa1', + b'\x80', # empty array + b'\x43foo', + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'collections not supported as map keys'): + cborutil.decodeall(encoded) + + def testillegalvalue(self): + encoded = b''.join([ + b'\xa1', # map headers + b'\x43foo', # key + b'\x5f\x03bar\xff', # indefinite length value + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'indefinite length bytestrings not ' + 'allowed as map values'): + cborutil.decodeall(encoded) + + def testpartialdecode(self): + source = {b'key1': b'value1'} + encoded = b''.join(cborutil.streamencode(source)) + + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (True, 1, 1, cborutil.SPECIAL_START_MAP)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 1, 1, cborutil.SPECIAL_START_MAP)) + + source = {b'key%d' % i: None for i in range(23)} + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (True, 23, 1, cborutil.SPECIAL_START_MAP)) + + source = {b'key%d' % i: None for i in range(24)} + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (True, 24, 2, cborutil.SPECIAL_START_MAP)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (True, 24, 2, cborutil.SPECIAL_START_MAP)) + + source = {b'key%d' % i: None for i in range(256)} + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (True, 256, 3, cborutil.SPECIAL_START_MAP)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (True, 256, 3, cborutil.SPECIAL_START_MAP)) + + source = {b'key%d' % i: None for i in range(65536)} + encoded = b''.join(cborutil.streamencode(source)) + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -4, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -3, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:3]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:4]), + (False, None, -1, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:5]), + (True, 65536, 5, cborutil.SPECIAL_START_MAP)) + self.assertEqual(cborutil.decodeitem(encoded[0:6]), + (True, 65536, 5, cborutil.SPECIAL_START_MAP)) + +class SemanticTagTests(TestCase): + def testdecodeforbidden(self): + for i in range(500): + if i == cborutil.SEMANTIC_TAG_FINITE_SET: + continue + + tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC, + i) + + encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42) + + # Partial decode is incomplete. + if i < 24: + pass + elif i < 256: + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -1, cborutil.SPECIAL_NONE)) + elif i < 65536: + self.assertEqual(cborutil.decodeitem(encoded[0:1]), + (False, None, -2, cborutil.SPECIAL_NONE)) + self.assertEqual(cborutil.decodeitem(encoded[0:2]), + (False, None, -1, cborutil.SPECIAL_NONE)) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'semantic tag \d+ not allowed'): + cborutil.decodeitem(encoded) + +class SpecialTypesTests(TestCase): + def testforbiddentypes(self): + for i in range(256): + if i == cborutil.SUBTYPE_FALSE: + continue + elif i == cborutil.SUBTYPE_TRUE: + continue + elif i == cborutil.SUBTYPE_NULL: + continue + + encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'special type \d+ not allowed'): + cborutil.decodeitem(encoded) + +class SansIODecoderTests(TestCase): + def testemptyinput(self): + decoder = cborutil.sansiodecoder() + self.assertEqual(decoder.decode(b''), (False, 0, 0)) + +class DecodeallTests(TestCase): + def testemptyinput(self): + self.assertEqual(cborutil.decodeall(b''), []) + + def testpartialinput(self): + encoded = b''.join([ + b'\x82', # array of 2 elements + b'\x01', # integer 1 + ]) + + with self.assertRaisesRegex(cborutil.CBORDecodeError, + 'input data not complete'): + cborutil.decodeall(encoded) + if __name__ == '__main__': import silenttestrunner silenttestrunner.main(__name__)