Make the server decode a pack as it streams.
author: Dave Borowitz <dborowitz@google.com>
Thu, 11 Mar 2010 22:12:18 +0000 (14:12 -0800)
committer: Dave Borowitz <dborowitz@google.com>
Fri, 16 Apr 2010 19:04:12 +0000 (12:04 -0700)
This, in combination with using recv() instead of read(), makes it so we
never do blocking reads past the end of the pack stream, even when the
client doesn't close the connection.

This is done via a PackStreamVerifier class that reads from a Protocol,
unpacks and counts objects, writes through to a file, and computes the
SHA-1 checksum on the fly. It is necessary because the only way we know
when the pack is supposed to end is by parsing the header and reading
the correct number of objects; otherwise, any further reads from the
client would hang.

Changed the Handler constructors to take a Protocol instead of taking
read and write callbacks separately. Modified some interfaces to utility
functions in pack.py so they can be used by the server-side code.

Change-Id: Id4d11e34658d1f00ad06e45330d0d128b367d8e5

dulwich/pack.py
dulwich/server.py
dulwich/tests/compat/test_server.py
dulwich/tests/test_server.py
dulwich/tests/test_web.py
dulwich/web.py

index 41fa873..60b0d72 100644 (file)
@@ -417,12 +417,12 @@ class PackIndex2(PackIndex):
   
 
 
-def read_pack_header(f):
+def read_pack_header(read):
     """Read the header of a pack file.
 
-    :param f: File-like object to read from
+    :param read: Read function
     """
-    header = f.read(12)
+    header = read(12)
     assert header[:4] == "PACK"
     (version,) = unpack_from(">L", header, 4)
     assert version in (2, 3), "Version was %d" % version
@@ -434,20 +434,25 @@ def chunks_length(chunks):
     return sum(imap(len, chunks))
 
 
-def unpack_object(read):
+def unpack_object(read_all, read_some=None):
     """Unpack a Git object.
 
-    :return: tuple with type, uncompressed data as chunks, compressed size and 
-        tail data
+    :param read_all: Read function that blocks until the number of requested
+        bytes are read.
+    :param read_some: Read function that returns at least one byte, but may not
+        return the number of bytes requested.
+    :return: tuple with type, uncompressed data, compressed size and tail data.
     """
-    bytes = take_msb_bytes(read)
+    if read_some is None:
+        read_some = read_all
+    bytes = take_msb_bytes(read_all)
     type = (bytes[0] >> 4) & 0x07
     size = bytes[0] & 0x0f
     for i, byte in enumerate(bytes[1:]):
         size += (byte & 0x7f) << ((i * 7) + 4)
     raw_base = len(bytes)
     if type == 6: # offset delta
-        bytes = take_msb_bytes(read)
+        bytes = take_msb_bytes(read_all)
         raw_base += len(bytes)
         assert not (bytes[-1] & 0x80)
         delta_base_offset = bytes[0] & 0x7f
@@ -455,17 +460,17 @@ def unpack_object(read):
             delta_base_offset += 1
             delta_base_offset <<= 7
             delta_base_offset += (byte & 0x7f)
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert size == chunks_length(uncomp)
         return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
     elif type == 7: # ref delta
-        basename = read(20)
+        basename = read_all(20)
         raw_base += 20
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert size == chunks_length(uncomp)
         return type, (basename, uncomp), comp_len+raw_base, unused
     else:
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert chunks_length(uncomp) == size
         return type, uncomp, comp_len+raw_base, unused
 
@@ -522,7 +527,7 @@ class PackData(object):
             self._file = GitFile(self._filename, 'rb')
         else:
             self._file = file
-        (version, self._num_objects) = read_pack_header(self._file)
+        (version, self._num_objects) = read_pack_header(self._file.read)
         self._offset_cache = LRUSizeCache(1024*1024*20, 
             compute_size=_compute_object_size)
 
index 63fb331..c08a2d5 100644 (file)
@@ -27,15 +27,22 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 
 import collections
+from cStringIO import StringIO
+import socket
 import SocketServer
