client: add get_refs method for just ref discovery
[jelmer/dulwich.git] / dulwich / client.py
1 # client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008-2013 Jelmer Vernooij <jelmer@samba.org>
3 # Copyright (C) 2008 John Carr
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 """Client side support for the Git protocol.
21
22 The Dulwich client supports the following capabilities:
23
24  * thin-pack
25  * multi_ack_detailed
26  * multi_ack
27  * side-band-64k
28  * ofs-delta
29  * report-status
30  * delete-refs
31
32 Known capabilities that are not supported:
33
34  * shallow
35  * no-progress
36  * include-tag
37 """
38
39 __docformat__ = 'restructuredText'
40
41 from contextlib import closing
42 from io import BytesIO, BufferedReader
43 import dulwich
44 import select
45 import socket
46 import subprocess
47 import sys
48
49 try:
50     import urllib2
51     import urlparse
52 except ImportError:
53     import urllib.request as urllib2
54     import urllib.parse as urlparse
55
56 from dulwich.errors import (
57     GitProtocolError,
58     NotGitRepository,
59     SendPackError,
60     UpdateRefsError,
61     )
62 from dulwich.protocol import (
63     _RBUFSIZE,
64     CAPABILITY_DELETE_REFS,
65     CAPABILITY_MULTI_ACK,
66     CAPABILITY_MULTI_ACK_DETAILED,
67     CAPABILITY_OFS_DELTA,
68     CAPABILITY_REPORT_STATUS,
69     CAPABILITY_SIDE_BAND_64K,
70     CAPABILITY_THIN_PACK,
71     COMMAND_DONE,
72     COMMAND_HAVE,
73     COMMAND_WANT,
74     SIDE_BAND_CHANNEL_DATA,
75     SIDE_BAND_CHANNEL_PROGRESS,
76     SIDE_BAND_CHANNEL_FATAL,
77     PktLineParser,
78     Protocol,
79     ProtocolFile,
80     TCP_GIT_PORT,
81     ZERO_SHA,
82     extract_capabilities,
83     )
84 from dulwich.pack import (
85     write_pack_objects,
86     )
87 from dulwich.refs import (
88     read_info_refs,
89     )
90
91
92 def _fileno_can_read(fileno):
93     """Check if a file descriptor is readable."""
94     return len(select.select([fileno], [], [], 0)[0]) > 0
95
# Capabilities shared by the fetch and send capability lists below.
COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K]
# Capabilities a GitClient offers when fetching.
FETCH_CAPABILITIES = [
    CAPABILITY_THIN_PACK,
    CAPABILITY_MULTI_ACK,
    CAPABILITY_MULTI_ACK_DETAILED,
] + COMMON_CAPABILITIES
# Capabilities a GitClient offers when sending.
SEND_CAPABILITIES = [CAPABILITY_REPORT_STATUS] + COMMON_CAPABILITIES
101
102
class ReportStatusParser(object):
    """Collect the status lines sent by servers with the 'report-status'
    capability and turn failures into exceptions.
    """

    def __init__(self):
        self._ref_statuses = []
        self._ref_status_ok = True
        self._pack_status = None
        self._done = False

    def check(self):
        """Raise an exception if the collected report contains a failure.

        :raise SendPackError: Raised when the server could not unpack
        :raise UpdateRefsError: Raised when refs could not be updated
        """
        if self._pack_status not in (b'unpack ok', None):
            raise SendPackError(self._pack_status)
        if self._ref_status_ok:
            return
        ref_status = {}
        succeeded = set()
        for line in self._ref_statuses:
            status, sep, ref = line.partition(b' ')
            if not sep:
                # malformed response, move on to the next one
                continue
            if status == b'ng':
                # When text follows the ref name, it replaces 'ng' as the
                # recorded status for that ref.
                name, has_reason, reason = ref.partition(b' ')
                if has_reason:
                    ref, status = name, reason
            else:
                succeeded.add(ref)
            ref_status[ref] = status
        # TODO(jelmer): don't assume encoding of refs is ascii.
        failed = [ref.decode('ascii') for ref in ref_status
                  if ref not in succeeded]
        raise UpdateRefsError(
            ', '.join(failed) + ' failed to update', ref_status=ref_status)

    def handle_packet(self, pkt):
        """Record one packet of the status report.

        The first packet is stored as the pack status; every later one is
        stored as a ref status. A None packet marks the end of the report.

        :param pkt: Packet contents, or None for a flush packet
        :raise GitProtocolError: Raised when packets are received after a
            flush packet.
        """
        if self._done:
            raise GitProtocolError("received more data after status report")
        if pkt is None:
            self._done = True
            return
        line = pkt.strip()
        if self._pack_status is None:
            self._pack_status = line
            return
        self._ref_statuses.append(line)
        if not line.startswith(b'ok '):
            self._ref_status_ok = False
159
160
def read_pkt_refs(proto):
    """Read the ref advertisement from a protocol object.

    The capabilities are extracted from the first ref line.

    :param proto: Protocol object to read the packet sequence from
    :return: (refs, capabilities) tuple; refs maps ref name to sha, or is
        None (with an empty capability set) when no refs were received
    :raise GitProtocolError: if a line starting with 'ERR' is received
    """
    capabilities = None
    advertised = {}
    # Receive refs from server
    for line in proto.read_pkt_seq():
        sha, name = line.rstrip(b'\n').split(None, 1)
        if sha == b'ERR':
            raise GitProtocolError(name)
        if capabilities is None:
            name, capabilities = extract_capabilities(name)
        advertised[name] = sha

    if not advertised:
        return None, set()
    return advertised, set(capabilities)
