Merge branch 'python3'.
[jelmer/dulwich.git] / dulwich / client.py
# client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008-2013 Jelmer Vernooij <jelmer@samba.org>
3 # Copyright (C) 2008 John Carr
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 """Client side support for the Git protocol.
21
22 The Dulwich client supports the following capabilities:
23
24  * thin-pack
25  * multi_ack_detailed
26  * multi_ack
27  * side-band-64k
28  * ofs-delta
29  * report-status
30  * delete-refs
31
32 Known capabilities that are not supported:
33
34  * shallow
35  * no-progress
36  * include-tag
37 """
38
39 __docformat__ = 'restructuredText'
40
41 from io import BytesIO
42 import dulwich
43 import select
44 import socket
45 import subprocess
46 import sys
47
48 try:
49     import urllib2
50     import urlparse
51 except ImportError:
52     import urllib.request as urllib2
53     import urllib.parse as urlparse
54
55 from dulwich.errors import (
56     GitProtocolError,
57     NotGitRepository,
58     SendPackError,
59     UpdateRefsError,
60     )
61 from dulwich.protocol import (
62     _RBUFSIZE,
63     PktLineParser,
64     Protocol,
65     ProtocolFile,
66     TCP_GIT_PORT,
67     ZERO_SHA,
68     extract_capabilities,
69     )
70 from dulwich.pack import (
71     write_pack_objects,
72     )
73 from dulwich.refs import (
74     read_info_refs,
75     )
76
77
78 def _fileno_can_read(fileno):
79     """Check if a file descriptor is readable."""
80     return len(select.select([fileno], [], [], 0)[0]) > 0
81
# Capabilities requested both when fetching and when pushing.
COMMON_CAPABILITIES = ['ofs-delta', 'side-band-64k']
# Capabilities requested from upload-pack (fetch); GitClient drops
# 'thin-pack' from its copy when thin packs are disabled.
FETCH_CAPABILITIES = ['thin-pack', 'multi_ack', 'multi_ack_detailed']
FETCH_CAPABILITIES.extend(COMMON_CAPABILITIES)
# Capabilities requested from receive-pack (push).
SEND_CAPABILITIES = ['report-status']
SEND_CAPABILITIES.extend(COMMON_CAPABILITIES)
86
87
class ReportStatusParser(object):
    """Collect and validate the status report of a 'report-status' server.

    The first packet fed in is the pack (unpack) status. Every following
    packet is a per-ref status line. A None packet (flush) ends the report.
    """

    def __init__(self):
        self._done = False
        self._pack_status = None
        self._ref_status_ok = True
        self._ref_statuses = []

    def check(self):
        """Raise if the collected report indicates a failure.

        :raise SendPackError: the pack status is set and is not 'unpack ok'
        :raise UpdateRefsError: at least one ref status line did not start
            with 'ok '
        """
        if self._pack_status is not None and self._pack_status != 'unpack ok':
            raise SendPackError(self._pack_status)
        if self._ref_status_ok:
            return
        ref_status = {}
        succeeded = set()
        for line in self._ref_statuses:
            if ' ' not in line:
                # Malformed line: skip it and look at the next one.
                continue
            verdict, ref = line.split(' ', 1)
            if verdict != 'ng':
                succeeded.add(ref)
            elif ' ' in ref:
                # 'ng <ref> <reason>': record the reason as the status.
                ref, verdict = ref.split(' ', 1)
            ref_status[ref] = verdict
        failed = [ref for ref in ref_status if ref not in succeeded]
        raise UpdateRefsError('%s failed to update' % ', '.join(failed),
                              ref_status=ref_status)

    def handle_packet(self, pkt):
        """Feed one packet into the parser; None marks the final flush.

        :raise GitProtocolError: Raised when packets are received after a
            flush packet.
        """
        if self._done:
            raise GitProtocolError("received more data after status report")
        if pkt is None:
            self._done = True
        elif self._pack_status is None:
            self._pack_status = pkt.strip()
        else:
            line = pkt.strip()
            self._ref_statuses.append(line)
            if not line.startswith('ok '):
                self._ref_status_ok = False
144
145
def read_pkt_refs(proto):
    """Read the ref advertisement sent by a git server.

    :param proto: Protocol object to read pkt-lines from
    :return: tuple of (dict mapping ref name to sha, set of server
        capabilities). When no refs were advertised this is (None, set()).
    :raise GitProtocolError: when the server sent an 'ERR' line
    """
    refs = {}
    server_capabilities = None
    for line in proto.read_pkt_seq():
        sha, ref = line.rstrip('\n').split(None, 1)
        if sha == 'ERR':
            raise GitProtocolError(ref)
        # Only the first advertised ref carries the capability list.
        if server_capabilities is None:
            ref, server_capabilities = extract_capabilities(ref)
        refs[ref] = sha
    if not refs:
        return None, set()
    return refs, set(server_capabilities)
