class cappedreader(object):
    """A file object proxy that allows reading up to N bytes.

    Given a source file object, instances of this type allow reading up to
    N bytes from that source file object. Attempts to read past the allowed
    limit are treated as EOF.

    It is assumed that I/O is not performed on the original file object
    in addition to I/O that is performed by this instance. If there is,
    state tracking will get out of sync and unexpected results will ensue.
    """
    def __init__(self, fh, limit):
        """Allow reading up to <limit> bytes from <fh>.

        ``fh`` is any object exposing ``read(size)``. ``limit`` is the
        total number of bytes this proxy may hand out over its lifetime;
        a limit of zero or less means the proxy is at EOF from the start.
        """
        self._fh = fh
        self._left = limit

    def read(self, n=-1):
        """Read up to ``n`` bytes without exceeding the remaining budget.

        A negative ``n`` or ``None`` means "everything still allowed".
        Returns ``b''`` once the budget is exhausted. Fewer bytes than
        requested are returned if the source itself runs out first.
        """
        # ``<= 0`` rather than ``not self._left``: a negative budget must
        # never be forwarded to the source, where a negative size would
        # mean "read until EOF" and defeat the cap entirely.
        if self._left <= 0:
            return b''

        # File objects conventionally accept None as "no size given";
        # comparing None with an int raises TypeError on Python 3.
        if n is None or n < 0:
            n = self._left

        data = self._fh.read(min(n, self._left))
        self._left -= len(data)
        # Internal invariant: the source must not return more than asked.
        assert self._left >= 0

        return data
class CappedReaderTests(unittest.TestCase):
    """Exercise util.cappedreader against an in-memory 100 byte source."""

    def testreadfull(self):
        # A single read() call, with the requested size equal to, larger
        # than, or omitted relative to the cap.
        source = io.BytesIO(b'x' * 100)

        reader = util.cappedreader(source, 10)
        res = reader.read(10)
        self.assertEqual(res, b'x' * 10)
        self.assertEqual(source.tell(), 10)
        source.seek(0)

        reader = util.cappedreader(source, 15)
        res = reader.read(16)
        self.assertEqual(res, b'x' * 15)
        self.assertEqual(source.tell(), 15)
        source.seek(0)

        reader = util.cappedreader(source, 100)
        res = reader.read(100)
        self.assertEqual(res, b'x' * 100)
        self.assertEqual(source.tell(), 100)
        source.seek(0)

        reader = util.cappedreader(source, 50)
        res = reader.read()
        self.assertEqual(res, b'x' * 50)
        self.assertEqual(source.tell(), 50)
        source.seek(0)

    def testreadnegative(self):
        # read(-1) returns everything up to the cap.
        source = io.BytesIO(b'x' * 100)

        reader = util.cappedreader(source, 20)
        res = reader.read(-1)
        self.assertEqual(res, b'x' * 20)
        self.assertEqual(source.tell(), 20)
        source.seek(0)

        reader = util.cappedreader(source, 100)
        res = reader.read(-1)
        self.assertEqual(res, b'x' * 100)
        self.assertEqual(source.tell(), 100)
        source.seek(0)

    def testreadmultiple(self):
        # Several small reads consume the cap incrementally; once it is
        # exhausted further reads return empty without touching the source.
        source = io.BytesIO(b'x' * 100)

        reader = util.cappedreader(source, 10)
        for i in range(10):
            res = reader.read(1)
            self.assertEqual(res, b'x')
            self.assertEqual(source.tell(), i + 1)

        self.assertEqual(source.tell(), 10)
        res = reader.read(1)
        self.assertEqual(res, b'')
        self.assertEqual(source.tell(), 10)
        source.seek(0)

        reader = util.cappedreader(source, 45)
        for i in range(4):
            res = reader.read(10)
            self.assertEqual(res, b'x' * 10)
            self.assertEqual(source.tell(), (i + 1) * 10)

        # Only 5 bytes of budget remain, so the final read is short.
        res = reader.read(10)
        self.assertEqual(res, b'x' * 5)
        self.assertEqual(source.tell(), 45)

    # Renamed from ``readlimitpasteof``: without the ``test`` prefix the
    # unittest loader never collects the method, so this scenario was
    # silently skipped.
    def testreadlimitpasteof(self):
        # The cap exceeds the source size: the source's own EOF wins.
        source = io.BytesIO(b'x' * 100)

        reader = util.cappedreader(source, 1024)
        res = reader.read(1000)
        self.assertEqual(res, b'x' * 100)
        self.assertEqual(source.tell(), 100)
        res = reader.read(1000)
        self.assertEqual(res, b'')
        self.assertEqual(source.tell(), 100)

if __name__ == '__main__':
    import silenttestrunner
    silenttestrunner.main(__name__)