class bufferingdecoder(object):
    """CBOR decoder that retains input the wrapped decoder did not consume.

    Each call to ``decode()`` hands the wrapped ``sansiodecoder`` any bytes
    left unread by the previous call followed by the newly supplied bytes.
    Whatever the wrapped decoder still leaves unread is kept for the next
    call.

    TODO consider adding limits as to the maximum amount of data that can
    be buffered.
    """

    def __init__(self):
        self._decoder = sansiodecoder()
        # Bytes not yet consumed by the wrapped decoder, or None when there
        # is nothing carried over.
        self._leftover = None

    def decode(self, b):
        """Feed ``b`` to the decoder, prefixed by any buffered bytes.

        Returns a 3-tuple of:

        * the "values available" flag reported by the wrapped decoder;
        * the wrapped decoder's read count minus the number of bytes that
          were already buffered before this call (so this is negative when
          the wrapped decoder reads less than the buffered prefix);
        * the "bytes wanted" figure reported by the wrapped decoder.
        """
        pending = self._leftover
        carried = 0

        if pending:
            carried = len(pending)
            # Concatenate before clearing so a failed concatenation leaves
            # the buffered bytes in place.
            b = pending + b
            self._leftover = None

        available, readcount, wanted = self._decoder.decode(b)

        # Keep the unread tail around for the next call.
        if readcount < len(b):
            self._leftover = b[readcount:]

        return available, readcount - carried, wanted

    def getavailable(self):
        """Return whatever the wrapped decoder's ``getavailable()`` returns."""
        return self._decoder.getavailable()
class BufferingDecoderTests(TestCase):
    """Tests for ``cborutil.bufferingdecoder`` fed with chunked input."""

    def testsimple(self):
        # A single top-level list mixing bytestrings, a map, booleans, None
        # and a nested list.
        source = [
            b'foobar',
            b'x' * 128,
            {b'foo': b'bar'},
            True,
            False,
            None,
            [None] * 128,
        ]

        encoded = b''.join(cborutil.streamencode(source))
        total = len(encoded)

        # Feed the encoded bytes in fixed-size slices, trying every slice
        # size from 1 through 31 with a fresh decoder each time.
        for step in range(1, 32):
            decoder = cborutil.bufferingdecoder()

            for offset in range(0, total, step):
                decoder.decode(encoded[offset:offset + step])

            self.assertEqual(decoder.getavailable(), [source])