Support IPv6 for git:// connections.
[jelmer/dulwich.git] / dulwich / client.py
1 # client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
3 # Copyright (C) 2008 John Carr
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 """Client side support for the Git protocol."""
21
# Declares the markup used by this module's docstrings (the docstrings below
# use reST-style ``:param:`` / ``:return:`` fields).
__docformat__ = 'restructuredText'
23
24 import select
25 import socket
26 import urlparse
27
28 from dulwich.errors import (
29     GitProtocolError,
30     SendPackError,
31     UpdateRefsError,
32     )
33 from dulwich.protocol import (
34     Protocol,
35     TCP_GIT_PORT,
36     ZERO_SHA,
37     extract_capabilities,
38     )
39 from dulwich.pack import (
40     write_pack_data,
41     )
42
43
# urlparse only splits out a network location (host, port, user) for the
# schemes listed in urlparse.uses_netloc. Python 2.6.6 added 'git' and
# 'git+ssh' to that list upstream; register them by hand so that earlier
# Pythons parse such URLs the same way.
for scheme in ('git', 'git+ssh'):
    if scheme in urlparse.uses_netloc:
        # Already known to this Python; do not add a duplicate entry.
        continue
    urlparse.uses_netloc.append(scheme)
49
50 def _fileno_can_read(fileno):
51     """Check if a file descriptor is readable."""
52     return len(select.select([fileno], [], [], 0)[0]) > 0
53
# Capabilities requested for both fetching and sending.
COMMON_CAPABILITIES = ['ofs-delta']

# Capabilities requested when fetching a pack (copied per GitClient instance).
FETCH_CAPABILITIES = ['multi_ack', 'side-band-64k']
FETCH_CAPABILITIES.extend(COMMON_CAPABILITIES)

