Merge use of constants for OFS/REF deltas.
[jelmer/dulwich-libgit2.git] / dulwich / pack.py
index 5db1fe4a26ba6d9461a1bd34c35514dde0dcb0dc..4344e5bdedb904fb0d36237ef31bd2caae1a34da 100644 (file)
@@ -35,6 +35,12 @@ try:
 except ImportError:
     from misc import defaultdict
 
+from cStringIO import (
+    StringIO,
+    )
+from collections import (
+    deque,
+    )
 import difflib
 from itertools import (
     chain,
@@ -429,6 +435,7 @@ def read_pack_header(read):
     """Read the header of a pack file.
 
     :param read: Read function
+    :return: Tuple with pack version and number of objects
     """
     header = read(12)
     assert header[:4] == "PACK"
@@ -490,6 +497,146 @@ def _compute_object_size((num, obj)):
     return chunks_length(obj)
 
 
+class PackStreamReader(object):
+    """Class to read a pack stream.
+
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate.
+    """
+
+    def __init__(self, read_all, read_some=None):
+        self.read_all = read_all  # blocking read callback (see read())
+        if read_some is None:
+            self.read_some = read_all
+        else:
+            self.read_some = read_some
+        self.sha = make_sha()  # verifier hash; covers everything except the 20-byte trailer
+        self._offset = 0  # total bytes fetched from the underlying stream
+        self._rbuf = StringIO()  # data fetched but not yet consumed by the caller
+        # trailer keeps the last 20 bytes read (the candidate pack SHA-1); a deque avoids memory allocation on small reads
+        self._trailer = deque()
+
+    def _read(self, read, size):
+        """Read up to size bytes using the given callback.
+
+        As a side effect, update the verifier's hash, excluding the last 20
+        bytes read (the candidate pack trailer).
+
+        :param read: The read callback to read from.
+        :param size: The maximum number of bytes to read; the particular
+            behavior is callback-specific.
+        """
+        data = read(size)
+
+        # maintain a trailer of the last 20 bytes we've read
+        n = len(data)
+        self._offset += n
+        tn = len(self._trailer)
+        if n >= 20:
+            to_pop = tn
+            to_add = 20
+        else:
+            to_pop = max(n + tn - 20, 0)  # evict only the bytes that no longer fit the 20-byte window
+            to_add = n
+        for _ in xrange(to_pop):
+            self.sha.update(self._trailer.popleft())
+        self._trailer.extend(data[-to_add:])
+
+        # hash everything but the trailer
+        self.sha.update(data[:-to_add])
+        return data
+
+    def _buf_len(self):  # bytes remaining unread in self._rbuf, without consuming them
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, os.SEEK_END)
+        end = buf.tell()
+        buf.seek(start)
+        return end - start
+
+    @property
+    def offset(self):
+        return self._offset - self._buf_len()  # bytes actually consumed by the caller so far
+
+    def read(self, size):
+        """Read, blocking until size bytes are read."""
+        buf_len = self._buf_len()
+        if buf_len >= size:
+            return self._rbuf.read(size)  # fully satisfied from the pushback buffer
+        buf_data = self._rbuf.read()
+        self._rbuf = StringIO()
+        return buf_data + self._read(self.read_all, size - buf_len)  # top up from the stream
+
+    def recv(self, size):
+        """Read up to size bytes, blocking until one byte is read."""
+        buf_len = self._buf_len()
+        if buf_len:
+            data = self._rbuf.read(size)
+            if size >= buf_len:  # the read above drained the buffer completely
+                self._rbuf = StringIO()
+            return data
+        return self._read(self.read_some, size)  # nothing buffered: read_some may return fewer than size bytes
+
+    def __len__(self):
+        return self._num_objects  # only valid once read_objects() has parsed the header
+
+    def read_objects(self):
+        """Read the objects in this pack file.
+
+        :raise AssertionError: if there is an error in the pack format.
+        :raise ChecksumMismatch: if the checksum of the pack contents does not
+            match the checksum in the pack trailer.
+        :raise zlib.error: if an error occurred during zlib decompression.
+        :raise IOError: if reading from the underlying stream fails.
+        """
+        pack_version, self._num_objects = read_pack_header(self.read)  # header gives (version, object count)
+        for i in xrange(self._num_objects):
+            type, uncomp, comp_len, unused = unpack_object(self.read, self.recv)
+            yield type, uncomp, comp_len
+
+            # prepend any unused data to current read buffer
+            buf = StringIO()
+            buf.write(unused)
+            buf.write(self._rbuf.read())
+            buf.seek(0)
+            self._rbuf = buf
+
+        pack_sha = sha_to_hex(''.join([c for c in self._trailer]))  # expected SHA-1 from the pack trailer
+        calculated_sha = self.sha.hexdigest()
+        if pack_sha != calculated_sha:
+            raise ChecksumMismatch(pack_sha, calculated_sha)
+
+
+class PackObjectIterator(object):
+    """Iterate over the objects in a seekable pack file, yielding (offset, type, obj, crc32) tuples."""
+    def __init__(self, pack, progress=None):
+        self.i = 0  # number of objects yielded so far
+        self.offset = pack._header_size  # file offset of the next object to read
+        self.num = len(pack)  # total number of objects in the pack
+        self.map = pack._file  # underlying (seekable) pack file
+        self._progress = progress  # optional callback invoked as progress(i, num)
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return self.num
+
+    def next(self):
+        if self.i == self.num:
+            raise StopIteration
+        self.map.seek(self.offset)
+        (type, obj, total_size, unused) = unpack_object(self.map.read)
+        self.map.seek(self.offset)  # rewind to checksum the raw (compressed) object bytes
+        crc32 = zlib.crc32(self.map.read(total_size)) & 0xffffffff  # mask for a portable unsigned value
+        ret = (self.offset, type, obj, crc32)
+        self.offset += total_size
+        if self._progress is not None:
+            self._progress(self.i, self.num)
+        self.i+=1
+        return ret
+
+
 class PackData(object):
     """The data contained in a packfile.
 
@@ -611,35 +758,7 @@ class PackData(object):
         return (type, apply_delta(base_chunks, delta))
 
     def iterobjects(self, progress=None):
-
-        class ObjectIterator(object):
-
-            def __init__(self, pack):
-                self.i = 0
-                self.offset = pack._header_size
-                self.num = len(pack)
-                self.map = pack._file
-
-            def __iter__(self):
-                return self
-
-            def __len__(self):
-                return self.num
-
-            def next(self):
-                if self.i == self.num:
-                    raise StopIteration
-                self.map.seek(self.offset)
-                (type, obj, total_size, unused) = unpack_object(self.map.read)
-                self.map.seek(self.offset)
-                crc32 = zlib.crc32(self.map.read(total_size)) & 0xffffffff
-                ret = (self.offset, type, obj, crc32)
-                self.offset += total_size
-                if progress:
-                    progress(self.i, self.num)
-                self.i+=1
-                return ret
-        return ObjectIterator(self)
+        return PackObjectIterator(self, progress)  # iteration logic now lives in the module-level PackObjectIterator
 
     def iterentries(self, ext_resolve_ref=None, progress=None):
         """Yield entries summarizing the contents of this pack.