161
162
163 # TODO(durin42): this doesn't correctly degrade if the server doesn't
164 # support some capabilities. This should work properly with servers
165 # that don't support multi_ack.
class GitClient(object):
    """Git smart server client.

    Base class for the concrete clients. It stores the capabilities to
    request and implements the request/response helpers they share.
    ``send_pack`` and ``fetch_pack`` must be provided by subclasses.
    """

    def __init__(self, thin_packs=True, report_activity=None):
        """Create a new GitClient instance.

        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self._report_activity = report_activity
        self._report_status_parser = None
        self._fetch_capabilities = set(FETCH_CAPABILITIES)
        self._send_capabilities = set(SEND_CAPABILITIES)
        if not thin_packs:
            self._fetch_capabilities.remove('thin-pack')

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Callback given the current remote refs; it
            returns the refs the remote should end up with
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        raise NotImplementedError(self.send_pack)

    def fetch(self, path, target, determine_wants=None, progress=None):
        """Fetch into a target repository.

        The received pack is written into a pack opened on the target's
        object store. It is committed on success and aborted if fetching
        raises.

        :param path: Path to fetch from
        :param target: Target repository to fetch into
        :param determine_wants: Optional function to determine what refs
            to fetch; defaults to the target object store's
            determine_wants_all
        :param progress: Optional progress function
        :return: remote refs as dictionary
        """
        if determine_wants is None:
            determine_wants = target.object_store.determine_wants_all
        f, commit, abort = target.object_store.add_pack()
        try:
            result = self.fetch_pack(
                path, determine_wants, target.get_graph_walker(), f.write,
                progress)
        except BaseException:
            # Discard the partially written pack, then propagate.
            abort()
            raise
        else:
            commit()
        return result

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param path: Remote repository path
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        """
        raise NotImplementedError(self.fetch_pack)

    def _parse_status_report(self, proto):
        """Read a report-status response directly from *proto*.

        :param proto: Protocol object to read from
        :raise SendPackError: if the unpack status is not 'unpack ok'; the
            rest of the response is drained first
        :raise UpdateRefsError: if any ref status does not start with 'ok '
        """
        unpack = proto.read_pkt_line().strip()
        if unpack != 'unpack ok':
            st = True
            # flush remaining error data
            while st is not None:
                st = proto.read_pkt_line()
            raise SendPackError(unpack)
        statuses = []
        errs = False
        ref_status = proto.read_pkt_line()
        while ref_status:
            ref_status = ref_status.strip()
            statuses.append(ref_status)
            if not ref_status.startswith('ok '):
                errs = True
            ref_status = proto.read_pkt_line()

        if errs:
            ref_status = {}
            ok = set()
            for status in statuses:
                if ' ' not in status:
                    # malformed response, move on to the next one
                    continue
                status, ref = status.split(' ', 1)

                if status == 'ng':
                    # 'ng <ref> <reason>': keep the reason as the status.
                    if ' ' in ref:
                        ref, status = ref.split(' ', 1)
                else:
                    ok.add(ref)
                ref_status[ref] = status
            raise UpdateRefsError('%s failed to update' %
                                  ', '.join([ref for ref in ref_status
                                             if ref not in ok]),
                                  ref_status=ref_status)

    def _read_side_band64k_data(self, proto, channel_callbacks):
        """Read per-channel data.

        This requires the side-band-64k capability. The first byte of every
        packet is the channel number; the remainder of the packet is handed
        to that channel's callback.

        :param proto: Protocol object to read from
        :param channel_callbacks: Dictionary mapping channels to packet
            handlers to use. None for a callback discards channel data.
        :raise AssertionError: for a channel missing from channel_callbacks
        """
        for pkt in proto.read_pkt_seq():
            # Slice instead of index: pkt[0] is a str on Python 2 but an int
            # on Python 3, while ord() accepts a length-1 slice of either.
            channel = ord(pkt[:1])
            pkt = pkt[1:]
            try:
                cb = channel_callbacks[channel]
            except KeyError:
                raise AssertionError('Invalid sideband channel %d' % channel)
            else:
                if cb is not None:
                    cb(pkt)

    def _handle_receive_pack_head(self, proto, capabilities, old_refs,
                                  new_refs):
        """Handle the head of a 'git-receive-pack' request.

        Writes one '<old> <new> <ref>' command per ref whose sha changes
        (capabilities are attached to the first one), then a flush packet.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param old_refs: Old refs, as received from the server
        :param new_refs: New refs
        :return: (have, want) tuple
        """
        want = []
        have = [x for x in old_refs.values() if x != ZERO_SHA]
        sent_capabilities = False

        # Set union instead of concatenating .keys(): dict views cannot be
        # added together on Python 3.
        for refname in set(new_refs) | set(old_refs):
            old_sha1 = old_refs.get(refname, ZERO_SHA)
            new_sha1 = new_refs.get(refname, ZERO_SHA)

            if old_sha1 != new_sha1:
                if sent_capabilities:
                    proto.write_pkt_line('%s %s %s' % (
                        old_sha1, new_sha1, refname))
                else:
                    proto.write_pkt_line(
                        '%s %s %s\0%s' % (old_sha1, new_sha1, refname,
                                          ' '.join(capabilities)))
                    sent_capabilities = True
            if new_sha1 not in have and new_sha1 != ZERO_SHA:
                want.append(new_sha1)
        proto.write_pkt_line(None)
        return (have, want)

    def _handle_receive_pack_tail(self, proto, capabilities, progress=None):
        """Handle the tail of a 'git-receive-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param progress: Optional progress reporting function
        :raise SendPackError: propagated from the status report check
        :raise UpdateRefsError: propagated from the status report check
        """
        if "side-band-64k" in capabilities:
            if progress is None:
                progress = lambda x: None
            channel_callbacks = {2: progress}
            if 'report-status' in capabilities:
                # The status report is itself pkt-line framed inside
                # sideband channel 1.
                channel_callbacks[1] = PktLineParser(
                    self._report_status_parser.handle_packet).parse
            self._read_side_band64k_data(proto, channel_callbacks)
        else:
            if 'report-status' in capabilities:
                for pkt in proto.read_pkt_seq():
                    self._report_status_parser.handle_packet(pkt)
        if self._report_status_parser is not None:
            self._report_status_parser.check()

    def _handle_upload_pack_head(self, proto, capabilities, graph_walker,
                                 wants, can_read):
        """Handle the head of a 'git-upload-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param graph_walker: GraphWalker instance to call .ack() on
        :param wants: List of commits to fetch
        :param can_read: function that returns a boolean that indicates
            whether there is extra graph data to read on proto
        """
        assert isinstance(wants, list) and isinstance(wants[0], str)
        proto.write_pkt_line('want %s %s\n' % (
            wants[0], ' '.join(capabilities)))
        for want in wants[1:]:
            proto.write_pkt_line('want %s\n' % want)
        proto.write_pkt_line(None)
        have = next(graph_walker)
        while have:
            proto.write_pkt_line('have %s\n' % have)
            if can_read():
                pkt = proto.read_pkt_line()
                parts = pkt.rstrip('\n').split(' ')
                if parts[0] == 'ACK':
                    graph_walker.ack(parts[1])
                    if len(parts) < 3:
                        # A bare 'ACK <sha>' (server without multi_ack)
                        # means a common base was found; stop sending haves
                        # and go straight to 'done'.
                        break
                    if parts[2] in ('continue', 'common'):
                        pass
                    elif parts[2] == 'ready':
                        break
                    else:
                        raise AssertionError(
                            "%s not in ('continue', 'ready', 'common')" %
                            parts[2])
            have = next(graph_walker)
        proto.write_pkt_line('done\n')

    def _handle_upload_pack_tail(self, proto, capabilities, graph_walker,
                                 pack_data, progress=None, rbufsize=_RBUFSIZE):
        """Handle the tail of a 'git-upload-pack' request.

        Consumes the remaining ACK lines, then streams pack data to
        pack_data until the server closes the stream.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param graph_walker: GraphWalker instance to call .ack() on
        :param pack_data: Function to call with pack data
        :param progress: Optional progress reporting function
        :param rbufsize: Read buffer size
        """
        pkt = proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip('\n').split(' ')
            if parts[0] == 'ACK':
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] not in (
                    'ready', 'continue', 'common'):
                break
            pkt = proto.read_pkt_line()
        if "side-band-64k" in capabilities:
            if progress is None:
                # Just ignore progress data
                progress = lambda x: None
            self._read_side_band64k_data(proto, {1: pack_data, 2: progress})
        else:
            while True:
                data = proto.read(rbufsize)
                # Test emptiness rather than comparing to "": on Python 3 the
                # stream yields bytes, and b"" != "" would loop forever.
                if not data:
                    break
                pack_data(data)
