Raise sha error if necessary, always return refs, fix docs.
[jelmer/dulwich-libgit2.git] / dulwich / client.py
# client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
3 # Copyright (C) 2008 John Carr
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
"""Client side support for the Git protocol.

Provides a protocol-level GitClient plus transport-specific clients that talk
to a server over TCP (git://), a local subprocess, or SSH.
"""

__docformat__ = 'restructuredText'
23
24 import os
25 import select
26 import socket
27 import subprocess
28
29 from dulwich.errors import (
30     ChecksumMismatch,
31     )
32 from dulwich.protocol import (
33     Protocol,
34     TCP_GIT_PORT,
35     extract_capabilities,
36     )
37 from dulwich.pack import (
38     write_pack_data,
39     )
40
41
42 def _fileno_can_read(fileno):
43     """Check if a file descriptor is readable."""
44     return len(select.select([fileno], [], [], 0)[0]) > 0
45
46
class SimpleFetchGraphWalker(object):
    """Graph walker that finds out what commits are missing."""

    def __init__(self, local_heads, get_parents):
        """Create a new SimpleFetchGraphWalker instance.

        :param local_heads: SHA1s to start walking from. next() offers them
            and their ancestors one at a time (GitClient.fetch_pack sends
            them to the server as "have" lines).
        :param get_parents: Function for finding the parents of a SHA1.
        """
        self.heads = set(local_heads)
        self.get_parents = get_parents
        # Maps every SHA1 returned by next() to its parents, so ack() can
        # walk down to the ancestors it implies.
        self.parents = {}
        # SHA1s already acknowledged, directly or as an ancestor of an
        # acknowledged commit; these are never offered again.
        self._acked = set()

    def ack(self, sha):
        """Ack that a particular revision and its ancestors are present in the target.

        The walk is iterative (long histories cannot hit the recursion limit)
        and visits each SHA1 at most once, even below merges.

        :param sha: SHA1 acknowledged by the target.
        """
        pending = [sha]
        while pending:
            current = pending.pop()
            if current in self._acked:
                continue
            self._acked.add(current)
            self.heads.discard(current)
            pending.extend(self.parents.get(current, ()))

    def next(self):
        """Iterate over revisions that might be missing in the target.

        :return: A SHA1 that has not been offered yet, or None once the walk
            is exhausted.
        """
        if not self.heads:
            return None
        ret = self.heads.pop()
        ps = self.get_parents(ret)
        self.parents[ret] = ps
        # Skip parents that were already offered or acknowledged. Otherwise
        # ancestors shared by several children are walked again and again.
        self.heads.update(p for p in ps
                          if p not in self.parents and p not in self._acked)
        return ret
77
78
CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"]


