cf355cf1619931af1f855803c5b4efe6da5f289b
[jelmer/dulwich.git] / dulwich / client.py
1 # client.py -- Implementation of the client side git protocols
2 # Copyright (C) 2008-2013 Jelmer Vernooij <jelmer@samba.org>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; either version 2
7 # or (at your option) a later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Client side support for the Git protocol.
20
21 The Dulwich client supports the following capabilities:
22
23  * thin-pack
24  * multi_ack_detailed
25  * multi_ack
26  * side-band-64k
27  * ofs-delta
28  * quiet
29  * report-status
30  * delete-refs
31
32 Known capabilities that are not supported:
33
34  * shallow
35  * no-progress
36  * include-tag
37 """
38
# Declares the markup language used by the docstrings in this module.
__docformat__ = 'restructuredText'
40
41 from contextlib import closing
42 from io import BytesIO, BufferedReader
43 import dulwich
44 import select
45 import shlex
46 import socket
47 import subprocess
48 import sys
49
50 try:
51     import urllib2
52     import urlparse
53 except ImportError:
54     import urllib.request as urllib2
55     import urllib.parse as urlparse
56
57 from dulwich.errors import (
58     GitProtocolError,
59     NotGitRepository,
60     SendPackError,
61     UpdateRefsError,
62     )
63 from dulwich.protocol import (
64     _RBUFSIZE,
65     capability_agent,
66     CAPABILITY_DELETE_REFS,
67     CAPABILITY_MULTI_ACK,
68     CAPABILITY_MULTI_ACK_DETAILED,
69     CAPABILITY_OFS_DELTA,
70     CAPABILITY_QUIET,
71     CAPABILITY_REPORT_STATUS,
72     CAPABILITY_SIDE_BAND_64K,
73     CAPABILITY_THIN_PACK,
74     COMMAND_DONE,
75     COMMAND_HAVE,
76     COMMAND_WANT,
77     SIDE_BAND_CHANNEL_DATA,
78     SIDE_BAND_CHANNEL_PROGRESS,
79     SIDE_BAND_CHANNEL_FATAL,
80     PktLineParser,
81     Protocol,
82     ProtocolFile,
83     TCP_GIT_PORT,
84     ZERO_SHA,
85     extract_capabilities,
86     )
87 from dulwich.pack import (
88     write_pack_objects,
89     )
90 from dulwich.refs import (
91     read_info_refs,
92     )
93
94
95 def _fileno_can_read(fileno):
96     """Check if a file descriptor is readable."""
97     return len(select.select([fileno], [], [], 0)[0]) > 0
98
99
# Capabilities requested both when fetching and when sending packs.
COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K]
# Capabilities requested from upload-pack (fetch).
FETCH_CAPABILITIES = [
    CAPABILITY_THIN_PACK,
    CAPABILITY_MULTI_ACK,
    CAPABILITY_MULTI_ACK_DETAILED,
] + COMMON_CAPABILITIES
# Capabilities requested from receive-pack (send).
SEND_CAPABILITIES = [
    CAPABILITY_REPORT_STATUS,
] + COMMON_CAPABILITIES
105
106
class ReportStatusParser(object):
    """Handle status as reported by servers with 'report-status' capability.

    Packets are fed one at a time to handle_packet(); once the report has
    been read, check() raises if the server reported a failure.
    """

    def __init__(self):
        # True once the terminating flush packet (None) has been seen.
        self._done = False
        # First line of the report (e.g. b'unpack ok'), stripped.
        self._pack_status = None
        # Becomes False as soon as a ref status line does not start 'ok '.
        self._ref_status_ok = True
        # Stripped ref status lines, in the order they arrived.
        self._ref_statuses = []

    def check(self):
        """Raise an exception if the collected report indicates an error.

        :raise SendPackError: Raised when the server could not unpack
        :raise UpdateRefsError: Raised when refs could not be updated
        """
        pack_status = self._pack_status
        if pack_status is not None and pack_status != b'unpack ok':
            raise SendPackError(pack_status)
        if self._ref_status_ok:
            return
        ref_status = {}
        succeeded = set()
        for line in self._ref_statuses:
            if b' ' not in line:
                # malformed response, move on to the next one
                continue
            verdict, refname = line.split(b' ', 1)
            if verdict != b'ng':
                succeeded.add(refname)
            elif b' ' in refname:
                # 'ng <ref> <reason>': keep the reason as the status.
                refname, verdict = refname.split(b' ', 1)
            ref_status[refname] = verdict
        # TODO(jelmer): don't assume encoding of refs is ascii.
        failed = [name.decode('ascii') for name in ref_status
                  if name not in succeeded]
        raise UpdateRefsError(', '.join(failed) + ' failed to update',
                              ref_status=ref_status)

    def handle_packet(self, pkt):
        """Handle a packet.

        :param pkt: A packet payload, or None for a flush packet.
        :raise GitProtocolError: Raised when packets are received after a
            flush packet.
        """
        if self._done:
            raise GitProtocolError("received more data after status report")
        if pkt is None:
            self._done = True
        elif self._pack_status is None:
            self._pack_status = pkt.strip()
        else:
            line = pkt.strip()
            self._ref_statuses.append(line)
            if not line.startswith(b'ok '):
                self._ref_status_ok = False
163
164
def read_pkt_refs(proto):
    """Read the ref advertisement sent by a git server.

    :param proto: Protocol object to read packets from
    :return: tuple of (refs, capabilities). refs maps ref name to sha and is
        None when no refs were advertised; capabilities is a set.
    :raise GitProtocolError: if the server sent an ERR packet
    """
    refs = {}
    capabilities = None
    for pkt in proto.read_pkt_seq():
        sha, ref = pkt.rstrip(b'\n').split(None, 1)
        if sha == b'ERR':
            raise GitProtocolError(ref)
        # Only the first advertised line carries the capability list.
        if capabilities is None:
            ref, capabilities = extract_capabilities(ref)
        refs[ref] = sha

    if not refs:
        return None, set([])
    return refs, set(capabilities)
180
181
# TODO(durin42): this doesn't correctly degrade if the server doesn't
# support some capabilities. This should work properly with servers
# that don't support multi_ack.
class GitClient(object):
    """Git smart server client.

    Transport-specific subclasses implement send_pack, fetch_pack and
    get_refs; this base class holds the capability sets and the protocol
    helpers those implementations share.
    """

    def __init__(self, thin_packs=True, report_activity=None, quiet=False):
        """Create a new GitClient instance.

        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        :param quiet: Whether to request the 'quiet' capability when
            sending packs.
        """
        self._report_activity = report_activity
        # Set by send_pack implementations when 'report-status' is
        # negotiated; consulted by _handle_receive_pack_tail.
        self._report_status_parser = None
        self._fetch_capabilities = set(FETCH_CAPABILITIES)
        self._fetch_capabilities.add(capability_agent())
        self._send_capabilities = set(SEND_CAPABILITIES)
        self._send_capabilities.add(capability_agent())
        if quiet:
            self._send_capabilities.add(CAPABILITY_QUIET)
        if not thin_packs:
            self._fetch_capabilities.remove(CAPABILITY_THIN_PACK)

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path (as bytestring)
        :param determine_wants: Callback called with the remote refs that
            returns the refs to set on the remote
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        raise NotImplementedError(self.send_pack)

    def fetch(self, path, target, determine_wants=None, progress=None):
        """Fetch into a target repository.

        :param path: Path to fetch from (as bytestring)
        :param target: Target repository to fetch into
        :param determine_wants: Optional function to determine what refs
            to fetch
        :param progress: Optional progress function
        :return: remote refs as dictionary
        """
        if determine_wants is None:
            determine_wants = target.object_store.determine_wants_all

        f, commit, abort = target.object_store.add_pack()
        try:
            result = self.fetch_pack(
                path, determine_wants, target.get_graph_walker(), f.write,
                progress)
        except:
            # Discard the partially written pack on any failure, then
            # propagate the original exception.
            abort()
            raise
        else:
            commit()
        return result

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        """
        raise NotImplementedError(self.fetch_pack)

    def get_refs(self, path):
        """Retrieve the current refs from a git smart server.

        :param path: Path to the repo to fetch from. (as bytestring)
        """
        raise NotImplementedError(self.get_refs)

    def _parse_status_report(self, proto):
        """Read a 'report-status' response directly from proto.

        :param proto: Protocol object to read from
        :raise SendPackError: if the server did not report 'unpack ok'
        :raise UpdateRefsError: if any ref status line is not 'ok'
        """
        unpack = proto.read_pkt_line().strip()
        if unpack != b'unpack ok':
            st = True
            # flush remaining error data
            while st is not None:
                st = proto.read_pkt_line()
            raise SendPackError(unpack)
        statuses = []
        errs = False
        ref_status = proto.read_pkt_line()
        while ref_status:
            ref_status = ref_status.strip()
            statuses.append(ref_status)
            if not ref_status.startswith(b'ok '):
                errs = True
            ref_status = proto.read_pkt_line()

        if errs:
            ref_status = {}
            ok = set()
            for status in statuses:
                if b' ' not in status:
                    # malformed response, move on to the next one
                    continue
                status, ref = status.split(b' ', 1)

                if status == b'ng':
                    if b' ' in ref:
                        ref, status = ref.split(b' ', 1)
                else:
                    ok.add(ref)
                ref_status[ref] = status
            # Refs are bytes: decode them before joining with a text
            # separator (mixing str and bytes raises TypeError on Python 3).
            # TODO(jelmer): don't assume encoding of refs is ascii.
            raise UpdateRefsError(', '.join([
                ref.decode('ascii') for ref in ref_status if ref not in ok]) +
                ' failed to update', ref_status=ref_status)

    def _read_side_band64k_data(self, proto, channel_callbacks):
        """Read per-channel data.

        This requires the side-band-64k capability.

        :param proto: Protocol object to read from
        :param channel_callbacks: Dictionary mapping channels to packet
            handlers to use. None for a callback discards channel data.
        """
        for pkt in proto.read_pkt_seq():
            # The first byte of each packet selects the channel.
            channel = ord(pkt[:1])
            pkt = pkt[1:]
            try:
                cb = channel_callbacks[channel]
            except KeyError:
                raise AssertionError('Invalid sideband channel %d' % channel)
            else:
                if cb is not None:
                    cb(pkt)

    def _handle_receive_pack_head(self, proto, capabilities, old_refs,
                                  new_refs):
        """Handle the head of a 'git-receive-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param old_refs: Old refs, as received from the server
        :param new_refs: New refs
        :return: (have, want) tuple
        """
        want = []
        have = [x for x in old_refs.values() if not x == ZERO_SHA]
        sent_capabilities = False

        all_refs = set(new_refs.keys()).union(set(old_refs.keys()))
        for refname in all_refs:
            old_sha1 = old_refs.get(refname, ZERO_SHA)
            new_sha1 = new_refs.get(refname, ZERO_SHA)

            if old_sha1 != new_sha1:
                if sent_capabilities:
                    proto.write_pkt_line(old_sha1 + b' ' + new_sha1 + b' ' + refname)
                else:
                    # The capability list is appended, after a NUL, to the
                    # first update command only.
                    proto.write_pkt_line(
                        old_sha1 + b' ' + new_sha1 + b' ' + refname + b'\0' +
                        b' '.join(capabilities))
                    sent_capabilities = True
            if new_sha1 not in have and new_sha1 != ZERO_SHA:
                want.append(new_sha1)
        proto.write_pkt_line(None)
        return (have, want)

    def _handle_receive_pack_tail(self, proto, capabilities, progress=None):
        """Handle the tail of a 'git-receive-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param progress: Optional progress reporting function
        """
        if CAPABILITY_SIDE_BAND_64K in capabilities:
            if progress is None:
                progress = lambda x: None
            channel_callbacks = {SIDE_BAND_CHANNEL_PROGRESS: progress}
            if CAPABILITY_REPORT_STATUS in capabilities:
                # The status report is itself pkt-line framed inside the
                # data channel.
                channel_callbacks[SIDE_BAND_CHANNEL_DATA] = PktLineParser(
                    self._report_status_parser.handle_packet).parse
            self._read_side_band64k_data(proto, channel_callbacks)
        else:
            if CAPABILITY_REPORT_STATUS in capabilities:
                for pkt in proto.read_pkt_seq():
                    self._report_status_parser.handle_packet(pkt)
        if self._report_status_parser is not None:
            self._report_status_parser.check()

    def _handle_upload_pack_head(self, proto, capabilities, graph_walker,
                                 wants, can_read):
        """Handle the head of a 'git-upload-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param graph_walker: GraphWalker instance to call .ack() on
        :param wants: List of commits to fetch
        :param can_read: function that returns a boolean that indicates
            whether there is extra graph data to read on proto
        """
        assert isinstance(wants, list) and isinstance(wants[0], bytes)
        proto.write_pkt_line(COMMAND_WANT + b' ' + wants[0] + b' ' + b' '.join(capabilities) + b'\n')
        for want in wants[1:]:
            proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n')
        proto.write_pkt_line(None)
        have = next(graph_walker)
        while have:
            proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n')
            if can_read():
                pkt = proto.read_pkt_line()
                parts = pkt.rstrip(b'\n').split(b' ')
                if parts[0] == b'ACK':
                    graph_walker.ack(parts[1])
                    if len(parts) < 3:
                        # Plain "ACK <sha>" without a status (server not in
                        # multi_ack mode): a common commit was found, so stop
                        # negotiating instead of failing on parts[2].
                        break
                    if parts[2] in (b'continue', b'common'):
                        pass
                    elif parts[2] == b'ready':
                        break
                    else:
                        raise AssertionError(
                            "%s not in ('continue', 'ready', 'common')" %
                            parts[2])
            have = next(graph_walker)
        proto.write_pkt_line(COMMAND_DONE + b'\n')

    def _handle_upload_pack_tail(self, proto, capabilities, graph_walker,
                                 pack_data, progress=None, rbufsize=_RBUFSIZE):
        """Handle the tail of a 'git-upload-pack' request.

        :param proto: Protocol object to read from
        :param capabilities: List of negotiated capabilities
        :param graph_walker: GraphWalker instance to call .ack() on
        :param pack_data: Function to call with pack data
        :param progress: Optional progress reporting function
        :param rbufsize: Read buffer size
        """
        # Consume the remaining ACK/NAK lines before the pack data starts.
        pkt = proto.read_pkt_line()
        while pkt:
            parts = pkt.rstrip(b'\n').split(b' ')
            if parts[0] == b'ACK':
                graph_walker.ack(parts[1])
            if len(parts) < 3 or parts[2] not in (
                    b'ready', b'continue', b'common'):
                break
            pkt = proto.read_pkt_line()
        if CAPABILITY_SIDE_BAND_64K in capabilities:
            if progress is None:
                # Just ignore progress data
                progress = lambda x: None
            self._read_side_band64k_data(proto, {
                SIDE_BAND_CHANNEL_DATA: pack_data,
                SIDE_BAND_CHANNEL_PROGRESS: progress}
            )
        else:
            # Without side-band the raw pack follows until EOF.
            while True:
                data = proto.read(rbufsize)
                if data == b"":
                    break
                pack_data(data)