416
417
class TraditionalGitClient(GitClient):
    """Traditional Git client.

    Speaks the pkt-line git protocol over a connection supplied by the
    subclass' ``_connect`` (e.g. TCPGitClient, SubprocessGitClient).
    """

    def _connect(self, cmd, path):
        """Create a connection to the server.

        This method is abstract - concrete implementations should
        implement their own variant which connects to the server and
        returns an initialized Protocol object with the service ready
        for use and a can_read function which may be used to see if
        reads would block.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service.
        """
        raise NotImplementedError()

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Callback given a copy of the remote refs; it
            returns the refs the remote should end up with, or None to abort
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional callback called with progress updates
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.
        :return: the new refs (or the old refs if nothing was sent)

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        proto, unused_can_read = self._connect('receive-pack', path)
        with proto:
            old_refs, server_capabilities = read_pkt_refs(proto)
            if old_refs is None:
                # No refs advertised: every ref is new to the remote.
                old_refs = {}
            negotiated_capabilities = (
                self._send_capabilities & server_capabilities)

            # Always reset the parser so a report left over from an earlier
            # push on this client is never re-checked for a server that does
            # not send one.
            if 'report-status' in negotiated_capabilities:
                self._report_status_parser = ReportStatusParser()
            else:
                self._report_status_parser = None
            report_status_parser = self._report_status_parser

            try:
                new_refs = orig_new_refs = determine_wants(dict(old_refs))
            except BaseException:
                # Terminate the (empty) command list before propagating.
                proto.write_pkt_line(None)
                raise

            # Checked before filtering deletions: dict(None) would raise.
            if new_refs is None:
                proto.write_pkt_line(None)
                return old_refs

            if 'delete-refs' not in server_capabilities:
                # Server does not support deletions. Fail later.
                new_refs = dict(orig_new_refs)
                for ref, sha in orig_new_refs.items():
                    if sha == ZERO_SHA:
                        if report_status_parser is not None:
                            # Same 'ng <ref> <reason>' shape a server sends,
                            # so check() attributes the failure to the ref.
                            report_status_parser._ref_statuses.append(
                                'ng %s remote does not support deleting refs'
                                % ref)
                            report_status_parser._ref_status_ok = False
                        del new_refs[ref]

            if len(new_refs) == 0 and orig_new_refs:
                # NOOP - Original new refs filtered out by policy
                proto.write_pkt_line(None)
                if report_status_parser is not None:
                    report_status_parser.check()
                return old_refs

            (have, want) = self._handle_receive_pack_head(
                proto, negotiated_capabilities, old_refs, new_refs)
            if not want and old_refs == new_refs:
                return new_refs
            objects = generate_pack_contents(have, want)

            # A pack is needed if there are objects, or if any non-deletion
            # ref changes sha.
            dowrite = len(objects) > 0
            dowrite = dowrite or any(old_refs.get(ref) != sha
                                     for (ref, sha) in new_refs.items()
                                     if sha != ZERO_SHA)
            if dowrite:
                write_pack(proto.write_file(), objects)

            self._handle_receive_pack_tail(
                proto, negotiated_capabilities, progress)
            return new_refs

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param path: Remote repository path
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: the remote refs, or None when the server advertised none
        """
        proto, can_read = self._connect('upload-pack', path)
        with proto:
            refs, server_capabilities = read_pkt_refs(proto)
            negotiated_capabilities = (
                self._fetch_capabilities & server_capabilities)

            if refs is None:
                proto.write_pkt_line(None)
                return refs

            try:
                wants = determine_wants(refs)
            except BaseException:
                proto.write_pkt_line(None)
                raise
            if wants is not None:
                wants = [cid for cid in wants if cid != ZERO_SHA]
            if not wants:
                proto.write_pkt_line(None)
                return refs
            self._handle_upload_pack_head(
                proto, negotiated_capabilities, graph_walker, wants, can_read)
            self._handle_upload_pack_tail(
                proto, negotiated_capabilities, graph_walker, pack_data, progress)
            return refs

    def archive(self, path, committish, write_data, progress=None,
                write_error=None):
        """Request an archive of *committish* through 'upload-archive'.

        :param path: Repository path
        :param committish: Value sent to the server as the 'argument' line
        :param write_data: Callback for sideband channel 1 data
        :param progress: Optional callback for sideband channel 2 data;
            None discards it
        :param write_error: Optional callback for sideband channel 3 data;
            None discards it
        :raise GitProtocolError: if the server replies with an 'ERR' line
        :raise AssertionError: on any other unexpected reply
        """
        # str service name, like 'upload-pack'/'receive-pack' above; a bytes
        # literal would be rendered as "git-b'upload-archive'" on Python 3.
        proto, can_read = self._connect('upload-archive', path)
        with proto:
            proto.write_pkt_line("argument %s" % committish)
            proto.write_pkt_line(None)
            pkt = proto.read_pkt_line()
            if pkt == "NACK\n":
                return
            elif pkt == "ACK\n":
                pass
            elif pkt is not None and pkt.startswith("ERR "):
                raise GitProtocolError(pkt[4:].rstrip("\n"))
            else:
                raise AssertionError("invalid response %r" % pkt)
            ret = proto.read_pkt_line()
            if ret is not None:
                raise AssertionError("expected pkt tail")
            self._read_side_band64k_data(proto, {
                1: write_data, 2: progress, 3: write_error})
560
561
class TCPGitClient(TraditionalGitClient):
    """A Git Client that works over TCP directly (i.e. git://)."""

    def __init__(self, host, port=None, *args, **kwargs):
        if port is None:
            port = TCP_GIT_PORT
        self._host = host
        self._port = port
        TraditionalGitClient.__init__(self, *args, **kwargs)

    def _connect(self, cmd, path):
        """Open a git:// connection and send the service request.

        Every address returned by getaddrinfo is tried in turn; the first
        socket that connects is used.

        :param cmd: Git service name (e.g. 'upload-pack')
        :param path: Repository path on the server
        :return: (Protocol, can_read) tuple
        :raise socket.error: the last connection error, or a "no address
            found" error if getaddrinfo returned nothing
        """
        sockaddrs = socket.getaddrinfo(
            self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        s = None
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            s = socket.socket(family, socktype, proto)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                s.connect(sockaddr)
                break
            except socket.error as e:
                # Copy to a separate name: on Python 3 the "as" target is
                # unbound when the handler exits, so "raise err" below would
                # fail with NameError.
                err = e
                s.close()
                s = None
        if s is None:
            raise err
        # -1 means system default buffering
        rfile = s.makefile('rb', -1)
        # 0 means unbuffered
        wfile = s.makefile('wb', 0)

        def close():
            rfile.close()
            wfile.close()
            s.close()

        proto = Protocol(rfile.read, wfile.write, close,
                         report_activity=self._report_activity)
        if path.startswith("/~"):
            path = path[1:]
        proto.send_cmd('git-%s' % cmd, path, 'host=%s' % self._host)
        return proto, lambda: _fileno_can_read(s)
604
605
class SubprocessWrapper(object):
    """A socket-like object that talks to a subprocess via pipes."""

    def __init__(self, proc):
        self.proc = proc
        self.read = proc.stdout.read
        self.write = proc.stdin.write

    def can_read(self):
        """Return True if the subprocess' stdout has data ready to read."""
        # subprocess.mswindows only exists on Python 2 (Python 3 renamed it
        # to the private _mswindows); it was defined as this same test.
        if sys.platform == 'win32':
            from msvcrt import get_osfhandle
            from win32pipe import PeekNamedPipe
            handle = get_osfhandle(self.proc.stdout.fileno())
            return PeekNamedPipe(handle, 0)[2] != 0
        else:
            return _fileno_can_read(self.proc.stdout.fileno())

    def close(self):
        """Close the pipes and wait for the subprocess to exit."""
        self.proc.stdin.close()
        self.proc.stdout.close()
        if self.proc.stderr:
            self.proc.stderr.close()
        self.proc.wait()
629
630
class SubprocessGitClient(TraditionalGitClient):
    """Git client that talks to a server using a subprocess."""

    def __init__(self, *args, **kwargs):
        self._connection = None
        # 'stderr' is consumed here (it is handed to Popen) and must not be
        # forwarded to the base-class constructor.
        self._stderr = kwargs.pop('stderr', None)
        TraditionalGitClient.__init__(self, *args, **kwargs)

    def _connect(self, service, path):
        """Spawn ``git <service> <path>`` and speak the protocol over its pipes.

        :param service: Git service name (e.g. 'upload-pack')
        :param path: Repository path passed to the git command
        :return: (Protocol, can_read) tuple
        """
        proc = subprocess.Popen(['git', service, path], bufsize=0,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=self._stderr)
        wrapper = SubprocessWrapper(proc)
        proto = Protocol(wrapper.read, wrapper.write, wrapper.close,
                         report_activity=self._report_activity)
        return proto, wrapper.can_read
651
652
class LocalGitClient(GitClient):
    """Git Client that just uses a local Repo."""

    def __init__(self, thin_packs=True, report_activity=None):
        """Create a new LocalGitClient instance.

        The repository path is not given here; it is passed to each
        ``fetch``/``fetch_pack`` call instead.

        :param thin_packs: Accepted for interface compatibility with
            GitClient; ignored.
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self._report_activity = report_activity
        # Ignore the thin_packs argument

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        Not implemented for local repositories.

        :param path: Repository path
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.

        :raises NotImplementedError: always
        """
        raise NotImplementedError(self.send_pack)

    def fetch(self, path, target, determine_wants=None, progress=None):
        """Fetch into a target repository.

        :param path: Path to fetch from
        :param target: Target repository to fetch into
        :param determine_wants: Optional function to determine what refs
            to fetch
        :param progress: Optional progress function
        :return: whatever ``Repo.fetch`` returns for the source repository
        """
        from dulwich.repo import Repo
        r = Repo(path)
        return r.fetch(target, determine_wants=determine_wants,
                       progress=progress)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a local repository.

        :param path: Path to the local repository
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        """
        from dulwich.repo import Repo
        r = Repo(path)
        objects_iter = r.fetch_objects(determine_wants, graph_walker, progress)

        # Did the process short-circuit (e.g. in a stateless RPC call)? Note
        # that the client still expects a 0-object pack in most cases.
        if objects_iter is None:
            return
        write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
717
718
# What Git client to use for local access: the subprocess-based client,
# which runs the local 'git' binary for each service request.
default_local_git_client_cls = SubprocessGitClient
721
722
class SSHVendor(object):
    """A client side SSH implementation."""

    def connect_ssh(self, host, command, username=None, port=None):
        """Deprecated alias of :meth:`run_command`.

        Emits a DeprecationWarning and forwards every argument unchanged.
        """
        import warnings
        message = ("SSHVendor.connect_ssh has been renamed to "
                   "SSHVendor.run_command")
        warnings.warn(message, DeprecationWarning)
        return self.run_command(host, command, username=username, port=port)

    def run_command(self, host, command, username=None, port=None):
        """Connect to an SSH server.

        Run a command remotely and return a file-like object for interaction
        with the remote command. Subclasses must implement this.

        :param host: Host name
        :param command: Command to run
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        :raise NotImplementedError: always, in this base class
        """
        raise NotImplementedError(self.run_command)
745
746
class SubprocessSSHVendor(SSHVendor):
    """SSH vendor that shells out to the local 'ssh' command."""

    def run_command(self, host, command, username=None, port=None):
        """Run *command* on *host* through the local ``ssh`` binary.

        :param host: Host name
        :param command: Remote command, as a list appended to the ssh
            argument list
        :param username: Optional user name (sent as ``user@host``)
        :param port: Optional SSH port
        :return: SubprocessWrapper around the ssh process
        """
        import subprocess
        # FIXME: This has no way to deal with passwords..
        if username is None:
            target = host
        else:
            target = '%s@%s' % (username, host)
        args = ['ssh', '-x']
        if port is not None:
            args += ['-p', str(port)]
        args.append(target)
        proc = subprocess.Popen(args + command, stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        return SubprocessWrapper(proc)
763
764
try:
    import paramiko
except ImportError:
    pass
else:
    import threading

    class ParamikoWrapper(object):
        """File-like wrapper around a paramiko channel running a command.

        Stdout is read on demand via read(). Stderr is drained by a
        background thread, forwarded to ``progress_stderr`` (if given) and
        accumulated in ``self.stderr``.
        """

        STDERR_READ_N = 2048  # 2k

        def __init__(self, client, channel, progress_stderr=None):
            """Wrap ``channel`` and start the stderr monitor thread.

            :param client: paramiko SSHClient owning the channel; closed by
                close()
            :param channel: paramiko channel the command was executed on
            :param progress_stderr: Optional callback receiving each chunk
                of stderr data
            """
            self.client = client
            self.channel = channel
            self.progress_stderr = progress_stderr
            # stderr is always monitored so it can be collected into
            # self.stderr, whether or not a progress callback was supplied.
            self.should_monitor = True
            self.monitor_thread = None
            # recv_stderr() yields bytes, so accumulate into bytes.
            self.stderr = b''

            # Channel must block
            self.channel.setblocking(True)

            # Start
            self.monitor_thread = threading.Thread(
                target=self.monitor_stderr)
            # Don't let a wrapper that was never closed keep the
            # interpreter alive at exit.
            self.monitor_thread.daemon = True
            self.monitor_thread.start()

        def monitor_stderr(self):
            """Thread body: read stderr until the channel is closed."""
            while self.should_monitor:
                # Block and read
                data = self.read_stderr(self.STDERR_READ_N)

                # Socket closed
                if not data:
                    self.should_monitor = False
                    break

                # Emit data
                if self.progress_stderr:
                    self.progress_stderr(data)

                # Append to buffer
                self.stderr += data

        def stop_monitoring(self):
            """Stop the stderr thread and collect any leftover stderr data.

            Safe to call more than once, and also after the monitor thread
            has already exited on its own.
            """
            if self.monitor_thread is None:
                return
            self.should_monitor = False
            self.monitor_thread.join()
            self.monitor_thread = None

            # Get left over data
            data = self.channel.in_stderr_buffer.empty()
            self.stderr += data

        def can_read(self):
            """Return whether stdout data is ready to be read."""
            return self.channel.recv_ready()

        def write(self, data):
            """Send all of ``data`` to the remote command's stdin."""
            return self.channel.sendall(data)

        def read_stderr(self, n):
            """Read up to ``n`` bytes of the remote command's stderr."""
            return self.channel.recv_stderr(n)

        def read(self, n=None):
            """Read from the remote command's stdout.

            :param n: Number of bytes to read; None reads until the channel
                is closed.
            :return: ``n`` bytes, or fewer if the channel closed first; an
                empty bytestring once the channel is closed.
            """
            chunks = []
            remaining = n
            while remaining is None or remaining > 0:
                if remaining is None:
                    want = _RBUFSIZE
                else:
                    want = remaining
                data = self.channel.recv(want)
                if not data:
                    # Closed socket: return whatever was read so far.
                    break
                chunks.append(data)
                if remaining is not None:
                    remaining -= len(data)
            return b''.join(chunks)

        def close(self):
            """Close the channel, collect remaining stderr, close the client."""
            self.channel.close()
            self.stop_monitoring()
            self.client.close()

    class ParamikoSSHVendor(SSHVendor):
        """SSH vendor that uses paramiko instead of the ``ssh`` binary."""

        def __init__(self):
            # Extra keyword arguments passed to paramiko's
            # SSHClient.connect().
            self.ssh_kwargs = {}

        def run_command(self, host, command, username=None, port=None,
                        progress_stderr=None):
            """Run a command remotely over a paramiko SSH session.

            :param host: Host name
            :param command: Command to run, as a list of strings; the parts
                are joined with spaces into a single remote command line
            :param username: Optional name of user to log in as
            :param port: Optional SSH port to use (defaults to 22)
            :param progress_stderr: Optional callback for stderr data
            :return: ParamikoWrapper for interacting with the command
            """
            # Paramiko needs an explicit port. None is not valid
            if port is None:
                port = 22

            client = paramiko.SSHClient()

            # NOTE(review): the base MissingHostKeyPolicy does not reject
            # unknown host keys and no known_hosts file is loaded here, so
            # the server's identity is not verified. Consider
            # client.load_system_host_keys() plus a rejecting policy.
            policy = paramiko.client.MissingHostKeyPolicy()
            client.set_missing_host_key_policy(policy)
            try:
                client.connect(host, username=username, port=port,
                               **self.ssh_kwargs)

                # Open SSH session
                channel = client.get_transport().open_session()

                # exec_command takes one command line; join the parts the
                # same way the ssh binary joins its trailing arguments.
                channel.exec_command(' '.join(command))
            except Exception:
                # Don't leak the connection if setup fails part-way.
                client.close()
                raise

            return ParamikoWrapper(
                client, channel, progress_stderr=progress_stderr)
873
874
# Can be overridden by users: any zero-argument callable returning an object
# that has a run_command(host, command, username=None, port=None) method.
# SSHGitClient._connect calls get_ssh_vendor() anew for every connection.
get_ssh_vendor = SubprocessSSHVendor
877
878
class SSHGitClient(TraditionalGitClient):
    """Git client that runs the git server commands over SSH.

    The SSH connection is made by the vendor obtained from the module-level
    ``get_ssh_vendor`` callable.
    """

    def __init__(self, host, port=None, username=None, *args, **kwargs):
        """Create a new SSHGitClient.

        :param host: Host name to connect to
        :param port: Optional SSH port
        :param username: Optional name of user to log in as
        """
        self.host = host
        self.port = port
        self.username = username
        TraditionalGitClient.__init__(self, *args, **kwargs)
        # Maps a git command name to the remote executable to run instead
        # of the default 'git-<cmd>'.
        self.alternative_paths = {}

    def _get_cmd_path(self, cmd):
        """Return the remote executable to run for git command ``cmd``."""
        return self.alternative_paths.get(cmd, 'git-%s' % cmd)

    def _connect(self, cmd, path):
        """Start ``cmd`` on the remote host for repository ``path``.

        :return: Tuple of (Protocol, can_read callable)
        """
        if path.startswith("/~"):
            path = path[1:]
        # The path is single-quoted on the remote command line; embedded
        # single quotes are closed, escaped and reopened so they cannot
        # terminate the quoting early.
        quoted_path = "'%s'" % path.replace("'", "'\\''")
        con = get_ssh_vendor().run_command(
            self.host, ["%s %s" % (self._get_cmd_path(cmd), quoted_path)],
            port=self.port, username=self.username)
        return (Protocol(con.read, con.write, con.close,
                         report_activity=self._report_activity),
                con.can_read)
900
901
def default_user_agent_string():
    """Return the default HTTP User-Agent string.

    The result is "dulwich/" followed by the dot-joined components of
    ``dulwich.__version__``.
    """
    version = ".".join(map(str, dulwich.__version__))
    return "dulwich/" + version
904
905
def default_urllib2_opener(config):
    """Build the urllib2 opener used by HttpGitClient.

    :param config: Optional config object; the ``http.proxy`` and
        ``http.useragent`` options are honoured when set.
    :return: urllib2 opener with a proxy handler (if configured) and a
        User-agent header.
    """
    proxy_server = None
    user_agent = None
    if config is not None:
        # An unset option may surface as KeyError from config.get();
        # treat that the same as "not configured".
        try:
            proxy_server = config.get("http", "proxy")
        except KeyError:
            pass
        try:
            user_agent = config.get("http", "useragent")
        except KeyError:
            pass
    handlers = []
    if proxy_server is not None:
        # git applies http.proxy to both http and https remotes.
        handlers.append(urllib2.ProxyHandler(
            {"http": proxy_server, "https": proxy_server}))
    opener = urllib2.build_opener(*handlers)
    if user_agent is None:
        user_agent = default_user_agent_string()
    opener.addheaders = [('User-agent', user_agent)]
    return opener
923
924
class HttpGitClient(GitClient):
    """Git client for repositories served over HTTP(S).

    Reference discovery detects whether the server speaks the "smart" or the
    "dumb" HTTP protocol and records the result in ``self.dumb``. Pack
    transfer is only implemented against smart servers.
    """

    def __init__(self, base_url, dumb=None, opener=None, config=None, *args,
                 **kwargs):
        """Create a new HttpGitClient.

        :param base_url: Base URL of the remote; a trailing slash is added
            if missing.
        :param dumb: True to request the dumb protocol; False or None to
            request the smart protocol. Overwritten with what the server
            actually speaks after each reference discovery.
        :param opener: Optional urllib2 opener; by default one is built by
            default_urllib2_opener(config).
        :param config: Optional config object used for the default opener.
        """
        self.base_url = base_url.rstrip("/") + "/"
        self.dumb = dumb
        if opener is None:
            self.opener = default_urllib2_opener(config)
        else:
            self.opener = opener
        GitClient.__init__(self, *args, **kwargs)

    def _get_url(self, path):
        """Return the absolute URL for ``path``, ending in a slash."""
        return urlparse.urljoin(self.base_url, path).rstrip("/") + "/"

    @staticmethod
    def _get_content_type(resp):
        """Return the content type of an HTTP response.

        Supports both the Python 2 (``gettype``) and the Python 3
        (``get_content_type``) message APIs.
        """
        info = resp.info()
        if hasattr(info, "gettype"):
            return info.gettype()
        return info.get_content_type()

    def _http_request(self, url, headers=None, data=None):
        """Perform an HTTP request, translating HTTP error statuses.

        :param url: URL to request
        :param headers: Optional dict of extra request headers
        :param data: Optional request body
        :return: The response object returned by the opener
        :raises NotGitRepository: if the server answers 404
        :raises GitProtocolError: for any other HTTP error status
        """
        if headers is None:
            headers = {}
        req = urllib2.Request(url, headers=headers, data=data)
        try:
            return self.opener.open(req)
        except urllib2.HTTPError as e:
            if e.code == 404:
                raise NotGitRepository()
            raise GitProtocolError("unexpected http response %d" % e.code)

    def _discover_references(self, service, url):
        """Fetch the remote refs from ``info/refs``.

        Also updates ``self.dumb`` from the response's content type.

        :param service: Service name, e.g. "git-upload-pack"
        :param url: Repository URL; must end with a slash
        :return: Tuple of (refs, server capabilities); the capabilities are
            an empty set for dumb servers.
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, "info/refs")
        headers = {}
        # Ask for the smart protocol unless the dumb one was requested.
        if self.dumb is not True:
            url += "?service=%s" % service
            headers["Content-Type"] = "application/x-%s-request" % service
        resp = self._http_request(url, headers)
        try:
            content_type = self._get_content_type(resp)
            self.dumb = not content_type.startswith("application/x-git-")
            if not self.dumb:
                proto = Protocol(resp.read, None)
                # The first line should mention the service
                pkts = list(proto.read_pkt_seq())
                if pkts != [('# service=%s\n' % service)]:
                    raise GitProtocolError(
                        "unexpected first line %r from smart server" % pkts)
                return read_pkt_refs(proto)
            else:
                return read_info_refs(resp), set()
        finally:
            resp.close()

    def _smart_request(self, service, url, data):
        """POST ``data`` to a smart-protocol service endpoint.

        :param service: Service name, e.g. "git-receive-pack"
        :param url: Repository URL; must end with a slash
        :param data: Request body
        :return: The response object; the caller must close it.
        :raises GitProtocolError: if the response has an unexpected
            content type
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, service)
        headers = {"Content-Type": "application/x-%s-request" % service}
        resp = self._http_request(url, headers, data)
        content_type = self._get_content_type(resp)
        if content_type != ("application/x-%s-result" % service):
            # Don't leak the connection when rejecting the response.
            resp.close()
            raise GitProtocolError("Invalid content-type from server: %s"
                                   % content_type)
        return resp

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path
        :param determine_wants: Callback that receives a copy of the remote
            refs and returns the new refs, or None to abort.
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        :raises NotImplementedError: if the server only speaks the dumb
            protocol
        """
        url = self._get_url(path)
        old_refs, server_capabilities = self._discover_references(
            "git-receive-pack", url)
        negotiated_capabilities = self._send_capabilities & server_capabilities

        if 'report-status' in negotiated_capabilities:
            self._report_status_parser = ReportStatusParser()

        new_refs = determine_wants(dict(old_refs))
        if new_refs is None:
            return old_refs
        if self.dumb:
            raise NotImplementedError(self.send_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        (have, want) = self._handle_receive_pack_head(
            req_proto, negotiated_capabilities, old_refs, new_refs)
        if not want and old_refs == new_refs:
            return new_refs
        objects = generate_pack_contents(have, want)
        if len(objects) > 0:
            write_pack(req_proto.write_file(), objects)
        resp = self._smart_request("git-receive-pack", url,
                                   data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
                progress)
            return new_refs
        finally:
            resp.close()

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param path: Repository path
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: Dictionary with the refs of the remote repository
        :raises NotImplementedError: if the server only speaks the dumb
            protocol
        """
        url = self._get_url(path)
        refs, server_capabilities = self._discover_references(
            "git-upload-pack", url)
        negotiated_capabilities = self._fetch_capabilities & server_capabilities
        wants = determine_wants(refs)
        if wants is not None:
            wants = [cid for cid in wants if cid != ZERO_SHA]
        if not wants:
            return refs
        if self.dumb:
            raise NotImplementedError(self.fetch_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        self._handle_upload_pack_head(
            req_proto, negotiated_capabilities, graph_walker, wants,
            lambda: False)
        resp = self._smart_request(
            "git-upload-pack", url, data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
                graph_walker, pack_data, progress)
            return refs
        finally:
            resp.close()
1067
1068
def get_transport_and_path_from_url(url, config=None, **kwargs):
    """Obtain a git client from a URL.

    Recognised schemes: git, ssh (and its alias git+ssh), http, https and
    file.

    :param url: URL to open
    :param config: Optional config object (only used for http/https URLs)
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    :raises ValueError: if the URL scheme is not recognised
    """
    parsed = urlparse.urlparse(url)
    if parsed.scheme == 'git':
        return (TCPGitClient(parsed.hostname, port=parsed.port, **kwargs),
                parsed.path)
    elif parsed.scheme in ('git+ssh', 'ssh'):
        # One leading slash is stripped, so the remote command receives a
        # relative path.
        path = parsed.path
        if path.startswith('/'):
            path = path[1:]
        return SSHGitClient(parsed.hostname, port=parsed.port,
                            username=parsed.username, **kwargs), path
    elif parsed.scheme in ('http', 'https'):
        return HttpGitClient(urlparse.urlunparse(parsed), config=config,
                             **kwargs), parsed.path
    elif parsed.scheme == 'file':
        return default_local_git_client_cls(**kwargs), parsed.path

    raise ValueError("unknown scheme '%s'" % parsed.scheme)
1096
1097
def get_transport_and_path(location, **kwargs):
    """Obtain a git client from a URL, an scp-style location or a path.

    :param location: URL, ``[user@]host:path`` location, or local path
    :param config: Optional config object (only used for HTTP(S) URLs)
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    """
    # config is only meaningful for URL-based (HTTP) clients; keep it out of
    # the SSH/local client constructors, as get_transport_and_path_from_url
    # does.
    config = kwargs.pop('config', None)

    # First, try to parse it as a URL
    try:
        return get_transport_and_path_from_url(location, config=config,
                                               **kwargs)
    except ValueError:
        pass

    if (sys.platform == 'win32' and
            location[:1].isalpha() and location[1:3] == ':\\'):
        # Windows local path, e.g. C:\repo
        return default_local_git_client_cls(**kwargs), location

    if ':' in location:
        # scp-style SSH location: [user@]host:path. Split on the first colon
        # only, so the path itself may contain colons.
        user_host, path = location.split(':', 1)
        if '@' in user_host:
            # The host follows the last '@'; the user name may contain '@'.
            user, host = user_host.rsplit('@', 1)
            return SSHGitClient(host, username=user, **kwargs), path
        return SSHGitClient(user_host, **kwargs), path

    # Otherwise, assume it's a local path.
    return default_local_git_client_cls(**kwargs), location
1129     # Otherwise, assume it's a local path.
1130     return default_local_git_client_cls(**kwargs), location