# Capabilities requested when sending a pack (copied per GitClient instance).
SEND_CAPABILITIES = ['report-status']
SEND_CAPABILITIES.extend(COMMON_CAPABILITIES)
57
58 # TODO(durin42): this doesn't correctly degrade if the server doesn't
59 # support some capabilities. This should work properly with servers
60 # that don't support side-band-64k and multi_ack.
class GitClient(object):
    """Git smart server client.

    Implements the client side of the fetch (upload-pack) and send
    (receive-pack) exchanges on top of a Protocol object. The transport
    itself is supplied by subclasses through ``_connect``.
    """

    def __init__(self, thin_packs=True, report_activity=None):
        """Create a new GitClient instance.

        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self._report_activity = report_activity
        # Work on copies so that per-instance additions (such as
        # 'thin-pack' below) never modify the module-level defaults.
        self._fetch_capabilities = list(FETCH_CAPABILITIES)
        self._send_capabilities = list(SEND_CAPABILITIES)
        if thin_packs:
            self._fetch_capabilities.append('thin-pack')

    def _connect(self, cmd, path):
        """Create a connection to the server.

        This method is abstract - concrete implementations should
        implement their own variant which connects to the server and
        returns an initialized Protocol object with the service ready
        for use and a can_read function which may be used to see if
        reads would block.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service.
        """
        raise NotImplementedError()

    def _read_refs(self, proto):
        """Read the ref advertisement sent by the server.

        :param proto: Protocol object to read from
        :return: Tuple of (dict mapping ref name to sha, server
            capabilities). The capabilities are taken from the first
            advertised ref and are None if no ref was advertised.
        :raises GitProtocolError: if the server sent an ERR packet
        """
        server_capabilities = None
        refs = {}
        # Receive refs from server
        for pkt in proto.read_pkt_seq():
            (sha, ref) = pkt.rstrip('\n').split(' ', 1)
            if sha == 'ERR':
                raise GitProtocolError(ref)
            if server_capabilities is None:
                # Only the first ref line carries the capability list.
                (ref, server_capabilities) = extract_capabilities(ref)
            refs[ref] = sha
        return refs, server_capabilities

    def _parse_status_report(self, proto):
        """Read and check the report-status response to a send.

        :param proto: Protocol object to read from
        :raises SendPackError: if the first line is not 'unpack ok'
        :raises UpdateRefsError: if any ref status line does not start
            with 'ok '
        """
        unpack = proto.read_pkt_line().strip()
        if unpack != 'unpack ok':
            st = True
            # flush remaining error data
            while st is not None:
                st = proto.read_pkt_line()
            raise SendPackError(unpack)
        statuses = []
        errs = False
        ref_status = proto.read_pkt_line()
        while ref_status:
            ref_status = ref_status.strip()
            statuses.append(ref_status)
            if not ref_status.startswith('ok '):
                errs = True
            ref_status = proto.read_pkt_line()

        if errs:
            # Map each ref to 'ok' or, for 'ng' lines, to the reason text.
            ref_status = {}
            ok = set()
            for status in statuses:
                if ' ' not in status:
                    # malformed response, move on to the next one
                    continue
                status, ref = status.split(' ', 1)

                if status == 'ng':
                    if ' ' in ref:
                        ref, status = ref.split(' ', 1)
                else:
                    ok.add(ref)
                ref_status[ref] = status
            raise UpdateRefsError('%s failed to update' %
                                  ', '.join([ref for ref in ref_status
                                             if ref not in ok]),
                                  ref_status=ref_status)

    # TODO(durin42): add side-band-64k capability support here and advertise it
    def send_pack(self, path, determine_wants, generate_pack_contents):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Callback that is passed the dict of refs
            advertised by the remote and returns the dict of refs the
            remote should have afterwards.
        :param generate_pack_contents: Function that can return the shas of the
            objects to upload. It is called with (have, want).
        :return: The new refs as returned by determine_wants, or an empty
            dict if determine_wants returned nothing.

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        proto, unused_can_read = self._connect('receive-pack', path)
        old_refs, server_capabilities = self._read_refs(proto)
        if server_capabilities is None:
            # No ref was advertised, so no capability list was seen.
            server_capabilities = []
        # Negotiate on a per-call copy. Removing from the instance list
        # would persist across calls and make a later list.remove() fail.
        send_capabilities = list(self._send_capabilities)
        if ('report-status' in send_capabilities and
                'report-status' not in server_capabilities):
            send_capabilities.remove('report-status')
        new_refs = determine_wants(old_refs)
        if not new_refs:
            proto.write_pkt_line(None)
            return {}
        want = []
        have = [x for x in old_refs.values() if not x == ZERO_SHA]
        sent_capabilities = False
        for refname in set(new_refs.keys() + old_refs.keys()):
            old_sha1 = old_refs.get(refname, ZERO_SHA)
            new_sha1 = new_refs.get(refname, ZERO_SHA)
            if old_sha1 != new_sha1:
                if sent_capabilities:
                    proto.write_pkt_line('%s %s %s' % (old_sha1, new_sha1,
                                                            refname))
                else:
                    # The capability list rides on the first update line
                    # only, separated from it by a NUL byte.
                    proto.write_pkt_line(
                      '%s %s %s\0%s' % (old_sha1, new_sha1, refname,
                                        ' '.join(send_capabilities)))
                    sent_capabilities = True
            if new_sha1 not in have and new_sha1 != ZERO_SHA:
                want.append(new_sha1)
        proto.write_pkt_line(None)
        if not want:
            return new_refs
        objects = generate_pack_contents(have, want)
        entries, sha = write_pack_data(proto.write_file(), objects,
                                       len(objects))

        if 'report-status' in send_capabilities:
            self._parse_status_report(proto)
        # wait for EOF before returning
        data = proto.read()
        if data:
            raise SendPackError('Unexpected response %r' % data)
        return new_refs

    def fetch(self, path, target, determine_wants=None, progress=None):
        """Fetch into a target repository.

        :param path: Path to fetch from
        :param target: Target repository to fetch into
        :param determine_wants: Optional function to determine what refs
            to fetch
        :param progress: Optional progress function
        :return: remote refs
        """
        if determine_wants is None:
            determine_wants = target.object_store.determine_wants_all
        f, commit = target.object_store.add_pack()
        try:
            return self.fetch_pack(path, determine_wants,
                target.get_graph_walker(), f.write, progress)
        finally:
            # NOTE: commit() runs whether or not fetch_pack raised.
            commit()

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress):
        """Retrieve a pack from a git smart server.

        :param path: Path to fetch from
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: dict of refs advertised by the server
        """
        proto, can_read = self._connect('upload-pack', path)
        (refs, server_capabilities) = self._read_refs(proto)
        wants = determine_wants(refs)
        if not wants:
            proto.write_pkt_line(None)
            return refs
        assert isinstance(wants, list) and type(wants[0]) == str
        # The requested capabilities are attached to the first want only.
        proto.write_pkt_line('want %s %s\n' % (
            wants[0], ' '.join(self._fetch_capabilities)))
        for want in wants[1:]:
            proto.write_pkt_line('want %s\n' % want)
        proto.write_pkt_line(None)
        have = graph_walker.next()
        while have:
            proto.write_pkt_line('have %s\n' % have)
            # Only read when data is waiting, so that sending haves is
            # never blocked on the server.
            if can_read():
                pkt = proto.read_pkt_line()
                parts = pkt.rstrip('\n').split(' ')
                if parts[0] == 'ACK':
                    graph_walker.ack(parts[1])
                    assert parts[2] == 'continue'
            have = graph_walker.next()
        proto.write_pkt_line('done\n')
        pkt = proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip('\n').split(' ')
            if parts[0] == 'ACK':
                # Use the stripped fields: splitting the raw packet would
                # leave the trailing newline on the sha of a final
                # 'ACK <sha>\n' line.
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] != 'continue':
                break
            pkt = proto.read_pkt_line()
        # TODO(durin42): this is broken if the server didn't support the
        # side-band-64k capability.
        for pkt in proto.read_pkt_seq():
            # The first byte selects the sideband channel:
            # 1 carries pack data, 2 carries progress messages.
            channel = ord(pkt[0])
            pkt = pkt[1:]
            if channel == 1:
                pack_data(pkt)
            elif channel == 2:
                if progress is not None:
                    progress(pkt)
            else:
                raise AssertionError('Invalid sideband channel %d' % channel)
        return refs
269
270
class TCPGitClient(GitClient):
    """A Git Client that works over TCP directly (i.e. git://)."""

    def __init__(self, host, port=None, *args, **kwargs):
        """Create a new TCPGitClient instance.

        :param host: Host name or address of the git server
        :param port: TCP port to connect to; TCP_GIT_PORT when None
        Remaining arguments are passed on to GitClient.__init__.
        """
        if port is None:
            port = TCP_GIT_PORT
        self._host = host
        self._port = port
        GitClient.__init__(self, *args, **kwargs)

    def _connect(self, cmd, path):
        """Open a TCP connection and request the given git service.

        Every address returned by getaddrinfo (any address family, so
        IPv4 and IPv6) is tried in order and the first one that connects
        is used.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service.
        :return: Tuple of (Protocol object, can_read function)
        :raises socket.error: if no address could be connected to; the
            error of the last failed attempt is raised
        """
        import sys
        sockaddrs = socket.getaddrinfo(self._host, self._port,
            socket.AF_UNSPEC, socket.SOCK_STREAM, 0, 0)
        s = None
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, protonum, canonname, sockaddr) in sockaddrs:
            # Reset first so a failure in socket.socket() below can never
            # close (or keep) a socket left over from a previous attempt.
            s = None
            try:
                s = socket.socket(family, socktype, protonum)
                s.setsockopt(socket.IPPROTO_TCP,
                                        socket.TCP_NODELAY, 1)
                s.connect(sockaddr)
                # Connected: stop here instead of going on to open (and
                # leak) further connections to the remaining addresses.
                break
            except socket.error:
                # sys.exc_info() works on every Python version, unlike
                # either spelling of "except ..., err" / "except ... as err".
                err = sys.exc_info()[1]
                if s is not None:
                    s.close()
                s = None
        if s is None:
            raise err
        # -1 means system default buffering
        rfile = s.makefile('rb', -1)
        # 0 means unbuffered
        wfile = s.makefile('wb', 0)
        proto = Protocol(rfile.read, wfile.write,
                         report_activity=self._report_activity)
        proto.send_cmd('git-%s' % cmd, path, 'host=%s' % self._host)
        return proto, lambda: _fileno_can_read(s)
