91d6e8fff51f559a9e7c5414294cd21782240a78
[jelmer/dulwich-libgit2.git] / dulwich / client.py
1 # client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
3 # Copyright (C) 2008 John Carr
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
"""Client side support for the Git protocol.

Contains GitClient, which drives fetch_pack/send_pack exchanges over
caller-supplied read/write callbacks, plus transports that set those
callbacks up over TCP (git://), local subprocesses and ssh.
"""

__docformat__ = "restructuredText"
23
24 import os
25 import select
26 import socket
27 import subprocess
28
29 from dulwich.protocol import (
30     Protocol,
31     TCP_GIT_PORT,
32     extract_capabilities,
33     )
34 from dulwich.pack import (
35     write_pack_data,
36     )
37
38
def _fileno_can_read(fileno):
    """Return True if ``fileno`` has data ready to be read right now.

    Polls with a zero timeout, so the call never blocks.
    """
    readable, _, _ = select.select([fileno], [], [], 0)
    return bool(readable)
41
42
class SimpleFetchGraphWalker(object):
    """Walk local history to produce "have" shas for a fetch negotiation.

    Starting from ``local_heads``, next() hands out one commit at a time and
    queues its parents; ack() prunes a commit the server reports as common,
    together with all of its already-walked ancestors, from what is still to
    be offered.
    """

    def __init__(self, local_heads, get_parents):
        """Create a new walker.

        :param local_heads: Iterable of commit shas to start walking from.
        :param get_parents: Callable mapping a commit sha to an iterable of
            its parent shas.
        """
        self.heads = set(local_heads)
        self.get_parents = get_parents
        # sha -> parents for every commit already handed out by next().
        self.parents = {}
        # Commits acknowledged as common (directly or through a walked
        # descendant); they are never offered or queued again.
        self._common = set()

    def ack(self, ref):
        """Mark ``ref`` and its already-walked ancestors as common.

        Uses an explicit stack instead of recursion so that acknowledging a
        commit at the top of a long walked chain cannot exceed Python's
        recursion limit.
        """
        pending = [ref]
        while pending:
            sha = pending.pop()
            if sha in self._common:
                continue
            self._common.add(sha)
            self.heads.discard(sha)
            pending.extend(self.parents.get(sha, ()))

    def next(self):
        """Return the next commit sha to offer, or None when none remain."""
        if not self.heads:
            return None
        ret = self.heads.pop()
        ps = self.get_parents(ret)
        self.parents[ret] = ps
        # Skip parents already handed out or known to be common, so no
        # commit is offered twice (e.g. a shared ancestor reached via a merge).
        self.heads.update(
            p for p in ps if p not in self.parents and p not in self._common)
        return ret
