Factor out a function to convert a line to a pkt-line.
[jelmer/dulwich-libgit2.git] / dulwich / protocol.py
1 # protocol.py -- Shared parts of the git protocols
2 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # or (at your option) any later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 """Generic functions for talking the git smart server protocol."""
21
22 from cStringIO import StringIO
23 import os
24 import socket
25
26 from dulwich.errors import (
27     HangupException,
28     GitProtocolError,
29     )
30 from dulwich.misc import (
31     SEEK_END,
32     )
33
34 TCP_GIT_PORT = 9418
35
36 ZERO_SHA = "0" * 40
37
38 SINGLE_ACK = 0
39 MULTI_ACK = 1
40 MULTI_ACK_DETAILED = 2
41
42
43 class ProtocolFile(object):
44     """
45     Some network ops are like file ops. The file ops expect to operate on
46     file objects, so provide them with a dummy file.
47     """
48
49     def __init__(self, read, write):
50         self.read = read
51         self.write = write
52
53     def tell(self):
54         pass
55
56     def close(self):
57         pass
58
59
60 def pkt_line(data):
61     """Wrap data in a pkt-line.
62
63     :param data: The data to wrap, as a str or None.
64     :return: The data prefixed with its length in pkt-line format; if data was
65         None, returns the flush-pkt ('0000')
66     """
67     if data is None:
68         return '0000'
69     return '%04x%s' % (len(data) + 4, data)
70
71
72 class Protocol(object):
73
74     def __init__(self, read, write, report_activity=None):
75         self.read = read
76         self.write = write
77         self.report_activity = report_activity
78
79     def read_pkt_line(self):
80         """
81         Reads a 'pkt line' from the remote git process
82
83         :return: The next string from the stream
84         """
85         try:
86             sizestr = self.read(4)
87             if not sizestr:
88                 raise HangupException()
89             size = int(sizestr, 16)
90             if size == 0:
91                 if self.report_activity:
92                     self.report_activity(4, 'read')
93                 return None
94             if self.report_activity:
95                 self.report_activity(size, 'read')
96             return self.read(size-4)
97         except socket.error, e:
98             raise GitProtocolError(e)
99
100     def read_pkt_seq(self):
101         pkt = self.read_pkt_line()
102         while pkt:
103             yield pkt
104             pkt = self.read_pkt_line()
105
106     def write_pkt_line(self, line):
107         """
108         Sends a 'pkt line' to the remote git process
109
110         :param line: A string containing the data to send
111         """
112         try:
113             line = pkt_line(line)
114             self.write(line)
115             if self.report_activity:
116                 self.report_activity(len(line), 'write')
117         except socket.error, e:
118             raise GitProtocolError(e)
119
120     def write_file(self):
121         class ProtocolFile(object):
122
123             def __init__(self, proto):
124                 self._proto = proto
125                 self._offset = 0
126
127             def write(self, data):
128                 self._proto.write(data)
129                 self._offset += len(data)
130
131             def tell(self):
132                 return self._offset
133
134             def close(self):
135                 pass
136
137         return ProtocolFile(self)
138
139     def write_sideband(self, channel, blob):
140         """
141         Write data to the sideband (a git multiplexing method)
142
143         :param channel: int specifying which channel to write to
144         :param blob: a blob of data (as a string) to send on this channel
145         """
146         # a pktline can be a max of 65520. a sideband line can therefore be
147         # 65520-5 = 65515
148         # WTF: Why have the len in ASCII, but the channel in binary.
149         while blob:
150             self.write_pkt_line("%s%s" % (chr(channel), blob[:65515]))
151             blob = blob[65515:]
152
153     def send_cmd(self, cmd, *args):
154         """
155         Send a command and some arguments to a git server
156
157         Only used for git://
158
159         :param cmd: The remote service to access
160         :param args: List of arguments to send to remove service
161         """
162         self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
163
164     def read_cmd(self):
165         """
166         Read a command and some arguments from the git client
167
168         Only used for git://
169
170         :return: A tuple of (command, [list of arguments])
171         """
172         line = self.read_pkt_line()
173         splice_at = line.find(" ")
174         cmd, args = line[:splice_at], line[splice_at+1:]
175         assert args[-1] == "\x00"
176         return cmd, args[:-1].split(chr(0))
177
178
179 _RBUFSIZE = 8192  # Default read buffer size.
180
181
182 class ReceivableProtocol(Protocol):
183     """Variant of Protocol that allows reading up to a size without blocking.
184
185     This class has a recv() method that behaves like socket.recv() in addition
186     to a read() method.
187
188     If you want to read n bytes from the wire and block until exactly n bytes
189     (or EOF) are read, use read(n). If you want to read at most n bytes from the
190     wire but don't care if you get less, use recv(n). Note that recv(n) will
191     still block until at least one byte is read.
192     """
193
194     def __init__(self, recv, write, report_activity=None, rbufsize=_RBUFSIZE):
195         super(ReceivableProtocol, self).__init__(self.read, write,
196                                                  report_activity)
197         self._recv = recv
198         self._rbuf = StringIO()
199         self._rbufsize = rbufsize
200
201     def read(self, size):
202         # From _fileobj.read in socket.py in the Python 2.6.5 standard library,
203         # with the following modifications:
204         #  - omit the size <= 0 branch
205         #  - seek back to start rather than 0 in case some buffer has been
206         #    consumed.
207         #  - use SEEK_END instead of the magic number.
208         # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
209         # Licensed under the Python Software Foundation License.
210         # TODO: see if buffer is more efficient than cStringIO.
211         assert size > 0
212
213         # Our use of StringIO rather than lists of string objects returned by
214         # recv() minimizes memory usage and fragmentation that occurs when
215         # rbufsize is large compared to the typical return value of recv().
216         buf = self._rbuf
217         start = buf.tell()
218         buf.seek(0, SEEK_END)
219         # buffer may have been partially consumed by recv()
220         buf_len = buf.tell() - start
221         if buf_len >= size:
222             # Already have size bytes in our buffer?  Extract and return.
223             buf.seek(start)
224             rv = buf.read(size)
225             self._rbuf = StringIO()
226             self._rbuf.write(buf.read())
227             self._rbuf.seek(0)
228             return rv
229
230         self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
231         while True:
232             left = size - buf_len
233             # recv() will malloc the amount of memory given as its
234             # parameter even though it often returns much less data
235             # than that.  The returned data string is short lived
236             # as we copy it into a StringIO and free it.  This avoids
237             # fragmentation issues on many platforms.
238             data = self._recv(left)
239             if not data:
240                 break
241             n = len(data)
242             if n == size and not buf_len:
243                 # Shortcut.  Avoid buffer data copies when:
244                 # - We have no data in our buffer.
245                 # AND
246                 # - Our call to recv returned exactly the
247                 #   number of bytes we were asked to read.
248                 return data
249             if n == left:
250                 buf.write(data)
251                 del data  # explicit free
252                 break
253             assert n <= left, "_recv(%d) returned %d bytes" % (left, n)
254             buf.write(data)
255             buf_len += n
256             del data  # explicit free
257             #assert buf_len == buf.tell()
258         buf.seek(start)
259         return buf.read()
260
261     def recv(self, size):
262         assert size > 0
263
264         buf = self._rbuf
265         start = buf.tell()
266         buf.seek(0, SEEK_END)
267         buf_len = buf.tell()
268         buf.seek(start)
269
270         left = buf_len - start
271         if not left:
272             # only read from the wire if our read buffer is exhausted
273             data = self._recv(self._rbufsize)
274             if len(data) == size:
275                 # shortcut: skip the buffer if we read exactly size bytes
276                 return data
277             buf = StringIO()
278             buf.write(data)
279             buf.seek(0)
280             del data  # explicit free
281             self._rbuf = buf
282         return buf.read(size)
283
284
285 def extract_capabilities(text):
286     """Extract a capabilities list from a string, if present.
287
288     :param text: String to extract from
289     :return: Tuple with text with capabilities removed and list of capabilities
290     """
291     if not "\0" in text:
292         return text, []
293     text, capabilities = text.rstrip().split("\0")
294     return (text, capabilities.strip().split(" "))
295
296
297 def extract_want_line_capabilities(text):
298     """Extract a capabilities list from a want line, if present.
299
300     Note that want lines have capabilities separated from the rest of the line
301     by a space instead of a null byte. Thus want lines have the form:
302
303         want obj-id cap1 cap2 ...
304
305     :param text: Want line to extract from
306     :return: Tuple with text with capabilities removed and list of capabilities
307     """
308     split_text = text.rstrip().split(" ")
309     if len(split_text) < 3:
310         return text, []
311     return (" ".join(split_text[:2]), split_text[2:])
312
313
314 def ack_type(capabilities):
315     """Extract the ack type from a capabilities list."""
316     if 'multi_ack_detailed' in capabilities:
317         return MULTI_ACK_DETAILED
318     elif 'multi_ack' in capabilities:
319         return MULTI_ACK
320     return SINGLE_ACK