class GitClient(object):
    """Git smart server client.

    Speaks the pkt-line based git protocol over a pair of read/write
    callbacks; transport-specific subclasses set those up.
    """

    def __init__(self, can_read, read, write, thin_packs=True, 
        report_activity=None):
        """Create a new GitClient instance.

        :param can_read: Function that returns True if there is data available
            to be read.
        :param read: Callback for reading data, takes number of bytes to read
        :param write: Callback for writing data
        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self.proto = Protocol(read, write, report_activity)
        self._can_read = can_read
        self._capabilities = list(CAPABILITIES)
        if thin_packs:
            self._capabilities.append("thin-pack")

    def capabilities(self):
        """Return the capabilities this client advertises, space separated."""
        return " ".join(self._capabilities)

    def read_refs(self):
        """Read the ref advertisement sent by the server.

        :return: Tuple of (refs, server_capabilities). refs is a dict mapping
            ref names to SHA1s. server_capabilities is what
            extract_capabilities() parsed from the first line, or None if no
            lines were received.
        """
        server_capabilities = None
        refs = {}
        # Receive refs from server
        for pkt in self.proto.read_pkt_seq():
            (sha, ref) = pkt.rstrip("\n").split(" ", 1)
            if server_capabilities is None:
                # Only the first ref line carries the capability list.
                (ref, server_capabilities) = extract_capabilities(ref)
            refs[ref] = sha
        return refs, server_capabilities

    def send_pack(self, path, determine_wants, generate_pack_contents):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Function that receives the dict of refs
            advertised by the server. It returns a dict mapping ref names to
            the new SHA1s they should point at.
        :param generate_pack_contents: Function that receives the list of new
            SHA1s being pushed and the list of old SHA1s the server already
            has. It returns the objects to upload (must support len()).
        :raises ChecksumMismatch: if the server sends back a pack checksum
            that differs from the one that was written.
        :return: The changed refs that were sent, or {} if there was nothing
            to send.
        """
        refs, server_capabilities = self.read_refs()
        changed_refs = determine_wants(refs)
        if not changed_refs:
            self.proto.write_pkt_line(None)
            return {}
        zero_sha = "0" * 40
        want = []
        have = []
        sent_capabilities = False
        for changed_ref, new_sha1 in changed_refs.iteritems():
            old_sha1 = refs.get(changed_ref, zero_sha)
            if sent_capabilities:
                self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, changed_ref))
            else:
                # Our capabilities are appended, NUL separated, to the first
                # command line only.
                self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, changed_ref, self.capabilities()))
                sent_capabilities = True
            want.append(new_sha1)
            if old_sha1 != zero_sha:
                have.append(old_sha1)
        self.proto.write_pkt_line(None)
        objects = generate_pack_contents(want, have)
        (entries, sha) = write_pack_data(self.proto.write_file(), objects, len(objects))
        self.proto.write(sha)

        # Read the final confirmation sha, if the server sends one. File-like
        # readers return "" at EOF (and some callbacks may return None); both
        # mean "no confirmation". Anything else must match what we wrote.
        remote_sha = self.proto.read(20)
        if remote_sha and remote_sha != sha:
            raise ChecksumMismatch(sha, remote_sha)

        return changed_refs

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress):
        """Retrieve a pack from a git smart server.

        :param path: Repository path
        :param determine_wants: Callback that receives the dict of refs
            advertised by the server and returns a list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: Dict of refs advertised by the server.
        """
        (refs, server_capabilities) = self.read_refs()
        wants = determine_wants(refs)
        if not wants:
            self.proto.write_pkt_line(None)
            return refs
        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
        for want in wants[1:]:
            self.proto.write_pkt_line("want %s\n" % want)
        self.proto.write_pkt_line(None)
        have = graph_walker.next()
        while have:
            self.proto.write_pkt_line("have %s\n" % have)
            if self._can_read():
                pkt = self.proto.read_pkt_line()
                parts = pkt.rstrip("\n").split(" ")
                if parts[0] == "ACK":
                    graph_walker.ack(parts[1])
                    # multi_ack is requested (see CAPABILITIES), so ACKs during
                    # negotiation are expected to end in "continue".
                    assert parts[2] == "continue"
            have = graph_walker.next()
        self.proto.write_pkt_line("done\n")
        pkt = self.proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip("\n").split(" ")
            if parts[0] == "ACK":
                # Use the newline-stripped field. Splitting the raw packet
                # would keep the trailing "\n" for a bare "ACK <sha>\n" line.
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] != "continue":
                break
            pkt = self.proto.read_pkt_line()
        for pkt in self.proto.read_pkt_seq():
            # The first byte of each packet is the side-band channel number.
            channel = ord(pkt[0])
            pkt = pkt[1:]
            if channel == 1:
                pack_data(pkt)
            elif channel == 2:
                progress(pkt)
            else:
                raise AssertionError("Invalid sideband channel %d" % channel)
        return refs
