Move PackStreamReader to dulwich.pack.
[jelmer/dulwich-libgit2.git] / dulwich / pack.py
index bc982ab9a0679632321d684f8f721720b5e78fd7..fcad7e3f2977ea121585026a2708c6007d0696a4 100644 (file)
@@ -35,6 +35,12 @@ try:
 except ImportError:
     from misc import defaultdict
 
+from cStringIO import (
+    StringIO,
+    )
+from collections import (
+    deque,
+    )
 import difflib
 from itertools import (
     chain,
@@ -485,6 +491,107 @@ def _compute_object_size((num, obj)):
     return chunks_length(obj)
 
 
+class PackStreamReader(object):
+    """Class to read a pack stream.
+
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate.
+    """
+
+    def __init__(self, read_all, read_some=None):
+        self.read_all = read_all
+        if read_some is None:
+            self.read_some = read_all
+        else:
+            self.read_some = read_some
+        self.sha = make_sha()
+        self._rbuf = StringIO()
+        # trailer is a deque to avoid memory allocation on small reads
+        self._trailer = deque()
+
+    def _read(self, read, size):
+        """Read up to size bytes using the given callback.
+
+        As a side effect, update the verifier's hash (excluding the last 20
+        bytes read) and write through to the output file.
+
+        :param read: The read callback to read from.
+        :param size: The maximum number of bytes to read; the particular
+            behavior is callback-specific.
+        """
+        data = read(size)
+
+        # maintain a trailer of the last 20 bytes we've read
+        n = len(data)
+        tn = len(self._trailer)
+        if n >= 20:
+            to_pop = tn
+            to_add = 20
+        else:
+            to_pop = max(n + tn - 20, 0)
+            to_add = n
+        for _ in xrange(to_pop):
+            self.sha.update(self._trailer.popleft())
+        self._trailer.extend(data[-to_add:])
+
+        # hash everything but the trailer
+        self.sha.update(data[:-to_add])
+        return data
+
+    def _buf_len(self):
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, os.SEEK_END)
+        end = buf.tell()
+        buf.seek(start)
+        return end - start
+
+    def read(self, size):
+        """Read, blocking until size bytes are read."""
+        buf_len = self._buf_len()
+        if buf_len >= size:
+            return self._rbuf.read(size)
+        buf_data = self._rbuf.read()
+        self._rbuf = StringIO()
+        return buf_data + self._read(self.read_all, size - buf_len)
+
+    def recv(self, size):
+        """Read up to size bytes, blocking until one byte is read."""
+        buf_len = self._buf_len()
+        if buf_len:
+            data = self._rbuf.read(size)
+            if size >= buf_len:
+                self._rbuf = StringIO()
+            return data
+        return self._read(self.read_some, size)
+
+    def read_objects(self):
+        """Read the objects in this pack file.
+
+        :raise AssertionError: if there is an error in the pack format.
+        :raise ChecksumMismatch: if the checksum of the pack contents does not
+            match the checksum in the pack trailer.
+        :raise zlib.error: if an error occurred during zlib decompression.
+        :raise IOError: if an error occurred writing to the output file.
+        """
+        pack_version, num_objects = read_pack_header(self.read)
+        for i in xrange(num_objects):
+            type, uncomp, comp_len, unused = unpack_object(self.read, self.recv)
+            yield type, uncomp, comp_len
+
+            # prepend any unused data to current read buffer
+            buf = StringIO()
+            buf.write(unused)
+            buf.write(self._rbuf.read())
+            buf.seek(0)
+            self._rbuf = buf
+
+        pack_sha = sha_to_hex(''.join([c for c in self._trailer]))
+        calculated_sha = self.sha.hexdigest()
+        if pack_sha != calculated_sha:
+            raise ChecksumMismatch(pack_sha, calculated_sha)
+
+
 class PackData(object):
     """The data contained in a packfile.
 
@@ -634,6 +741,7 @@ class PackData(object):
                     progress(self.i, self.num)
                 self.i+=1
                 return ret
+
         return ObjectIterator(self)
 
     def iterentries(self, ext_resolve_ref=None, progress=None):