From a22eb8f17ac34d59bd05a947792098119e654787 Mon Sep 17 00:00:00 2001 From: Dave Borowitz Date: Thu, 11 Mar 2010 14:12:18 -0800 Subject: [PATCH] Make the server decode a pack as it streams. This, in combination with using recv() instead of read(), makes it so we never do blocking reads past the end of the pack stream, even when the client doesn't close the connection. This is done via a PackStreamVerifier class that reads from a Protocol, unpacks and counts objects, writes through to a file, and computes the SHA-1 checksum on the fly. It is necessary because the only way we know when the pack is supposed to end is by parsing the header and reading the correct number of objects; otherwise, any further reads from the client would hang. Changed the Handler constructors to take a Protocol instead of taking read and write callbacks separately. Modified some interfaces to utility functions in pack.py so they can be used by the server-side code. Change-Id: Id4d11e34658d1f00ad06e45330d0d128b367d8e5 --- dulwich/pack.py | 31 ++++--- dulwich/server.py | 134 ++++++++++++++++++++++++---- dulwich/tests/compat/test_server.py | 1 + dulwich/tests/test_server.py | 2 +- dulwich/tests/test_web.py | 22 ++--- dulwich/web.py | 12 +-- 6 files changed, 150 insertions(+), 52 deletions(-) diff --git a/dulwich/pack.py b/dulwich/pack.py index 41fa873..60b0d72 100644 --- a/dulwich/pack.py +++ b/dulwich/pack.py @@ -417,12 +417,12 @@ class PackIndex2(PackIndex): -def read_pack_header(f): +def read_pack_header(read): """Read the header of a pack file. - :param f: File-like object to read from + :param read: Read function """ - header = f.read(12) + header = read(12) assert header[:4] == "PACK" (version,) = unpack_from(">L", header, 4) assert version in (2, 3), "Version was %d" % version @@ -434,20 +434,25 @@ def chunks_length(chunks): return sum(imap(len, chunks)) -def unpack_object(read): +def unpack_object(read_all, read_some=None): """Unpack a Git object. - :return: tuple with type, uncompressed data as chunks, compressed size and - tail data + :param read_all: Read function that blocks until the number of requested + bytes are read. + :param read_some: Read function that returns at least one byte, but may not + return the number of bytes requested. + :return: tuple with type, uncompressed data, compressed size and tail data. """ - bytes = take_msb_bytes(read) + if read_some is None: + read_some = read_all + bytes = take_msb_bytes(read_all) type = (bytes[0] >> 4) & 0x07 size = bytes[0] & 0x0f for i, byte in enumerate(bytes[1:]): size += (byte & 0x7f) << ((i * 7) + 4) raw_base = len(bytes) if type == 6: # offset delta - bytes = take_msb_bytes(read) + bytes = take_msb_bytes(read_all) raw_base += len(bytes) assert not (bytes[-1] & 0x80) delta_base_offset = bytes[0] & 0x7f @@ -455,17 +460,17 @@ def unpack_object(read): delta_base_offset += 1 delta_base_offset <<= 7 delta_base_offset += (byte & 0x7f) - uncomp, comp_len, unused = read_zlib_chunks(read, size) + uncomp, comp_len, unused = read_zlib_chunks(read_some, size) assert size == chunks_length(uncomp) return type, (delta_base_offset, uncomp), comp_len+raw_base, unused elif type == 7: # ref delta - basename = read(20) + basename = read_all(20) raw_base += 20 - uncomp, comp_len, unused = read_zlib_chunks(read, size) + uncomp, comp_len, unused = read_zlib_chunks(read_some, size) assert size == chunks_length(uncomp) return type, (basename, uncomp), comp_len+raw_base, unused else: - uncomp, comp_len, unused = read_zlib_chunks(read, size) + uncomp, comp_len, unused = read_zlib_chunks(read_some, size) assert chunks_length(uncomp) == size return type, uncomp, comp_len+raw_base, unused @@ -522,7 +527,7 @@ class PackData(object): self._file = GitFile(self._filename, 'rb') else: self._file = file - (version, self._num_objects) = read_pack_header(self._file) + (version, self._num_objects) = read_pack_header(self._file.read) self._offset_cache = LRUSizeCache(1024*1024*20, compute_size=_compute_object_size) diff --git a/dulwich/server.py b/dulwich/server.py index 63fb331..c08a2d5 100644 --- a/dulwich/server.py +++ b/dulwich/server.py @@ -27,15 +27,22 @@ Documentation/technical directory in the cgit distribution, and in particular: import collections +from cStringIO import StringIO +import socket import SocketServer +import zlib from dulwich.errors import ( ApplyDeltaError, ChecksumMismatch, GitProtocolError, ) +from dulwich.misc import ( + make_sha, + ) from dulwich.objects import ( hex_to_sha, + sha_to_hex, ) from dulwich.protocol import ( ProtocolFile, @@ -51,6 +58,8 @@ from dulwich.protocol import ( ack_type, ) from dulwich.pack import ( + read_pack_header, + unpack_object, write_pack_data, ) @@ -103,6 +112,105 @@ class BackendRepo(object): raise NotImplementedError +class PackStreamVerifier(object): + """Class to verify a pack stream as it is being read. + + The pack is read from a ReceivableProtocol using read() or recv() as + appropriate and written out to the given file-like object. + """ + + def __init__(self, proto, outfile): + self.proto = proto + self.outfile = outfile + self.sha = make_sha() + self._rbuf = StringIO() + # trailer is a deque to avoid memory allocation on small reads + self._trailer = collections.deque() + + def _read(self, read, size): + """Read up to size bytes using the given callback. + + As a side effect, update the verifier's hash (excluding the last 20 + bytes read) and write through to the output file. + + :param read: The read callback to read from. + :param size: The maximum number of bytes to read; the particular + behavior is callback-specific. + """ + data = read(size) + + # maintain a trailer of the last 20 bytes we've read + n = len(data) + tn = len(self._trailer) + if n >= 20: + to_pop = tn + to_add = 20 + else: + to_pop = max(n + tn - 20, 0) + to_add = n + for _ in xrange(to_pop): + self.sha.update(self._trailer.popleft()) + self._trailer.extend(data[-to_add:]) + + # hash everything but the trailer + self.sha.update(data[:-to_add]) + self.outfile.write(data) + return data + + def _buf_len(self): + buf = self._rbuf + start = buf.tell() + buf.seek(0, 2) + end = buf.tell() + buf.seek(start) + return end - start + + def read(self, size): + """Read, blocking until size bytes are read.""" + buf_len = self._buf_len() + if buf_len >= size: + return self._rbuf.read(size) + buf_data = self._rbuf.read() + self._rbuf = StringIO() + return buf_data + self._read(self.proto.read, size - buf_len) + + def recv(self, size): + """Read up to size bytes, blocking until one byte is read.""" + buf_len = self._buf_len() + if buf_len: + data = self._rbuf.read(size) + if size >= buf_len: + self._rbuf = StringIO() + return data + return self._read(self.proto.recv, size) + + def verify(self): + """Verify a pack stream and write it to the output file. + + :raise AssertionError: if there is an error in the pack format. + :raise ChecksumMismatch: if the checksum of the pack contents does not + match the checksum in the pack trailer. + :raise socket.error: if an error occurred reading from the socket. + :raise zlib.error: if an error occurred during zlib decompression. + :raise IOError: if an error occurred writing to the output file. + """ + _, num_objects = read_pack_header(self.read) + for i in xrange(num_objects): + type, _, _, unused = unpack_object(self.read, self.recv) + + # prepend any unused data to current read buffer + buf = StringIO() + buf.write(unused) + buf.write(self._rbuf.read()) + buf.seek(0) + self._rbuf = buf + + pack_sha = sha_to_hex(''.join([c for c in self._trailer])) + calculated_sha = self.sha.hexdigest() + if pack_sha != calculated_sha: + raise ChecksumMismatch(pack_sha, calculated_sha) + + class DictBackend(Backend): """Trivial backend that looks up Git repositories in a dictionary.""" @@ -117,9 +225,9 @@ class DictBackend(Backend): class Handler(object): """Smart protocol command handler base class.""" - def __init__(self, backend, read, write): + def __init__(self, backend, proto): self.backend = backend - self.proto = Protocol(read, write) + self.proto = proto self._client_capabilities = None def capability_line(self): @@ -158,9 +266,9 @@ class Handler(object): class UploadPackHandler(Handler): """Protocol handler for uploading a pack to the server.""" - def __init__(self, backend, args, read, write, + def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False): - Handler.__init__(self, backend, read, write) + Handler.__init__(self, backend, proto) self.repo = backend.open_repository(args[0]) self._graph_walker = None self.stateless_rpc = stateless_rpc @@ -522,9 +630,9 @@ class MultiAckDetailedGraphWalkerImpl(object): class ReceivePackHandler(Handler): """Protocol handler for downloading a pack from the client.""" - def __init__(self, backend, args, read, write, + def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False): - Handler.__init__(self, backend, read, write) + Handler.__init__(self, backend, proto) self.repo = backend.open_repository(args[0]) self.stateless_rpc = stateless_rpc self.advertise_refs = advertise_refs @@ -532,20 +640,14 @@ class ReceivePackHandler(Handler): def capabilities(self): return ("report-status", "delete-refs") - def _apply_pack(self, refs, read): + def _apply_pack(self, refs): f, commit = self.repo.object_store.add_thin_pack() all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError) status = [] unpack_error = None # TODO: more informative error messages than just the exception string try: - # TODO: decode the pack as we stream to avoid blocking reads beyond - # the end of data (when using HTTP/1.1 chunked encoding) - while True: - data = read(10240) - if not data: - break - f.write(data) + PackStreamVerifier(self.proto, f).verify() except all_exceptions, e: unpack_error = str(e).replace('\n', '') try: @@ -620,7 +722,7 @@ class ReceivePackHandler(Handler): ref = self.proto.read_pkt_line() # backend can now deal with this refs and read a pack using self.read - status = self.repo._apply_pack(client_refs, self.proto.read) + status = self._apply_pack(client_refs) # when we have read all the pack from the client, send a status report # if the client asked for it @@ -649,7 +751,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler): else: return - h = cls(self.server.backend, args, self.rfile.read, self.wfile.write) + h = cls(self.server.backend, args, proto) h.handle() diff --git a/dulwich/tests/compat/test_server.py b/dulwich/tests/compat/test_server.py index 2f58a9c..22d329d 100644 --- a/dulwich/tests/compat/test_server.py +++ b/dulwich/tests/compat/test_server.py @@ -78,4 +78,5 @@ class GitServerTestCase(ServerTests, CompatTestCase): return port def test_push_to_dulwich(self): + # TODO(dborowitz): enable after merging thin pack fixes. raise TestSkipped('Skipping push test due to known deadlock bug.') diff --git a/dulwich/tests/test_server.py b/dulwich/tests/test_server.py index e13ed22..e551f5b 100644 --- a/dulwich/tests/test_server.py +++ b/dulwich/tests/test_server.py @@ -79,7 +79,7 @@ class TestProto(object): class HandlerTestCase(TestCase): def setUp(self): - self._handler = Handler(Backend(), None, None) + self._handler = Handler(Backend(), None) self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3') self._handler.required_capabilities = lambda: ('cap2',) diff --git a/dulwich/tests/test_web.py b/dulwich/tests/test_web.py index 17bf499..c9d7d70 100644 --- a/dulwich/tests/test_web.py +++ b/dulwich/tests/test_web.py @@ -153,28 +153,16 @@ class DumbHandlersTestCase(WebTestCase): class SmartHandlersTestCase(WebTestCase): - class TestProtocol(object): - def __init__(self, handler): - self._handler = handler - - def write_pkt_line(self, line): - if line is None: - self._handler.write('flush-pkt\n') - else: - self._handler.write('pkt-line: %s' % line) - class _TestUploadPackHandler(object): - def __init__(self, backend, args, read, write, stateless_rpc=False, + def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False): self.args = args - self.read = read - self.write = write - self.proto = SmartHandlersTestCase.TestProtocol(self) + self.proto = proto self.stateless_rpc = stateless_rpc self.advertise_refs = advertise_refs def handle(self): - self.write('handled input: %s' % self.read()) + self.proto.write('handled input: %s' % self.proto.recv(1024)) def _MakeHandler(self, *args, **kwargs): self._handler = self._TestUploadPackHandler(*args, **kwargs) @@ -222,8 +210,8 @@ class SmartHandlersTestCase(WebTestCase): mat = re.search('.*', '/git-upload-pack') output = ''.join(get_info_refs(self._req, 'backend', mat, services=self.services())) - self.assertEquals(('pkt-line: # service=git-upload-pack\n' - 'flush-pkt\n' + self.assertEquals(('001e# service=git-upload-pack\n' + '0000' # input is ignored by the handler 'handled input: '), output) self.assertTrue(self._handler.advertise_refs) diff --git a/dulwich/web.py b/dulwich/web.py index 28af518..c075a99 100644 --- a/dulwich/web.py +++ b/dulwich/web.py @@ -26,6 +26,9 @@ try: from urlparse import parse_qs except ImportError: from dulwich.misc import parse_qs +from dulwich.protocol import ( + ReceivableProtocol, + ) from dulwich.server import ( ReceivePackHandler, UploadPackHandler, @@ -138,9 +141,8 @@ def get_info_refs(req, backend, mat, services=None): req.nocache() req.respond(HTTP_OK, 'application/x-%s-advertisement' % service) output = StringIO() - dummy_input = StringIO() # GET request, handler doesn't need to read - handler = handler_cls(backend, [url_prefix(mat)], - dummy_input.read, output.write, + proto = ReceivableProtocol(StringIO().read, output.write) + handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True, advertise_refs=True) handler.proto.write_pkt_line('# service=%s\n' % service) handler.proto.write_pkt_line(None) @@ -216,8 +218,8 @@ def handle_service_request(req, backend, mat, services=None): # content-length if 'CONTENT_LENGTH' in req.environ: input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH'])) - handler = handler_cls(backend, [url_prefix(mat)], input.read, output.write, - stateless_rpc=True) + proto = ReceivableProtocol(input.read, output.write) + handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True) handler.handle() yield output.getvalue() -- 2.34.1