65
66
# Capabilities every GitClient requests; GitClient copies this list and
# appends "thin-pack" per instance when thin packs are enabled.
CAPABILITIES = [
    "multi_ack",
    "side-band-64k",
    "ofs-delta",
]
68
69
class GitClient(object):
    """Git smart server client.

    Speaks the git pack protocol over caller-supplied read/write callbacks;
    the transport classes in this module create those callbacks.
    """

    def __init__(self, can_read, read, write, thin_packs=True,
                 report_activity=None):
        """Create a new GitClient instance.

        :param can_read: Function that returns True if there is data available
            to be read.
        :param read: Callback for reading data, takes number of bytes to read
        :param write: Callback for writing data
        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self.proto = Protocol(read, write, report_activity)
        self._can_read = can_read
        # Kept so send_pack can hand pack data straight to the transport.
        self._write = write
        self._capabilities = list(CAPABILITIES)
        if thin_packs:
            self._capabilities.append("thin-pack")

    def capabilities(self):
        """Return the requested capabilities as a space separated string."""
        return " ".join(self._capabilities)

    def read_refs(self):
        """Read the ref advertisement sent by the server.

        :return: Tuple of (refs, server_capabilities). refs maps ref name to
            sha; server_capabilities is what extract_capabilities parsed out
            of the advertisement (None if the server advertised no refs).
        """
        server_capabilities = None
        refs = {}
        for pkt in self.proto.read_pkt_seq():
            (sha, ref) = pkt.rstrip("\n").split(" ", 1)
            if server_capabilities is None:
                (ref, server_capabilities) = extract_capabilities(ref)
            refs[ref] = sha
        return refs, server_capabilities

    def send_pack(self, path, generate_pack_contents=None):
        """Upload a pack to a remote repository.

        No ref updates are computed yet (see FIXME), so currently this reads
        the server's refs and ends the exchange with a flush-pkt without
        calling generate_pack_contents. It is therefore optional for now.

        :param path: Repository path
        :param generate_pack_contents: Function that can return the shas of the
            objects to upload; called as (want, have, None).
        """
        refs, server_capabilities = self.read_refs()
        # Entries are (old_sha, new_sha, ref_name), matching the receive-pack
        # command lines written below.
        changed_refs = [] # FIXME
        if not changed_refs:
            self.proto.write_pkt_line(None)
            return
        if generate_pack_contents is None:
            raise TypeError(
                "generate_pack_contents is required when refs are updated")
        # Only the first command line carries the capability list.
        old_sha, new_sha, name = changed_refs[0]
        self.proto.write_pkt_line("%s %s %s\0%s" % (
            old_sha, new_sha, name, self.capabilities()))
        for changed_ref in changed_refs[1:]:
            self.proto.write_pkt_line("%s %s %s" % tuple(changed_ref))
        self.proto.write_pkt_line(None)
        want = []
        have = []
        for old_sha, new_sha, name in changed_refs:
            want.append(new_sha)
            # An all-zero old sha means the ref does not exist remotely yet.
            if old_sha != "0" * 40:
                have.append(old_sha)
        shas = generate_pack_contents(want, have, None)
        # NOTE(review): confirm whether write_pack_data expects a write
        # callable (as passed here) or a file-like object.
        write_pack_data(self._write, shas, len(shas))

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
        """Retrieve a pack from a git smart server.

        :param path: Repository path (unused here; transports send it).
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :raise AssertionError: On an unexpected ACK during negotiation or an
            unknown sideband channel.
        """
        (refs, server_capabilities) = self.read_refs()
        wants = determine_wants(refs)
        if not wants:
            self.proto.write_pkt_line(None)
            return
        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
        for want in wants[1:]:
            self.proto.write_pkt_line("want %s\n" % want)
        self.proto.write_pkt_line(None)
        have = graph_walker.next()
        while have:
            self.proto.write_pkt_line("have %s\n" % have)
            # Only consume a reply that is already waiting, so the loop of
            # "have" lines never blocks on the server.
            if self._can_read():
                pkt = self.proto.read_pkt_line()
                parts = pkt.rstrip("\n").split(" ")
                if parts[0] == "ACK":
                    graph_walker.ack(parts[1])
                    # Explicit check instead of assert so it survives -O.
                    if len(parts) < 3 or parts[2] != "continue":
                        raise AssertionError(
                            "Unexpected ACK during negotiation: %r" % pkt)
            have = graph_walker.next()
        self.proto.write_pkt_line("done\n")
        pkt = self.proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip("\n").split(" ")
            if parts[0] == "ACK":
                # Use the newline-stripped sha so the walker can match it.
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] != "continue":
                break
            pkt = self.proto.read_pkt_line()
        for pkt in self.proto.read_pkt_seq():
            channel = ord(pkt[0])
            pkt = pkt[1:]
            if channel == 1:
                pack_data(pkt)
            elif channel == 2:
                progress(pkt)
            else:
                raise AssertionError("Invalid sideband channel %d" % channel)
174             else:
175                 raise AssertionError("Invalid sideband channel %d" % channel)
176
177
178 class TCPGitClient(GitClient):
179     """A Git Client that works over TCP directly (i.e. git://)."""
180
181     def __init__(self, host, port=None, *args, **kwargs):
182         self._socket = socket.socket(type=socket.SOCK_STREAM)
183         if port is None:
184             port = TCP_GIT_PORT
185         self._socket.connect((host, port))
186         self.rfile = self._socket.makefile('rb', -1)
187         self.wfile = self._socket.makefile('wb', 0)
188         self.host = host
189         super(TCPGitClient, self).__init__(lambda: _fileno_can_read(self._socket.fileno()), self.rfile.read, self.wfile.write, *args, **kwargs)
190
191     def send_pack(self, path):
192         """Send a pack to a remote host.
193
194         :param path: Path of the repository on the remote host
195         """
196         self.proto.send_cmd("git-receive-pack", path, "host=%s" % self.host)
197         super(TCPGitClient, self).send_pack(path)
198
199     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
200         """Fetch a pack from the remote host.
201         
202         :param path: Path of the reposiutory on the remote host
203         :param determine_wants: Callback that receives available refs dict and 
204             should return list of sha's to fetch.
205         :param graph_walker: GraphWalker instance used to find missing shas
206         :param pack_data: Callback for writing pack data
207         :param progress: Callback for writing progress
208         """
209         self.proto.send_cmd("git-upload-pack", path, "host=%s" % self.host)
210         super(TCPGitClient, self).fetch_pack(path, determine_wants, graph_walker, pack_data, progress)
211
212
213 class SubprocessGitClient(GitClient):
214
215     def __init__(self, *args, **kwargs):
216         self.proc = None
217         self._args = args
218         self._kwargs = kwargs
219
220     def _connect(self, service, *args):
221         argv = [service] + list(args)
222         self.proc = subprocess.Popen(argv, bufsize=0,
223                                 stdin=subprocess.PIPE,
224                                 stdout=subprocess.PIPE)
225         def read_fn(size):
226             return self.proc.stdout.read(size)
227         def write_fn(data):
228             self.proc.stdin.write(data)
229             self.proc.stdin.flush()
230         return GitClient(lambda: _fileno_can_read(self.proc.stdout.fileno()), read_fn, write_fn, *args, **kwargs)
231
232     def send_pack(self, path):
233         client = self._connect("git-receive-pack", path)
234         client.send_pack(path)
235
236     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
237         progress):
238         client = self._connect("git-upload-pack", path)
239         client.fetch_pack(path, determine_wants, graph_walker, pack_data, progress)
240
241
242 class SSHSubprocess(object):
243     """A socket-like object that talks to an ssh subprocess via pipes."""
244
245     def __init__(self, proc):
246         self.proc = proc
247
248     def send(self, data):
249         return os.write(self.proc.stdin.fileno(), data)
250
251     def recv(self, count):
252         return self.proc.stdout.read(count)
253
254     def close(self):
255         self.proc.stdin.close()
256         self.proc.stdout.close()
257         self.proc.wait()
258
259
260 class SSHVendor(object):
261
262     def connect_ssh(self, host, command, username=None, port=None):
263         #FIXME: This has no way to deal with passwords..
264         args = ['ssh', '-x']
265         if port is not None:
266             args.extend(['-p', str(port)])
267         if username is not None:
268             host = "%s@%s" % (username, host)
269         args.append(host)
270         proc = subprocess.Popen(args + command,
271                                 stdin=subprocess.PIPE,
272                                 stdout=subprocess.PIPE)
273         return SSHSubprocess(proc)
274
275 # Can be overridden by users
276 get_ssh_vendor = SSHVendor
277
278
279 class SSHGitClient(GitClient):
280
281     def __init__(self, host, port=None, *args, **kwargs):
282         self.host = host
283         self.port = port
284         self._args = args
285         self._kwargs = kwargs
286
287     def send_pack(self, path):
288         remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack %s" % path], port=self.port)
289         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
290         client.send_pack(path)
291
292     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
293         remote = get_ssh_vendor().connect_ssh(self.host, ["git-upload-pack %s" % path], port=self.port)
294         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
295         client.fetch_pack(path, determine_wants, graph_walker, pack_data, progress)
296