##// END OF EJS Templates
cborutil: cast bytearray to bytes...
Gregory Szorc -
r40160:b638219a default
parent child Browse files
Show More
@@ -1,990 +1,995 b''
1 1 # cborutil.py - CBOR extensions
2 2 #
3 3 # Copyright 2018 Gregory Szorc <gregory.szorc@gmail.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import struct
11 11 import sys
12 12
13 13 from .. import pycompat
14 14
15 15 # Very short very of RFC 7049...
16 16 #
17 17 # Each item begins with a byte. The 3 high bits of that byte denote the
18 18 # "major type." The lower 5 bits denote the "subtype." Each major type
19 19 # has its own encoding mechanism.
20 20 #
21 21 # Most types have lengths. However, bytestring, string, array, and map
22 22 # can be indefinite length. These are denotes by a subtype with value 31.
23 23 # Sub-components of those types then come afterwards and are terminated
24 24 # by a "break" byte.
25 25
26 26 MAJOR_TYPE_UINT = 0
27 27 MAJOR_TYPE_NEGINT = 1
28 28 MAJOR_TYPE_BYTESTRING = 2
29 29 MAJOR_TYPE_STRING = 3
30 30 MAJOR_TYPE_ARRAY = 4
31 31 MAJOR_TYPE_MAP = 5
32 32 MAJOR_TYPE_SEMANTIC = 6
33 33 MAJOR_TYPE_SPECIAL = 7
34 34
35 35 SUBTYPE_MASK = 0b00011111
36 36
37 37 SUBTYPE_FALSE = 20
38 38 SUBTYPE_TRUE = 21
39 39 SUBTYPE_NULL = 22
40 40 SUBTYPE_HALF_FLOAT = 25
41 41 SUBTYPE_SINGLE_FLOAT = 26
42 42 SUBTYPE_DOUBLE_FLOAT = 27
43 43 SUBTYPE_INDEFINITE = 31
44 44
45 45 SEMANTIC_TAG_FINITE_SET = 258
46 46
47 47 # Indefinite types begin with their major type ORd with information value 31.
48 48 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
49 49 r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
50 50 BEGIN_INDEFINITE_ARRAY = struct.pack(
51 51 r'>B', MAJOR_TYPE_ARRAY << 5 | SUBTYPE_INDEFINITE)
52 52 BEGIN_INDEFINITE_MAP = struct.pack(
53 53 r'>B', MAJOR_TYPE_MAP << 5 | SUBTYPE_INDEFINITE)
54 54
55 55 ENCODED_LENGTH_1 = struct.Struct(r'>B')
56 56 ENCODED_LENGTH_2 = struct.Struct(r'>BB')
57 57 ENCODED_LENGTH_3 = struct.Struct(r'>BH')
58 58 ENCODED_LENGTH_4 = struct.Struct(r'>BL')
59 59 ENCODED_LENGTH_5 = struct.Struct(r'>BQ')
60 60
61 61 # The break ends an indefinite length item.
62 62 BREAK = b'\xff'
63 63 BREAK_INT = 255
64 64
65 65 def encodelength(majortype, length):
66 66 """Obtain a value encoding the major type and its length."""
67 67 if length < 24:
68 68 return ENCODED_LENGTH_1.pack(majortype << 5 | length)
69 69 elif length < 256:
70 70 return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
71 71 elif length < 65536:
72 72 return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
73 73 elif length < 4294967296:
74 74 return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
75 75 else:
76 76 return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
77 77
78 78 def streamencodebytestring(v):
79 79 yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
80 80 yield v
81 81
82 82 def streamencodebytestringfromiter(it):
83 83 """Convert an iterator of chunks to an indefinite bytestring.
84 84
85 85 Given an input that is iterable and each element in the iterator is
86 86 representable as bytes, emit an indefinite length bytestring.
87 87 """
88 88 yield BEGIN_INDEFINITE_BYTESTRING
89 89
90 90 for chunk in it:
91 91 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
92 92 yield chunk
93 93
94 94 yield BREAK
95 95
96 96 def streamencodeindefinitebytestring(source, chunksize=65536):
97 97 """Given a large source buffer, emit as an indefinite length bytestring.
98 98
99 99 This is a generator of chunks constituting the encoded CBOR data.
100 100 """
101 101 yield BEGIN_INDEFINITE_BYTESTRING
102 102
103 103 i = 0
104 104 l = len(source)
105 105
106 106 while True:
107 107 chunk = source[i:i + chunksize]
108 108 i += len(chunk)
109 109
110 110 yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
111 111 yield chunk
112 112
113 113 if i >= l:
114 114 break
115 115
116 116 yield BREAK
117 117
118 118 def streamencodeint(v):
119 119 if v >= 18446744073709551616 or v < -18446744073709551616:
120 120 raise ValueError('big integers not supported')
121 121
122 122 if v >= 0:
123 123 yield encodelength(MAJOR_TYPE_UINT, v)
124 124 else:
125 125 yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
126 126
127 127 def streamencodearray(l):
128 128 """Encode a known size iterable to an array."""
129 129
130 130 yield encodelength(MAJOR_TYPE_ARRAY, len(l))
131 131
132 132 for i in l:
133 133 for chunk in streamencode(i):
134 134 yield chunk
135 135
136 136 def streamencodearrayfromiter(it):
137 137 """Encode an iterator of items to an indefinite length array."""
138 138
139 139 yield BEGIN_INDEFINITE_ARRAY
140 140
141 141 for i in it:
142 142 for chunk in streamencode(i):
143 143 yield chunk
144 144
145 145 yield BREAK
146 146
147 147 def _mixedtypesortkey(v):
148 148 return type(v).__name__, v
149 149
150 150 def streamencodeset(s):
151 151 # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
152 152 # semantic tag 258 for finite sets.
153 153 yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
154 154
155 155 for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
156 156 yield chunk
157 157
158 158 def streamencodemap(d):
159 159 """Encode dictionary to a generator.
160 160
161 161 Does not supporting indefinite length dictionaries.
162 162 """
163 163 yield encodelength(MAJOR_TYPE_MAP, len(d))
164 164
165 165 for key, value in sorted(d.iteritems(),
166 166 key=lambda x: _mixedtypesortkey(x[0])):
167 167 for chunk in streamencode(key):
168 168 yield chunk
169 169 for chunk in streamencode(value):
170 170 yield chunk
171 171
172 172 def streamencodemapfromiter(it):
173 173 """Given an iterable of (key, value), encode to an indefinite length map."""
174 174 yield BEGIN_INDEFINITE_MAP
175 175
176 176 for key, value in it:
177 177 for chunk in streamencode(key):
178 178 yield chunk
179 179 for chunk in streamencode(value):
180 180 yield chunk
181 181
182 182 yield BREAK
183 183
184 184 def streamencodebool(b):
185 185 # major type 7, simple value 20 and 21.
186 186 yield b'\xf5' if b else b'\xf4'
187 187
188 188 def streamencodenone(v):
189 189 # major type 7, simple value 22.
190 190 yield b'\xf6'
191 191
192 192 STREAM_ENCODERS = {
193 193 bytes: streamencodebytestring,
194 194 int: streamencodeint,
195 195 pycompat.long: streamencodeint,
196 196 list: streamencodearray,
197 197 tuple: streamencodearray,
198 198 dict: streamencodemap,
199 199 set: streamencodeset,
200 200 bool: streamencodebool,
201 201 type(None): streamencodenone,
202 202 }
203 203
204 204 def streamencode(v):
205 205 """Encode a value in a streaming manner.
206 206
207 207 Given an input object, encode it to CBOR recursively.
208 208
209 209 Returns a generator of CBOR encoded bytes. There is no guarantee
210 210 that each emitted chunk fully decodes to a value or sub-value.
211 211
212 212 Encoding is deterministic - unordered collections are sorted.
213 213 """
214 214 fn = STREAM_ENCODERS.get(v.__class__)
215 215
216 216 if not fn:
217 217 raise ValueError('do not know how to encode %s' % type(v))
218 218
219 219 return fn(v)
220 220
221 221 class CBORDecodeError(Exception):
222 222 """Represents an error decoding CBOR."""
223 223
224 224 if sys.version_info.major >= 3:
225 225 def _elementtointeger(b, i):
226 226 return b[i]
227 227 else:
228 228 def _elementtointeger(b, i):
229 229 return ord(b[i])
230 230
231 231 STRUCT_BIG_UBYTE = struct.Struct(r'>B')
232 232 STRUCT_BIG_USHORT = struct.Struct('>H')
233 233 STRUCT_BIG_ULONG = struct.Struct('>L')
234 234 STRUCT_BIG_ULONGLONG = struct.Struct('>Q')
235 235
236 236 SPECIAL_NONE = 0
237 237 SPECIAL_START_INDEFINITE_BYTESTRING = 1
238 238 SPECIAL_START_ARRAY = 2
239 239 SPECIAL_START_MAP = 3
240 240 SPECIAL_START_SET = 4
241 241 SPECIAL_INDEFINITE_BREAK = 5
242 242
243 243 def decodeitem(b, offset=0):
244 244 """Decode a new CBOR value from a buffer at offset.
245 245
246 246 This function attempts to decode up to one complete CBOR value
247 247 from ``b`` starting at offset ``offset``.
248 248
249 249 The beginning of a collection (such as an array, map, set, or
250 250 indefinite length bytestring) counts as a single value. For these
251 251 special cases, a state flag will indicate that a special value was seen.
252 252
253 253 When called, the function either returns a decoded value or gives
254 254 a hint as to how many more bytes are needed to do so. By calling
255 255 the function repeatedly given a stream of bytes, the caller can
256 256 build up the original values.
257 257
258 258 Returns a tuple with the following elements:
259 259
260 260 * Bool indicating whether a complete value was decoded.
261 261 * A decoded value if first value is True otherwise None
262 262 * Integer number of bytes. If positive, the number of bytes
263 263 read. If negative, the number of bytes we need to read to
264 264 decode this value or the next chunk in this value.
265 265 * One of the ``SPECIAL_*`` constants indicating special treatment
266 266 for this value. ``SPECIAL_NONE`` means this is a fully decoded
267 267 simple value (such as an integer or bool).
268 268 """
269 269
270 270 initial = _elementtointeger(b, offset)
271 271 offset += 1
272 272
273 273 majortype = initial >> 5
274 274 subtype = initial & SUBTYPE_MASK
275 275
276 276 if majortype == MAJOR_TYPE_UINT:
277 277 complete, value, readcount = decodeuint(subtype, b, offset)
278 278
279 279 if complete:
280 280 return True, value, readcount + 1, SPECIAL_NONE
281 281 else:
282 282 return False, None, readcount, SPECIAL_NONE
283 283
284 284 elif majortype == MAJOR_TYPE_NEGINT:
285 285 # Negative integers are the same as UINT except inverted minus 1.
286 286 complete, value, readcount = decodeuint(subtype, b, offset)
287 287
288 288 if complete:
289 289 return True, -value - 1, readcount + 1, SPECIAL_NONE
290 290 else:
291 291 return False, None, readcount, SPECIAL_NONE
292 292
293 293 elif majortype == MAJOR_TYPE_BYTESTRING:
294 294 # Beginning of bytestrings are treated as uints in order to
295 295 # decode their length, which may be indefinite.
296 296 complete, size, readcount = decodeuint(subtype, b, offset,
297 297 allowindefinite=True)
298 298
299 299 # We don't know the size of the bytestring. It must be a definitive
300 300 # length since the indefinite subtype would be encoded in the initial
301 301 # byte.
302 302 if not complete:
303 303 return False, None, readcount, SPECIAL_NONE
304 304
305 305 # We know the length of the bytestring.
306 306 if size is not None:
307 307 # And the data is available in the buffer.
308 308 if offset + readcount + size <= len(b):
309 309 value = b[offset + readcount:offset + readcount + size]
310 310 return True, value, readcount + size + 1, SPECIAL_NONE
311 311
312 312 # And we need more data in order to return the bytestring.
313 313 else:
314 314 wanted = len(b) - offset - readcount - size
315 315 return False, None, wanted, SPECIAL_NONE
316 316
317 317 # It is an indefinite length bytestring.
318 318 else:
319 319 return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
320 320
321 321 elif majortype == MAJOR_TYPE_STRING:
322 322 raise CBORDecodeError('string major type not supported')
323 323
324 324 elif majortype == MAJOR_TYPE_ARRAY:
325 325 # Beginning of arrays are treated as uints in order to decode their
326 326 # length. We don't allow indefinite length arrays.
327 327 complete, size, readcount = decodeuint(subtype, b, offset)
328 328
329 329 if complete:
330 330 return True, size, readcount + 1, SPECIAL_START_ARRAY
331 331 else:
332 332 return False, None, readcount, SPECIAL_NONE
333 333
334 334 elif majortype == MAJOR_TYPE_MAP:
335 335 # Beginning of maps are treated as uints in order to decode their
336 336 # number of elements. We don't allow indefinite length arrays.
337 337 complete, size, readcount = decodeuint(subtype, b, offset)
338 338
339 339 if complete:
340 340 return True, size, readcount + 1, SPECIAL_START_MAP
341 341 else:
342 342 return False, None, readcount, SPECIAL_NONE
343 343
344 344 elif majortype == MAJOR_TYPE_SEMANTIC:
345 345 # Semantic tag value is read the same as a uint.
346 346 complete, tagvalue, readcount = decodeuint(subtype, b, offset)
347 347
348 348 if not complete:
349 349 return False, None, readcount, SPECIAL_NONE
350 350
351 351 # This behavior here is a little wonky. The main type being "decorated"
352 352 # by this semantic tag follows. A more robust parser would probably emit
353 353 # a special flag indicating this as a semantic tag and let the caller
354 354 # deal with the types that follow. But since we don't support many
355 355 # semantic tags, it is easier to deal with the special cases here and
356 356 # hide complexity from the caller. If we add support for more semantic
357 357 # tags, we should probably move semantic tag handling into the caller.
358 358 if tagvalue == SEMANTIC_TAG_FINITE_SET:
359 359 if offset + readcount >= len(b):
360 360 return False, None, -1, SPECIAL_NONE
361 361
362 362 complete, size, readcount2, special = decodeitem(b,
363 363 offset + readcount)
364 364
365 365 if not complete:
366 366 return False, None, readcount2, SPECIAL_NONE
367 367
368 368 if special != SPECIAL_START_ARRAY:
369 369 raise CBORDecodeError('expected array after finite set '
370 370 'semantic tag')
371 371
372 372 return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
373 373
374 374 else:
375 375 raise CBORDecodeError('semantic tag %d not allowed' % tagvalue)
376 376
377 377 elif majortype == MAJOR_TYPE_SPECIAL:
378 378 # Only specific values for the information field are allowed.
379 379 if subtype == SUBTYPE_FALSE:
380 380 return True, False, 1, SPECIAL_NONE
381 381 elif subtype == SUBTYPE_TRUE:
382 382 return True, True, 1, SPECIAL_NONE
383 383 elif subtype == SUBTYPE_NULL:
384 384 return True, None, 1, SPECIAL_NONE
385 385 elif subtype == SUBTYPE_INDEFINITE:
386 386 return True, None, 1, SPECIAL_INDEFINITE_BREAK
387 387 # If value is 24, subtype is in next byte.
388 388 else:
389 389 raise CBORDecodeError('special type %d not allowed' % subtype)
390 390 else:
391 391 assert False
392 392
393 393 def decodeuint(subtype, b, offset=0, allowindefinite=False):
394 394 """Decode an unsigned integer.
395 395
396 396 ``subtype`` is the lower 5 bits from the initial byte CBOR item
397 397 "header." ``b`` is a buffer containing bytes. ``offset`` points to
398 398 the index of the first byte after the byte that ``subtype`` was
399 399 derived from.
400 400
401 401 ``allowindefinite`` allows the special indefinite length value
402 402 indicator.
403 403
404 404 Returns a 3-tuple of (successful, value, count).
405 405
406 406 The first element is a bool indicating if decoding completed. The 2nd
407 407 is the decoded integer value or None if not fully decoded or the subtype
408 408 is 31 and ``allowindefinite`` is True. The 3rd value is the count of bytes.
409 409 If positive, it is the number of additional bytes decoded. If negative,
410 410 it is the number of additional bytes needed to decode this value.
411 411 """
412 412
413 413 # Small values are inline.
414 414 if subtype < 24:
415 415 return True, subtype, 0
416 416 # Indefinite length specifier.
417 417 elif subtype == 31:
418 418 if allowindefinite:
419 419 return True, None, 0
420 420 else:
421 421 raise CBORDecodeError('indefinite length uint not allowed here')
422 422 elif subtype >= 28:
423 423 raise CBORDecodeError('unsupported subtype on integer type: %d' %
424 424 subtype)
425 425
426 426 if subtype == 24:
427 427 s = STRUCT_BIG_UBYTE
428 428 elif subtype == 25:
429 429 s = STRUCT_BIG_USHORT
430 430 elif subtype == 26:
431 431 s = STRUCT_BIG_ULONG
432 432 elif subtype == 27:
433 433 s = STRUCT_BIG_ULONGLONG
434 434 else:
435 435 raise CBORDecodeError('bounds condition checking violation')
436 436
437 437 if len(b) - offset >= s.size:
438 438 return True, s.unpack_from(b, offset)[0], s.size
439 439 else:
440 440 return False, None, len(b) - offset - s.size
441 441
442 442 class bytestringchunk(bytes):
443 443 """Represents a chunk/segment in an indefinite length bytestring.
444 444
445 445 This behaves like a ``bytes`` but in addition has the ``isfirst``
446 446 and ``islast`` attributes indicating whether this chunk is the first
447 447 or last in an indefinite length bytestring.
448 448 """
449 449
450 450 def __new__(cls, v, first=False, last=False):
451 451 self = bytes.__new__(cls, v)
452 452 self.isfirst = first
453 453 self.islast = last
454 454
455 455 return self
456 456
457 457 class sansiodecoder(object):
458 458 """A CBOR decoder that doesn't perform its own I/O.
459 459
460 460 To use, construct an instance and feed it segments containing
461 461 CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
462 462 indicates whether a fully-decoded value is available, how many bytes
463 463 were consumed, and offers a hint as to how many bytes should be fed
464 464 in next time to decode the next value.
465 465
466 466 The decoder assumes it will decode N discrete CBOR values, not just
467 467 a single value. i.e. if the bytestream contains uints packed one after
468 468 the other, the decoder will decode them all, rather than just the initial
469 469 one.
470 470
471 471 When ``decode()`` indicates a value is available, call ``getavailable()``
472 472 to return all fully decoded values.
473 473
474 474 ``decode()`` can partially decode input. It is up to the caller to keep
475 475 track of what data was consumed and to pass unconsumed data in on the
476 476 next invocation.
477 477
478 478 The decoder decodes atomically at the *item* level. See ``decodeitem()``.
479 479 If an *item* cannot be fully decoded, the decoder won't record it as
480 480 partially consumed. Instead, the caller will be instructed to pass in
481 481 the initial bytes of this item on the next invocation. This does result
482 482 in some redundant parsing. But the overhead should be minimal.
483 483
484 484 This decoder only supports a subset of CBOR as required by Mercurial.
485 485 It lacks support for:
486 486
487 487 * Indefinite length arrays
488 488 * Indefinite length maps
489 489 * Use of indefinite length bytestrings as keys or values within
490 490 arrays, maps, or sets.
491 491 * Nested arrays, maps, or sets within sets
492 492 * Any semantic tag that isn't a mathematical finite set
493 493 * Floating point numbers
494 494 * Undefined special value
495 495
496 496 CBOR types are decoded to Python types as follows:
497 497
498 498 uint -> int
499 499 negint -> int
500 500 bytestring -> bytes
501 501 map -> dict
502 502 array -> list
503 503 True -> bool
504 504 False -> bool
505 505 null -> None
506 506 indefinite length bytestring chunk -> [bytestringchunk]
507 507
508 508 The only non-obvious mapping here is an indefinite length bytestring
509 509 to the ``bytestringchunk`` type. This is to facilitate streaming
510 510 indefinite length bytestrings out of the decoder and to differentiate
511 511 a regular bytestring from an indefinite length bytestring.
512 512 """
513 513
514 514 _STATE_NONE = 0
515 515 _STATE_WANT_MAP_KEY = 1
516 516 _STATE_WANT_MAP_VALUE = 2
517 517 _STATE_WANT_ARRAY_VALUE = 3
518 518 _STATE_WANT_SET_VALUE = 4
519 519 _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
520 520 _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
521 521
522 522 def __init__(self):
523 523 # TODO add support for limiting size of bytestrings
524 524 # TODO add support for limiting number of keys / values in collections
525 525 # TODO add support for limiting size of buffered partial values
526 526
527 527 self.decodedbytecount = 0
528 528
529 529 self._state = self._STATE_NONE
530 530
531 531 # Stack of active nested collections. Each entry is a dict describing
532 532 # the collection.
533 533 self._collectionstack = []
534 534
535 535 # Fully decoded key to use for the current map.
536 536 self._currentmapkey = None
537 537
538 538 # Fully decoded values available for retrieval.
539 539 self._decodedvalues = []
540 540
541 541 @property
542 542 def inprogress(self):
543 543 """Whether the decoder has partially decoded a value."""
544 544 return self._state != self._STATE_NONE
545 545
546 546 def decode(self, b, offset=0):
547 547 """Attempt to decode bytes from an input buffer.
548 548
549 549 ``b`` is a collection of bytes and ``offset`` is the byte
550 550 offset within that buffer from which to begin reading data.
551 551
552 552 ``b`` must support ``len()`` and accessing bytes slices via
553 553 ``__slice__``. Typically ``bytes`` instances are used.
554 554
555 555 Returns a tuple with the following fields:
556 556
557 557 * Bool indicating whether values are available for retrieval.
558 558 * Integer indicating the number of bytes that were fully consumed,
559 559 starting from ``offset``.
560 560 * Integer indicating the number of bytes that are desired for the
561 561 next call in order to decode an item.
562 562 """
563 563 if not b:
564 564 return bool(self._decodedvalues), 0, 0
565 565
566 566 initialoffset = offset
567 567
568 568 # We could easily split the body of this loop into a function. But
569 569 # Python performance is sensitive to function calls and collections
570 570 # are composed of many items. So leaving as a while loop could help
571 571 # with performance. One thing that may not help is the use of
572 572 # if..elif versus a lookup/dispatch table. There may be value
573 573 # in switching that.
574 574 while offset < len(b):
575 575 # Attempt to decode an item. This could be a whole value or a
576 576 # special value indicating an event, such as start or end of a
577 577 # collection or indefinite length type.
578 578 complete, value, readcount, special = decodeitem(b, offset)
579 579
580 580 if readcount > 0:
581 581 self.decodedbytecount += readcount
582 582
583 583 if not complete:
584 584 assert readcount < 0
585 585 return (
586 586 bool(self._decodedvalues),
587 587 offset - initialoffset,
588 588 -readcount,
589 589 )
590 590
591 591 offset += readcount
592 592
593 593 # No nested state. We either have a full value or beginning of a
594 594 # complex value to deal with.
595 595 if self._state == self._STATE_NONE:
596 596 # A normal value.
597 597 if special == SPECIAL_NONE:
598 598 self._decodedvalues.append(value)
599 599
600 600 elif special == SPECIAL_START_ARRAY:
601 601 self._collectionstack.append({
602 602 'remaining': value,
603 603 'v': [],
604 604 })
605 605 self._state = self._STATE_WANT_ARRAY_VALUE
606 606
607 607 elif special == SPECIAL_START_MAP:
608 608 self._collectionstack.append({
609 609 'remaining': value,
610 610 'v': {},
611 611 })
612 612 self._state = self._STATE_WANT_MAP_KEY
613 613
614 614 elif special == SPECIAL_START_SET:
615 615 self._collectionstack.append({
616 616 'remaining': value,
617 617 'v': set(),
618 618 })
619 619 self._state = self._STATE_WANT_SET_VALUE
620 620
621 621 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
622 622 self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
623 623
624 624 else:
625 625 raise CBORDecodeError('unhandled special state: %d' %
626 626 special)
627 627
628 628 # This value becomes an element of the current array.
629 629 elif self._state == self._STATE_WANT_ARRAY_VALUE:
630 630 # Simple values get appended.
631 631 if special == SPECIAL_NONE:
632 632 c = self._collectionstack[-1]
633 633 c['v'].append(value)
634 634 c['remaining'] -= 1
635 635
636 636 # self._state doesn't need changed.
637 637
638 638 # An array nested within an array.
639 639 elif special == SPECIAL_START_ARRAY:
640 640 lastc = self._collectionstack[-1]
641 641 newvalue = []
642 642
643 643 lastc['v'].append(newvalue)
644 644 lastc['remaining'] -= 1
645 645
646 646 self._collectionstack.append({
647 647 'remaining': value,
648 648 'v': newvalue,
649 649 })
650 650
651 651 # self._state doesn't need changed.
652 652
653 653 # A map nested within an array.
654 654 elif special == SPECIAL_START_MAP:
655 655 lastc = self._collectionstack[-1]
656 656 newvalue = {}
657 657
658 658 lastc['v'].append(newvalue)
659 659 lastc['remaining'] -= 1
660 660
661 661 self._collectionstack.append({
662 662 'remaining': value,
663 663 'v': newvalue
664 664 })
665 665
666 666 self._state = self._STATE_WANT_MAP_KEY
667 667
668 668 elif special == SPECIAL_START_SET:
669 669 lastc = self._collectionstack[-1]
670 670 newvalue = set()
671 671
672 672 lastc['v'].append(newvalue)
673 673 lastc['remaining'] -= 1
674 674
675 675 self._collectionstack.append({
676 676 'remaining': value,
677 677 'v': newvalue,
678 678 })
679 679
680 680 self._state = self._STATE_WANT_SET_VALUE
681 681
682 682 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
683 683 raise CBORDecodeError('indefinite length bytestrings '
684 684 'not allowed as array values')
685 685
686 686 else:
687 687 raise CBORDecodeError('unhandled special item when '
688 688 'expecting array value: %d' % special)
689 689
690 690 # This value becomes the key of the current map instance.
691 691 elif self._state == self._STATE_WANT_MAP_KEY:
692 692 if special == SPECIAL_NONE:
693 693 self._currentmapkey = value
694 694 self._state = self._STATE_WANT_MAP_VALUE
695 695
696 696 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
697 697 raise CBORDecodeError('indefinite length bytestrings '
698 698 'not allowed as map keys')
699 699
700 700 elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP,
701 701 SPECIAL_START_SET):
702 702 raise CBORDecodeError('collections not supported as map '
703 703 'keys')
704 704
705 705 # We do not allow special values to be used as map keys.
706 706 else:
707 707 raise CBORDecodeError('unhandled special item when '
708 708 'expecting map key: %d' % special)
709 709
710 710 # This value becomes the value of the current map key.
711 711 elif self._state == self._STATE_WANT_MAP_VALUE:
712 712 # Simple values simply get inserted into the map.
713 713 if special == SPECIAL_NONE:
714 714 lastc = self._collectionstack[-1]
715 715 lastc['v'][self._currentmapkey] = value
716 716 lastc['remaining'] -= 1
717 717
718 718 self._state = self._STATE_WANT_MAP_KEY
719 719
720 720 # A new array is used as the map value.
721 721 elif special == SPECIAL_START_ARRAY:
722 722 lastc = self._collectionstack[-1]
723 723 newvalue = []
724 724
725 725 lastc['v'][self._currentmapkey] = newvalue
726 726 lastc['remaining'] -= 1
727 727
728 728 self._collectionstack.append({
729 729 'remaining': value,
730 730 'v': newvalue,
731 731 })
732 732
733 733 self._state = self._STATE_WANT_ARRAY_VALUE
734 734
735 735 # A new map is used as the map value.
736 736 elif special == SPECIAL_START_MAP:
737 737 lastc = self._collectionstack[-1]
738 738 newvalue = {}
739 739
740 740 lastc['v'][self._currentmapkey] = newvalue
741 741 lastc['remaining'] -= 1
742 742
743 743 self._collectionstack.append({
744 744 'remaining': value,
745 745 'v': newvalue,
746 746 })
747 747
748 748 self._state = self._STATE_WANT_MAP_KEY
749 749
750 750 # A new set is used as the map value.
751 751 elif special == SPECIAL_START_SET:
752 752 lastc = self._collectionstack[-1]
753 753 newvalue = set()
754 754
755 755 lastc['v'][self._currentmapkey] = newvalue
756 756 lastc['remaining'] -= 1
757 757
758 758 self._collectionstack.append({
759 759 'remaining': value,
760 760 'v': newvalue,
761 761 })
762 762
763 763 self._state = self._STATE_WANT_SET_VALUE
764 764
765 765 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
766 766 raise CBORDecodeError('indefinite length bytestrings not '
767 767 'allowed as map values')
768 768
769 769 else:
770 770 raise CBORDecodeError('unhandled special item when '
771 771 'expecting map value: %d' % special)
772 772
773 773 self._currentmapkey = None
774 774
775 775 # This value is added to the current set.
776 776 elif self._state == self._STATE_WANT_SET_VALUE:
777 777 if special == SPECIAL_NONE:
778 778 lastc = self._collectionstack[-1]
779 779 lastc['v'].add(value)
780 780 lastc['remaining'] -= 1
781 781
782 782 elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
783 783 raise CBORDecodeError('indefinite length bytestrings not '
784 784 'allowed as set values')
785 785
786 786 elif special in (SPECIAL_START_ARRAY,
787 787 SPECIAL_START_MAP,
788 788 SPECIAL_START_SET):
789 789 raise CBORDecodeError('collections not allowed as set '
790 790 'values')
791 791
792 792 # We don't allow non-trivial types to exist as set values.
793 793 else:
794 794 raise CBORDecodeError('unhandled special item when '
795 795 'expecting set value: %d' % special)
796 796
797 797 # This value represents the first chunk in an indefinite length
798 798 # bytestring.
799 799 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
800 800 # We received a full chunk.
801 801 if special == SPECIAL_NONE:
802 802 self._decodedvalues.append(bytestringchunk(value,
803 803 first=True))
804 804
805 805 self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
806 806
807 807 # The end of stream marker. This means it is an empty
808 808 # indefinite length bytestring.
809 809 elif special == SPECIAL_INDEFINITE_BREAK:
810 810 # We /could/ convert this to a b''. But we want to preserve
811 811 # the nature of the underlying data so consumers expecting
812 812 # an indefinite length bytestring get one.
813 813 self._decodedvalues.append(bytestringchunk(b'',
814 814 first=True,
815 815 last=True))
816 816
817 817 # Since indefinite length bytestrings can't be used in
818 818 # collections, we must be at the root level.
819 819 assert not self._collectionstack
820 820 self._state = self._STATE_NONE
821 821
822 822 else:
823 823 raise CBORDecodeError('unexpected special value when '
824 824 'expecting bytestring chunk: %d' %
825 825 special)
826 826
827 827 # This value represents the non-initial chunk in an indefinite
828 828 # length bytestring.
829 829 elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
830 830 # We received a full chunk.
831 831 if special == SPECIAL_NONE:
832 832 self._decodedvalues.append(bytestringchunk(value))
833 833
834 834 # The end of stream marker.
835 835 elif special == SPECIAL_INDEFINITE_BREAK:
836 836 self._decodedvalues.append(bytestringchunk(b'', last=True))
837 837
838 838 # Since indefinite length bytestrings can't be used in
839 839 # collections, we must be at the root level.
840 840 assert not self._collectionstack
841 841 self._state = self._STATE_NONE
842 842
843 843 else:
844 844 raise CBORDecodeError('unexpected special value when '
845 845 'expecting bytestring chunk: %d' %
846 846 special)
847 847
848 848 else:
849 849 raise CBORDecodeError('unhandled decoder state: %d' %
850 850 self._state)
851 851
852 852 # We could have just added the final value in a collection. End
853 853 # all complete collections at the top of the stack.
854 854 while True:
855 855 # Bail if we're not waiting on a new collection item.
856 856 if self._state not in (self._STATE_WANT_ARRAY_VALUE,
857 857 self._STATE_WANT_MAP_KEY,
858 858 self._STATE_WANT_SET_VALUE):
859 859 break
860 860
861 861 # Or we are expecting more items for this collection.
862 862 lastc = self._collectionstack[-1]
863 863
864 864 if lastc['remaining']:
865 865 break
866 866
867 867 # The collection at the top of the stack is complete.
868 868
869 869 # Discard it, as it isn't needed for future items.
870 870 self._collectionstack.pop()
871 871
872 872 # If this is a nested collection, we don't emit it, since it
873 873 # will be emitted by its parent collection. But we do need to
874 874 # update state to reflect what the new top-most collection
875 875 # on the stack is.
876 876 if self._collectionstack:
877 877 self._state = {
878 878 list: self._STATE_WANT_ARRAY_VALUE,
879 879 dict: self._STATE_WANT_MAP_KEY,
880 880 set: self._STATE_WANT_SET_VALUE,
881 881 }[type(self._collectionstack[-1]['v'])]
882 882
883 883 # If this is the root collection, emit it.
884 884 else:
885 885 self._decodedvalues.append(lastc['v'])
886 886 self._state = self._STATE_NONE
887 887
888 888 return (
889 889 bool(self._decodedvalues),
890 890 offset - initialoffset,
891 891 0,
892 892 )
893 893
894 894 def getavailable(self):
895 895 """Returns an iterator over fully decoded values.
896 896
897 897 Once values are retrieved, they won't be available on the next call.
898 898 """
899 899
900 900 l = list(self._decodedvalues)
901 901 self._decodedvalues = []
902 902 return l
903 903
904 904 class bufferingdecoder(object):
905 905 """A CBOR decoder that buffers undecoded input.
906 906
907 907 This is a glorified wrapper around ``sansiodecoder`` that adds a buffering
908 908 layer. All input that isn't consumed by ``sansiodecoder`` will be buffered
909 909 and concatenated with any new input that arrives later.
910 910
911 911 TODO consider adding limits as to the maximum amount of data that can
912 912 be buffered.
913 913 """
914 914 def __init__(self):
915 915 self._decoder = sansiodecoder()
916 916 self._chunks = []
917 917 self._wanted = 0
918 918
919 919 def decode(self, b):
920 920 """Attempt to decode bytes to CBOR values.
921 921
922 922 Returns a tuple with the following fields:
923 923
924 924 * Bool indicating whether new values are available for retrieval.
925 925 * Integer number of bytes decoded from the new input.
926 926 * Integer number of bytes wanted to decode the next value.
927 927 """
928 # We /might/ be able to support passing a bytearray all the
929 # way through. For now, let's cheat.
930 if isinstance(b, bytearray):
931 b = bytes(b)
932
928 933 # Our strategy for buffering is to aggregate the incoming chunks in a
929 934 # list until we've received enough data to decode the next item.
930 935 # This is slightly more complicated than using an ``io.BytesIO``
931 936 # or continuously concatenating incoming data. However, because it
932 937 # isn't constantly reallocating backing memory for a growing buffer,
933 938 # it prevents excessive memory thrashing and is significantly faster,
934 939 # especially in cases where the percentage of input chunks that don't
935 940 # decode into a full item is high.
936 941
937 942 if self._chunks:
938 943 # A previous call said we needed N bytes to decode the next item.
939 944 # But this call doesn't provide enough data. We buffer the incoming
940 945 # chunk without attempting to decode.
941 946 if len(b) < self._wanted:
942 947 self._chunks.append(b)
943 948 self._wanted -= len(b)
944 949 return False, 0, self._wanted
945 950
946 951 # Else we may have enough data to decode the next item. Aggregate
947 952 # old data with new and reset the buffer.
948 953 newlen = len(b)
949 954 self._chunks.append(b)
950 955 b = b''.join(self._chunks)
951 956 self._chunks = []
952 957 oldlen = len(b) - newlen
953 958
954 959 else:
955 960 oldlen = 0
956 961
957 962 available, readcount, wanted = self._decoder.decode(b)
958 963 self._wanted = wanted
959 964
960 965 if readcount < len(b):
961 966 self._chunks.append(b[readcount:])
962 967
963 968 return available, readcount - oldlen, wanted
964 969
965 970 def getavailable(self):
966 971 return self._decoder.getavailable()
967 972
968 973 def decodeall(b):
969 974 """Decode all CBOR items present in an iterable of bytes.
970 975
971 976 In addition to regular decode errors, raises CBORDecodeError if the
972 977 entirety of the passed buffer does not fully decode to complete CBOR
973 978 values. This includes failure to decode any value, incomplete collection
974 979 types, incomplete indefinite length items, and extra data at the end of
975 980 the buffer.
976 981 """
977 982 if not b:
978 983 return []
979 984
980 985 decoder = sansiodecoder()
981 986
982 987 havevalues, readcount, wantbytes = decoder.decode(b)
983 988
984 989 if readcount != len(b):
985 990 raise CBORDecodeError('input data not fully consumed')
986 991
987 992 if decoder.inprogress:
988 993 raise CBORDecodeError('input data not complete')
989 994
990 995 return decoder.getavailable()
@@ -1,984 +1,992 b''
1 1 from __future__ import absolute_import
2 2
3 3 import unittest
4 4
5 5 from mercurial.thirdparty import (
6 6 cbor,
7 7 )
8 8 from mercurial.utils import (
9 9 cborutil,
10 10 )
11 11
12 12 class TestCase(unittest.TestCase):
13 13 if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
14 14 # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
15 15 # the regex version.
16 16 assertRaisesRegex = (# camelcase-required
17 17 unittest.TestCase.assertRaisesRegexp)
18 18
19 19 def loadit(it):
20 20 return cbor.loads(b''.join(it))
21 21
22 22 class BytestringTests(TestCase):
23 23 def testsimple(self):
24 24 self.assertEqual(
25 25 list(cborutil.streamencode(b'foobar')),
26 26 [b'\x46', b'foobar'])
27 27
28 28 self.assertEqual(
29 29 loadit(cborutil.streamencode(b'foobar')),
30 30 b'foobar')
31 31
32 32 self.assertEqual(cborutil.decodeall(b'\x46foobar'),
33 33 [b'foobar'])
34 34
35 35 self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
36 36 [b'foobar', b'fizbi'])
37 37
38 38 def testlong(self):
39 39 source = b'x' * 1048576
40 40
41 41 self.assertEqual(loadit(cborutil.streamencode(source)), source)
42 42
43 43 encoded = b''.join(cborutil.streamencode(source))
44 44 self.assertEqual(cborutil.decodeall(encoded), [source])
45 45
46 46 def testfromiter(self):
47 47 # This is the example from RFC 7049 Section 2.2.2.
48 48 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
49 49
50 50 self.assertEqual(
51 51 list(cborutil.streamencodebytestringfromiter(source)),
52 52 [
53 53 b'\x5f',
54 54 b'\x44',
55 55 b'\xaa\xbb\xcc\xdd',
56 56 b'\x43',
57 57 b'\xee\xff\x99',
58 58 b'\xff',
59 59 ])
60 60
61 61 self.assertEqual(
62 62 loadit(cborutil.streamencodebytestringfromiter(source)),
63 63 b''.join(source))
64 64
65 65 self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
66 66 b'\x43\xee\xff\x99\xff'),
67 67 [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
68 68
69 69 for i, chunk in enumerate(
70 70 cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
71 71 b'\x43\xee\xff\x99\xff')):
72 72 self.assertIsInstance(chunk, cborutil.bytestringchunk)
73 73
74 74 if i == 0:
75 75 self.assertTrue(chunk.isfirst)
76 76 else:
77 77 self.assertFalse(chunk.isfirst)
78 78
79 79 if i == 2:
80 80 self.assertTrue(chunk.islast)
81 81 else:
82 82 self.assertFalse(chunk.islast)
83 83
84 84 def testfromiterlarge(self):
85 85 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
86 86
87 87 self.assertEqual(
88 88 loadit(cborutil.streamencodebytestringfromiter(source)),
89 89 b''.join(source))
90 90
91 91 def testindefinite(self):
92 92 source = b'\x00\x01\x02\x03' + b'\xff' * 16384
93 93
94 94 it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
95 95
96 96 self.assertEqual(next(it), b'\x5f')
97 97 self.assertEqual(next(it), b'\x42')
98 98 self.assertEqual(next(it), b'\x00\x01')
99 99 self.assertEqual(next(it), b'\x42')
100 100 self.assertEqual(next(it), b'\x02\x03')
101 101 self.assertEqual(next(it), b'\x42')
102 102 self.assertEqual(next(it), b'\xff\xff')
103 103
104 104 dest = b''.join(cborutil.streamencodeindefinitebytestring(
105 105 source, chunksize=42))
106 106 self.assertEqual(cbor.loads(dest), source)
107 107
108 108 self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
109 109
110 110 for chunk in cborutil.decodeall(dest):
111 111 self.assertIsInstance(chunk, cborutil.bytestringchunk)
112 112 self.assertIn(len(chunk), (0, 8, 42))
113 113
114 114 encoded = b'\x5f\xff'
115 115 b = cborutil.decodeall(encoded)
116 116 self.assertEqual(b, [b''])
117 117 self.assertTrue(b[0].isfirst)
118 118 self.assertTrue(b[0].islast)
119 119
120 120 def testdecodevariouslengths(self):
121 121 for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536):
122 122 source = b'x' * i
123 123 encoded = b''.join(cborutil.streamencode(source))
124 124
125 125 if len(source) < 24:
126 126 hlen = 1
127 127 elif len(source) < 256:
128 128 hlen = 2
129 129 elif len(source) < 65536:
130 130 hlen = 3
131 131 elif len(source) < 1048576:
132 132 hlen = 5
133 133
134 134 self.assertEqual(cborutil.decodeitem(encoded),
135 135 (True, source, hlen + len(source),
136 136 cborutil.SPECIAL_NONE))
137 137
138 138 def testpartialdecode(self):
139 139 encoded = b''.join(cborutil.streamencode(b'foobar'))
140 140
141 141 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
142 142 (False, None, -6, cborutil.SPECIAL_NONE))
143 143 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
144 144 (False, None, -5, cborutil.SPECIAL_NONE))
145 145 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
146 146 (False, None, -4, cborutil.SPECIAL_NONE))
147 147 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
148 148 (False, None, -3, cborutil.SPECIAL_NONE))
149 149 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
150 150 (False, None, -2, cborutil.SPECIAL_NONE))
151 151 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
152 152 (False, None, -1, cborutil.SPECIAL_NONE))
153 153 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
154 154 (True, b'foobar', 7, cborutil.SPECIAL_NONE))
155 155
156 156 def testpartialdecodevariouslengths(self):
157 157 lens = [
158 158 2,
159 159 3,
160 160 10,
161 161 23,
162 162 24,
163 163 25,
164 164 31,
165 165 100,
166 166 254,
167 167 255,
168 168 256,
169 169 257,
170 170 16384,
171 171 65534,
172 172 65535,
173 173 65536,
174 174 65537,
175 175 131071,
176 176 131072,
177 177 131073,
178 178 1048575,
179 179 1048576,
180 180 1048577,
181 181 ]
182 182
183 183 for size in lens:
184 184 if size < 24:
185 185 hlen = 1
186 186 elif size < 2**8:
187 187 hlen = 2
188 188 elif size < 2**16:
189 189 hlen = 3
190 190 elif size < 2**32:
191 191 hlen = 5
192 192 else:
193 193 assert False
194 194
195 195 source = b'x' * size
196 196 encoded = b''.join(cborutil.streamencode(source))
197 197
198 198 res = cborutil.decodeitem(encoded[0:1])
199 199
200 200 if hlen > 1:
201 201 self.assertEqual(res, (False, None, -(hlen - 1),
202 202 cborutil.SPECIAL_NONE))
203 203 else:
204 204 self.assertEqual(res, (False, None, -(size + hlen - 1),
205 205 cborutil.SPECIAL_NONE))
206 206
207 207 # Decoding partial header reports remaining header size.
208 208 for i in range(hlen - 1):
209 209 self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
210 210 (False, None, -(hlen - i - 1),
211 211 cborutil.SPECIAL_NONE))
212 212
213 213 # Decoding complete header reports item size.
214 214 self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
215 215 (False, None, -size, cborutil.SPECIAL_NONE))
216 216
217 217 # Decoding single byte after header reports item size - 1
218 218 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
219 219 (False, None, -(size - 1), cborutil.SPECIAL_NONE))
220 220
221 221 # Decoding all but the last byte reports -1 needed.
222 222 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
223 223 (False, None, -1, cborutil.SPECIAL_NONE))
224 224
225 225 # Decoding last byte retrieves value.
226 226 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
227 227 (True, source, hlen + size, cborutil.SPECIAL_NONE))
228 228
229 229 def testindefinitepartialdecode(self):
230 230 encoded = b''.join(cborutil.streamencodebytestringfromiter(
231 231 [b'foobar', b'biz']))
232 232
233 233 # First item should be begin of bytestring special.
234 234 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
235 235 (True, None, 1,
236 236 cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
237 237
238 238 # Second item should be the first chunk. But only available when
239 239 # we give it 7 bytes (1 byte header + 6 byte chunk).
240 240 self.assertEqual(cborutil.decodeitem(encoded[1:2]),
241 241 (False, None, -6, cborutil.SPECIAL_NONE))
242 242 self.assertEqual(cborutil.decodeitem(encoded[1:3]),
243 243 (False, None, -5, cborutil.SPECIAL_NONE))
244 244 self.assertEqual(cborutil.decodeitem(encoded[1:4]),
245 245 (False, None, -4, cborutil.SPECIAL_NONE))
246 246 self.assertEqual(cborutil.decodeitem(encoded[1:5]),
247 247 (False, None, -3, cborutil.SPECIAL_NONE))
248 248 self.assertEqual(cborutil.decodeitem(encoded[1:6]),
249 249 (False, None, -2, cborutil.SPECIAL_NONE))
250 250 self.assertEqual(cborutil.decodeitem(encoded[1:7]),
251 251 (False, None, -1, cborutil.SPECIAL_NONE))
252 252
253 253 self.assertEqual(cborutil.decodeitem(encoded[1:8]),
254 254 (True, b'foobar', 7, cborutil.SPECIAL_NONE))
255 255
256 256 # Third item should be second chunk. But only available when
257 257 # we give it 4 bytes (1 byte header + 3 byte chunk).
258 258 self.assertEqual(cborutil.decodeitem(encoded[8:9]),
259 259 (False, None, -3, cborutil.SPECIAL_NONE))
260 260 self.assertEqual(cborutil.decodeitem(encoded[8:10]),
261 261 (False, None, -2, cborutil.SPECIAL_NONE))
262 262 self.assertEqual(cborutil.decodeitem(encoded[8:11]),
263 263 (False, None, -1, cborutil.SPECIAL_NONE))
264 264
265 265 self.assertEqual(cborutil.decodeitem(encoded[8:12]),
266 266 (True, b'biz', 4, cborutil.SPECIAL_NONE))
267 267
268 268 # Fourth item should be end of indefinite stream marker.
269 269 self.assertEqual(cborutil.decodeitem(encoded[12:13]),
270 270 (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
271 271
272 272 # Now test the behavior when going through the decoder.
273 273
274 274 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
275 275 (False, 1, 0))
276 276 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
277 277 (False, 1, 6))
278 278 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
279 279 (False, 1, 5))
280 280 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
281 281 (False, 1, 4))
282 282 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
283 283 (False, 1, 3))
284 284 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
285 285 (False, 1, 2))
286 286 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
287 287 (False, 1, 1))
288 288 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
289 289 (True, 8, 0))
290 290
291 291 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
292 292 (True, 8, 3))
293 293 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
294 294 (True, 8, 2))
295 295 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
296 296 (True, 8, 1))
297 297 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
298 298 (True, 12, 0))
299 299
300 300 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
301 301 (True, 13, 0))
302 302
303 303 decoder = cborutil.sansiodecoder()
304 304 decoder.decode(encoded[0:8])
305 305 values = decoder.getavailable()
306 306 self.assertEqual(values, [b'foobar'])
307 307 self.assertTrue(values[0].isfirst)
308 308 self.assertFalse(values[0].islast)
309 309
310 310 self.assertEqual(decoder.decode(encoded[8:12]),
311 311 (True, 4, 0))
312 312 values = decoder.getavailable()
313 313 self.assertEqual(values, [b'biz'])
314 314 self.assertFalse(values[0].isfirst)
315 315 self.assertFalse(values[0].islast)
316 316
317 317 self.assertEqual(decoder.decode(encoded[12:]),
318 318 (True, 1, 0))
319 319 values = decoder.getavailable()
320 320 self.assertEqual(values, [b''])
321 321 self.assertFalse(values[0].isfirst)
322 322 self.assertTrue(values[0].islast)
323 323
324 324 class StringTests(TestCase):
325 325 def testdecodeforbidden(self):
326 326 encoded = b'\x63foo'
327 327 with self.assertRaisesRegex(cborutil.CBORDecodeError,
328 328 'string major type not supported'):
329 329 cborutil.decodeall(encoded)
330 330
331 331 class IntTests(TestCase):
332 332 def testsmall(self):
333 333 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
334 334 self.assertEqual(cborutil.decodeall(b'\x00'), [0])
335 335
336 336 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
337 337 self.assertEqual(cborutil.decodeall(b'\x01'), [1])
338 338
339 339 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
340 340 self.assertEqual(cborutil.decodeall(b'\x02'), [2])
341 341
342 342 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
343 343 self.assertEqual(cborutil.decodeall(b'\x03'), [3])
344 344
345 345 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
346 346 self.assertEqual(cborutil.decodeall(b'\x04'), [4])
347 347
348 348 # Multiple value decode works.
349 349 self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
350 350 [0, 1, 2, 3, 4])
351 351
352 352 def testnegativesmall(self):
353 353 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
354 354 self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
355 355
356 356 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
357 357 self.assertEqual(cborutil.decodeall(b'\x21'), [-2])
358 358
359 359 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
360 360 self.assertEqual(cborutil.decodeall(b'\x22'), [-3])
361 361
362 362 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
363 363 self.assertEqual(cborutil.decodeall(b'\x23'), [-4])
364 364
365 365 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
366 366 self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
367 367
368 368 # Multiple value decode works.
369 369 self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
370 370 [-1, -2, -3, -4, -5])
371 371
372 372 def testrange(self):
373 373 for i in range(-70000, 70000, 10):
374 374 encoded = b''.join(cborutil.streamencode(i))
375 375
376 376 self.assertEqual(encoded, cbor.dumps(i))
377 377 self.assertEqual(cborutil.decodeall(encoded), [i])
378 378
379 379 def testdecodepartialubyte(self):
380 380 encoded = b''.join(cborutil.streamencode(250))
381 381
382 382 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
383 383 (False, None, -1, cborutil.SPECIAL_NONE))
384 384 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
385 385 (True, 250, 2, cborutil.SPECIAL_NONE))
386 386
387 387 def testdecodepartialbyte(self):
388 388 encoded = b''.join(cborutil.streamencode(-42))
389 389 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
390 390 (False, None, -1, cborutil.SPECIAL_NONE))
391 391 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
392 392 (True, -42, 2, cborutil.SPECIAL_NONE))
393 393
394 394 def testdecodepartialushort(self):
395 395 encoded = b''.join(cborutil.streamencode(2**15))
396 396
397 397 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
398 398 (False, None, -2, cborutil.SPECIAL_NONE))
399 399 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
400 400 (False, None, -1, cborutil.SPECIAL_NONE))
401 401 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
402 402 (True, 2**15, 3, cborutil.SPECIAL_NONE))
403 403
404 404 def testdecodepartialshort(self):
405 405 encoded = b''.join(cborutil.streamencode(-1024))
406 406
407 407 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
408 408 (False, None, -2, cborutil.SPECIAL_NONE))
409 409 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
410 410 (False, None, -1, cborutil.SPECIAL_NONE))
411 411 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
412 412 (True, -1024, 3, cborutil.SPECIAL_NONE))
413 413
414 414 def testdecodepartialulong(self):
415 415 encoded = b''.join(cborutil.streamencode(2**28))
416 416
417 417 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
418 418 (False, None, -4, cborutil.SPECIAL_NONE))
419 419 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
420 420 (False, None, -3, cborutil.SPECIAL_NONE))
421 421 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
422 422 (False, None, -2, cborutil.SPECIAL_NONE))
423 423 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
424 424 (False, None, -1, cborutil.SPECIAL_NONE))
425 425 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
426 426 (True, 2**28, 5, cborutil.SPECIAL_NONE))
427 427
428 428 def testdecodepartiallong(self):
429 429 encoded = b''.join(cborutil.streamencode(-1048580))
430 430
431 431 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
432 432 (False, None, -4, cborutil.SPECIAL_NONE))
433 433 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
434 434 (False, None, -3, cborutil.SPECIAL_NONE))
435 435 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
436 436 (False, None, -2, cborutil.SPECIAL_NONE))
437 437 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
438 438 (False, None, -1, cborutil.SPECIAL_NONE))
439 439 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
440 440 (True, -1048580, 5, cborutil.SPECIAL_NONE))
441 441
442 442 def testdecodepartialulonglong(self):
443 443 encoded = b''.join(cborutil.streamencode(2**32))
444 444
445 445 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
446 446 (False, None, -8, cborutil.SPECIAL_NONE))
447 447 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
448 448 (False, None, -7, cborutil.SPECIAL_NONE))
449 449 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
450 450 (False, None, -6, cborutil.SPECIAL_NONE))
451 451 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
452 452 (False, None, -5, cborutil.SPECIAL_NONE))
453 453 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
454 454 (False, None, -4, cborutil.SPECIAL_NONE))
455 455 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
456 456 (False, None, -3, cborutil.SPECIAL_NONE))
457 457 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
458 458 (False, None, -2, cborutil.SPECIAL_NONE))
459 459 self.assertEqual(cborutil.decodeitem(encoded[0:8]),
460 460 (False, None, -1, cborutil.SPECIAL_NONE))
461 461 self.assertEqual(cborutil.decodeitem(encoded[0:9]),
462 462 (True, 2**32, 9, cborutil.SPECIAL_NONE))
463 463
464 464 with self.assertRaisesRegex(
465 465 cborutil.CBORDecodeError, 'input data not fully consumed'):
466 466 cborutil.decodeall(encoded[0:1])
467 467
468 468 with self.assertRaisesRegex(
469 469 cborutil.CBORDecodeError, 'input data not fully consumed'):
470 470 cborutil.decodeall(encoded[0:2])
471 471
472 472 def testdecodepartiallonglong(self):
473 473 encoded = b''.join(cborutil.streamencode(-7000000000))
474 474
475 475 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
476 476 (False, None, -8, cborutil.SPECIAL_NONE))
477 477 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
478 478 (False, None, -7, cborutil.SPECIAL_NONE))
479 479 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
480 480 (False, None, -6, cborutil.SPECIAL_NONE))
481 481 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
482 482 (False, None, -5, cborutil.SPECIAL_NONE))
483 483 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
484 484 (False, None, -4, cborutil.SPECIAL_NONE))
485 485 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
486 486 (False, None, -3, cborutil.SPECIAL_NONE))
487 487 self.assertEqual(cborutil.decodeitem(encoded[0:7]),
488 488 (False, None, -2, cborutil.SPECIAL_NONE))
489 489 self.assertEqual(cborutil.decodeitem(encoded[0:8]),
490 490 (False, None, -1, cborutil.SPECIAL_NONE))
491 491 self.assertEqual(cborutil.decodeitem(encoded[0:9]),
492 492 (True, -7000000000, 9, cborutil.SPECIAL_NONE))
493 493
494 494 class ArrayTests(TestCase):
495 495 def testempty(self):
496 496 self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
497 497 self.assertEqual(loadit(cborutil.streamencode([])), [])
498 498
499 499 self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
500 500
501 501 def testbasic(self):
502 502 source = [b'foo', b'bar', 1, -10]
503 503
504 504 chunks = [
505 505 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
506 506
507 507 self.assertEqual(list(cborutil.streamencode(source)), chunks)
508 508
509 509 self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
510 510
511 511 def testemptyfromiter(self):
512 512 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
513 513 b'\x9f\xff')
514 514
515 515 with self.assertRaisesRegex(cborutil.CBORDecodeError,
516 516 'indefinite length uint not allowed'):
517 517 cborutil.decodeall(b'\x9f\xff')
518 518
519 519 def testfromiter1(self):
520 520 source = [b'foo']
521 521
522 522 self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
523 523 b'\x9f',
524 524 b'\x43', b'foo',
525 525 b'\xff',
526 526 ])
527 527
528 528 dest = b''.join(cborutil.streamencodearrayfromiter(source))
529 529 self.assertEqual(cbor.loads(dest), source)
530 530
531 531 with self.assertRaisesRegex(cborutil.CBORDecodeError,
532 532 'indefinite length uint not allowed'):
533 533 cborutil.decodeall(dest)
534 534
535 535 def testtuple(self):
536 536 source = (b'foo', None, 42)
537 537 encoded = b''.join(cborutil.streamencode(source))
538 538
539 539 self.assertEqual(cbor.loads(encoded), list(source))
540 540
541 541 self.assertEqual(cborutil.decodeall(encoded), [list(source)])
542 542
543 543 def testpartialdecode(self):
544 544 source = list(range(4))
545 545 encoded = b''.join(cborutil.streamencode(source))
546 546 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
547 547 (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
548 548 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
549 549 (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
550 550
551 551 source = list(range(23))
552 552 encoded = b''.join(cborutil.streamencode(source))
553 553 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
554 554 (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
555 555 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
556 556 (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
557 557
558 558 source = list(range(24))
559 559 encoded = b''.join(cborutil.streamencode(source))
560 560 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
561 561 (False, None, -1, cborutil.SPECIAL_NONE))
562 562 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
563 563 (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
564 564 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
565 565 (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
566 566
567 567 source = list(range(256))
568 568 encoded = b''.join(cborutil.streamencode(source))
569 569 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
570 570 (False, None, -2, cborutil.SPECIAL_NONE))
571 571 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
572 572 (False, None, -1, cborutil.SPECIAL_NONE))
573 573 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
574 574 (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
575 575 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
576 576 (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
577 577
578 578 def testnested(self):
579 579 source = [[], [], [[], [], []]]
580 580 encoded = b''.join(cborutil.streamencode(source))
581 581 self.assertEqual(cborutil.decodeall(encoded), [source])
582 582
583 583 source = [True, None, [True, 0, 2], [None], [], [[[]], -87]]
584 584 encoded = b''.join(cborutil.streamencode(source))
585 585 self.assertEqual(cborutil.decodeall(encoded), [source])
586 586
587 587 # A set within an array.
588 588 source = [None, {b'foo', b'bar', None, False}, set()]
589 589 encoded = b''.join(cborutil.streamencode(source))
590 590 self.assertEqual(cborutil.decodeall(encoded), [source])
591 591
592 592 # A map within an array.
593 593 source = [None, {}, {b'foo': b'bar', True: False}, [{}]]
594 594 encoded = b''.join(cborutil.streamencode(source))
595 595 self.assertEqual(cborutil.decodeall(encoded), [source])
596 596
597 597 def testindefinitebytestringvalues(self):
598 598 # Single value array whose value is an empty indefinite bytestring.
599 599 encoded = b'\x81\x5f\x40\xff'
600 600
601 601 with self.assertRaisesRegex(cborutil.CBORDecodeError,
602 602 'indefinite length bytestrings not '
603 603 'allowed as array values'):
604 604 cborutil.decodeall(encoded)
605 605
606 606 class SetTests(TestCase):
607 607 def testempty(self):
608 608 self.assertEqual(list(cborutil.streamencode(set())), [
609 609 b'\xd9\x01\x02',
610 610 b'\x80',
611 611 ])
612 612
613 613 self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
614 614
615 615 def testset(self):
616 616 source = {b'foo', None, 42}
617 617 encoded = b''.join(cborutil.streamencode(source))
618 618
619 619 self.assertEqual(cbor.loads(encoded), source)
620 620
621 621 self.assertEqual(cborutil.decodeall(encoded), [source])
622 622
623 623 def testinvalidtag(self):
624 624 # Must use array to encode sets.
625 625 encoded = b'\xd9\x01\x02\xa0'
626 626
627 627 with self.assertRaisesRegex(cborutil.CBORDecodeError,
628 628 'expected array after finite set '
629 629 'semantic tag'):
630 630 cborutil.decodeall(encoded)
631 631
632 632 def testpartialdecode(self):
633 633 # Semantic tag item will be 3 bytes. Set header will be variable
634 634 # depending on length.
635 635 encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
636 636 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
637 637 (False, None, -2, cborutil.SPECIAL_NONE))
638 638 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
639 639 (False, None, -1, cborutil.SPECIAL_NONE))
640 640 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
641 641 (False, None, -1, cborutil.SPECIAL_NONE))
642 642 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
643 643 (True, 23, 4, cborutil.SPECIAL_START_SET))
644 644 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
645 645 (True, 23, 4, cborutil.SPECIAL_START_SET))
646 646
647 647 encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
648 648 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
649 649 (False, None, -2, cborutil.SPECIAL_NONE))
650 650 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
651 651 (False, None, -1, cborutil.SPECIAL_NONE))
652 652 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
653 653 (False, None, -1, cborutil.SPECIAL_NONE))
654 654 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
655 655 (False, None, -1, cborutil.SPECIAL_NONE))
656 656 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
657 657 (True, 24, 5, cborutil.SPECIAL_START_SET))
658 658 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
659 659 (True, 24, 5, cborutil.SPECIAL_START_SET))
660 660
661 661 encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
662 662 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
663 663 (False, None, -2, cborutil.SPECIAL_NONE))
664 664 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
665 665 (False, None, -1, cborutil.SPECIAL_NONE))
666 666 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
667 667 (False, None, -1, cborutil.SPECIAL_NONE))
668 668 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
669 669 (False, None, -2, cborutil.SPECIAL_NONE))
670 670 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
671 671 (False, None, -1, cborutil.SPECIAL_NONE))
672 672 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
673 673 (True, 256, 6, cborutil.SPECIAL_START_SET))
674 674
675 675 def testinvalidvalue(self):
676 676 encoded = b''.join([
677 677 b'\xd9\x01\x02', # semantic tag
678 678 b'\x81', # array of size 1
679 679 b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
680 680 ])
681 681
682 682 with self.assertRaisesRegex(cborutil.CBORDecodeError,
683 683 'indefinite length bytestrings not '
684 684 'allowed as set values'):
685 685 cborutil.decodeall(encoded)
686 686
687 687 encoded = b''.join([
688 688 b'\xd9\x01\x02',
689 689 b'\x81',
690 690 b'\x80', # empty array
691 691 ])
692 692
693 693 with self.assertRaisesRegex(cborutil.CBORDecodeError,
694 694 'collections not allowed as set values'):
695 695 cborutil.decodeall(encoded)
696 696
697 697 encoded = b''.join([
698 698 b'\xd9\x01\x02',
699 699 b'\x81',
700 700 b'\xa0', # empty map
701 701 ])
702 702
703 703 with self.assertRaisesRegex(cborutil.CBORDecodeError,
704 704 'collections not allowed as set values'):
705 705 cborutil.decodeall(encoded)
706 706
707 707 encoded = b''.join([
708 708 b'\xd9\x01\x02',
709 709 b'\x81',
710 710 b'\xd9\x01\x02\x81\x01', # set with integer 1
711 711 ])
712 712
713 713 with self.assertRaisesRegex(cborutil.CBORDecodeError,
714 714 'collections not allowed as set values'):
715 715 cborutil.decodeall(encoded)
716 716
717 717 class BoolTests(TestCase):
718 718 def testbasic(self):
719 719 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
720 720 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
721 721
722 722 self.assertIs(loadit(cborutil.streamencode(True)), True)
723 723 self.assertIs(loadit(cborutil.streamencode(False)), False)
724 724
725 725 self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
726 726 self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
727 727
728 728 self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
729 729 [False, True, True, False])
730 730
731 731 class NoneTests(TestCase):
732 732 def testbasic(self):
733 733 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
734 734
735 735 self.assertIs(loadit(cborutil.streamencode(None)), None)
736 736
737 737 self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
738 738 self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
739 739
740 740 class MapTests(TestCase):
741 741 def testempty(self):
742 742 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
743 743 self.assertEqual(loadit(cborutil.streamencode({})), {})
744 744
745 745 self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
746 746
747 747 def testemptyindefinite(self):
748 748 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
749 749 b'\xbf', b'\xff'])
750 750
751 751 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
752 752
753 753 with self.assertRaisesRegex(cborutil.CBORDecodeError,
754 754 'indefinite length uint not allowed'):
755 755 cborutil.decodeall(b'\xbf\xff')
756 756
757 757 def testone(self):
758 758 source = {b'foo': b'bar'}
759 759 self.assertEqual(list(cborutil.streamencode(source)), [
760 760 b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
761 761
762 762 self.assertEqual(loadit(cborutil.streamencode(source)), source)
763 763
764 764 self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
765 765
766 766 def testmultiple(self):
767 767 source = {
768 768 b'foo': b'bar',
769 769 b'baz': b'value1',
770 770 }
771 771
772 772 self.assertEqual(loadit(cborutil.streamencode(source)), source)
773 773
774 774 self.assertEqual(
775 775 loadit(cborutil.streamencodemapfromiter(source.items())),
776 776 source)
777 777
778 778 encoded = b''.join(cborutil.streamencode(source))
779 779 self.assertEqual(cborutil.decodeall(encoded), [source])
780 780
781 781 def testcomplex(self):
782 782 source = {
783 783 b'key': 1,
784 784 2: -10,
785 785 }
786 786
787 787 self.assertEqual(loadit(cborutil.streamencode(source)),
788 788 source)
789 789
790 790 self.assertEqual(
791 791 loadit(cborutil.streamencodemapfromiter(source.items())),
792 792 source)
793 793
794 794 encoded = b''.join(cborutil.streamencode(source))
795 795 self.assertEqual(cborutil.decodeall(encoded), [source])
796 796
797 797 def testnested(self):
798 798 source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}}
799 799 encoded = b''.join(cborutil.streamencode(source))
800 800
801 801 self.assertEqual(cborutil.decodeall(encoded), [source])
802 802
803 803 source = {
804 804 b'key1': [],
805 805 b'key2': [None, False],
806 806 b'key3': {b'foo', b'bar'},
807 807 b'key4': {},
808 808 }
809 809 encoded = b''.join(cborutil.streamencode(source))
810 810 self.assertEqual(cborutil.decodeall(encoded), [source])
811 811
812 812 def testillegalkey(self):
813 813 encoded = b''.join([
814 814 # map header + len 1
815 815 b'\xa1',
816 816 # indefinite length bytestring "foo" in key position
817 817 b'\x5f\x03foo\xff'
818 818 ])
819 819
820 820 with self.assertRaisesRegex(cborutil.CBORDecodeError,
821 821 'indefinite length bytestrings not '
822 822 'allowed as map keys'):
823 823 cborutil.decodeall(encoded)
824 824
825 825 encoded = b''.join([
826 826 b'\xa1',
827 827 b'\x80', # empty array
828 828 b'\x43foo',
829 829 ])
830 830
831 831 with self.assertRaisesRegex(cborutil.CBORDecodeError,
832 832 'collections not supported as map keys'):
833 833 cborutil.decodeall(encoded)
834 834
835 835 def testillegalvalue(self):
836 836 encoded = b''.join([
837 837 b'\xa1', # map headers
838 838 b'\x43foo', # key
839 839 b'\x5f\x03bar\xff', # indefinite length value
840 840 ])
841 841
842 842 with self.assertRaisesRegex(cborutil.CBORDecodeError,
843 843 'indefinite length bytestrings not '
844 844 'allowed as map values'):
845 845 cborutil.decodeall(encoded)
846 846
847 847 def testpartialdecode(self):
848 848 source = {b'key1': b'value1'}
849 849 encoded = b''.join(cborutil.streamencode(source))
850 850
851 851 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
852 852 (True, 1, 1, cborutil.SPECIAL_START_MAP))
853 853 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
854 854 (True, 1, 1, cborutil.SPECIAL_START_MAP))
855 855
856 856 source = {b'key%d' % i: None for i in range(23)}
857 857 encoded = b''.join(cborutil.streamencode(source))
858 858 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
859 859 (True, 23, 1, cborutil.SPECIAL_START_MAP))
860 860
861 861 source = {b'key%d' % i: None for i in range(24)}
862 862 encoded = b''.join(cborutil.streamencode(source))
863 863 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
864 864 (False, None, -1, cborutil.SPECIAL_NONE))
865 865 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
866 866 (True, 24, 2, cborutil.SPECIAL_START_MAP))
867 867 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
868 868 (True, 24, 2, cborutil.SPECIAL_START_MAP))
869 869
870 870 source = {b'key%d' % i: None for i in range(256)}
871 871 encoded = b''.join(cborutil.streamencode(source))
872 872 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
873 873 (False, None, -2, cborutil.SPECIAL_NONE))
874 874 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
875 875 (False, None, -1, cborutil.SPECIAL_NONE))
876 876 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
877 877 (True, 256, 3, cborutil.SPECIAL_START_MAP))
878 878 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
879 879 (True, 256, 3, cborutil.SPECIAL_START_MAP))
880 880
881 881 source = {b'key%d' % i: None for i in range(65536)}
882 882 encoded = b''.join(cborutil.streamencode(source))
883 883 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
884 884 (False, None, -4, cborutil.SPECIAL_NONE))
885 885 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
886 886 (False, None, -3, cborutil.SPECIAL_NONE))
887 887 self.assertEqual(cborutil.decodeitem(encoded[0:3]),
888 888 (False, None, -2, cborutil.SPECIAL_NONE))
889 889 self.assertEqual(cborutil.decodeitem(encoded[0:4]),
890 890 (False, None, -1, cborutil.SPECIAL_NONE))
891 891 self.assertEqual(cborutil.decodeitem(encoded[0:5]),
892 892 (True, 65536, 5, cborutil.SPECIAL_START_MAP))
893 893 self.assertEqual(cborutil.decodeitem(encoded[0:6]),
894 894 (True, 65536, 5, cborutil.SPECIAL_START_MAP))
895 895
896 896 class SemanticTagTests(TestCase):
897 897 def testdecodeforbidden(self):
898 898 for i in range(500):
899 899 if i == cborutil.SEMANTIC_TAG_FINITE_SET:
900 900 continue
901 901
902 902 tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC,
903 903 i)
904 904
905 905 encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
906 906
907 907 # Partial decode is incomplete.
908 908 if i < 24:
909 909 pass
910 910 elif i < 256:
911 911 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
912 912 (False, None, -1, cborutil.SPECIAL_NONE))
913 913 elif i < 65536:
914 914 self.assertEqual(cborutil.decodeitem(encoded[0:1]),
915 915 (False, None, -2, cborutil.SPECIAL_NONE))
916 916 self.assertEqual(cborutil.decodeitem(encoded[0:2]),
917 917 (False, None, -1, cborutil.SPECIAL_NONE))
918 918
919 919 with self.assertRaisesRegex(cborutil.CBORDecodeError,
920 920 'semantic tag \d+ not allowed'):
921 921 cborutil.decodeitem(encoded)
922 922
923 923 class SpecialTypesTests(TestCase):
924 924 def testforbiddentypes(self):
925 925 for i in range(256):
926 926 if i == cborutil.SUBTYPE_FALSE:
927 927 continue
928 928 elif i == cborutil.SUBTYPE_TRUE:
929 929 continue
930 930 elif i == cborutil.SUBTYPE_NULL:
931 931 continue
932 932
933 933 encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
934 934
935 935 with self.assertRaisesRegex(cborutil.CBORDecodeError,
936 936 'special type \d+ not allowed'):
937 937 cborutil.decodeitem(encoded)
938 938
939 939 class SansIODecoderTests(TestCase):
940 940 def testemptyinput(self):
941 941 decoder = cborutil.sansiodecoder()
942 942 self.assertEqual(decoder.decode(b''), (False, 0, 0))
943 943
944 944 class BufferingDecoderTests(TestCase):
945 945 def testsimple(self):
946 946 source = [
947 947 b'foobar',
948 948 b'x' * 128,
949 949 {b'foo': b'bar'},
950 950 True,
951 951 False,
952 952 None,
953 953 [None for i in range(128)],
954 954 ]
955 955
956 956 encoded = b''.join(cborutil.streamencode(source))
957 957
958 958 for step in range(1, 32):
959 959 decoder = cborutil.bufferingdecoder()
960 960 start = 0
961 961
962 962 while start < len(encoded):
963 963 decoder.decode(encoded[start:start + step])
964 964 start += step
965 965
966 966 self.assertEqual(decoder.getavailable(), [source])
967 967
968 def testbytearray(self):
969 source = b''.join(cborutil.streamencode(b'foobar'))
970
971 decoder = cborutil.bufferingdecoder()
972 decoder.decode(bytearray(source))
973
974 self.assertEqual(decoder.getavailable(), [b'foobar'])
975
968 976 class DecodeallTests(TestCase):
969 977 def testemptyinput(self):
970 978 self.assertEqual(cborutil.decodeall(b''), [])
971 979
972 980 def testpartialinput(self):
973 981 encoded = b''.join([
974 982 b'\x82', # array of 2 elements
975 983 b'\x01', # integer 1
976 984 ])
977 985
978 986 with self.assertRaisesRegex(cborutil.CBORDecodeError,
979 987 'input data not complete'):
980 988 cborutil.decodeall(encoded)
981 989
982 990 if __name__ == '__main__':
983 991 import silenttestrunner
984 992 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now