447                     break
448                 pack_data(data)
449
450
class TraditionalGitClient(GitClient):
    """Traditional Git client."""

    def _connect(self, cmd, path):
        """Create a connection to the server.

        This method is abstract - concrete implementations should
        implement their own variant which connects to the server and
        returns an initialized Protocol object with the service ready
        for use and a can_read function which may be used to see if
        reads would block.

        :param cmd: The git service name to which we should connect.
        :param path: The path we should pass to the service. (as bytestring)
        """
        raise NotImplementedError()

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path (as bytestring)
        :param determine_wants: Callback called with a copy of the remote
            refs; returns the dict of refs to set on the remote (a value of
            ZERO_SHA requests deletion), or None to push nothing
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional callback called with progress updates
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.
        :return: the new refs, or the remote's old refs when determine_wants
            returned None or all requested updates were filtered out

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        proto, unused_can_read = self._connect(b'receive-pack', path)
        with proto:
            old_refs, server_capabilities = read_pkt_refs(proto)
            negotiated_capabilities = (
                self._send_capabilities & server_capabilities)

            # Start every push with a fresh parser, and drop any parser left
            # over from an earlier push when report-status is not negotiated,
            # so _handle_receive_pack_tail never checks stale statuses.
            if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
                self._report_status_parser = ReportStatusParser()
            else:
                self._report_status_parser = None
            report_status_parser = self._report_status_parser

            try:
                new_refs = orig_new_refs = determine_wants(dict(old_refs))
            except:
                proto.write_pkt_line(None)
                raise

            # Must come before the deletion filtering below, which would
            # otherwise fail on dict(None).
            if new_refs is None:
                proto.write_pkt_line(None)
                return old_refs

            if CAPABILITY_DELETE_REFS not in server_capabilities:
                # Server does not support deletions. Fail later.
                new_refs = dict(orig_new_refs)
                for ref, sha in orig_new_refs.items():
                    if sha == ZERO_SHA:
                        if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
                            # Record a failure for this ref in the same
                            # "ng <ref> <reason>" form a server would send.
                            report_status_parser._ref_statuses.append(
                                b'ng ' + ref +
                                b' remote does not support deleting refs')
                            report_status_parser._ref_status_ok = False
                        del new_refs[ref]

            if len(new_refs) == 0 and len(orig_new_refs):
                # NOOP - Original new refs filtered out by policy
                proto.write_pkt_line(None)
                if report_status_parser is not None:
                    report_status_parser.check()
                return old_refs

            (have, want) = self._handle_receive_pack_head(
                proto, negotiated_capabilities, old_refs, new_refs)
            if not want and old_refs == new_refs:
                return new_refs
            objects = generate_pack_contents(have, want)

            # Send a pack when there are objects, or when any ref is being
            # pointed at a (non-zero) sha it did not have before.
            dowrite = len(objects) > 0
            dowrite = dowrite or any(old_refs.get(ref) != sha
                                     for (ref, sha) in new_refs.items()
                                     if sha != ZERO_SHA)
            if dowrite:
                write_pack(proto.write_file(), objects)

            self._handle_receive_pack_tail(
                proto, negotiated_capabilities, progress)
            return new_refs

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: remote refs as dictionary, or None if the server has none
        """
        proto, can_read = self._connect(b'upload-pack', path)
        with proto:
            refs, server_capabilities = read_pkt_refs(proto)
            negotiated_capabilities = (
                self._fetch_capabilities & server_capabilities)

            if refs is None:
                proto.write_pkt_line(None)
                return refs

            try:
                wants = determine_wants(refs)
            except:
                proto.write_pkt_line(None)
                raise
            if wants is not None:
                wants = [cid for cid in wants if cid != ZERO_SHA]
            if not wants:
                proto.write_pkt_line(None)
                return refs
            self._handle_upload_pack_head(
                proto, negotiated_capabilities, graph_walker, wants, can_read)
            self._handle_upload_pack_tail(
                proto, negotiated_capabilities, graph_walker, pack_data, progress)
            return refs

    def get_refs(self, path):
        """Retrieve the current refs from a git smart server."""
        # stock `git ls-remote` uses upload-pack
        proto, _ = self._connect(b'upload-pack', path)
        with proto:
            refs, _ = read_pkt_refs(proto)
            return refs

    def archive(self, path, committish, write_data, progress=None,
                write_error=None):
        """Request an archive of committish from the 'upload-archive' service.

        :param path: Repository path (as bytestring)
        :param committish: Commit-ish (as bytestring), sent to the server as
            an "argument" line
        :param write_data: Callback called with archive data
        :param progress: Optional callback for progress data
        :param write_error: Optional callback for fatal-channel data
        :raise GitProtocolError: if the server replies with an ERR packet
        """
        proto, can_read = self._connect(b'upload-archive', path)
        with proto:
            proto.write_pkt_line(b"argument " + committish)
            proto.write_pkt_line(None)
            pkt = proto.read_pkt_line()
            if pkt == b"NACK\n":
                return
            elif pkt == b"ACK\n":
                pass
            elif pkt is not None and pkt.startswith(b"ERR "):
                raise GitProtocolError(pkt[4:].rstrip(b"\n"))
            else:
                # Includes an unexpected flush packet (pkt is None).
                raise AssertionError("invalid response %r" % pkt)
            ret = proto.read_pkt_line()
            if ret is not None:
                raise AssertionError("expected pkt tail")
            self._read_side_band64k_data(proto, {
                SIDE_BAND_CHANNEL_DATA: write_data,
                SIDE_BAND_CHANNEL_PROGRESS: progress,
                SIDE_BAND_CHANNEL_FATAL: write_error})
600                 SIDE_BAND_CHANNEL_PROGRESS: progress,
601                 SIDE_BAND_CHANNEL_FATAL: write_error})
602
603
class TCPGitClient(TraditionalGitClient):
    """A Git Client that works over TCP directly (i.e. git://)."""

    def __init__(self, host, port=None, **kwargs):
        """Create a new TCPGitClient.

        :param host: Hostname of the git daemon (ASCII-encoded into the
            service request)
        :param port: Port to connect to; defaults to TCP_GIT_PORT
        Remaining keyword arguments are passed to TraditionalGitClient.
        """
        if port is None:
            port = TCP_GIT_PORT
        self._host = host
        self._port = port
        TraditionalGitClient.__init__(self, **kwargs)

    def _connect(self, cmd, path):
        """Connect to the daemon and request the git service cmd for path.

        :param cmd: Service name without the 'git-' prefix (as bytestring)
        :param path: Repository path (as bytestring)
        :return: tuple of (Protocol, can_read callable)
        :raise TypeError: if cmd or path is not bytes
        :raise socket.error: if none of the resolved addresses accepted
            the connection
        """
        if type(cmd) is not bytes:
            raise TypeError(cmd)
        if type(path) is not bytes:
            raise TypeError(path)
        sockaddrs = socket.getaddrinfo(
            self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        s = None
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                s = socket.socket(family, socktype, proto)
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                s.connect(sockaddr)
            except socket.error as e:
                # Python 3 unbinds the "except ... as" name when the handler
                # exits, so keep the last error under a separate name.
                err = e
                if s is not None:
                    s.close()
                s = None
            else:
                break
        if s is None:
            raise err
        # -1 means system default buffering
        rfile = s.makefile('rb', -1)
        # 0 means unbuffered
        wfile = s.makefile('wb', 0)

        def close():
            rfile.close()
            wfile.close()
            s.close()

        proto = Protocol(rfile.read, wfile.write, close,
                         report_activity=self._report_activity)
        if path.startswith(b"/~"):
            path = path[1:]
        # TODO(jelmer): Alternative to ascii?
        proto.send_cmd(b'git-' + cmd, path, b'host=' + self._host.encode('ascii'))
        return proto, lambda: _fileno_can_read(s)
649         proto.send_cmd(b'git-' + cmd, path, b'host=' + self._host.encode('ascii'))
650         return proto, lambda: _fileno_can_read(s)
651
652
class SubprocessWrapper(object):
    """A socket-like object that talks to a subprocess via pipes."""

    def __init__(self, proc):
        """Wrap proc, whose stdin is written to and stdout read from.

        :param proc: Process object (e.g. ``subprocess.Popen``) with stdin
            and stdout pipes, an optional stderr, and a wait() method.
        """
        self.proc = proc
        if sys.version_info[0] == 2:
            self.read = proc.stdout.read
        else:
            # On Python 3 the pipe is read through a BufferedReader, whose
            # read(n) keeps reading until n bytes or EOF.
            self.read = BufferedReader(proc.stdout).read
        self.write = proc.stdin.write

    def can_read(self):
        """Return True if data is available on the subprocess' stdout."""
        # subprocess.mswindows no longer exists on Python 3.5+ (it became
        # private), so test the platform directly.
        if sys.platform == 'win32':
            from msvcrt import get_osfhandle
            from win32pipe import PeekNamedPipe
            handle = get_osfhandle(self.proc.stdout.fileno())
            data, total_bytes_avail, msg_bytes_left = PeekNamedPipe(handle, 0)
            return total_bytes_avail != 0
        else:
            return _fileno_can_read(self.proc.stdout.fileno())

    def close(self):
        """Close the pipes and wait for the subprocess to exit."""
        self.proc.stdin.close()
        self.proc.stdout.close()
        if self.proc.stderr:
            self.proc.stderr.close()
        self.proc.wait()
679         self.proc.wait()
680
681
def find_git_command():
    """Return the argv prefix used to run the system Git (usually C Git).

    On Windows, pywin32 is used to locate the git executable when it is
    importable; otherwise git is run through cmd.exe.
    """
    if sys.platform != 'win32':
        return ['git']
    # support .exe, .bat and .cmd
    try:  # to avoid overhead
        import win32api
    except ImportError:  # run through cmd.exe with some overhead
        return ['cmd', '/c', 'git']
    found = win32api.FindExecutable('git')
    return [found[1]]
693     else:
694         return ['git']
695
696
class SubprocessGitClient(TraditionalGitClient):
    """Git client that talks to a server using a subprocess."""

    def __init__(self, **kwargs):
        """Create a new SubprocessGitClient.

        :param stderr: Optional value passed as ``stderr`` to
            ``subprocess.Popen`` for the spawned git process.
        Remaining keyword arguments are passed to TraditionalGitClient.
        """
        self._connection = None
        self._stderr = kwargs.pop('stderr', None)
        TraditionalGitClient.__init__(self, **kwargs)

    # argv prefix used to run git; None means use find_git_command().
    git_command = None

    def _connect(self, service, path):
        """Spawn git for service on path and wrap its pipes in a Protocol.

        :param service: Git service name, e.g. b'upload-pack' (as bytestring)
        :param path: Repository path (as bytestring)
        :return: tuple of (Protocol, can_read callable)
        :raise TypeError: if service or path is not bytes
        """
        if type(service) is not bytes:
            raise TypeError(service)
        if type(path) is not bytes:
            raise TypeError(path)
        # Honour a configured git_command; previously the local name was
        # left unbound (UnboundLocalError) whenever it was set.
        git_command = self.git_command
        if git_command is None:
            git_command = find_git_command()
        argv = list(git_command) + [service.decode('ascii'), path]
        p = SubprocessWrapper(
            subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=self._stderr))
        return Protocol(p.read, p.write, p.close,
                        report_activity=self._report_activity), p.can_read
723         return Protocol(p.read, p.write, p.close,
724                         report_activity=self._report_activity), p.can_read
725
726
class LocalGitClient(GitClient):
    """Git Client that just uses a local Repo."""

    def __init__(self, thin_packs=True, report_activity=None):
        """Create a new LocalGitClient instance.

        :param thin_packs: Whether or not thin packs should be retrieved
        :param report_activity: Optional callback for reporting transport
            activity.
        """
        self._report_activity = report_activity
        # Ignore the thin_packs argument

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path (as bytestring)
        :param determine_wants: Callback called with a copy of the target's
            refs; returns the refs to set (ZERO_SHA requests deletion), or
            None to push nothing
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.
        :return: the new refs, or the existing refs if determine_wants
            returned None

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        """
        from dulwich.repo import Repo

        with closing(Repo(path)) as target:
            old_refs = target.get_refs()
            new_refs = determine_wants(dict(old_refs))
            if new_refs is None:
                # Nothing to push (same contract as the network clients).
                return old_refs

            have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA]
            have_set = set(have)
            # Refs that only exist in old_refs map to ZERO_SHA and never add
            # a want, so only new_refs needs to be examined.
            want = [new_sha1 for new_sha1 in new_refs.values()
                    if new_sha1 != ZERO_SHA and new_sha1 not in have_set]

            if not want and old_refs == new_refs:
                return new_refs

            target.object_store.add_objects(generate_pack_contents(have, want))

            for name, sha in new_refs.items():
                if sha == ZERO_SHA:
                    # ZERO_SHA means "delete"; storing it would leave a ref
                    # pointing at a non-existent object.
                    if name in old_refs:
                        del target.refs[name]
                else:
                    target.refs[name] = sha

        return new_refs

    def fetch(self, path, target, determine_wants=None, progress=None):
        """Fetch into a target repository.

        :param path: Path to fetch from (as bytestring)
        :param target: Target repository to fetch into
        :param determine_wants: Optional function to determine what refs
            to fetch
        :param progress: Optional progress function
        :return: remote refs as dictionary
        """
        from dulwich.repo import Repo
        with closing(Repo(path)) as r:
            return r.fetch(target, determine_wants=determine_wants,
                           progress=progress)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: the source repository's refs, or None if fetch_objects
            short-circuited
        """
        from dulwich.repo import Repo
        with closing(Repo(path)) as r:
            objects_iter = r.fetch_objects(determine_wants, graph_walker, progress)

            # Did the process short-circuit (e.g. in a stateless RPC call)? Note
            # that the client still expects a 0-object pack in most cases.
            if objects_iter is None:
                return
            write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
            # Return the refs, like TraditionalGitClient.fetch_pack does.
            return r.get_refs()

    def get_refs(self, path):
        """Retrieve the current refs from a git smart server."""
        from dulwich.repo import Repo

        with closing(Repo(path)) as target:
            return target.get_refs()
818         with closing(Repo(path)) as target:
819             return target.get_refs()
820
821
# What Git client to use for local access. get_transport_and_path() and
# get_transport_and_path_from_url() look this name up at call time for
# file:// URLs and plain paths, so reassigning it swaps the local client.
default_local_git_client_cls = LocalGitClient
824
825
class SSHVendor(object):
    """A client side SSH implementation."""

    def connect_ssh(self, host, command, username=None, port=None):
        """Deprecated alias for :meth:`run_command`.

        Emits a DeprecationWarning attributed to the caller, then delegates
        to ``run_command`` with the same arguments.
        """
        import warnings
        # stacklevel=2 so the warning points at the code still using the
        # old name rather than at this shim.
        warnings.warn(
            "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command",
            DeprecationWarning, stacklevel=2)
        return self.run_command(host, command, username=username, port=port)

    def run_command(self, host, command, username=None, port=None):
        """Connect to an SSH server.

        Run a command remotely and return a file-like object for interaction
        with the remote command. Subclasses must override this; the base
        implementation raises NotImplementedError.

        :param host: Host name
        :param command: Command to run (as argv array)
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        """
        raise NotImplementedError(self.run_command)
848
849
class SubprocessSSHVendor(SSHVendor):
    """SSH vendor that shells out to the local 'ssh' command."""

    def run_command(self, host, command, username=None, port=None):
        """Spawn the local ``ssh`` binary to run ``command`` on ``host``.

        :param host: Host name
        :param command: argv list of bytestrings to run remotely
        :param username: Optional name of user to log in as
        :param port: Optional SSH port to use
        :raises TypeError: if ``command`` is not a list of bytestrings
        :return: SubprocessWrapper around the spawned ssh process
        """
        is_argv = type(command) is list and all(
            isinstance(arg, bytes) for arg in command)
        if not is_argv:
            raise TypeError(command)

        import subprocess
        # FIXME: This has no way to deal with passwords..
        ssh_argv = ['ssh', '-x']
        if port is not None:
            ssh_argv += ['-p', str(port)]
        if username is None:
            destination = host
        else:
            destination = '%s@%s' % (username, host)
        ssh_argv.append(destination)
        proc = subprocess.Popen(ssh_argv + command,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        return SubprocessWrapper(proc)
870
871
def ParamikoSSHVendor(**kwargs):
    """Deprecated factory for the Paramiko-based SSH vendor.

    Emits a DeprecationWarning attributed to the caller, then returns
    ``dulwich.contrib.paramiko.ParamikoSSHVendor(**kwargs)``.
    """
    import warnings
    # stacklevel=2 so the warning points at the caller, not at this shim.
    warnings.warn(
        "ParamikoSSHVendor has been moved to dulwich.contrib.paramiko.",
        DeprecationWarning, stacklevel=2)
    from dulwich.contrib.paramiko import ParamikoSSHVendor
    return ParamikoSSHVendor(**kwargs)
879
880
# Can be overridden by users. SSHGitClient calls this with no arguments when
# no explicit ``vendor`` is given; the result must provide run_command().
get_ssh_vendor = SubprocessSSHVendor
883
884
class SSHGitClient(TraditionalGitClient):
    """Git client that runs the git-* server commands over SSH."""

    def __init__(self, host, port=None, username=None, vendor=None, **kwargs):
        """Create a new SSHGitClient.

        :param host: Host to connect to
        :param port: Optional SSH port
        :param username: Optional user name to log in as
        :param vendor: Optional object providing run_command(); defaults to
            the result of calling the module-level ``get_ssh_vendor``
        :param kwargs: Passed on to TraditionalGitClient
        """
        self.host = host
        self.port = port
        self.username = username
        TraditionalGitClient.__init__(self, **kwargs)
        # Overrides for the remote command line, keyed by the ``cmd`` passed
        # to _connect; the default is b'git-' + cmd.
        self.alternative_paths = {}
        if vendor is not None:
            self.vendor = vendor
        else:
            self.vendor = get_ssh_vendor()

    def _get_cmd_path(self, cmd):
        """Return the argv (list of bytestrings) used to run ``cmd`` remotely."""
        cmd = self.alternative_paths.get(cmd, b'git-' + cmd)
        assert isinstance(cmd, bytes)
        if sys.version_info[:2] <= (2, 6):
            # On Python <= 2.6 split the bytestring directly.
            return shlex.split(cmd)
        else:
            # TODO(jelmer): Don't decode/encode here
            return [x.encode('ascii') for x in shlex.split(cmd.decode('ascii'))]

    def _connect(self, cmd, path):
        """Start ``cmd`` on the remote host for the repository at ``path``.

        :param cmd: Command name as bytestring
        :param path: Remote repository path as bytestring
        :raises TypeError: if ``cmd`` or ``path`` is not a bytestring
        :return: tuple of (Protocol, can_read callable)
        """
        if type(cmd) is not bytes:
            # Report the offending argument (previously this reported path).
            raise TypeError(cmd)
        if type(path) is not bytes:
            raise TypeError(path)
        if path.startswith(b"/~"):
            # Drop the leading slash so the remote side receives a
            # home-relative "~..." path.
            path = path[1:]
        argv = self._get_cmd_path(cmd) + [path]
        con = self.vendor.run_command(
            self.host, argv, port=self.port, username=self.username)
        return (Protocol(con.read, con.write, con.close,
                         report_activity=self._report_activity),
                con.can_read)
920
921
def default_user_agent_string():
    """Return the default HTTP User-Agent: ``dulwich/<dotted version>``."""
    version = ".".join(str(part) for part in dulwich.__version__)
    return "dulwich/%s" % version
924
925
def default_urllib2_opener(config):
    """Build a urllib2 opener honouring the http.* settings in ``config``.

    :param config: Optional config object. ``config.get(section, name)`` is
        consulted for ``http.proxy`` and ``http.useragent``; a KeyError from
        it is treated as "not set".
    :return: opener with a ProxyHandler for plain http when http.proxy is
        set, and a User-agent header (http.useragent, or
        default_user_agent_string() when unset).
    """
    def _get_http_setting(name):
        # Config.get signals a missing setting with KeyError rather than
        # returning None, so a config without these keys must not be fatal.
        if config is None:
            return None
        try:
            return config.get("http", name)
        except KeyError:
            return None

    proxy_server = _get_http_setting("proxy")
    handlers = []
    if proxy_server is not None:
        handlers.append(urllib2.ProxyHandler({"http": proxy_server}))
    opener = urllib2.build_opener(*handlers)
    user_agent = _get_http_setting("useragent")
    if user_agent is None:
        user_agent = default_user_agent_string()
    opener.addheaders = [('User-agent', user_agent)]
    return opener
943
944
class HttpGitClient(GitClient):
    """Git client for repositories served over HTTP(S).

    Fetching and pushing use the smart HTTP protocol. A server detected as
    "dumb" can still list refs, but fetch_pack/send_pack raise
    NotImplementedError for it.
    """

    def __init__(self, base_url, dumb=None, opener=None, config=None, **kwargs):
        """Create a new HttpGitClient.

        :param base_url: Base URL of the server (normalised to end in "/")
        :param dumb: True if the server is known to be dumb, False if known
            to be smart, None to detect it from the first response
        :param opener: Optional urllib2 opener; by default one is built with
            default_urllib2_opener(config)
        :param config: Optional config object, only used to build the opener
        :param kwargs: Passed on to GitClient
        """
        self.base_url = base_url.rstrip("/") + "/"
        self.dumb = dumb
        if opener is None:
            self.opener = default_urllib2_opener(config)
        else:
            self.opener = opener
        GitClient.__init__(self, **kwargs)

    def __repr__(self):
        return "%s(%r, dumb=%r)" % (type(self).__name__, self.base_url, self.dumb)

    def _get_url(self, path):
        """Return the absolute URL for ``path``, always ending in "/"."""
        return urlparse.urljoin(self.base_url, path).rstrip("/") + "/"

    def _http_request(self, url, headers=None, data=None):
        """Perform an HTTP request and return the response object.

        :param url: Absolute URL to request
        :param headers: Optional dict of extra request headers
        :param data: Optional request body
        :raises NotGitRepository: if the server answers 404
        :raises GitProtocolError: for any other HTTP error status
        """
        # None instead of a shared mutable {} default.
        if headers is None:
            headers = {}
        req = urllib2.Request(url, headers=headers, data=data)
        try:
            resp = self.opener.open(req)
        except urllib2.HTTPError as e:
            if e.code == 404:
                raise NotGitRepository()
            # Any other error status is a protocol failure; never fall
            # through with ``resp`` unbound.
            raise GitProtocolError("unexpected http response %d" % e.code)
        return resp

    def _discover_references(self, service, url):
        """Fetch info/refs and parse the advertised refs.

        Also (re)detects from the response Content-Type whether the server
        is dumb and stores the result in ``self.dumb``.

        :param service: Service name, e.g. b"git-upload-pack"
        :param url: Repository URL ending in "/"
        :return: tuple of (refs dict, server capabilities)
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, "info/refs")
        headers = {}
        # Ask for the smart advertisement unless the server is known to be
        # dumb. Testing ``is not True`` matters: once a smart server has
        # been detected self.dumb is False, and later requests (e.g. a push
        # after a fetch) must still request the smart advertisement.
        if self.dumb is not True:
            url += "?service=%s" % service
            headers["Content-Type"] = "application/x-%s-request" % service
        resp = self._http_request(url, headers)
        try:
            self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
            if not self.dumb:
                proto = Protocol(resp.read, None)
                # The first line should mention the service
                pkts = list(proto.read_pkt_seq())
                if pkts != [('# service=%s\n' % service)]:
                    raise GitProtocolError(
                        "unexpected first line %r from smart server" % pkts)
                return read_pkt_refs(proto)
            else:
                return read_info_refs(resp), set()
        finally:
            resp.close()

    def _smart_request(self, service, url, data):
        """Send ``data`` to the ``service`` endpoint below ``url``.

        :raises GitProtocolError: if the response Content-Type is not
            application/x-<service>-result
        """
        assert url[-1] == "/"
        url = urlparse.urljoin(url, service)
        headers = {"Content-Type": "application/x-%s-request" % service}
        resp = self._http_request(url, headers, data)
        if resp.info().gettype() != ("application/x-%s-result" % service):
            raise GitProtocolError("Invalid content-type from server: %s"
                % resp.info().gettype())
        return resp

    def send_pack(self, path, determine_wants, generate_pack_contents,
                  progress=None, write_pack=write_pack_objects):
        """Upload a pack to a remote repository.

        :param path: Repository path (as bytestring)
        :param determine_wants: Callback that receives the remote refs and
            returns the desired refs; returning None aborts the push
        :param generate_pack_contents: Function that can return a sequence of
            the shas of the objects to upload.
        :param progress: Optional progress function
        :param write_pack: Function called with (file, iterable of objects) to
            write the objects returned by generate_pack_contents to the server.
        :return: the new refs (or the old refs if nothing was pushed)

        :raises SendPackError: if server rejects the pack data
        :raises UpdateRefsError: if the server supports report-status
                                 and rejects ref updates
        :raises NotImplementedError: if the server is dumb
        """
        url = self._get_url(path)
        old_refs, server_capabilities = self._discover_references(
            b"git-receive-pack", url)
        negotiated_capabilities = self._send_capabilities & server_capabilities

        if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
            self._report_status_parser = ReportStatusParser()

        new_refs = determine_wants(dict(old_refs))
        if new_refs is None:
            return old_refs
        if self.dumb:
            raise NotImplementedError(self.send_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        (have, want) = self._handle_receive_pack_head(
            req_proto, negotiated_capabilities, old_refs, new_refs)
        if not want and old_refs == new_refs:
            return new_refs
        objects = generate_pack_contents(have, want)
        if len(objects) > 0:
            write_pack(req_proto.write_file(), objects)
        resp = self._smart_request(b"git-receive-pack", url,
                                   data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
                progress)
            return new_refs
        finally:
            resp.close()

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                   progress=None):
        """Retrieve a pack from a git smart server.

        :param path: Repository path (as bytestring)
        :param determine_wants: Callback that returns list of commits to fetch
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
        :param progress: Callback for progress reports (strings)
        :return: Dictionary with the refs of the remote repository
        :raises NotImplementedError: if the server is dumb
        """
        url = self._get_url(path)
        refs, server_capabilities = self._discover_references(
            b"git-upload-pack", url)
        negotiated_capabilities = self._fetch_capabilities & server_capabilities
        wants = determine_wants(refs)
        if wants is not None:
            wants = [cid for cid in wants if cid != ZERO_SHA]
        if not wants:
            return refs
        if self.dumb:
            raise NotImplementedError(self.fetch_pack)
        req_data = BytesIO()
        req_proto = Protocol(None, req_data.write)
        self._handle_upload_pack_head(
            req_proto, negotiated_capabilities, graph_walker, wants,
            lambda: False)
        resp = self._smart_request(
            b"git-upload-pack", url, data=req_data.getvalue())
        try:
            resp_proto = Protocol(resp.read, None)
            self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
                graph_walker, pack_data, progress)
            return refs
        finally:
            resp.close()

    def get_refs(self, path):
        """Retrieve the current refs from the HTTP server (smart or dumb)."""
        url = self._get_url(path)
        refs, _ = self._discover_references(
            b"git-upload-pack", url)
        return refs
1096
1097
def get_transport_and_path_from_url(url, config=None, **kwargs):
    """Obtain a git client from a URL.

    Recognised schemes: git://, ssh:// and git+ssh://, http(s):// and file://.

    :param url: URL to open (a unicode string)
    :param config: Optional config object (only passed to HTTP(S) clients)
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    :raises ValueError: if the URL scheme is not recognised
    """
    parsed = urlparse.urlparse(url)
    if parsed.scheme == 'git':
        return (TCPGitClient(parsed.hostname, port=parsed.port, **kwargs),
                parsed.path)
    elif parsed.scheme in ('git+ssh', 'ssh'):
        # Accept the standard ssh:// spelling too. Otherwise such URLs raise
        # ValueError here and get_transport_and_path falls back to scp-style
        # parsing, which takes "ssh" as the host name.
        path = parsed.path
        if path.startswith('/'):
            path = path[1:]
        return SSHGitClient(parsed.hostname, port=parsed.port,
                            username=parsed.username, **kwargs), path
    elif parsed.scheme in ('http', 'https'):
        return HttpGitClient(urlparse.urlunparse(parsed), config=config,
                **kwargs), parsed.path
    elif parsed.scheme == 'file':
        return default_local_git_client_cls(**kwargs), parsed.path

    raise ValueError("unknown scheme '%s'" % parsed.scheme)
1125
1126
def get_transport_and_path(location, **kwargs):
    """Obtain a git client from a URL, scp-style location or local path.

    Tried in order: a URL understood by get_transport_and_path_from_url(),
    a Windows drive path (on win32 only), scp-style ``[user@]host:path``,
    and finally a plain local path.

    :param location: URL or path (a string)
    :param config: Optional config object
    :param thin_packs: Whether or not thin packs should be retrieved
    :param report_activity: Optional callback for reporting transport
        activity.
    :return: Tuple with client instance and relative path.
    """
    # First, try to parse it as a URL
    try:
        return get_transport_and_path_from_url(location, **kwargs)
    except ValueError:
        pass

    # Slice instead of indexing so an empty location cannot raise IndexError;
    # Windows accepts both "C:\\" and "C:/" drive paths.
    if (sys.platform == 'win32' and
            location[:1].isalpha() and location[1:3] in (':\\', ':/')):
        # Windows local path
        return default_local_git_client_cls(**kwargs), location

    if ':' in location:
        # scp-style SSH location: [user@]host:path. Split only on the first
        # ':' so the path itself may contain colons.
        user_host, path = location.split(':', 1)
        if '@' in user_host:
            # Only the last '@' separates the user name from the host.
            user, host = user_host.rsplit('@', 1)
            return SSHGitClient(host, username=user, **kwargs), path
        return SSHGitClient(user_host, **kwargs), path

    # Otherwise, assume it's a local path.
    return default_local_git_client_cls(**kwargs), location
1158     # Otherwise, assume it's a local path.
1159     return default_local_git_client_cls(**kwargs), location