Allow opening pack objects from memory.
[jelmer/dulwich-libgit2.git] / dulwich / protocol.py
index 36f2d2dd0db31406ab368d644e0b805a665b8647..eabe486d30d4dedaba8bb468756c05ddaae54c72 100644 (file)
@@ -1,10 +1,11 @@
 # protocol.py -- Shared parts of the git protocols
 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
+# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
 #
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License.
+# or (at your option) any later version of the License.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
+"""Generic functions for talking the git smart server protocol."""
 
-class Protocol(object):
+import socket
+
+from dulwich.errors import (
+    HangupException,
+    GitProtocolError,
+    )
+
+TCP_GIT_PORT = 9418
+
+SINGLE_ACK = 0
+MULTI_ACK = 1
+MULTI_ACK_DETAILED = 2
+
+class ProtocolFile(object):
+    """
+    Some network ops are like file ops. The file ops expect to operate on
+    file objects, so provide them with a dummy file.
+    """
 
     def __init__(self, read, write):
         self.read = read
         self.write = write
 
+    def tell(self):
+        pass
+
+    def close(self):
+        pass
+
+
+class Protocol(object):
+
+    def __init__(self, read, write, report_activity=None):
+        self.read = read
+        self.write = write
+        self.report_activity = report_activity
+
     def read_pkt_line(self):
         """
         Reads a 'pkt line' from the remote git process
 
         :return: The next string from the stream
         """
-        sizestr = self.read(4)
-        if not sizestr:
-            return None
-        size = int(sizestr, 16)
-        if size == 0:
-            return None
-        return self.read(size-4)
+        try:
+            sizestr = self.read(4)
+            if not sizestr:
+                raise HangupException()
+            size = int(sizestr, 16)
+            if size == 0:
+                if self.report_activity:
+                    self.report_activity(4, 'read')
+                return None
+            if self.report_activity:
+                self.report_activity(size, 'read')
+            return self.read(size-4)
+        except socket.error, e:
+            raise GitProtocolError(e)
 
     def read_pkt_seq(self):
         pkt = self.read_pkt_line()
@@ -49,10 +89,36 @@ class Protocol(object):
 
         :param line: A string containing the data to send
         """
-        if line is None:
-            self.write("0000")
-        else:
-            self.write("%04x%s" % (len(line)+4, line))
+        try:
+            if line is None:
+                self.write("0000")
+                if self.report_activity:
+                    self.report_activity(4, 'write')
+            else:
+                self.write("%04x%s" % (len(line)+4, line))
+                if self.report_activity:
+                    self.report_activity(4+len(line), 'write')
+        except socket.error, e:
+            raise GitProtocolError(e)
+
+    def write_file(self):
+        class ProtocolFile(object):
+
+            def __init__(self, proto):
+                self._proto = proto
+                self._offset = 0
+
+            def write(self, data):
+                self._proto.write(data)
+                self._offset += len(data)
+
+            def tell(self):
+                return self._offset
+
+            def close(self):
+                pass
+
+        return ProtocolFile(self)
 
     def write_sideband(self, channel, blob):
         """
@@ -61,11 +127,72 @@ class Protocol(object):
         :param channel: int specifying which channel to write to
         :param blob: a blob of data (as a string) to send on this channel
         """
-        # a pktline can be a max of 65535. a sideband line can therefore be
-        # 65535-5 = 65530
+        # a pktline can be a max of 65520. a sideband line can therefore be
+        # 65520-5 = 65515
         # WTF: Why have the len in ASCII, but the channel in binary.
         while blob:
-            self.write_pkt_line("%s%s" % (chr(channel), blob[:65530]))
-            blob = blob[65530:]
+            self.write_pkt_line("%s%s" % (chr(channel), blob[:65515]))
+            blob = blob[65515:]
+
+    def send_cmd(self, cmd, *args):
+        """
+        Send a command and some arguments to a git server
+
+        Only used for git://
+
+        :param cmd: The remote service to access
+        :param args: List of arguments to send to remove service
+        """
+        self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
+
+    def read_cmd(self):
+        """
+        Read a command and some arguments from the git client
+
+        Only used for git://
+
+        :return: A tuple of (command, [list of arguments])
+        """
+        line = self.read_pkt_line()
+        splice_at = line.find(" ")
+        cmd, args = line[:splice_at], line[splice_at+1:]
+        assert args[-1] == "\x00"
+        return cmd, args[:-1].split(chr(0))
+
+
+def extract_capabilities(text):
+    """Extract a capabilities list from a string, if present.
+
+    :param text: String to extract from
+    :return: Tuple with text with capabilities removed and list of capabilities
+    """
+    if not "\0" in text:
+        return text, []
+    text, capabilities = text.rstrip().split("\0")
+    return (text, capabilities.split(" "))
+
+
+def extract_want_line_capabilities(text):
+    """Extract a capabilities list from a want line, if present.
+
+    Note that want lines have capabilities separated from the rest of the line
+    by a space instead of a null byte. Thus want lines have the form:
+
+        want obj-id cap1 cap2 ...
+
+    :param text: Want line to extract from
+    :return: Tuple with text with capabilities removed and list of capabilities
+    """
+    split_text = text.rstrip().split(" ")
+    if len(split_text) < 3:
+        return text, []
+    return (" ".join(split_text[:2]), split_text[2:])
 
 
+def ack_type(capabilities):
+    """Extract the ack type from a capabilities list."""
+    if 'multi_ack_detailed' in capabilities:
+      return MULTI_ACK_DETAILED
+    elif 'multi_ack' in capabilities:
+        return MULTI_ACK
+    return SINGLE_ACK