##// END OF EJS Templates
cborutil: implement sans I/O decoder...
Gregory Szorc -
r39448:aeb551a3 default
parent child Browse files
Show More
This diff has been collapsed as it changes many lines, (715 lines changed) Show them Hide them
@@ -8,6 +8,7 b''
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import struct
10 import struct
11 import sys
11
12
12 from ..thirdparty.cbor.cbor2 import (
13 from ..thirdparty.cbor.cbor2 import (
13 decoder as decodermod,
14 decoder as decodermod,
@@ -35,11 +36,16 b' MAJOR_TYPE_SPECIAL = 7'
35
36
36 SUBTYPE_MASK = 0b00011111
37 SUBTYPE_MASK = 0b00011111
37
38
39 SUBTYPE_FALSE = 20
40 SUBTYPE_TRUE = 21
41 SUBTYPE_NULL = 22
38 SUBTYPE_HALF_FLOAT = 25
42 SUBTYPE_HALF_FLOAT = 25
39 SUBTYPE_SINGLE_FLOAT = 26
43 SUBTYPE_SINGLE_FLOAT = 26
40 SUBTYPE_DOUBLE_FLOAT = 27
44 SUBTYPE_DOUBLE_FLOAT = 27
41 SUBTYPE_INDEFINITE = 31
45 SUBTYPE_INDEFINITE = 31
42
46
47 SEMANTIC_TAG_FINITE_SET = 258
48
43 # Indefinite types begin with their major type ORd with information value 31.
49 # Indefinite types begin with their major type ORd with information value 31.
44 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
50 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
45 r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
51 r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
@@ -146,7 +152,7 b' def _mixedtypesortkey(v):'
146 def streamencodeset(s):
152 def streamencodeset(s):
147 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
153 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
148 # semantic tag 258 for finite sets.
154 # semantic tag 258 for finite sets.
149 yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
155 yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
150
156
151 for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
157 for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
152 yield chunk
158 yield chunk
@@ -260,3 +266,710 b' def readindefinitebytestringtoiter(fh, e'
260 len(chunk), length))
266 len(chunk), length))
261
267
262 yield chunk
268 yield chunk
269
270 class CBORDecodeError(Exception):
271 """Represents an error decoding CBOR."""
272
273 if sys.version_info.major >= 3:
274 def _elementtointeger(b, i):
275 return b[i]
276 else:
277 def _elementtointeger(b, i):
278 return ord(b[i])
279
280 STRUCT_BIG_UBYTE = struct.Struct(r'>B')
281 STRUCT_BIG_USHORT = struct.Struct('>H')
282 STRUCT_BIG_ULONG = struct.Struct('>L')
283 STRUCT_BIG_ULONGLONG = struct.Struct('>Q')
284
285 SPECIAL_NONE = 0
286 SPECIAL_START_INDEFINITE_BYTESTRING = 1
287 SPECIAL_START_ARRAY = 2
288 SPECIAL_START_MAP = 3
289 SPECIAL_START_SET = 4
290 SPECIAL_INDEFINITE_BREAK = 5
291
292 def decodeitem(b, offset=0):
293 """Decode a new CBOR value from a buffer at offset.
294
295 This function attempts to decode up to one complete CBOR value
296 from ``b`` starting at offset ``offset``.
297
298 The beginning of a collection (such as an array, map, set, or
299 indefinite length bytestring) counts as a single value. For these
300 special cases, a state flag will indicate that a special value was seen.
301
302 When called, the function either returns a decoded value or gives
303 a hint as to how many more bytes are needed to do so. By calling
304 the function repeatedly given a stream of bytes, the caller can
305 build up the original values.
306
307 Returns a tuple with the following elements:
308
309 * Bool indicating whether a complete value was decoded.
310 * A decoded value if first value is True otherwise None
311 * Integer number of bytes. If positive, the number of bytes
312 read. If negative, the number of bytes we need to read to
313 decode this value or the next chunk in this value.
314 * One of the ``SPECIAL_*`` constants indicating special treatment
315 for this value. ``SPECIAL_NONE`` means this is a fully decoded
316 simple value (such as an integer or bool).
317 """
318
319 initial = _elementtointeger(b, offset)
320 offset += 1
321
322 majortype = initial >> 5
323 subtype = initial & SUBTYPE_MASK
324
325 if majortype == MAJOR_TYPE_UINT:
326 complete, value, readcount = decodeuint(subtype, b, offset)
327
328 if complete:
329 return True, value, readcount + 1, SPECIAL_NONE
330 else:
331 return False, None, readcount, SPECIAL_NONE
332
333 elif majortype == MAJOR_TYPE_NEGINT:
334 # Negative integers are the same as UINT except inverted minus 1.
335 complete, value, readcount = decodeuint(subtype, b, offset)
336
337 if complete:
338 return True, -value - 1, readcount + 1, SPECIAL_NONE
339 else:
340 return False, None, readcount, SPECIAL_NONE
341
342 elif majortype == MAJOR_TYPE_BYTESTRING:
343 # Beginning of bytestrings are treated as uints in order to
344 # decode their length, which may be indefinite.
345 complete, size, readcount = decodeuint(subtype, b, offset,
346 allowindefinite=True)
347
348 # We don't know the size of the bytestring. It must be a definitive
349 # length since the indefinite subtype would be encoded in the initial
350 # byte.
351 if not complete:
352 return False, None, readcount, SPECIAL_NONE
353
354 # We know the length of the bytestring.
355 if size is not None:
356 # And the data is available in the buffer.
357 if offset + readcount + size <= len(b):
358 value = b[offset + readcount:offset + readcount + size]
359 return True, value, readcount + size + 1, SPECIAL_NONE
360
361 # And we need more data in order to return the bytestring.
362 else:
363 wanted = len(b) - offset - readcount - size
364 return False, None, wanted, SPECIAL_NONE
365
366 # It is an indefinite length bytestring.
367 else:
368 return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
369
370 elif majortype == MAJOR_TYPE_STRING:
371 raise CBORDecodeError('string major type not supported')
372
373 elif majortype == MAJOR_TYPE_ARRAY:
374 # Beginning of arrays are treated as uints in order to decode their
375 # length. We don't allow indefinite length arrays.
376 complete, size, readcount = decodeuint(subtype, b, offset)
377
378 if complete:
379 return True, size, readcount + 1, SPECIAL_START_ARRAY
380 else:
381 return False, None, readcount, SPECIAL_NONE
382
383 elif majortype == MAJOR_TYPE_MAP:
384 # Beginning of maps are treated as uints in order to decode their
385 # number of elements. We don't allow indefinite length arrays.
386 complete, size, readcount = decodeuint(subtype, b, offset)
387
388 if complete:
389 return True, size, readcount + 1, SPECIAL_START_MAP
390 else:
391 return False, None, readcount, SPECIAL_NONE
392
393 elif majortype == MAJOR_TYPE_SEMANTIC:
394 # Semantic tag value is read the same as a uint.
395 complete, tagvalue, readcount = decodeuint(subtype, b, offset)
396
397 if not complete:
398 return False, None, readcount, SPECIAL_NONE
399
400 # This behavior here is a little wonky. The main type being "decorated"
401 # by this semantic tag follows. A more robust parser would probably emit
402 # a special flag indicating this as a semantic tag and let the caller
403 # deal with the types that follow. But since we don't support many
404 # semantic tags, it is easier to deal with the special cases here and
405 # hide complexity from the caller. If we add support for more semantic
406 # tags, we should probably move semantic tag handling into the caller.
407 if tagvalue == SEMANTIC_TAG_FINITE_SET:
408 if offset + readcount >= len(b):
409 return False, None, -1, SPECIAL_NONE
410
411 complete, size, readcount2, special = decodeitem(b,
412 offset + readcount)
413
414 if not complete:
415 return False, None, readcount2, SPECIAL_NONE
416
417 if special != SPECIAL_START_ARRAY:
418 raise CBORDecodeError('expected array after finite set '
419 'semantic tag')
420
421 return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
422
423 else:
424 raise CBORDecodeError('semantic tag %d not allowed' % tagvalue)
425
426 elif majortype == MAJOR_TYPE_SPECIAL:
427 # Only specific values for the information field are allowed.
428 if subtype == SUBTYPE_FALSE:
429 return True, False, 1, SPECIAL_NONE
430 elif subtype == SUBTYPE_TRUE:
431 return True, True, 1, SPECIAL_NONE
432 elif subtype == SUBTYPE_NULL:
433 return True, None, 1, SPECIAL_NONE
434 elif subtype == SUBTYPE_INDEFINITE:
435 return True, None, 1, SPECIAL_INDEFINITE_BREAK
436 # If value is 24, subtype is in next byte.
437 else:
438 raise CBORDecodeError('special type %d not allowed' % subtype)
439 else:
440 assert False
441
442 def decodeuint(subtype, b, offset=0, allowindefinite=False):
443 """Decode an unsigned integer.
444
445 ``subtype`` is the lower 5 bits from the initial byte CBOR item
446 "header." ``b`` is a buffer containing bytes. ``offset`` points to
447 the index of the first byte after the byte that ``subtype`` was
448 derived from.
449
450 ``allowindefinite`` allows the special indefinite length value
451 indicator.
452
453 Returns a 3-tuple of (successful, value, count).
454
455 The first element is a bool indicating if decoding completed. The 2nd
456 is the decoded integer value or None if not fully decoded or the subtype
457 is 31 and ``allowindefinite`` is True. The 3rd value is the count of bytes.
458 If positive, it is the number of additional bytes decoded. If negative,
459 it is the number of additional bytes needed to decode this value.
460 """
461
462 # Small values are inline.
463 if subtype < 24:
464 return True, subtype, 0
465 # Indefinite length specifier.
466 elif subtype == 31:
467 if allowindefinite:
468 return True, None, 0
469 else:
470 raise CBORDecodeError('indefinite length uint not allowed here')
471 elif subtype >= 28:
472 raise CBORDecodeError('unsupported subtype on integer type: %d' %
473 subtype)
474
475 if subtype == 24:
476 s = STRUCT_BIG_UBYTE
477 elif subtype == 25:
478 s = STRUCT_BIG_USHORT
479 elif subtype == 26:
480 s = STRUCT_BIG_ULONG
481 elif subtype == 27:
482 s = STRUCT_BIG_ULONGLONG
483 else:
484 raise CBORDecodeError('bounds condition checking violation')
485
486 if len(b) - offset >= s.size:
487 return True, s.unpack_from(b, offset)[0], s.size
488 else:
489 return False, None, len(b) - offset - s.size
490
491 class bytestringchunk(bytes):
492 """Represents a chunk/segment in an indefinite length bytestring.
493
494 This behaves like a ``bytes`` but in addition has the ``isfirst``
495 and ``islast`` attributes indicating whether this chunk is the first
496 or last in an indefinite length bytestring.
497 """
498
499 def __new__(cls, v, first=False, last=False):
500 self = bytes.__new__(cls, v)
501 self.isfirst = first
502 self.islast = last
503
504 return self
505
506 class sansiodecoder(object):
507 """A CBOR decoder that doesn't perform its own I/O.
508
509 To use, construct an instance and feed it segments containing
510 CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
511 indicates whether a fully-decoded value is available, how many bytes
512 were consumed, and offers a hint as to how many bytes should be fed
513 in next time to decode the next value.
514
515 The decoder assumes it will decode N discrete CBOR values, not just
516 a single value. i.e. if the bytestream contains uints packed one after
517 the other, the decoder will decode them all, rather than just the initial
518 one.
519
520 When ``decode()`` indicates a value is available, call ``getavailable()``
521 to return all fully decoded values.
522
523 ``decode()`` can partially decode input. It is up to the caller to keep
524 track of what data was consumed and to pass unconsumed data in on the
525 next invocation.
526
527 The decoder decodes atomically at the *item* level. See ``decodeitem()``.
528 If an *item* cannot be fully decoded, the decoder won't record it as
529 partially consumed. Instead, the caller will be instructed to pass in
530 the initial bytes of this item on the next invocation. This does result
531 in some redundant parsing. But the overhead should be minimal.
532
533 This decoder only supports a subset of CBOR as required by Mercurial.
534 It lacks support for:
535
536 * Indefinite length arrays
537 * Indefinite length maps
538 * Use of indefinite length bytestrings as keys or values within
539 arrays, maps, or sets.
540 * Nested arrays, maps, or sets within sets
541 * Any semantic tag that isn't a mathematical finite set
542 * Floating point numbers
543 * Undefined special value
544
545 CBOR types are decoded to Python types as follows:
546
547 uint -> int
548 negint -> int
549 bytestring -> bytes
550 map -> dict
551 array -> list
552 True -> bool
553 False -> bool
554 null -> None
555 indefinite length bytestring chunk -> [bytestringchunk]
556
557 The only non-obvious mapping here is an indefinite length bytestring
558 to the ``bytestringchunk`` type. This is to facilitate streaming
559 indefinite length bytestrings out of the decoder and to differentiate
560 a regular bytestring from an indefinite length bytestring.
561 """
562
563 _STATE_NONE = 0
564 _STATE_WANT_MAP_KEY = 1
565 _STATE_WANT_MAP_VALUE = 2
566 _STATE_WANT_ARRAY_VALUE = 3
567 _STATE_WANT_SET_VALUE = 4
568 _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
569 _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
570
571 def __init__(self):
572 # TODO add support for limiting size of bytestrings
573 # TODO add support for limiting number of keys / values in collections
574 # TODO add support for limiting size of buffered partial values
575
576 self.decodedbytecount = 0
577
578 self._state = self._STATE_NONE
579
580 # Stack of active nested collections. Each entry is a dict describing
581 # the collection.
582 self._collectionstack = []
583
584 # Fully decoded key to use for the current map.
585 self._currentmapkey = None
586
587 # Fully decoded values available for retrieval.
588 self._decodedvalues = []
589
590 @property
591 def inprogress(self):
592 """Whether the decoder has partially decoded a value."""
593 return self._state != self._STATE_NONE
594
595 def decode(self, b, offset=0):
596 """Attempt to decode bytes from an input buffer.
597
598 ``b`` is a collection of bytes and ``offset`` is the byte
599 offset within that buffer from which to begin reading data.
600
601 ``b`` must support ``len()`` and accessing bytes slices via
602 ``__slice__``. Typically ``bytes`` instances are used.
603
604 Returns a tuple with the following fields:
605
606 * Bool indicating whether values are available for retrieval.
607 * Integer indicating the number of bytes that were fully consumed,
608 starting from ``offset``.
609 * Integer indicating the number of bytes that are desired for the
610 next call in order to decode an item.
611 """
612 if not b:
613 return bool(self._decodedvalues), 0, 0
614
615 initialoffset = offset
616
617 # We could easily split the body of this loop into a function. But
618 # Python performance is sensitive to function calls and collections
619 # are composed of many items. So leaving as a while loop could help
620 # with performance. One thing that may not help is the use of
621 # if..elif versus a lookup/dispatch table. There may be value
622 # in switching that.
623 while offset < len(b):
624 # Attempt to decode an item. This could be a whole value or a
625 # special value indicating an event, such as start or end of a
626 # collection or indefinite length type.
627 complete, value, readcount, special = decodeitem(b, offset)
628
629 if readcount > 0:
630 self.decodedbytecount += readcount
631
632 if not complete:
633 assert readcount < 0
634 return (
635 bool(self._decodedvalues),
636 offset - initialoffset,
637 -readcount,
638 )
639
640 offset += readcount
641
642 # No nested state. We either have a full value or beginning of a
643 # complex value to deal with.
644 if self._state == self._STATE_NONE:
645 # A normal value.
646 if special == SPECIAL_NONE:
647 self._decodedvalues.append(value)
648
649 elif special == SPECIAL_START_ARRAY:
650 self._collectionstack.append({
651 'remaining': value,
652 'v': [],
653 })
654 self._state = self._STATE_WANT_ARRAY_VALUE
655
656 elif special == SPECIAL_START_MAP:
657 self._collectionstack.append({
658 'remaining': value,
659 'v': {},
660 })
661 self._state = self._STATE_WANT_MAP_KEY
662
663 elif special == SPECIAL_START_SET:
664 self._collectionstack.append({
665 'remaining': value,
666 'v': set(),
667 })
668 self._state = self._STATE_WANT_SET_VALUE
669
670 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
671 self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
672
673 else:
674 raise CBORDecodeError('unhandled special state: %d' %
675 special)
676
677 # This value becomes an element of the current array.
678 elif self._state == self._STATE_WANT_ARRAY_VALUE:
679 # Simple values get appended.
680 if special == SPECIAL_NONE:
681 c = self._collectionstack[-1]
682 c['v'].append(value)
683 c['remaining'] -= 1
684
685 # self._state doesn't need changed.
686
687 # An array nested within an array.
688 elif special == SPECIAL_START_ARRAY:
689 lastc = self._collectionstack[-1]
690 newvalue = []
691
692 lastc['v'].append(newvalue)
693 lastc['remaining'] -= 1
694
695 self._collectionstack.append({
696 'remaining': value,
697 'v': newvalue,
698 })
699
700 # self._state doesn't need changed.
701
702 # A map nested within an array.
703 elif special == SPECIAL_START_MAP:
704 lastc = self._collectionstack[-1]
705 newvalue = {}
706
707 lastc['v'].append(newvalue)
708 lastc['remaining'] -= 1
709
710 self._collectionstack.append({
711 'remaining': value,
712 'v': newvalue
713 })
714
715 self._state = self._STATE_WANT_MAP_KEY
716
717 elif special == SPECIAL_START_SET:
718 lastc = self._collectionstack[-1]
719 newvalue = set()
720
721 lastc['v'].append(newvalue)
722 lastc['remaining'] -= 1
723
724 self._collectionstack.append({
725 'remaining': value,
726 'v': newvalue,
727 })
728
729 self._state = self._STATE_WANT_SET_VALUE
730
731 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
732 raise CBORDecodeError('indefinite length bytestrings '
733 'not allowed as array values')
734
735 else:
736 raise CBORDecodeError('unhandled special item when '
737 'expecting array value: %d' % special)
738
739 # This value becomes the key of the current map instance.
740 elif self._state == self._STATE_WANT_MAP_KEY:
741 if special == SPECIAL_NONE:
742 self._currentmapkey = value
743 self._state = self._STATE_WANT_MAP_VALUE
744
745 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
746 raise CBORDecodeError('indefinite length bytestrings '
747 'not allowed as map keys')
748
749 elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP,
750 SPECIAL_START_SET):
751 raise CBORDecodeError('collections not supported as map '
752 'keys')
753
754 # We do not allow special values to be used as map keys.
755 else:
756 raise CBORDecodeError('unhandled special item when '
757 'expecting map key: %d' % special)
758
759 # This value becomes the value of the current map key.
760 elif self._state == self._STATE_WANT_MAP_VALUE:
761 # Simple values simply get inserted into the map.
762 if special == SPECIAL_NONE:
763 lastc = self._collectionstack[-1]
764 lastc['v'][self._currentmapkey] = value
765 lastc['remaining'] -= 1
766
767 self._state = self._STATE_WANT_MAP_KEY
768
769 # A new array is used as the map value.
770 elif special == SPECIAL_START_ARRAY:
771 lastc = self._collectionstack[-1]
772 newvalue = []
773
774 lastc['v'][self._currentmapkey] = newvalue
775 lastc['remaining'] -= 1
776
777 self._collectionstack.append({
778 'remaining': value,
779 'v': newvalue,
780 })
781
782 self._state = self._STATE_WANT_ARRAY_VALUE
783
784 # A new map is used as the map value.
785 elif special == SPECIAL_START_MAP:
786 lastc = self._collectionstack[-1]
787 newvalue = {}
788
789 lastc['v'][self._currentmapkey] = newvalue
790 lastc['remaining'] -= 1
791
792 self._collectionstack.append({
793 'remaining': value,
794 'v': newvalue,
795 })
796
797 self._state = self._STATE_WANT_MAP_KEY
798
799 # A new set is used as the map value.
800 elif special == SPECIAL_START_SET:
801 lastc = self._collectionstack[-1]
802 newvalue = set()
803
804 lastc['v'][self._currentmapkey] = newvalue
805 lastc['remaining'] -= 1
806
807 self._collectionstack.append({
808 'remaining': value,
809 'v': newvalue,
810 })
811
812 self._state = self._STATE_WANT_SET_VALUE
813
814 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
815 raise CBORDecodeError('indefinite length bytestrings not '
816 'allowed as map values')
817
818 else:
819 raise CBORDecodeError('unhandled special item when '
820 'expecting map value: %d' % special)
821
822 self._currentmapkey = None
823
824 # This value is added to the current set.
825 elif self._state == self._STATE_WANT_SET_VALUE:
826 if special == SPECIAL_NONE:
827 lastc = self._collectionstack[-1]
828 lastc['v'].add(value)
829 lastc['remaining'] -= 1
830
831 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
832 raise CBORDecodeError('indefinite length bytestrings not '
833 'allowed as set values')
834
835 elif special in (SPECIAL_START_ARRAY,
836 SPECIAL_START_MAP,
837 SPECIAL_START_SET):
838 raise CBORDecodeError('collections not allowed as set '
839 'values')
840
841 # We don't allow non-trivial types to exist as set values.
842 else:
843 raise CBORDecodeError('unhandled special item when '
844 'expecting set value: %d' % special)
845
846 # This value represents the first chunk in an indefinite length
847 # bytestring.
848 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
849 # We received a full chunk.
850 if special == SPECIAL_NONE:
851 self._decodedvalues.append(bytestringchunk(value,
852 first=True))
853
854 self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
855
856 # The end of stream marker. This means it is an empty
857 # indefinite length bytestring.
858 elif special == SPECIAL_INDEFINITE_BREAK:
859 # We /could/ convert this to a b''. But we want to preserve
860 # the nature of the underlying data so consumers expecting
861 # an indefinite length bytestring get one.
862 self._decodedvalues.append(bytestringchunk(b'',
863 first=True,
864 last=True))
865
866 # Since indefinite length bytestrings can't be used in
867 # collections, we must be at the root level.
868 assert not self._collectionstack
869 self._state = self._STATE_NONE
870
871 else:
872 raise CBORDecodeError('unexpected special value when '
873 'expecting bytestring chunk: %d' %
874 special)
875
876 # This value represents the non-initial chunk in an indefinite
877 # length bytestring.
878 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
879 # We received a full chunk.
880 if special == SPECIAL_NONE:
881 self._decodedvalues.append(bytestringchunk(value))
882
883 # The end of stream marker.
884 elif special == SPECIAL_INDEFINITE_BREAK:
885 self._decodedvalues.append(bytestringchunk(b'', last=True))
886
887 # Since indefinite length bytestrings can't be used in
888 # collections, we must be at the root level.
889 assert not self._collectionstack
890 self._state = self._STATE_NONE
891
892 else:
893 raise CBORDecodeError('unexpected special value when '
894 'expecting bytestring chunk: %d' %
895 special)
896
897 else:
898 raise CBORDecodeError('unhandled decoder state: %d' %
899 self._state)
900
901 # We could have just added the final value in a collection. End
902 # all complete collections at the top of the stack.
903 while True:
904 # Bail if we're not waiting on a new collection item.
905 if self._state not in (self._STATE_WANT_ARRAY_VALUE,
906 self._STATE_WANT_MAP_KEY,
907 self._STATE_WANT_SET_VALUE):
908 break
909
910 # Or we are expecting more items for this collection.
911 lastc = self._collectionstack[-1]
912
913 if lastc['remaining']:
914 break
915
916 # The collection at the top of the stack is complete.
917
918 # Discard it, as it isn't needed for future items.
919 self._collectionstack.pop()
920
921 # If this is a nested collection, we don't emit it, since it
922 # will be emitted by its parent collection. But we do need to
923 # update state to reflect what the new top-most collection
924 # on the stack is.
925 if self._collectionstack:
926 self._state = {
927 list: self._STATE_WANT_ARRAY_VALUE,
928 dict: self._STATE_WANT_MAP_KEY,
929 set: self._STATE_WANT_SET_VALUE,
930 }[type(self._collectionstack[-1]['v'])]
931
932 # If this is the root collection, emit it.
933 else:
934 self._decodedvalues.append(lastc['v'])
935 self._state = self._STATE_NONE
936
937 return (
938 bool(self._decodedvalues),
939 offset - initialoffset,
940 0,
941 )
942
943 def getavailable(self):
944 """Returns an iterator over fully decoded values.
945
946 Once values are retrieved, they won't be available on the next call.
947 """
948
949 l = list(self._decodedvalues)
950 self._decodedvalues = []
951 return l
952
953 def decodeall(b):
954 """Decode all CBOR items present in an iterable of bytes.
955
956 In addition to regular decode errors, raises CBORDecodeError if the
957 entirety of the passed buffer does not fully decode to complete CBOR
958 values. This includes failure to decode any value, incomplete collection
959 types, incomplete indefinite length items, and extra data at the end of
960 the buffer.
961 """
962 if not b:
963 return []
964
965 decoder = sansiodecoder()
966
967 havevalues, readcount, wantbytes = decoder.decode(b)
968
969 if readcount != len(b):
970 raise CBORDecodeError('input data not fully consumed')
971
972 if decoder.inprogress:
973 raise CBORDecodeError('input data not complete')
974
975 return decoder.getavailable()
This diff has been collapsed as it changes many lines, (793 lines changed) Show them Hide them
@@ -10,10 +10,17 b' from mercurial.utils import ('
10 cborutil,
10 cborutil,
11 )
11 )
12
12
13 class TestCase(unittest.TestCase):
14 if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
15 # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
16 # the regex version.
17 assertRaisesRegex = (# camelcase-required
18 unittest.TestCase.assertRaisesRegexp)
19
13 def loadit(it):
20 def loadit(it):
14 return cbor.loads(b''.join(it))
21 return cbor.loads(b''.join(it))
15
22
16 class BytestringTests(unittest.TestCase):
23 class BytestringTests(TestCase):
17 def testsimple(self):
24 def testsimple(self):
18 self.assertEqual(
25 self.assertEqual(
19 list(cborutil.streamencode(b'foobar')),
26 list(cborutil.streamencode(b'foobar')),
@@ -23,11 +30,20 b' class BytestringTests(unittest.TestCase)'
23 loadit(cborutil.streamencode(b'foobar')),
30 loadit(cborutil.streamencode(b'foobar')),
24 b'foobar')
31 b'foobar')
25
32
33 self.assertEqual(cborutil.decodeall(b'\x46foobar'),
34 [b'foobar'])
35
36 self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
37 [b'foobar', b'fizbi'])
38
26 def testlong(self):
39 def testlong(self):
27 source = b'x' * 1048576
40 source = b'x' * 1048576
28
41
29 self.assertEqual(loadit(cborutil.streamencode(source)), source)
42 self.assertEqual(loadit(cborutil.streamencode(source)), source)
30
43
44 encoded = b''.join(cborutil.streamencode(source))
45 self.assertEqual(cborutil.decodeall(encoded), [source])
46
31 def testfromiter(self):
47 def testfromiter(self):
32 # This is the example from RFC 7049 Section 2.2.2.
48 # This is the example from RFC 7049 Section 2.2.2.
33 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
49 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
@@ -47,6 +63,25 b' class BytestringTests(unittest.TestCase)'
47 loadit(cborutil.streamencodebytestringfromiter(source)),
63 loadit(cborutil.streamencodebytestringfromiter(source)),
48 b''.join(source))
64 b''.join(source))
49
65
66 self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
67 b'\x43\xee\xff\x99\xff'),
68 [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
69
70 for i, chunk in enumerate(
71 cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
72 b'\x43\xee\xff\x99\xff')):
73 self.assertIsInstance(chunk, cborutil.bytestringchunk)
74
75 if i == 0:
76 self.assertTrue(chunk.isfirst)
77 else:
78 self.assertFalse(chunk.isfirst)
79
80 if i == 2:
81 self.assertTrue(chunk.islast)
82 else:
83 self.assertFalse(chunk.islast)
84
50 def testfromiterlarge(self):
85 def testfromiterlarge(self):
51 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
86 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
52
87
@@ -71,6 +106,18 b' class BytestringTests(unittest.TestCase)'
71 source, chunksize=42))
106 source, chunksize=42))
72 self.assertEqual(cbor.loads(dest), source)
107 self.assertEqual(cbor.loads(dest), source)
73
108
109 self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
110
111 for chunk in cborutil.decodeall(dest):
112 self.assertIsInstance(chunk, cborutil.bytestringchunk)
113 self.assertIn(len(chunk), (0, 8, 42))
114
115 encoded = b'\x5f\xff'
116 b = cborutil.decodeall(encoded)
117 self.assertEqual(b, [b''])
118 self.assertTrue(b[0].isfirst)
119 self.assertTrue(b[0].islast)
120
74 def testreadtoiter(self):
121 def testreadtoiter(self):
75 source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
122 source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
76
123
@@ -81,42 +128,405 b' class BytestringTests(unittest.TestCase)'
81 with self.assertRaises(StopIteration):
128 with self.assertRaises(StopIteration):
82 next(it)
129 next(it)
83
130
84 class IntTests(unittest.TestCase):
131 def testdecodevariouslengths(self):
132 for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536):
133 source = b'x' * i
134 encoded = b''.join(cborutil.streamencode(source))
135
136 if len(source) < 24:
137 hlen = 1
138 elif len(source) < 256:
139 hlen = 2
140 elif len(source) < 65536:
141 hlen = 3
142 elif len(source) < 1048576:
143 hlen = 5
144
145 self.assertEqual(cborutil.decodeitem(encoded),
146 (True, source, hlen + len(source),
147 cborutil.SPECIAL_NONE))
148
149 def testpartialdecode(self):
150 encoded = b''.join(cborutil.streamencode(b'foobar'))
151
152 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
153 (False, None, -6, cborutil.SPECIAL_NONE))
154 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
155 (False, None, -5, cborutil.SPECIAL_NONE))
156 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
157 (False, None, -4, cborutil.SPECIAL_NONE))
158 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
159 (False, None, -3, cborutil.SPECIAL_NONE))
160 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
161 (False, None, -2, cborutil.SPECIAL_NONE))
162 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
163 (False, None, -1, cborutil.SPECIAL_NONE))
164 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
165 (True, b'foobar', 7, cborutil.SPECIAL_NONE))
166
167 def testpartialdecodevariouslengths(self):
168 lens = [
169 2,
170 3,
171 10,
172 23,
173 24,
174 25,
175 31,
176 100,
177 254,
178 255,
179 256,
180 257,
181 16384,
182 65534,
183 65535,
184 65536,
185 65537,
186 131071,
187 131072,
188 131073,
189 1048575,
190 1048576,
191 1048577,
192 ]
193
194 for size in lens:
195 if size < 24:
196 hlen = 1
197 elif size < 2**8:
198 hlen = 2
199 elif size < 2**16:
200 hlen = 3
201 elif size < 2**32:
202 hlen = 5
203 else:
204 assert False
205
206 source = b'x' * size
207 encoded = b''.join(cborutil.streamencode(source))
208
209 res = cborutil.decodeitem(encoded[0:1])
210
211 if hlen > 1:
212 self.assertEqual(res, (False, None, -(hlen - 1),
213 cborutil.SPECIAL_NONE))
214 else:
215 self.assertEqual(res, (False, None, -(size + hlen - 1),
216 cborutil.SPECIAL_NONE))
217
218 # Decoding partial header reports remaining header size.
219 for i in range(hlen - 1):
220 self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
221 (False, None, -(hlen - i - 1),
222 cborutil.SPECIAL_NONE))
223
224 # Decoding complete header reports item size.
225 self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
226 (False, None, -size, cborutil.SPECIAL_NONE))
227
228 # Decoding single byte after header reports item size - 1
229 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
230 (False, None, -(size - 1), cborutil.SPECIAL_NONE))
231
232 # Decoding all but the last byte reports -1 needed.
233 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
234 (False, None, -1, cborutil.SPECIAL_NONE))
235
236 # Decoding last byte retrieves value.
237 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
238 (True, source, hlen + size, cborutil.SPECIAL_NONE))
239
240 def testindefinitepartialdecode(self):
241 encoded = b''.join(cborutil.streamencodebytestringfromiter(
242 [b'foobar', b'biz']))
243
244 # First item should be begin of bytestring special.
245 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
246 (True, None, 1,
247 cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
248
249 # Second item should be the first chunk. But only available when
250 # we give it 7 bytes (1 byte header + 6 byte chunk).
251 self.assertEqual(cborutil.decodeitem(encoded[1:2]),
252 (False, None, -6, cborutil.SPECIAL_NONE))
253 self.assertEqual(cborutil.decodeitem(encoded[1:3]),
254 (False, None, -5, cborutil.SPECIAL_NONE))
255 self.assertEqual(cborutil.decodeitem(encoded[1:4]),
256 (False, None, -4, cborutil.SPECIAL_NONE))
257 self.assertEqual(cborutil.decodeitem(encoded[1:5]),
258 (False, None, -3, cborutil.SPECIAL_NONE))
259 self.assertEqual(cborutil.decodeitem(encoded[1:6]),
260 (False, None, -2, cborutil.SPECIAL_NONE))
261 self.assertEqual(cborutil.decodeitem(encoded[1:7]),
262 (False, None, -1, cborutil.SPECIAL_NONE))
263
264 self.assertEqual(cborutil.decodeitem(encoded[1:8]),
265 (True, b'foobar', 7, cborutil.SPECIAL_NONE))
266
267 # Third item should be second chunk. But only available when
268 # we give it 4 bytes (1 byte header + 3 byte chunk).
269 self.assertEqual(cborutil.decodeitem(encoded[8:9]),
270 (False, None, -3, cborutil.SPECIAL_NONE))
271 self.assertEqual(cborutil.decodeitem(encoded[8:10]),
272 (False, None, -2, cborutil.SPECIAL_NONE))
273 self.assertEqual(cborutil.decodeitem(encoded[8:11]),
274 (False, None, -1, cborutil.SPECIAL_NONE))
275
276 self.assertEqual(cborutil.decodeitem(encoded[8:12]),
277 (True, b'biz', 4, cborutil.SPECIAL_NONE))
278
279 # Fourth item should be end of indefinite stream marker.
280 self.assertEqual(cborutil.decodeitem(encoded[12:13]),
281 (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
282
283 # Now test the behavior when going through the decoder.
284
285 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
286 (False, 1, 0))
287 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
288 (False, 1, 6))
289 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
290 (False, 1, 5))
291 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
292 (False, 1, 4))
293 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
294 (False, 1, 3))
295 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
296 (False, 1, 2))
297 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
298 (False, 1, 1))
299 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
300 (True, 8, 0))
301
302 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
303 (True, 8, 3))
304 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
305 (True, 8, 2))
306 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
307 (True, 8, 1))
308 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
309 (True, 12, 0))
310
311 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
312 (True, 13, 0))
313
314 decoder = cborutil.sansiodecoder()
315 decoder.decode(encoded[0:8])
316 values = decoder.getavailable()
317 self.assertEqual(values, [b'foobar'])
318 self.assertTrue(values[0].isfirst)
319 self.assertFalse(values[0].islast)
320
321 self.assertEqual(decoder.decode(encoded[8:12]),
322 (True, 4, 0))
323 values = decoder.getavailable()
324 self.assertEqual(values, [b'biz'])
325 self.assertFalse(values[0].isfirst)
326 self.assertFalse(values[0].islast)
327
328 self.assertEqual(decoder.decode(encoded[12:]),
329 (True, 1, 0))
330 values = decoder.getavailable()
331 self.assertEqual(values, [b''])
332 self.assertFalse(values[0].isfirst)
333 self.assertTrue(values[0].islast)
334
335 class StringTests(TestCase):
336 def testdecodeforbidden(self):
337 encoded = b'\x63foo'
338 with self.assertRaisesRegex(cborutil.CBORDecodeError,
339 'string major type not supported'):
340 cborutil.decodeall(encoded)
341
342 class IntTests(TestCase):
85 def testsmall(self):
343 def testsmall(self):
86 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
344 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
345 self.assertEqual(cborutil.decodeall(b'\x00'), [0])
346
87 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
347 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
348 self.assertEqual(cborutil.decodeall(b'\x01'), [1])
349
88 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
350 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
351 self.assertEqual(cborutil.decodeall(b'\x02'), [2])
352
89 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
353 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
354 self.assertEqual(cborutil.decodeall(b'\x03'), [3])
355
90 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
356 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
357 self.assertEqual(cborutil.decodeall(b'\x04'), [4])
358
359 # Multiple value decode works.
360 self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
361 [0, 1, 2, 3, 4])
91
362
92 def testnegativesmall(self):
363 def testnegativesmall(self):
93 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
364 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
365 self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
366
94 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
367 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
368 self.assertEqual(cborutil.decodeall(b'\x21'), [-2])
369
95 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
370 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
371 self.assertEqual(cborutil.decodeall(b'\x22'), [-3])
372
96 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
373 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
374 self.assertEqual(cborutil.decodeall(b'\x23'), [-4])
375
97 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
376 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
377 self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
378
379 # Multiple value decode works.
380 self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
381 [-1, -2, -3, -4, -5])
98
382
99 def testrange(self):
383 def testrange(self):
100 for i in range(-70000, 70000, 10):
384 for i in range(-70000, 70000, 10):
101 self.assertEqual(
385 encoded = b''.join(cborutil.streamencode(i))
102 b''.join(cborutil.streamencode(i)),
386
103 cbor.dumps(i))
387 self.assertEqual(encoded, cbor.dumps(i))
388 self.assertEqual(cborutil.decodeall(encoded), [i])
389
390 def testdecodepartialubyte(self):
391 encoded = b''.join(cborutil.streamencode(250))
392
393 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
394 (False, None, -1, cborutil.SPECIAL_NONE))
395 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
396 (True, 250, 2, cborutil.SPECIAL_NONE))
397
398 def testdecodepartialbyte(self):
399 encoded = b''.join(cborutil.streamencode(-42))
400 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
401 (False, None, -1, cborutil.SPECIAL_NONE))
402 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
403 (True, -42, 2, cborutil.SPECIAL_NONE))
404
405 def testdecodepartialushort(self):
406 encoded = b''.join(cborutil.streamencode(2**15))
407
408 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
409 (False, None, -2, cborutil.SPECIAL_NONE))
410 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
411 (False, None, -1, cborutil.SPECIAL_NONE))
412 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
413 (True, 2**15, 3, cborutil.SPECIAL_NONE))
414
415 def testdecodepartialshort(self):
416 encoded = b''.join(cborutil.streamencode(-1024))
417
418 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
419 (False, None, -2, cborutil.SPECIAL_NONE))
420 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
421 (False, None, -1, cborutil.SPECIAL_NONE))
422 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
423 (True, -1024, 3, cborutil.SPECIAL_NONE))
424
425 def testdecodepartialulong(self):
426 encoded = b''.join(cborutil.streamencode(2**28))
427
428 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
429 (False, None, -4, cborutil.SPECIAL_NONE))
430 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
431 (False, None, -3, cborutil.SPECIAL_NONE))
432 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
433 (False, None, -2, cborutil.SPECIAL_NONE))
434 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
435 (False, None, -1, cborutil.SPECIAL_NONE))
436 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
437 (True, 2**28, 5, cborutil.SPECIAL_NONE))
438
439 def testdecodepartiallong(self):
440 encoded = b''.join(cborutil.streamencode(-1048580))
104
441
105 class ArrayTests(unittest.TestCase):
442 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
443 (False, None, -4, cborutil.SPECIAL_NONE))
444 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
445 (False, None, -3, cborutil.SPECIAL_NONE))
446 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
447 (False, None, -2, cborutil.SPECIAL_NONE))
448 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
449 (False, None, -1, cborutil.SPECIAL_NONE))
450 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
451 (True, -1048580, 5, cborutil.SPECIAL_NONE))
452
453 def testdecodepartialulonglong(self):
454 encoded = b''.join(cborutil.streamencode(2**32))
455
456 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
457 (False, None, -8, cborutil.SPECIAL_NONE))
458 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
459 (False, None, -7, cborutil.SPECIAL_NONE))
460 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
461 (False, None, -6, cborutil.SPECIAL_NONE))
462 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
463 (False, None, -5, cborutil.SPECIAL_NONE))
464 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
465 (False, None, -4, cborutil.SPECIAL_NONE))
466 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
467 (False, None, -3, cborutil.SPECIAL_NONE))
468 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
469 (False, None, -2, cborutil.SPECIAL_NONE))
470 self.assertEqual(cborutil.decodeitem(encoded[0:8]),
471 (False, None, -1, cborutil.SPECIAL_NONE))
472 self.assertEqual(cborutil.decodeitem(encoded[0:9]),
473 (True, 2**32, 9, cborutil.SPECIAL_NONE))
474
475 with self.assertRaisesRegex(
476 cborutil.CBORDecodeError, 'input data not fully consumed'):
477 cborutil.decodeall(encoded[0:1])
478
479 with self.assertRaisesRegex(
480 cborutil.CBORDecodeError, 'input data not fully consumed'):
481 cborutil.decodeall(encoded[0:2])
482
483 def testdecodepartiallonglong(self):
484 encoded = b''.join(cborutil.streamencode(-7000000000))
485
486 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
487 (False, None, -8, cborutil.SPECIAL_NONE))
488 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
489 (False, None, -7, cborutil.SPECIAL_NONE))
490 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
491 (False, None, -6, cborutil.SPECIAL_NONE))
492 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
493 (False, None, -5, cborutil.SPECIAL_NONE))
494 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
495 (False, None, -4, cborutil.SPECIAL_NONE))
496 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
497 (False, None, -3, cborutil.SPECIAL_NONE))
498 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
499 (False, None, -2, cborutil.SPECIAL_NONE))
500 self.assertEqual(cborutil.decodeitem(encoded[0:8]),
501 (False, None, -1, cborutil.SPECIAL_NONE))
502 self.assertEqual(cborutil.decodeitem(encoded[0:9]),
503 (True, -7000000000, 9, cborutil.SPECIAL_NONE))
504
505 class ArrayTests(TestCase):
106 def testempty(self):
506 def testempty(self):
107 self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
507 self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
108 self.assertEqual(loadit(cborutil.streamencode([])), [])
508 self.assertEqual(loadit(cborutil.streamencode([])), [])
109
509
510 self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
511
110 def testbasic(self):
512 def testbasic(self):
111 source = [b'foo', b'bar', 1, -10]
513 source = [b'foo', b'bar', 1, -10]
112
514
113 self.assertEqual(list(cborutil.streamencode(source)), [
515 chunks = [
114 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
516 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
517
518 self.assertEqual(list(cborutil.streamencode(source)), chunks)
519
520 self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
115
521
116 def testemptyfromiter(self):
522 def testemptyfromiter(self):
117 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
523 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
118 b'\x9f\xff')
524 b'\x9f\xff')
119
525
526 with self.assertRaisesRegex(cborutil.CBORDecodeError,
527 'indefinite length uint not allowed'):
528 cborutil.decodeall(b'\x9f\xff')
529
120 def testfromiter1(self):
530 def testfromiter1(self):
121 source = [b'foo']
531 source = [b'foo']
122
532
@@ -129,26 +539,193 b' class ArrayTests(unittest.TestCase):'
129 dest = b''.join(cborutil.streamencodearrayfromiter(source))
539 dest = b''.join(cborutil.streamencodearrayfromiter(source))
130 self.assertEqual(cbor.loads(dest), source)
540 self.assertEqual(cbor.loads(dest), source)
131
541
542 with self.assertRaisesRegex(cborutil.CBORDecodeError,
543 'indefinite length uint not allowed'):
544 cborutil.decodeall(dest)
545
132 def testtuple(self):
546 def testtuple(self):
133 source = (b'foo', None, 42)
547 source = (b'foo', None, 42)
548 encoded = b''.join(cborutil.streamencode(source))
134
549
135 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
550 self.assertEqual(cbor.loads(encoded), list(source))
136 list(source))
551
552 self.assertEqual(cborutil.decodeall(encoded), [list(source)])
553
554 def testpartialdecode(self):
555 source = list(range(4))
556 encoded = b''.join(cborutil.streamencode(source))
557 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
558 (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
559 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
560 (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
561
562 source = list(range(23))
563 encoded = b''.join(cborutil.streamencode(source))
564 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
565 (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
566 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
567 (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
568
569 source = list(range(24))
570 encoded = b''.join(cborutil.streamencode(source))
571 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
572 (False, None, -1, cborutil.SPECIAL_NONE))
573 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
574 (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
575 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
576 (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
137
577
138 class SetTests(unittest.TestCase):
578 source = list(range(256))
579 encoded = b''.join(cborutil.streamencode(source))
580 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
581 (False, None, -2, cborutil.SPECIAL_NONE))
582 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
583 (False, None, -1, cborutil.SPECIAL_NONE))
584 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
585 (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
586 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
587 (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
588
589 def testnested(self):
590 source = [[], [], [[], [], []]]
591 encoded = b''.join(cborutil.streamencode(source))
592 self.assertEqual(cborutil.decodeall(encoded), [source])
593
594 source = [True, None, [True, 0, 2], [None], [], [[[]], -87]]
595 encoded = b''.join(cborutil.streamencode(source))
596 self.assertEqual(cborutil.decodeall(encoded), [source])
597
598 # A set within an array.
599 source = [None, {b'foo', b'bar', None, False}, set()]
600 encoded = b''.join(cborutil.streamencode(source))
601 self.assertEqual(cborutil.decodeall(encoded), [source])
602
603 # A map within an array.
604 source = [None, {}, {b'foo': b'bar', True: False}, [{}]]
605 encoded = b''.join(cborutil.streamencode(source))
606 self.assertEqual(cborutil.decodeall(encoded), [source])
607
608 def testindefinitebytestringvalues(self):
609 # Single value array whose value is an empty indefinite bytestring.
610 encoded = b'\x81\x5f\x40\xff'
611
612 with self.assertRaisesRegex(cborutil.CBORDecodeError,
613 'indefinite length bytestrings not '
614 'allowed as array values'):
615 cborutil.decodeall(encoded)
616
617 class SetTests(TestCase):
139 def testempty(self):
618 def testempty(self):
140 self.assertEqual(list(cborutil.streamencode(set())), [
619 self.assertEqual(list(cborutil.streamencode(set())), [
141 b'\xd9\x01\x02',
620 b'\xd9\x01\x02',
142 b'\x80',
621 b'\x80',
143 ])
622 ])
144
623
624 self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
625
145 def testset(self):
626 def testset(self):
146 source = {b'foo', None, 42}
627 source = {b'foo', None, 42}
628 encoded = b''.join(cborutil.streamencode(source))
147
629
148 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
630 self.assertEqual(cbor.loads(encoded), source)
149 source)
631
632 self.assertEqual(cborutil.decodeall(encoded), [source])
633
634 def testinvalidtag(self):
635 # Must use array to encode sets.
636 encoded = b'\xd9\x01\x02\xa0'
637
638 with self.assertRaisesRegex(cborutil.CBORDecodeError,
639 'expected array after finite set '
640 'semantic tag'):
641 cborutil.decodeall(encoded)
642
643 def testpartialdecode(self):
644 # Semantic tag item will be 3 bytes. Set header will be variable
645 # depending on length.
646 encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
647 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
648 (False, None, -2, cborutil.SPECIAL_NONE))
649 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
650 (False, None, -1, cborutil.SPECIAL_NONE))
651 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
652 (False, None, -1, cborutil.SPECIAL_NONE))
653 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
654 (True, 23, 4, cborutil.SPECIAL_START_SET))
655 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
656 (True, 23, 4, cborutil.SPECIAL_START_SET))
657
658 encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
659 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
660 (False, None, -2, cborutil.SPECIAL_NONE))
661 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
662 (False, None, -1, cborutil.SPECIAL_NONE))
663 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
664 (False, None, -1, cborutil.SPECIAL_NONE))
665 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
666 (False, None, -1, cborutil.SPECIAL_NONE))
667 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
668 (True, 24, 5, cborutil.SPECIAL_START_SET))
669 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
670 (True, 24, 5, cborutil.SPECIAL_START_SET))
150
671
151 class BoolTests(unittest.TestCase):
672 encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
673 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
674 (False, None, -2, cborutil.SPECIAL_NONE))
675 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
676 (False, None, -1, cborutil.SPECIAL_NONE))
677 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
678 (False, None, -1, cborutil.SPECIAL_NONE))
679 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
680 (False, None, -2, cborutil.SPECIAL_NONE))
681 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
682 (False, None, -1, cborutil.SPECIAL_NONE))
683 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
684 (True, 256, 6, cborutil.SPECIAL_START_SET))
685
686 def testinvalidvalue(self):
687 encoded = b''.join([
688 b'\xd9\x01\x02', # semantic tag
689 b'\x81', # array of size 1
690 b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
691 ])
692
693 with self.assertRaisesRegex(cborutil.CBORDecodeError,
694 'indefinite length bytestrings not '
695 'allowed as set values'):
696 cborutil.decodeall(encoded)
697
698 encoded = b''.join([
699 b'\xd9\x01\x02',
700 b'\x81',
701 b'\x80', # empty array
702 ])
703
704 with self.assertRaisesRegex(cborutil.CBORDecodeError,
705 'collections not allowed as set values'):
706 cborutil.decodeall(encoded)
707
708 encoded = b''.join([
709 b'\xd9\x01\x02',
710 b'\x81',
711 b'\xa0', # empty map
712 ])
713
714 with self.assertRaisesRegex(cborutil.CBORDecodeError,
715 'collections not allowed as set values'):
716 cborutil.decodeall(encoded)
717
718 encoded = b''.join([
719 b'\xd9\x01\x02',
720 b'\x81',
721 b'\xd9\x01\x02\x81\x01', # set with integer 1
722 ])
723
724 with self.assertRaisesRegex(cborutil.CBORDecodeError,
725 'collections not allowed as set values'):
726 cborutil.decodeall(encoded)
727
728 class BoolTests(TestCase):
152 def testbasic(self):
729 def testbasic(self):
153 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
730 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
154 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
731 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
@@ -156,23 +733,38 b' class BoolTests(unittest.TestCase):'
156 self.assertIs(loadit(cborutil.streamencode(True)), True)
733 self.assertIs(loadit(cborutil.streamencode(True)), True)
157 self.assertIs(loadit(cborutil.streamencode(False)), False)
734 self.assertIs(loadit(cborutil.streamencode(False)), False)
158
735
159 class NoneTests(unittest.TestCase):
736 self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
737 self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
738
739 self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
740 [False, True, True, False])
741
742 class NoneTests(TestCase):
160 def testbasic(self):
743 def testbasic(self):
161 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
744 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
162
745
163 self.assertIs(loadit(cborutil.streamencode(None)), None)
746 self.assertIs(loadit(cborutil.streamencode(None)), None)
164
747
165 class MapTests(unittest.TestCase):
748 self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
749 self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
750
751 class MapTests(TestCase):
166 def testempty(self):
752 def testempty(self):
167 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
753 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
168 self.assertEqual(loadit(cborutil.streamencode({})), {})
754 self.assertEqual(loadit(cborutil.streamencode({})), {})
169
755
756 self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
757
170 def testemptyindefinite(self):
758 def testemptyindefinite(self):
171 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
759 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
172 b'\xbf', b'\xff'])
760 b'\xbf', b'\xff'])
173
761
174 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
762 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
175
763
764 with self.assertRaisesRegex(cborutil.CBORDecodeError,
765 'indefinite length uint not allowed'):
766 cborutil.decodeall(b'\xbf\xff')
767
176 def testone(self):
768 def testone(self):
177 source = {b'foo': b'bar'}
769 source = {b'foo': b'bar'}
178 self.assertEqual(list(cborutil.streamencode(source)), [
770 self.assertEqual(list(cborutil.streamencode(source)), [
@@ -180,6 +772,8 b' class MapTests(unittest.TestCase):'
180
772
181 self.assertEqual(loadit(cborutil.streamencode(source)), source)
773 self.assertEqual(loadit(cborutil.streamencode(source)), source)
182
774
775 self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
776
183 def testmultiple(self):
777 def testmultiple(self):
184 source = {
778 source = {
185 b'foo': b'bar',
779 b'foo': b'bar',
@@ -192,6 +786,9 b' class MapTests(unittest.TestCase):'
192 loadit(cborutil.streamencodemapfromiter(source.items())),
786 loadit(cborutil.streamencodemapfromiter(source.items())),
193 source)
787 source)
194
788
789 encoded = b''.join(cborutil.streamencode(source))
790 self.assertEqual(cborutil.decodeall(encoded), [source])
791
195 def testcomplex(self):
792 def testcomplex(self):
196 source = {
793 source = {
197 b'key': 1,
794 b'key': 1,
@@ -205,6 +802,170 b' class MapTests(unittest.TestCase):'
205 loadit(cborutil.streamencodemapfromiter(source.items())),
802 loadit(cborutil.streamencodemapfromiter(source.items())),
206 source)
803 source)
207
804
805 encoded = b''.join(cborutil.streamencode(source))
806 self.assertEqual(cborutil.decodeall(encoded), [source])
807
808 def testnested(self):
809 source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}}
810 encoded = b''.join(cborutil.streamencode(source))
811
812 self.assertEqual(cborutil.decodeall(encoded), [source])
813
814 source = {
815 b'key1': [],
816 b'key2': [None, False],
817 b'key3': {b'foo', b'bar'},
818 b'key4': {},
819 }
820 encoded = b''.join(cborutil.streamencode(source))
821 self.assertEqual(cborutil.decodeall(encoded), [source])
822
823 def testillegalkey(self):
824 encoded = b''.join([
825 # map header + len 1
826 b'\xa1',
827 # indefinite length bytestring "foo" in key position
828 b'\x5f\x03foo\xff'
829 ])
830
831 with self.assertRaisesRegex(cborutil.CBORDecodeError,
832 'indefinite length bytestrings not '
833 'allowed as map keys'):
834 cborutil.decodeall(encoded)
835
836 encoded = b''.join([
837 b'\xa1',
838 b'\x80', # empty array
839 b'\x43foo',
840 ])
841
842 with self.assertRaisesRegex(cborutil.CBORDecodeError,
843 'collections not supported as map keys'):
844 cborutil.decodeall(encoded)
845
846 def testillegalvalue(self):
847 encoded = b''.join([
848 b'\xa1', # map headers
849 b'\x43foo', # key
850 b'\x5f\x03bar\xff', # indefinite length value
851 ])
852
853 with self.assertRaisesRegex(cborutil.CBORDecodeError,
854 'indefinite length bytestrings not '
855 'allowed as map values'):
856 cborutil.decodeall(encoded)
857
858 def testpartialdecode(self):
859 source = {b'key1': b'value1'}
860 encoded = b''.join(cborutil.streamencode(source))
861
862 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
863 (True, 1, 1, cborutil.SPECIAL_START_MAP))
864 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
865 (True, 1, 1, cborutil.SPECIAL_START_MAP))
866
867 source = {b'key%d' % i: None for i in range(23)}
868 encoded = b''.join(cborutil.streamencode(source))
869 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
870 (True, 23, 1, cborutil.SPECIAL_START_MAP))
871
872 source = {b'key%d' % i: None for i in range(24)}
873 encoded = b''.join(cborutil.streamencode(source))
874 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
875 (False, None, -1, cborutil.SPECIAL_NONE))
876 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
877 (True, 24, 2, cborutil.SPECIAL_START_MAP))
878 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
879 (True, 24, 2, cborutil.SPECIAL_START_MAP))
880
881 source = {b'key%d' % i: None for i in range(256)}
882 encoded = b''.join(cborutil.streamencode(source))
883 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
884 (False, None, -2, cborutil.SPECIAL_NONE))
885 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
886 (False, None, -1, cborutil.SPECIAL_NONE))
887 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
888 (True, 256, 3, cborutil.SPECIAL_START_MAP))
889 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
890 (True, 256, 3, cborutil.SPECIAL_START_MAP))
891
892 source = {b'key%d' % i: None for i in range(65536)}
893 encoded = b''.join(cborutil.streamencode(source))
894 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
895 (False, None, -4, cborutil.SPECIAL_NONE))
896 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
897 (False, None, -3, cborutil.SPECIAL_NONE))
898 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
899 (False, None, -2, cborutil.SPECIAL_NONE))
900 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
901 (False, None, -1, cborutil.SPECIAL_NONE))
902 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
903 (True, 65536, 5, cborutil.SPECIAL_START_MAP))
904 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
905 (True, 65536, 5, cborutil.SPECIAL_START_MAP))
906
907 class SemanticTagTests(TestCase):
908 def testdecodeforbidden(self):
909 for i in range(500):
910 if i == cborutil.SEMANTIC_TAG_FINITE_SET:
911 continue
912
913 tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC,
914 i)
915
916 encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
917
918 # Partial decode is incomplete.
919 if i < 24:
920 pass
921 elif i < 256:
922 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
923 (False, None, -1, cborutil.SPECIAL_NONE))
924 elif i < 65536:
925 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
926 (False, None, -2, cborutil.SPECIAL_NONE))
927 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
928 (False, None, -1, cborutil.SPECIAL_NONE))
929
930 with self.assertRaisesRegex(cborutil.CBORDecodeError,
931 'semantic tag \d+ not allowed'):
932 cborutil.decodeitem(encoded)
933
934 class SpecialTypesTests(TestCase):
935 def testforbiddentypes(self):
936 for i in range(256):
937 if i == cborutil.SUBTYPE_FALSE:
938 continue
939 elif i == cborutil.SUBTYPE_TRUE:
940 continue
941 elif i == cborutil.SUBTYPE_NULL:
942 continue
943
944 encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
945
946 with self.assertRaisesRegex(cborutil.CBORDecodeError,
947 'special type \d+ not allowed'):
948 cborutil.decodeitem(encoded)
949
950 class SansIODecoderTests(TestCase):
951 def testemptyinput(self):
952 decoder = cborutil.sansiodecoder()
953 self.assertEqual(decoder.decode(b''), (False, 0, 0))
954
955 class DecodeallTests(TestCase):
956 def testemptyinput(self):
957 self.assertEqual(cborutil.decodeall(b''), [])
958
959 def testpartialinput(self):
960 encoded = b''.join([
961 b'\x82', # array of 2 elements
962 b'\x01', # integer 1
963 ])
964
965 with self.assertRaisesRegex(cborutil.CBORDecodeError,
966 'input data not complete'):
967 cborutil.decodeall(encoded)
968
208 if __name__ == '__main__':
969 if __name__ == '__main__':
209 import silenttestrunner
970 import silenttestrunner
210 silenttestrunner.main(__name__)
971 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now