+import zlib
 
 from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     GitProtocolError,
     )
+from dulwich.misc import (
+    make_sha,
+    )
 from dulwich.objects import (
     hex_to_sha,
+    sha_to_hex,
     )
 from dulwich.protocol import (
     ProtocolFile,
@@ -51,6 +58,8 @@ from dulwich.protocol import (
     ack_type,
     )
 from dulwich.pack import (
+    read_pack_header,
+    unpack_object,
     write_pack_data,
     )
 
@@ -103,6 +112,105 @@ class BackendRepo(object):
         raise NotImplementedError
 
 
+class PackStreamVerifier(object):
+    """Class to verify a pack stream as it is being read.
+
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate and written out to the given file-like object.
+    """
+
+    def __init__(self, proto, outfile):
+        self.proto = proto
+        self.outfile = outfile
+        self.sha = make_sha()
+        self._rbuf = StringIO()
+        # trailer is a deque to avoid memory allocation on small reads
+        self._trailer = collections.deque()
+
+    def _read(self, read, size):
+        """Read up to size bytes using the given callback.
+
+        As a side effect, update the verifier's hash (excluding the last 20
+        bytes read) and write through to the output file.
+
+        :param read: The read callback to read from.
+        :param size: The maximum number of bytes to read; the particular
+            behavior is callback-specific.
+        """
+        data = read(size)
+
+        # maintain a trailer of the last 20 bytes we've read
+        n = len(data)
+        tn = len(self._trailer)
+        if n >= 20:
+            to_pop = tn
+            to_add = 20
+        else:
+            to_pop = max(n + tn - 20, 0)
+            to_add = n
+        for _ in xrange(to_pop):
+            self.sha.update(self._trailer.popleft())
+        self._trailer.extend(data[-to_add:])
+
+        # hash everything but the trailer
+        self.sha.update(data[:-to_add])
+        self.outfile.write(data)
+        return data
+
+    def _buf_len(self):
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, 2)
+        end = buf.tell()
+        buf.seek(start)
+        return end - start
+
+    def read(self, size):
+        """Read, blocking until size bytes are read."""
+        buf_len = self._buf_len()
+        if buf_len >= size:
+            return self._rbuf.read(size)
+        buf_data = self._rbuf.read()
+        self._rbuf = StringIO()
+        return buf_data + self._read(self.proto.read, size - buf_len)
+
+    def recv(self, size):
+        """Read up to size bytes, blocking until one byte is read."""
+        buf_len = self._buf_len()
+        if buf_len:
+            data = self._rbuf.read(size)
+            if size >= buf_len:
+                self._rbuf = StringIO()
+            return data
+        return self._read(self.proto.recv, size)
+
+    def verify(self):
+        """Verify a pack stream and write it to the output file.
+
+        :raise AssertionError: if there is an error in the pack format.
+        :raise ChecksumMismatch: if the checksum of the pack contents does not
+            match the checksum in the pack trailer.
+        :raise socket.error: if an error occurred reading from the socket.
+        :raise zlib.error: if an error occurred during zlib decompression.
+        :raise IOError: if an error occurred writing to the output file.
+        """
+        _, num_objects = read_pack_header(self.read)
+        for i in xrange(num_objects):
+            type, _, _, unused = unpack_object(self.read, self.recv)
+
+            # prepend any unused data to current read buffer
+            buf = StringIO()
+            buf.write(unused)
+            buf.write(self._rbuf.read())
+            buf.seek(0)
+            self._rbuf = buf
+
+        pack_sha = sha_to_hex(''.join([c for c in self._trailer]))
+        calculated_sha = self.sha.hexdigest()
+        if pack_sha != calculated_sha:
+            raise ChecksumMismatch(pack_sha, calculated_sha)
+
+
 class DictBackend(Backend):
     """Trivial backend that looks up Git repositories in a dictionary."""
 
@@ -117,9 +225,9 @@ class DictBackend(Backend):
 class Handler(object):
     """Smart protocol command handler base class."""
 
-    def __init__(self, backend, read, write):
+    def __init__(self, backend, proto):
         self.backend = backend