203
204
class TCPGitClient(GitClient):
    """A Git Client that works over TCP directly (i.e. git://)."""

    def __init__(self, host, port=None, *args, **kwargs):
        """Connect to a git daemon.

        :param host: Host name to connect to
        :param port: TCP port; defaults to TCP_GIT_PORT
        Remaining arguments are passed on to GitClient.
        """
        if port is None:
            port = TCP_GIT_PORT
        self._socket = socket.socket(type=socket.SOCK_STREAM)
        try:
            self._socket.connect((host, port))
        except socket.error:
            # Don't leak the socket if the connection attempt fails.
            self._socket.close()
            raise
        self.rfile = self._socket.makefile('rb', -1)
        self.wfile = self._socket.makefile('wb', 0)
        self.host = host
        super(TCPGitClient, self).__init__(lambda: _fileno_can_read(self._socket.fileno()), self.rfile.read, self.wfile.write, *args, **kwargs)

    def send_pack(self, path, changed_refs, generate_pack_contents):
        """Send a pack to a remote host.

        :param path: Path of the repository on the remote host
        :param changed_refs: Function that receives the refs dict advertised
            by the server and returns a dict of refs to update. It is passed
            to GitClient.send_pack as determine_wants.
        :param generate_pack_contents: Function that returns the objects to
            upload
        """
        self.proto.send_cmd("git-receive-pack", path, "host=%s" % self.host)
        return super(TCPGitClient, self).send_pack(path, changed_refs, generate_pack_contents)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
        """Fetch a pack from the remote host.

        :param path: Path of the repository on the remote host
        :param determine_wants: Callback that receives available refs dict and 
            should return list of sha's to fetch.
        :param graph_walker: GraphWalker instance used to find missing shas
        :param pack_data: Callback for writing pack data
        :param progress: Callback for writing progress
        """
        self.proto.send_cmd("git-upload-pack", path, "host=%s" % self.host)
        return super(TCPGitClient, self).fetch_pack(path, determine_wants,
            graph_walker, pack_data, progress)
239
240
class SubprocessGitClient(GitClient):
    """Git client that talks to a server using a subprocess."""

    def __init__(self, *args, **kwargs):
        """Create a new SubprocessGitClient.

        No process is started here. The arguments (e.g. thin_packs=False) are
        stored and passed to the GitClient created for each operation.
        """
        self.proc = None
        self._args = args
        self._kwargs = kwargs

    def _connect(self, service, *args, **kwargs):
        """Spawn ``service`` and return a GitClient talking to it.

        :param service: Executable to run, e.g. "git-upload-pack"
        :param args: Command line arguments for the service
        :param kwargs: Extra keyword arguments for GitClient; they override
            the ones given to the constructor.
        :return: GitClient reading from the child's stdout and writing to
            its stdin.
        """
        argv = [service] + list(args)
        proc = subprocess.Popen(argv, bufsize=0,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        self.proc = proc
        # The callbacks bind this particular process, so a later _connect()
        # (which replaces self.proc) cannot redirect an earlier client.
        def read_fn(size):
            return proc.stdout.read(size)
        def write_fn(data):
            proc.stdin.write(data)
            proc.stdin.flush()
        # ``args`` are command line arguments for the service, not GitClient
        # options. GitClient gets the options given to the constructor.
        client_kwargs = dict(self._kwargs)
        client_kwargs.update(kwargs)
        return GitClient(lambda: _fileno_can_read(proc.stdout.fileno()),
                         read_fn, write_fn, *self._args, **client_kwargs)

    def send_pack(self, path, changed_refs, generate_pack_contents):
        """Upload a pack to the server.

        :param path: Path to the git repository on the server
        :param changed_refs: Function that receives the refs dict advertised
            by the server and returns a dict mapping ref names to their new
            SHA1s. It is passed to GitClient.send_pack as determine_wants.
        :param generate_pack_contents: Function that returns the objects to
            send
        """
        client = self._connect("git-receive-pack", path)
        return client.send_pack(path, changed_refs, generate_pack_contents)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
        progress):
        """Retrieve a pack from the server

        :param path: Path to the git repository on the server
        :param determine_wants: Function that receives existing refs 
            on the server and returns a list of desired shas
        :param graph_walker: GraphWalker instance
        :param pack_data: Function that can write pack data
        :param progress: Function that can write progress texts
        """
        client = self._connect("git-upload-pack", path)
        return client.fetch_pack(path, determine_wants, graph_walker, pack_data,
                                 progress)