306
307
class SubprocessWrapper(object):
    """Socket-like adapter around a subprocess and its stdin/stdout pipes."""

    def __init__(self, proc):
        """Wrap ``proc``.

        :param proc: Process object exposing ``stdin``, ``stdout`` and
            ``wait()``
        """
        self.proc = proc
        # Writes go to the child's stdin, reads come from its stdout.
        self.write = proc.stdin.write
        self.read = proc.stdout.read

    def can_read(self):
        """Return whether the child's stdout has data ready to be read."""
        stdout_fd = self.proc.stdout.fileno()
        return _fileno_can_read(stdout_fd)

    def close(self):
        """Close both pipes (stdin first), then wait for the child to exit."""
        for pipe in (self.proc.stdin, self.proc.stdout):
            pipe.close()
        self.proc.wait()
322         self.proc.wait()
323
324
class SubprocessGitClient(GitClient):
    """Git client that runs ``git <service> <path>`` as a local subprocess."""

    def __init__(self, *args, **kwargs):
        """Create the client; all arguments go to GitClient.__init__."""
        self._connection = None
        GitClient.__init__(self, *args, **kwargs)

    def _connect(self, service, path):
        """Start the git service as a child process and talk to it via pipes.

        :param service: The git service name to run
        :param path: The path we should pass to the service
        :return: Tuple of (Protocol object, can_read function)
        """
        import subprocess
        # bufsize=0 requests unbuffered pipes to the child.
        child = subprocess.Popen(['git', service, path], bufsize=0,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE)
        wrapper = SubprocessWrapper(child)
        proto = Protocol(wrapper.read, wrapper.write,
                         report_activity=self._report_activity)
        return proto, wrapper.can_read