-        self.proto = Protocol(read, write)
+        self.proto = proto
         self._client_capabilities = None
 
     def capability_line(self):
@@ -158,9 +266,9 @@ class Handler(object):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
@@ -522,9 +630,9 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
@@ -532,20 +640,14 @@ class ReceivePackHandler(Handler):
     def capabilities(self):
         return ("report-status", "delete-refs")
 
-    def _apply_pack(self, refs, read):
+    def _apply_pack(self, refs):
         f, commit = self.repo.object_store.add_thin_pack()
         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
         status = []
         unpack_error = None
         # TODO: more informative error messages than just the exception string
         try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
+            PackStreamVerifier(self.proto, f).verify()
         except all_exceptions, e:
             unpack_error = str(e).replace('\n', '')
         try:
@@ -620,7 +722,7 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        status = self.repo._apply_pack(client_refs, self.proto.read)
+        status = self._apply_pack(client_refs)
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
@@ -649,7 +751,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
         else:
             return
 
-        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
+        h = cls(self.server.backend, args, proto)
         h.handle()
 
 
index 2f58a9c..22d329d 100644 (file)
@@ -78,4 +78,5 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         return port
 
     def test_push_to_dulwich(self):
+        # TODO(dborowitz): enable after merging thin pack fixes.
         raise TestSkipped('Skipping push test due to known deadlock bug.')
index e13ed22..e551f5b 100644 (file)
@@ -79,7 +79,7 @@ class TestProto(object):
 class HandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = Handler(Backend(), None, None)
+        self._handler = Handler(Backend(), None)
         self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
         self._handler.required_capabilities = lambda: ('cap2',)
 
index 17bf499..c9d7d70 100644 (file)
@@ -153,28 +153,16 @@ class DumbHandlersTestCase(WebTestCase):
 
 class SmartHandlersTestCase(WebTestCase):
 
-    class TestProtocol(object):
-        def __init__(self, handler):
-            self._handler = handler
-
-        def write_pkt_line(self, line):
-            if line is None:
-                self._handler.write('flush-pkt\n')
-            else:
-                self._handler.write('pkt-line: %s' % line)
-
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, args, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, proto, stateless_rpc=False,
                      advertise_refs=False):
             self.args = args
-            self.read = read
-            self.write = write
-            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.proto = proto
             self.stateless_rpc = stateless_rpc
             self.advertise_refs = advertise_refs
 
         def handle(self):
-            self.write('handled input: %s' % self.read())
+            self.proto.write('handled input: %s' % self.proto.recv(1024))
 
     def _MakeHandler(self, *args, **kwargs):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
@@ -222,8 +210,8 @@ class SmartHandlersTestCase(WebTestCase):
         mat = re.search('.*', '/git-upload-pack')
         output = ''.join(get_info_refs(self._req, 'backend', mat,
                                        services=self.services()))
-        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
-                           'flush-pkt\n'
+        self.assertEquals(('001e# service=git-upload-pack\n'
+                           '0000'
                            # input is ignored by the handler
                            'handled input: '), output)
         self.assertTrue(self._handler.advertise_refs)
index 28af518..c075a99 100644 (file)
@@ -26,6 +26,9 @@ try:
     from urlparse import parse_qs
 except ImportError:
     from dulwich.misc import parse_qs
+from dulwich.protocol import (
+    ReceivableProtocol,
+    )
 from dulwich.server import (
     ReceivePackHandler,
     UploadPackHandler,
@@ -138,9 +141,8 @@ def get_info_refs(req, backend, mat, services=None):
         req.nocache()
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
-        dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, [url_prefix(mat)],
-                              dummy_input.read, output.write,
+        proto = ReceivableProtocol(StringIO().read, output.write)
+        handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
@@ -216,8 +218,8 @@ def handle_service_request(req, backend, mat, services=None):
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, [url_prefix(mat)], input.read, output.write,
-                          stateless_rpc=True)
+    proto = ReceivableProtocol(input.read, output.write)
+    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
     yield output.getvalue()