286
287
class SSHSubprocess(object):
    """Socket-like wrapper around an ssh child process and its pipes."""

    def __init__(self, proc):
        """Wrap ``proc``.

        :param proc: Process object with ``stdin``/``stdout`` pipes and a
            ``wait()`` method (SSHVendor.connect_ssh passes a
            subprocess.Popen).
        """
        self.proc = proc

    def send(self, data):
        """Write ``data`` to the child's stdin; return the bytes written."""
        stdin_fd = self.proc.stdin.fileno()
        written = os.write(stdin_fd, data)
        return written

    def recv(self, count):
        """Read at most ``count`` bytes from the child's stdout."""
        pipe = self.proc.stdout
        return pipe.read(count)

    def close(self):
        """Close stdin then stdout, and wait for the child to exit."""
        for pipe in (self.proc.stdin, self.proc.stdout):
            pipe.close()
        self.proc.wait()
304
305
class SSHVendor(object):
    """Opens SSH connections by running the ``ssh`` command-line client."""

    def connect_ssh(self, host, command, username=None, port=None):
        """Run ``command`` on ``host`` over ssh.

        :param host: Host name to connect to
        :param command: List of arguments appended to the ssh command line
        :param username: Optional user name to log in as
        :param port: Optional port number
        :raises ValueError: if the destination (``user@host`` or ``host``)
            starts with "-". ssh would parse it as an option, so a host such
            as "-oProxyCommand=..." would run an arbitrary local command.
        :return: SSHSubprocess wrapping the ssh process
        """
        #FIXME: This has no way to deal with passwords..
        args = ['ssh', '-x']
        if port is not None:
            args.extend(['-p', str(port)])
        if username is not None:
            host = "%s@%s" % (username, host)
        if host.startswith("-"):
            raise ValueError("Refusing to connect to %r: it would be "
                             "interpreted as an ssh option" % (host,))
        args.append(host)
        proc = subprocess.Popen(args + command,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        return SSHSubprocess(proc)


# Factory SSHGitClient calls to obtain an SSH vendor. Users can override it
# to customize how SSH connections are made.
get_ssh_vendor = SSHVendor
323
324
class SSHGitClient(GitClient):
    """Git client that runs git-receive-pack/git-upload-pack over SSH."""

    def __init__(self, host, port=None, *args, **kwargs):
        """Create a new SSHGitClient.

        :param host: Host to connect to
        :param port: Optional SSH port
        Remaining arguments are passed to the GitClient created for each
        operation.
        """
        self.host = host
        self.port = port
        self._args = args
        self._kwargs = kwargs

    def _connect(self, service, path):
        """Run ``service`` for ``path`` on the remote host.

        :param service: Remote command, e.g. "git-upload-pack"
        :param path: Repository path on the remote host
        :return: GitClient talking to the remote process.
        """
        # Single-quote the path for the remote shell, escaping embedded
        # quotes. Spaces and shell metacharacters then reach the git service
        # literally instead of being interpreted by the shell.
        quoted_path = "'%s'" % path.replace("'", "'\\''")
        remote = get_ssh_vendor().connect_ssh(
            self.host, ["%s %s" % (service, quoted_path)], port=self.port)
        return GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()),
                         remote.recv, remote.send, *self._args, **self._kwargs)

    def send_pack(self, path, determine_wants, generate_pack_contents):
        """Upload a pack to ``path`` on the remote host via git-receive-pack.

        See GitClient.send_pack for the meaning of the arguments.
        """
        client = self._connect("git-receive-pack", path)
        return client.send_pack(path, determine_wants, generate_pack_contents)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
        progress):
        """Fetch a pack from ``path`` on the remote host via git-upload-pack.

        See GitClient.fetch_pack for the meaning of the arguments.
        """
        client = self._connect("git-upload-pack", path)
        return client.fetch_pack(path, determine_wants, graph_walker, pack_data,
                                 progress)
344