340
341
class SSHVendor(object):
    """Opens SSH connections by running the ``ssh`` command line client."""

    def connect_ssh(self, host, command, username=None, port=None):
        """Run ``command`` on ``host`` through ``ssh``.

        :param host: Host to connect to
        :param command: List of arguments appended after the host
        :param username: Optional user name, sent as ``username@host``
        :param port: Optional port, passed to ssh with ``-p``
        :return: SubprocessWrapper around the ssh process
        """
        import subprocess
        #FIXME: This has no way to deal with passwords..
        argv = ['ssh', '-x']
        if port is not None:
            argv += ['-p', str(port)]
        if username is not None:
            host = '%s@%s' % (username, host)
        argv.append(host)
        ssh_process = subprocess.Popen(argv + command,
                                       stdin=subprocess.PIPE,
                                       stdout=subprocess.PIPE)
        return SubprocessWrapper(ssh_process)
355                                 stdout=subprocess.PIPE)
356         return SubprocessWrapper(proc)
357
358 # Can be overridden by users
359 get_ssh_vendor = SSHVendor
360
361
class SSHGitClient(GitClient):
    """Git client that reaches the remote git service over SSH."""

    def __init__(self, host, port=None, username=None, *args, **kwargs):
        """Create a new SSHGitClient instance.

        :param host: Host to connect to
        :param port: Optional SSH port
        :param username: Optional user name to log in with
        Remaining arguments are passed on to GitClient.__init__.
        """
        self.host = host
        self.port = port
        self.username = username
        GitClient.__init__(self, *args, **kwargs)
        # Maps a service name to the remote executable to run instead of
        # the default 'git-<service>'.
        self.alternative_paths = {}

    def _get_cmd_path(self, cmd):
        """Return the remote executable to run for service ``cmd``."""
        default_executable = 'git-%s' % cmd
        return self.alternative_paths.get(cmd, default_executable)

    def _connect(self, cmd, path):
        """Start the git service on the remote host through the SSH vendor.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service.
        :return: Tuple of (Protocol object, can_read function)
        """
        # The path is wrapped in single quotes as-is; nothing in it is
        # escaped.
        remote_command = "%s '%s'" % (self._get_cmd_path(cmd), path)
        vendor = get_ssh_vendor()
        con = vendor.connect_ssh(self.host, [remote_command],
                                 port=self.port, username=self.username)
        proto = Protocol(con.read, con.write,
                         report_activity=self._report_activity)
        return proto, con.can_read
378         return (Protocol(con.read, con.write, report_activity=self._report_activity),
379                 con.can_read)
380
381
382 def get_transport_and_path(uri):
383     """Obtain a git client from a URI or path.
384
385     :param uri: URI or path
386     :return: Tuple with client instance and relative path.
387     """
388     parsed = urlparse.urlparse(uri)
389     if parsed.scheme == 'git':
390         return TCPGitClient(parsed.hostname, port=parsed.port), parsed.path
391     elif parsed.scheme == 'git+ssh':
392         return SSHGitClient(parsed.hostname, port=parsed.port,
393                             username=parsed.username), parsed.path
394
395     if parsed.scheme and not parsed.netloc:
396         # SSH with no user@, zero or one leading slash.
397         return SSHGitClient(parsed.scheme), parsed.path
398     elif parsed.scheme:
399         raise ValueError('Unknown git protocol scheme: %s' % parsed.scheme)
400     elif '@' in parsed.path and ':' in parsed.path:
401         # SSH with user@host:foo.
402         user_host, path = parsed.path.split(':')
403         user, host = user_host.rsplit('@')
404         return SSHGitClient(host, username=user), path
405
406     # Otherwise, assume it's a local path.
407     return SubprocessGitClient(), uri