176
177
178 # TODO(durin42): this doesn't correctly degrade if the server doesn't
179 # support some capabilities. This should work properly with servers
180 # that don't support multi_ack.
181 class GitClient(object):
182     """Git smart server client.
183
184     """
185
186     def __init__(self, thin_packs=True, report_activity=None):
187         """Create a new GitClient instance.
188
189         :param thin_packs: Whether or not thin packs should be retrieved
190         :param report_activity: Optional callback for reporting transport
191             activity.
192         """
193         self._report_activity = report_activity
194         self._report_status_parser = None
195         self._fetch_capabilities = set(FETCH_CAPABILITIES)
196         self._send_capabilities = set(SEND_CAPABILITIES)
197         if not thin_packs:
198             self._fetch_capabilities.remove(CAPABILITY_THIN_PACK)
199
    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        This base implementation only raises NotImplementedError.

        :param path: Repository path
        :param determine_wants: Callback that determines the refs to set on
            the remote (the implementations in this file call it with a
            dict of the remote's current refs)
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        raise NotImplementedError(self.send_pack)
216
217     def fetch(self, path, target, determine_wants=None, progress=None):
218         """Fetch into a target repository.
219
220         :param path: Path to fetch from
221         :param target: Target repository to fetch into
222         :param determine_wants: Optional function to determine what refs
223             to fetch
224         :param progress: Optional progress function
225         :return: remote refs as dictionary
226         """
227         if determine_wants is None:
228             determine_wants = target.object_store.determine_wants_all
229         f, commit, abort = target.object_store.add_pack()
230         try:
231             result = self.fetch_pack(
232                 path, determine_wants, target.get_graph_walker(), f.write,
233                 progress)
234         except:
235             abort()
236             raise
237         else:
238             commit()
239         return result
240
    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        This base implementation only raises NotImplementedError.

        :param path: Path of the repository to fetch from
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        """
        raise NotImplementedError(self.fetch_pack)
251
    def get_refs(self, path):
        """Retrieve the current refs from a git smart server.

        This base implementation only raises NotImplementedError.

        :param path: Path to the repo to fetch from.
        """
        raise NotImplementedError(self.get_refs)
258
259     def _parse_status_report(self, proto):
260         unpack = proto.read_pkt_line().strip()
261         if unpack != b'unpack ok':
262             st = True
263             # flush remaining error data
264             while st is not None:
265                 st = proto.read_pkt_line()
266             raise SendPackError(unpack)
267         statuses = []
268         errs = False
269         ref_status = proto.read_pkt_line()
270         while ref_status:
271             ref_status = ref_status.strip()
272             statuses.append(ref_status)
273             if not ref_status.startswith(b'ok '):
274                 errs = True
275             ref_status = proto.read_pkt_line()
276
277         if errs:
278             ref_status = {}
279             ok = set()
280             for status in statuses:
281                 if b' ' not in status:
282                     # malformed response, move on to the next one
283                     continue
284                 status, ref = status.split(b' ', 1)
285
286                 if status == b'ng':
287                     if b' ' in ref:
288                         ref, status = ref.split(b' ', 1)
289                 else:
290                     ok.add(ref)
291                 ref_status[ref] = status
292             raise UpdateRefsError(', '.join([ref for ref in ref_status
293                                              if ref not in ok]) +
294                                              b' failed to update',
295                                   ref_status=ref_status)
296
297     def _read_side_band64k_data(self, proto, channel_callbacks):
298         """Read per-channel data.
299
300         This requires the side-band-64k capability.
301
302         :param proto: Protocol object to read from
303         :param channel_callbacks: Dictionary mapping channels to packet
304             handlers to use. None for a callback discards channel data.
305         """
306         for pkt in proto.read_pkt_seq():
307             channel = ord(pkt[:1])
308             pkt = pkt[1:]
309             try:
310                 cb = channel_callbacks[channel]
311             except KeyError:
312                 raise AssertionError('Invalid sideband channel %d' % channel)
313             else:
314                 if cb is not None:
315                     cb(pkt)
316
317     def _handle_receive_pack_head(self, proto, capabilities, old_refs,
318                                   new_refs):
319         """Handle the head of a 'git-receive-pack' request.
320
321         :param proto: Protocol object to read from
322         :param capabilities: List of negotiated capabilities
323         :param old_refs: Old refs, as received from the server
324         :param new_refs: New refs
325         :return: (have, want) tuple
326         """
327         want = []
328         have = [x for x in old_refs.values() if not x == ZERO_SHA]
329         sent_capabilities = False
330
331         all_refs = set(new_refs.keys()).union(set(old_refs.keys()))
332         for refname in all_refs:
333             old_sha1 = old_refs.get(refname, ZERO_SHA)
334             new_sha1 = new_refs.get(refname, ZERO_SHA)
335
336             if old_sha1 != new_sha1:
337                 if sent_capabilities:
338                     proto.write_pkt_line(old_sha1 + b' ' + new_sha1 + b' ' + refname)
339                 else:
340                     proto.write_pkt_line(
341                         old_sha1 + b' ' + new_sha1 + b' ' + refname + b'\0' +
342                         b' '.join(capabilities))
343                     sent_capabilities = True
344             if new_sha1 not in have and new_sha1 != ZERO_SHA:
345                 want.append(new_sha1)
346         proto.write_pkt_line(None)
347         return (have, want)
348
349     def _handle_receive_pack_tail(self, proto, capabilities, progress=None):
350         """Handle the tail of a 'git-receive-pack' request.
351
352         :param proto: Protocol object to read from
353         :param capabilities: List of negotiated capabilities
354         :param progress: Optional progress reporting function
355         """
356         if b"side-band-64k" in capabilities:
357             if progress is None:
358                 progress = lambda x: None
359             channel_callbacks = {2: progress}
360             if CAPABILITY_REPORT_STATUS in capabilities:
361                 channel_callbacks[1] = PktLineParser(
362                     self._report_status_parser.handle_packet).parse
363             self._read_side_band64k_data(proto, channel_callbacks)
364         else:
365             if CAPABILITY_REPORT_STATUS in capabilities:
366                 for pkt in proto.read_pkt_seq():
367                     self._report_status_parser.handle_packet(pkt)
368         if self._report_status_parser is not None:
369             self._report_status_parser.check()
370
371     def _handle_upload_pack_head(self, proto, capabilities, graph_walker,
372                                  wants, can_read):
373         """Handle the head of a 'git-upload-pack' request.
374
375         :param proto: Protocol object to read from
376         :param capabilities: List of negotiated capabilities
377         :param graph_walker: GraphWalker instance to call .ack() on
378         :param wants: List of commits to fetch
379         :param can_read: function that returns a boolean that indicates
380             whether there is extra graph data to read on proto
381         """
382         assert isinstance(wants, list) and isinstance(wants[0], bytes)
383         proto.write_pkt_line(COMMAND_WANT + b' ' + wants[0] + b' ' + b' '.join(capabilities) + b'\n')
384         for want in wants[1:]:
385             proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n')
386         proto.write_pkt_line(None)
387         have = next(graph_walker)
388         while have:
389             proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n')
390             if can_read():
391                 pkt = proto.read_pkt_line()
392                 parts = pkt.rstrip(b'\n').split(b' ')
393                 if parts[0] == b'ACK':
394                     graph_walker.ack(parts[1])
395                     if parts[2] in (b'continue', b'common'):
396                         pass
397                     elif parts[2] == b'ready':
398                         break
399                     else:
400                         raise AssertionError(
401                             "%s not in ('continue', 'ready', 'common)" %
402                             parts[2])
403             have = next(graph_walker)
404         proto.write_pkt_line(COMMAND_DONE + b'\n')
405
406     def _handle_upload_pack_tail(self, proto, capabilities, graph_walker,
407                                  pack_data, progress=None, rbufsize=_RBUFSIZE):
408         """Handle the tail of a 'git-upload-pack' request.
409
410         :param proto: Protocol object to read from
411         :param capabilities: List of negotiated capabilities
412         :param graph_walker: GraphWalker instance to call .ack() on
413         :param pack_data: Function to call with pack data
414         :param progress: Optional progress reporting function
415         :param rbufsize: Read buffer size
416         """
417         pkt = proto.read_pkt_line()
418         while pkt:
419             parts = pkt.rstrip(b'\n').split(b' ')
420             if parts[0] == b'ACK':
421                 graph_walker.ack(parts[1])
422             if len(parts) < 3 or parts[2] not in (
423                     b'ready', b'continue', b'common'):
424                 break
425             pkt = proto.read_pkt_line()
426         if CAPABILITY_SIDE_BAND_64K in capabilities:
427             if progress is None:
428                 # Just ignore progress data
429                 progress = lambda x: None
430             self._read_side_band64k_data(proto, {
431                 SIDE_BAND_CHANNEL_DATA: pack_data,
432                 SIDE_BAND_CHANNEL_PROGRESS: progress}
433             )
434         else:
435             while True:
436                 data = proto.read(rbufsize)
437                 if data == b"":
438                     break
439                 pack_data(data)
440
441
442 class TraditionalGitClient(GitClient):
443     """Traditional Git client."""
444
    def _connect(self, cmd, path):
        """Create a connection to the server.

        This method is abstract - concrete implementations should
        implement their own variant which connects to the server and
        returns an initialized Protocol object with the service ready
        for use and a can_read function which may be used to see if
        reads would block.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service.
        :return: (proto, can_read) tuple, as unpacked by the callers in
            this class
        """
        raise NotImplementedError()
458
459     def send_pack(self, path, determine_wants, generate_pack_contents,
460                   progress=None, write_pack=write_pack_objects):
461         """Upload a pack to a remote repository.
462
463         :param path: Repository path
464         :param generate_pack_contents: Function that can return a sequence of
465             the shas of the objects to upload.
466         :param progress: Optional callback called with progress updates
467         :param write_pack: Function called with (file, iterable of objects) to
468             write the objects returned by generate_pack_contents to the server.
469
470         :raises SendPackError: if server rejects the pack data
471         :raises UpdateRefsError: if the server supports report-status
472                                  and rejects ref updates
473         """
474         proto, unused_can_read = self._connect(b'receive-pack', path)
475         with proto:
476             old_refs, server_capabilities = read_pkt_refs(proto)
477             negotiated_capabilities = self._send_capabilities & server_capabilities
478
479             if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
480                 self._report_status_parser = ReportStatusParser()
481             report_status_parser = self._report_status_parser
482
483             try:
484                 new_refs = orig_new_refs = determine_wants(dict(old_refs))
485             except:
486                 proto.write_pkt_line(None)
487                 raise
488
489             if not CAPABILITY_DELETE_REFS in server_capabilities:
490                 # Server does not support deletions. Fail later.
491                 new_refs = dict(orig_new_refs)
492                 for ref, sha in orig_new_refs.items():
493                     if sha == ZERO_SHA:
494                         if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
495                             report_status_parser._ref_statuses.append(
496                                 b'ng ' + sha + b' remote does not support deleting refs')
497                             report_status_parser._ref_status_ok = False
498                         del new_refs[ref]
499
500             if new_refs is None:
501                 proto.write_pkt_line(None)
502                 return old_refs
503
504             if len(new_refs) == 0 and len(orig_new_refs):
505                 # NOOP - Original new refs filtered out by policy
506                 proto.write_pkt_line(None)
507                 if report_status_parser is not None:
508                     report_status_parser.check()
509                 return old_refs
510
511             (have, want) = self._handle_receive_pack_head(
512                 proto, negotiated_capabilities, old_refs, new_refs)
513             if not want and old_refs == new_refs:
514                 return new_refs
515             objects = generate_pack_contents(have, want)
516
517             dowrite = len(objects) > 0
518             dowrite = dowrite or any(old_refs.get(ref) != sha
519                                      for (ref, sha) in new_refs.items()
520                                      if sha != ZERO_SHA)
521             if dowrite:
522                 write_pack(proto.write_file(), objects)
523
524             self._handle_receive_pack_tail(
525                 proto, negotiated_capabilities, progress)
526             return new_refs
527
528     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
529                    progress=None):
530         """Retrieve a pack from a git smart server.
531
532         :param determine_wants: Callback that returns list of commits to fetch
533         :param graph_walker: Object with next() and ack().
534         :param pack_data: Callback called for each bit of data in the pack
535         :param progress: Callback for progress reports (strings)
536         """
537         proto, can_read = self._connect(b'upload-pack', path)
538         with proto:
539             refs, server_capabilities = read_pkt_refs(proto)
540             negotiated_capabilities = (
541                 self._fetch_capabilities & server_capabilities)
542
543             if refs is None:
544                 proto.write_pkt_line(None)
545                 return refs
546
547             try:
548                 wants = determine_wants(refs)
549             except:
550                 proto.write_pkt_line(None)
551                 raise
552             if wants is not None:
553                 wants = [cid for cid in wants if cid != ZERO_SHA]
554             if not wants:
555                 proto.write_pkt_line(None)
556                 return refs
557             self._handle_upload_pack_head(
558                 proto, negotiated_capabilities, graph_walker, wants, can_read)
559             self._handle_upload_pack_tail(
560                 proto, negotiated_capabilities, graph_walker, pack_data, progress)
561             return refs
562
563     def get_refs(self, path):
564         """Retrieve the current refs from a git smart server."""
565         # stock `git ls-remote` uses upload-pack
566         proto, _ = self._connect(b'upload-pack', path)
567         with proto:
568             refs, _ = read_pkt_refs(proto)
569             return refs
570
571     def archive(self, path, committish, write_data, progress=None,
572                 write_error=None):
573         proto, can_read = self._connect(b'upload-archive', path)
574         with proto:
575             proto.write_pkt_line(b"argument " + committish)
576             proto.write_pkt_line(None)
577             pkt = proto.read_pkt_line()
578             if pkt == b"NACK\n":
579                 return
580             elif pkt == b"ACK\n":
581                 pass
582             elif pkt.startswith(b"ERR "):
583                 raise GitProtocolError(pkt[4:].rstrip(b"\n"))
584             else:
585                 raise AssertionError("invalid response %r" % pkt)
586             ret = proto.read_pkt_line()
587             if ret is not None:
588                 raise AssertionError("expected pkt tail")
589             self._read_side_band64k_data(proto, {
590                 SIDE_BAND_CHANNEL_DATA: write_data,
591                 SIDE_BAND_CHANNEL_PROGRESS: progress,
592                 SIDE_BAND_CHANNEL_FATAL: write_error})
593
594
class TCPGitClient(TraditionalGitClient):
    """A Git Client that works over TCP directly (i.e. git://)."""

    def __init__(self, host, port=None, *args, **kwargs):
        """Create a new TCPGitClient instance.

        :param host: Host to connect to
        :param port: Port to connect to; TCP_GIT_PORT is used when None

        Any further arguments are passed on to TraditionalGitClient.
        """
        if port is None:
            port = TCP_GIT_PORT
        self._host = host
        self._port = port
        TraditionalGitClient.__init__(self, *args, **kwargs)

    def _connect(self, cmd, path):
        """Open a TCP connection and send the initial git command.

        Every address returned by getaddrinfo() is tried in turn until one
        accepts the connection.

        :param cmd: The git service name, without the 'git-' prefix
        :param path: The path to pass to the service
        :return: (Protocol, can_read) tuple
        :raise socket.error: the last connection error, or a "no address
            found" error when getaddrinfo() returned no addresses
        """
        sockaddrs = socket.getaddrinfo(
            self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        s = None
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            s = socket.socket(family, socktype, proto)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                s.connect(sockaddr)
                break
            except socket.error as e:
                # Python 3 unbinds the "as" name when the handler exits, so
                # the error is kept under a separate name for the raise
                # below.
                err = e
                s.close()
                s = None
        if s is None:
            raise err
        # -1 means system default buffering
        rfile = s.makefile('rb', -1)
        # 0 means unbuffered
        wfile = s.makefile('wb', 0)

        def close():
            rfile.close()
            wfile.close()
            s.close()

        proto = Protocol(rfile.read, wfile.write, close,
                         report_activity=self._report_activity)
        if path.startswith(b"/~"):
            path = path[1:]
        # NOTE(review): this concatenation needs self._host to be bytes on
        # Python 3 -- confirm what callers pass as host.
        proto.send_cmd(b'git-' + cmd, path, b'host=' + self._host)
        return proto, lambda: _fileno_can_read(s)
637
638
class SubprocessWrapper(object):
    """A socket-like object that talks to a subprocess via pipes."""

    def __init__(self, proc):
        """Wrap a subprocess.

        :param proc: Process object whose stdin and stdout are pipes
        """
        self.proc = proc
        if sys.version_info[0] == 2:
            self.read = proc.stdout.read
        else:
            # stdout is wrapped in a BufferedReader on Python 3.
            self.read = BufferedReader(proc.stdout).read
        self.write = proc.stdin.write

    def can_read(self):
        """Check whether data is available on the subprocess's stdout."""
        # sys.platform is used for the platform check because
        # subprocess.mswindows does not exist on Python 3.5 and later.
        if sys.platform == 'win32':
            from msvcrt import get_osfhandle
            from win32pipe import PeekNamedPipe
            handle = get_osfhandle(self.proc.stdout.fileno())
            data, total_bytes_avail, msg_bytes_left = PeekNamedPipe(handle, 0)
            return total_bytes_avail != 0
        else:
            return _fileno_can_read(self.proc.stdout.fileno())

    def close(self):
        """Close the subprocess's pipes and wait for it to exit."""
        self.proc.stdin.close()
        self.proc.stdout.close()
        if self.proc.stderr:
            self.proc.stderr.close()
        self.proc.wait()
666
667
def find_git_command():
    """Find command to run for system Git (usually C Git).

    :return: the command as an argv list
    """
    if sys.platform != 'win32':
        return ['git']
    # On Windows, support .exe, .bat and .cmd
    try:  # to avoid overhead
        import win32api
    except ImportError:  # run through cmd.exe with some overhead
        return ['cmd', '/c', 'git']
    _status, git_path = win32api.FindExecutable('git')
    return [git_path]
681
682
class SubprocessGitClient(TraditionalGitClient):
    """Git client that talks to a server using a subprocess."""

    # Command to run git with, as an argv list. When None, the result of
    # find_git_command() is used for each connection.
    git_command = None

    def __init__(self, *args, **kwargs):
        """Create a new SubprocessGitClient instance.

        :param stderr: Optional keyword argument passed as stderr to the
            subprocess

        All other arguments are passed on to TraditionalGitClient.
        """
        self._connection = None
        self._stderr = kwargs.pop('stderr', None)
        TraditionalGitClient.__init__(self, *args, **kwargs)

    def _connect(self, service, path):
        """Start git for the given service and wrap its pipes.

        :param service: The git service name (bytes)
        :param path: The path to pass to the service
        :return: (Protocol, can_read) tuple
        """
        import subprocess
        # Honour a configured git_command and only search for git when
        # none is set.
        git_command = self.git_command
        if git_command is None:
            git_command = find_git_command()
        argv = git_command + [service.decode('ascii'), path]
        p = SubprocessWrapper(
            subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=self._stderr))
        return Protocol(p.read, p.write, p.close,
                        report_activity=self._report_activity), p.can_read
708
709
710 class LocalGitClient(GitClient):
711     """Git Client that just uses a local Repo."""
712
    def __init__(self, thin_packs=True, report_activity=None):
        """Create a new LocalGitClient instance.

        :param thin_packs: Accepted but ignored by this client
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self._report_activity = report_activity
        # Ignore the thin_packs argument
723
724     def send_pack(self, path, determine_wants, generate_pack_contents,
725                   progress=None, write_pack=write_pack_objects):
726         """Upload a pack to a remote repository.
727
728         :param path: Repository path
729         :param generate_pack_contents: Function that can return a sequence of
730             the shas of the objects to upload.
731         :param progress: Optional progress function
732         :param write_pack: Function called with (file, iterable of objects) to
733             write the objects returned by generate_pack_contents to the server.
734
735         :raises SendPackError: if server rejects the pack data
736         :raises UpdateRefsError: if the server supports report-status
737                                  and rejects ref updates
738         """
739         from dulwich.repo import Repo
740
741         with closing(Repo(path)) as target:
742             old_refs = target.get_refs()
743             new_refs = determine_wants(dict(old_refs))
744
745             have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA]
746             want = []
747             all_refs = set(new_refs.keys()).union(set(old_refs.keys()))
748             for refname in all_refs:
749                 old_sha1 = old_refs.get(refname, ZERO_SHA)
750                 new_sha1 = new_refs.get(refname, ZERO_SHA)
751                 if new_sha1 not in have and new_sha1 != ZERO_SHA:
752                     want.append(new_sha1)
753
754             if not want and old_refs == new_refs:
755                 return new_refs
756
757             target.object_store.add_objects(generate_pack_contents(have, want))
758
759             for name, sha in new_refs.items():
760                 target.refs[name] = sha
761
762         return new_refs
763
764     def fetch(self, path, target, determine_wants=None, progress=None):
765         """Fetch into a target repository.
766
767         :param path: Path to fetch from
768         :param target: Target repository to fetch into
769         :param determine_wants: Optional function to determine what refs
770             to fetch
771         :param progress: Optional progress function
772         :return: remote refs as dictionary
773         """
774         from dulwich.repo import Repo
775         with closing(Repo(path)) as r:
776             return r.fetch(target, determine_wants=determine_wants,
777                            progress=progress)
778
779     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
780                    progress=None):
781         """Retrieve a pack from a git smart server.
782
783         :param determine_wants: Callback that returns list of commits to fetch
784         :param graph_walker: Object with next() and ack().
785         :param pack_data: Callback called for each bit of data in the pack
786         :param progress: Callback for progress reports (strings)
787         """
788         from dulwich.repo import Repo
789         with closing(Repo(path)) as r:
790             objects_iter = r.fetch_objects(determine_wants, graph_walker, progress)
791
792             # Did the process short-circuit (e.g. in a stateless RPC call)? Note
793             # that the client still expects a 0-object pack in most cases.
794             if objects_iter is None:
795                 return
796             write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
797
798     def get_refs(self, path):
799         """Retrieve the current refs from a git smart server."""
800         from dulwich.repo import Repo
801
802         with closing(Repo(path)) as target:
803             return target.get_refs()
804
805
806 # What Git client to use for local access
807 default_local_git_client_cls = LocalGitClient
808
809
class SSHVendor(object):
    """A client side SSH implementation."""

    def connect_ssh(self, host, command, username=None, port=None):
        """Deprecated alias of :meth:`run_command`.

        Emits a DeprecationWarning and forwards all arguments unchanged.

        :param host: Host name
        :param command: Command to run
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        :return: whatever run_command() returns
        """
        import warnings
        # stacklevel=2 attributes the warning to the code that still calls
        # the old name, not to this forwarding shim.
        warnings.warn(
            "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command",
            DeprecationWarning, stacklevel=2)
        return self.run_command(host, command, username=username, port=port)

    def run_command(self, host, command, username=None, port=None):
        """Connect to an SSH server.

        Run a command remotely and return a file-like object for interaction
        with the remote command.

        :param host: Host name
        :param command: Command to run
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        :raises NotImplementedError: always; subclasses provide the actual
            implementation
        """
        raise NotImplementedError(self.run_command)
832
833
class SubprocessSSHVendor(SSHVendor):
    """SSH vendor that shells out to the local 'ssh' command."""

    def run_command(self, host, command, username=None, port=None):
        """Run a command on a remote host through the local 'ssh' executable.

        :param host: Host name
        :param command: Command to run, as a list of arguments (it is
            appended to the 'ssh' argument list)
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        :return: SubprocessWrapper around the spawned 'ssh' process
        :raises ValueError: if the destination argument would start with '-'
        """
        #FIXME: This has no way to deal with passwords..
        args = ['ssh', '-x']
        if port is not None:
            args.extend(['-p', str(port)])
        if username is not None:
            host = '%s@%s' % (username, host)
        # The destination usually comes straight from a URL. If it starts
        # with a dash, ssh would parse it as an option (e.g. -oProxyCommand)
        # instead of a host, so refuse it rather than run it.
        if host.startswith('-'):
            raise ValueError(
                "refusing SSH destination %r: it would be parsed as an "
                "option" % host)
        args.append(host)
        proc = subprocess.Popen(args + command,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        return SubprocessWrapper(proc)
850
851
852 try:
853     import paramiko
854 except ImportError:
855     pass
856 else:
857     import threading
858
859     class ParamikoWrapper(object):
860         STDERR_READ_N = 2048  # 2k
861
862         def __init__(self, client, channel, progress_stderr=None):
863             self.client = client
864             self.channel = channel
865             self.progress_stderr = progress_stderr
866             self.should_monitor = bool(progress_stderr) or True
867             self.monitor_thread = None
868             self.stderr = ''
869
870             # Channel must block
871             self.channel.setblocking(True)
872
873             # Start
874             if self.should_monitor:
875                 self.monitor_thread = threading.Thread(
876                     target=self.monitor_stderr)
877                 self.monitor_thread.start()
878
879         def monitor_stderr(self):
880             while self.should_monitor:
881                 # Block and read
882                 data = self.read_stderr(self.STDERR_READ_N)
883
884                 # Socket closed
885                 if not data:
886                     self.should_monitor = False
887                     break
888
889                 # Emit data
890                 if self.progress_stderr:
891                     self.progress_stderr(data)
892
893                 # Append to buffer
894                 self.stderr += data
895
896         def stop_monitoring(self):
897             # Stop StdErr thread
898             if self.should_monitor:
899                 self.should_monitor = False
900                 self.monitor_thread.join()
901
902                 # Get left over data
903                 data = self.channel.in_stderr_buffer.empty()
904                 self.stderr += data
905
906         def can_read(self):
907             return self.channel.recv_ready()
908
909         def write(self, data):
910             return self.channel.sendall(data)
911
912         def read_stderr(self, n):
913             return self.channel.recv_stderr(n)
914
915         def read(self, n=None):
916             data = self.channel.recv(n)
917             data_len = len(data)
918
919             # Closed socket
920             if not data:
921                 return
922
923             # Read more if needed
924             if n and data_len < n:
925                 diff_len = n - data_len
926                 return data + self.read(diff_len)
927             return data
928
929         def close(self):
930             self.channel.close()
931             self.stop_monitoring()
932
933     class ParamikoSSHVendor(object):
934
935         def __init__(self):
936             self.ssh_kwargs = {}
937
938         def run_command(self, host, command, username=None, port=None,
939                         progress_stderr=None):
940
941             # Paramiko needs an explicit port. None is not valid
942             if port is None:
943                 port = 22
944
945             client = paramiko.SSHClient()
946
947             policy = paramiko.client.MissingHostKeyPolicy()
948             client.set_missing_host_key_policy(policy)
949             client.connect(host, username=username, port=port,
950                            **self.ssh_kwargs)
951
952             # Open SSH session
953             channel = client.get_transport().open_session()
954
955             # Run commands
956             channel.exec_command(*command)
957
958             return ParamikoWrapper(
959                 client, channel, progress_stderr=progress_stderr)
960
961
962 # Can be overridden by users
963 get_ssh_vendor = SubprocessSSHVendor
964
965
class SSHGitClient(TraditionalGitClient):
    """Git client that reaches the remote side through an SSH vendor."""

    def __init__(self, host, port=None, username=None, *args, **kwargs):
        """Create a client for ``host``.

        :param host: Host name to connect to
        :param port: Optional SSH port
        :param username: Optional name of user to log in as

        Remaining arguments are passed on to TraditionalGitClient.
        """
        self.host, self.port, self.username = host, port, username
        TraditionalGitClient.__init__(self, *args, **kwargs)
        # Maps a command name to the executable to run instead of the
        # default 'git-<name>'.
        self.alternative_paths = {}

    def _get_cmd_path(self, cmd):
        """Return the remote executable to run for command ``cmd``."""
        name = cmd.decode('ascii')
        fallback = 'git-' + name
        return self.alternative_paths.get(name, fallback)

    def _connect(self, cmd, path):
        """Start ``cmd`` for ``path`` on the remote host.

        :return: Tuple of (Protocol bound to the connection, the
            connection's can_read callable)
        """
        # A path of the form "/~..." is sent without its leading slash.
        remote_path = path[1:] if path.startswith("/~") else path
        vendor = get_ssh_vendor()
        argv = [self._get_cmd_path(cmd), remote_path]
        con = vendor.run_command(
            self.host, argv, port=self.port, username=self.username)
        proto = Protocol(con.read, con.write, con.close,
                         report_activity=self._report_activity)
        return (proto, con.can_read)
988
989
def default_user_agent_string():
    """Return the default User-agent value: "dulwich/" plus the dotted version."""
    dotted_version = ".".join(map(str, dulwich.__version__))
    return "dulwich/%s" % dotted_version
992
993
def default_urllib2_opener(config):
    """Build the default urllib2 opener used for HTTP access.

    :param config: Optional config object; "proxy" and "useragent" are
        looked up in its "http" section. May be None.
    :return: An opener with a proxy handler for "http" when a proxy is
        configured, and a User-agent header (the configured one, or the
        default dulwich one)
    """
    def http_setting(name):
        # No config at all, or a config whose get() raises KeyError for an
        # unset option, both mean "not configured".
        if config is None:
            return None
        try:
            return config.get("http", name)
        except KeyError:
            return None

    proxy_server = http_setting("proxy")
    handlers = []
    if proxy_server is not None:
        handlers.append(urllib2.ProxyHandler({"http": proxy_server}))
    opener = urllib2.build_opener(*handlers)
    user_agent = http_setting("useragent")
    if user_agent is None:
        user_agent = default_user_agent_string()
    opener.addheaders = [('User-agent', user_agent)]
    return opener
1011
1012
class HttpGitClient(GitClient):
    """Git client that talks to a repository over HTTP."""

    def __init__(self, base_url, dumb=None, opener=None, config=None, *args,
                 **kwargs):
        """Create a client for the repository at ``base_url``.

        :param base_url: Base URL; stored with exactly one trailing slash
        :param dumb: True/False if the server is known to be dumb/smart,
            None when unknown. Updated by every reference discovery.
        :param opener: Optional urllib2 opener; a default one is built from
            ``config`` when omitted
        :param config: Optional config object used for the default opener

        Remaining arguments are passed on to GitClient.
        """
        self.base_url = base_url.rstrip("/") + "/"
        self.dumb = dumb
        if opener is None:
            self.opener = default_urllib2_opener(config)
        else:
            self.opener = opener
        GitClient.__init__(self, *args, **kwargs)

    def __repr__(self):
        return "%s(%r, dumb=%r)" % (type(self).__name__, self.base_url, self.dumb)

    def _get_url(self, path):
        """Join ``path`` onto the base URL, with exactly one trailing slash."""
        return urlparse.urljoin(self.base_url, path).rstrip("/") + "/"

    def _http_request(self, url, headers=None, data=None):
        """Perform an HTTP request and return the response object.

        :param url: URL to request
        :param headers: Optional dictionary of request headers
        :param data: Optional request body
        :raises NotGitRepository: on a 404 response
        :raises GitProtocolError: on any other HTTP error response
        """
        if headers is None:
            headers = {}
        req = urllib2.Request(url, headers=headers, data=data)
        try:
            return self.opener.open(req)
        except urllib2.HTTPError as e:
            if e.code == 404:
                raise NotGitRepository()
            # Every HTTPError is a failure here; raising unconditionally
            # avoids falling through without a response to return.
            raise GitProtocolError("unexpected http response %d" % e.code)

    def _discover_references(self, service, url):
        """Fetch info/refs and record whether the server is smart or dumb.

        :param service: Name of the service to announce
        :param url: Repository URL, ending in "/"
        :return: Tuple of (refs, server capabilities); the capability set is
            empty for a dumb server
        :raises GitProtocolError: if a smart server does not start its reply
            with the service announcement
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, "info/refs")
        headers = {}
        # Ask for the smart protocol unless the server is already known to be
        # dumb. (Skipping it for known-smart servers would make the next
        # discovery on the same client come back as dumb.)
        if self.dumb is not True:
            url += "?service=%s" % service
            headers["Content-Type"] = "application/x-%s-request" % service
        resp = self._http_request(url, headers)
        try:
            self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
            if not self.dumb:
                proto = Protocol(resp.read, None)
                # The first line should mention the service
                pkts = list(proto.read_pkt_seq())
                if pkts != [('# service=%s\n' % service)]:
                    raise GitProtocolError(
                        "unexpected first line %r from smart server" % pkts)
                return read_pkt_refs(proto)
            else:
                return read_info_refs(resp), set()
        finally:
            resp.close()

    def _smart_request(self, service, url, data):
        """POST ``data`` to a smart service and return the response.

        :param service: Name of the service to call
        :param url: Repository URL, ending in "/"
        :param data: Request body
        :raises GitProtocolError: if the response content type is not the
            result type of ``service``
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, service)
        headers = {"Content-Type": "application/x-%s-request" % service}
        resp = self._http_request(url, headers, data)
        if resp.info().gettype() != ("application/x-%s-result" % service):
            raise GitProtocolError("Invalid content-type from server: %s"
                % resp.info().gettype())
        return resp

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Function called with a dict of the remote
            refs; returns the new refs to send, or None to send nothing
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.
        :return: the new refs, or the old remote refs when determine_wants
            returned None

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        :raises NotImplementedError: if the server is dumb
        """
        url = self._get_url(path)
        old_refs, server_capabilities = self._discover_references(
            b"git-receive-pack", url)
        negotiated_capabilities = self._send_capabilities & server_capabilities

        if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
            self._report_status_parser = ReportStatusParser()

        new_refs = determine_wants(dict(old_refs))
        if new_refs is None:
            return old_refs
        if self.dumb:
            raise NotImplementedError(self.send_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        (have, want) = self._handle_receive_pack_head(
            req_proto, negotiated_capabilities, old_refs, new_refs)
        if not want and old_refs == new_refs:
            return new_refs
        objects = generate_pack_contents(have, want)
        if len(objects) > 0:
            write_pack(req_proto.write_file(), objects)
        resp = self._smart_request(b"git-receive-pack", url,
                                   data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
                progress)
            return new_refs
        finally:
            resp.close()

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param path: Repository path
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: Dictionary with the refs of the remote repository
        :raises NotImplementedError: if something is wanted from a dumb
            server
        """
        url = self._get_url(path)
        refs, server_capabilities = self._discover_references(
            b"git-upload-pack", url)
        negotiated_capabilities = self._fetch_capabilities & server_capabilities
        wants = determine_wants(refs)
        if wants is not None:
            wants = [cid for cid in wants if cid != ZERO_SHA]
        if not wants:
            return refs
        if self.dumb:
            raise NotImplementedError(self.fetch_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        self._handle_upload_pack_head(
            req_proto, negotiated_capabilities, graph_walker, wants,
            lambda: False)
        resp = self._smart_request(
            b"git-upload-pack", url, data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
                graph_walker, pack_data, progress)
            return refs
        finally:
            resp.close()

    def get_refs(self, path):
        """Retrieve the current refs from a git smart server."""
        url = self._get_url(path)
        refs, _ = self._discover_references(
            b"git-upload-pack", url)
        return refs
1165
1166
def get_transport_and_path_from_url(url, config=None, **kwargs):
    """Obtain a git client from a URL.

    Supported schemes are git, git+ssh, ssh, http, https and file.

    :param url: URL to open
    :param config: Optional config object (only passed to HTTP clients)
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    :raises ValueError: if the URL scheme is not supported
    """
    parsed = urlparse.urlparse(url)
    if parsed.scheme == 'git':
        return (TCPGitClient(parsed.hostname, port=parsed.port, **kwargs),
                parsed.path)
    elif parsed.scheme in ('git+ssh', 'ssh'):
        # Plain ssh:// URLs are handled like git+ssh://; without this they
        # would be rejected here and then mis-parsed as "host:path".
        path = parsed.path
        if path.startswith('/'):
            path = parsed.path[1:]
        return SSHGitClient(parsed.hostname, port=parsed.port,
                            username=parsed.username, **kwargs), path
    elif parsed.scheme in ('http', 'https'):
        return HttpGitClient(urlparse.urlunparse(parsed), config=config,
                **kwargs), parsed.path
    elif parsed.scheme == 'file':
        return default_local_git_client_cls(**kwargs), parsed.path

    raise ValueError("unknown scheme '%s'" % parsed.scheme)
1194
1195
def get_transport_and_path(location, **kwargs):
    """Obtain a git client from a URL or path.

    The location is first tried as a URL. Failing that, "host:path" and
    "user@host:path" are treated as SSH locations, and anything else as a
    local path.

    :param location: URL or path
    :param config: Optional config object
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    """
    # First, try to parse it as a URL
    try:
        return get_transport_and_path_from_url(location, **kwargs)
    except ValueError:
        pass

    if (sys.platform == 'win32' and
            location[0:1].isalpha() and location[1:3] == ':\\'):
        # Windows local path
        return default_local_git_client_cls(**kwargs), location

    if ':' in location:
        # Only the first colon separates host from path; the path itself
        # may contain further colons.
        user_host, path = location.split(':', 1)
        if '@' in user_host:
            # SSH with user@host:foo. The last '@' separates user from host.
            user, host = user_host.rsplit('@', 1)
            return SSHGitClient(host, username=user, **kwargs), path
        # SSH with no user@, zero or one leading slash.
        return SSHGitClient(user_host, **kwargs), path

    # Otherwise, assume it's a local path.
    return default_local_git_client_cls(**kwargs), location