1d3cc7e33289cc057a56f15c6c97f2827a41f00c
[jelmer/dulwich.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 # Copyright (C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
4 #
5 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
6 # General Public License as published by the Free Software Foundation; version 2.0
7 # or (at your option) any later version. You can redistribute it and/or
8 # modify it under the terms of either of these two licenses.
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16 # You should have received a copy of the licenses; if not, see
17 # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
18 # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
19 # License, Version 2.0.
20 #
21
22 """Git smart network protocol server implementation.
23
24 For more details on the network protocol, see the
25 Documentation/technical directory in the cgit distribution, and in particular:
26
27 * Documentation/technical/protocol-capabilities.txt
28 * Documentation/technical/pack-protocol.txt
29
30 Currently supported capabilities:
31
32  * include-tag
33  * thin-pack
34  * multi_ack_detailed
35  * multi_ack
36  * side-band-64k
37  * ofs-delta
38  * no-progress
39  * report-status
40  * delete-refs
41  * shallow
42  * symref
43 """
44
45 import collections
46 import os
47 import socket
48 import sys
49 import zlib
50
51 try:
52     import SocketServer
53 except ImportError:
54     import socketserver as SocketServer
55
56 from dulwich.errors import (
57     ApplyDeltaError,
58     ChecksumMismatch,
59     GitProtocolError,
60     NotGitRepository,
61     UnexpectedCommandError,
62     ObjectFormatException,
63     )
64 from dulwich import log_utils
65 from dulwich.objects import (
66     Commit,
67     valid_hexsha,
68     )
69 from dulwich.pack import (
70     write_pack_objects,
71     )
72 from dulwich.protocol import (  # noqa: F401
73     BufferedPktLineWriter,
74     capability_agent,
75     CAPABILITIES_REF,
76     CAPABILITY_DELETE_REFS,
77     CAPABILITY_INCLUDE_TAG,
78     CAPABILITY_MULTI_ACK_DETAILED,
79     CAPABILITY_MULTI_ACK,
80     CAPABILITY_NO_DONE,
81     CAPABILITY_NO_PROGRESS,
82     CAPABILITY_OFS_DELTA,
83     CAPABILITY_QUIET,
84     CAPABILITY_REPORT_STATUS,
85     CAPABILITY_SHALLOW,
86     CAPABILITY_SIDE_BAND_64K,
87     CAPABILITY_THIN_PACK,
88     COMMAND_DEEPEN,
89     COMMAND_DONE,
90     COMMAND_HAVE,
91     COMMAND_SHALLOW,
92     COMMAND_UNSHALLOW,
93     COMMAND_WANT,
94     MULTI_ACK,
95     MULTI_ACK_DETAILED,
96     Protocol,
97     ProtocolFile,
98     ReceivableProtocol,
99     SIDE_BAND_CHANNEL_DATA,
100     SIDE_BAND_CHANNEL_PROGRESS,
101     SIDE_BAND_CHANNEL_FATAL,
102     SINGLE_ACK,
103     TCP_GIT_PORT,
104     ZERO_SHA,
105     ack_type,
106     extract_capabilities,
107     extract_want_line_capabilities,
108     symref_capabilities,
109     )
110 from dulwich.refs import (
111     ANNOTATED_TAG_SUFFIX,
112     write_info_refs,
113     )
114 from dulwich.repo import (
115     Repo,
116     )
117
118
# Module-level logger, named after this module and obtained through
# dulwich's log_utils helper.
logger = log_utils.getLogger(__name__)
120
121
class Backend(object):
    """Base class for objects that map a request path to a repository.

    Subclasses are expected to override open_repository(); the version here
    only raises NotImplementedError.
    """

    def open_repository(self, path):
        """Return the repository that lives at ``path``.

        :param path: Path to the repository
        :raise NotGitRepository: no git repository was found at path
        :return: Instance of BackendRepo
        """
        unimplemented = self.open_repository
        raise NotImplementedError(unimplemented)
133
134
class BackendRepo(object):
    """Interface the Git server expects a repository object to provide.

    Everything listed here is a subset of what dulwich.repo.Repo offers.
    """

    # Class-level placeholders; both default to None here.
    object_store = None
    refs = None

    def get_refs(self):
        """Get all the refs in the repository.

        :return: dict of name -> sha
        """
        raise NotImplementedError()

    def get_peeled(self, name):
        """Look up the cached peeled value of a ref, when one is available.

        :param name: Name of the ref to peel
        :return: The peeled value of the ref. If the ref is known not point to
            a tag, this will be the SHA the ref refers to. If no cached
            information about a tag is available, this method may return None,
            but it should attempt to peel the tag if possible. This base
            implementation always returns None.
        """
        return None

    def fetch_objects(self, determine_wants, graph_walker, progress,
                      get_tagged=None):
        """Yield the objects required for a list of commits.

        :param progress: is a callback to send progress messages to the client
        :param get_tagged: Function that returns a dict of pointed-to sha ->
            tag sha for including tags.
        """
        raise NotImplementedError()
174
175
class DictBackend(Backend):
    """Backend that serves repositories out of an in-memory mapping."""

    def __init__(self, repos):
        # Mapping of path -> repository object, indexed by open_repository().
        self.repos = repos

    def open_repository(self, path):
        """Return the repository registered under ``path``.

        :param path: Key to look up in the mapping given at construction
        :raise NotGitRepository: if ``path`` is not a key of the mapping
        :return: The object stored under ``path``
        """
        logger.debug('Opening repository at %s', path)
        try:
            found = self.repos[path]
        except KeyError:
            message = "No git repository was found at %(path)s" % {
                'path': path}
            raise NotGitRepository(message)
        return found
190
191
class FileSystemBackend(Backend):
    """Backend that resolves repository paths beneath a root directory."""

    def __init__(self, root=os.sep):
        super(FileSystemBackend, self).__init__()
        # Absolute root with a trailing separator; doubled separators (which
        # appear when the absolute path already ends in one) are collapsed.
        with_trailing_sep = os.path.abspath(root) + os.sep
        self.root = with_trailing_sep.replace(os.sep * 2, os.sep)

    def open_repository(self, path):
        """Open the repository at ``path``, interpreted relative to the root.

        :param path: Path of the repository; joined onto the configured root
        :raise NotGitRepository: if the resolved path is not inside the root
        :return: A Repo opened on the resolved absolute path
        """
        logger.debug('opening repository at %s', path)
        joined = os.path.join(self.root, path)
        abspath = os.path.abspath(joined) + os.sep
        # Compare case-normalised forms so the containment check follows the
        # platform's path case rules.
        inside_root = os.path.normcase(abspath).startswith(
            os.path.normcase(self.root))
        if inside_root:
            return Repo(abspath)
        raise NotGitRepository(
            "Path %r not inside root %r" % (path, self.root))
210
211
class Handler(object):
    """Common base for smart protocol command handlers."""

    def __init__(self, backend, proto, http_req=None):
        """Remember the collaborators a handler works with.

        :param backend: Backend used to open repositories
        :param proto: Protocol object used to talk to the client
        :param http_req: Optional HTTP request object; None by default
        """
        self.backend, self.proto, self.http_req = backend, proto, http_req

    def handle(self):
        """Run the command; must be provided by subclasses."""
        unimplemented = self.handle
        raise NotImplementedError(unimplemented)
222
223
class PackHandler(Handler):
    """Base for pack handlers; tracks the capabilities the client asked for."""

    def __init__(self, backend, proto, http_req=None):
        super(PackHandler, self).__init__(backend, proto, http_req)
        # Flag needed for the no-done capability; set by notify_done().
        self._done_received = False
        # Stays None until set_client_capabilities() has been called.
        self._client_capabilities = None

    @classmethod
    def capability_line(cls, capabilities):
        """Join capabilities into one bytestring, each preceded by a space.

        :param capabilities: Iterable of capability names (bytes)
        :return: The concatenated bytes
        """
        logger.info('Sending capabilities: %s', capabilities)
        return b"".join(b" " + name for name in capabilities)

    @classmethod
    def capabilities(cls):
        """Return the capabilities to advertise; subclasses must override."""
        raise NotImplementedError(cls.capabilities)

    @classmethod
    def innocuous_capabilities(cls):
        """Return capabilities a client may request even if unadvertised."""
        return [
            CAPABILITY_INCLUDE_TAG,
            CAPABILITY_THIN_PACK,
            CAPABILITY_NO_PROGRESS,
            CAPABILITY_OFS_DELTA,
            capability_agent(),
        ]

    @classmethod
    def required_capabilities(cls):
        """Return a list of capabilities that we require the client to have."""
        return []

    def set_client_capabilities(self, caps):
        """Validate and store the capabilities requested by the client.

        :param caps: Capabilities sent by the client
        :raise GitProtocolError: if a capability is neither advertised nor
            innocuous, or if a required capability is missing
        """
        permitted = set(self.innocuous_capabilities()).union(
            self.capabilities())
        for requested in caps:
            if requested in permitted:
                continue
            raise GitProtocolError(
                'Client asked for capability %s that was not advertised.'
                % requested)
        for needed in self.required_capabilities():
            if needed in caps:
                continue
            raise GitProtocolError(
                'Client does not support required capability %s.' % needed)
        self._client_capabilities = set(caps)
        logger.info('Client capabilities: %s', caps)

    def has_capability(self, cap):
        """Tell whether the client requested ``cap``.

        :raise GitProtocolError: if the client capabilities have not been
            set yet
        """
        requested = self._client_capabilities
        if requested is None:
            raise GitProtocolError(
                'Server attempted to access capability %s before asking client'
                % cap)
        return cap in requested

    def notify_done(self):
        """Record that the client's 'done' has been received."""
        self._done_received = True
275
276
class UploadPackHandler(PackHandler):
    """Protocol handler that sends a pack to the client."""

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        super(UploadPackHandler, self).__init__(
            backend, proto, http_req=http_req)
        # The first argument names the repository to serve.
        self.repo = backend.open_repository(args[0])
        self._graph_walker = None
        self.advertise_refs = advertise_refs
        # True while the have list is still being processed; during that
        # window the client accepts no other data (side-band included), so
        # progress() stays silent.
        self._processing_have_lines = False

    @classmethod
    def capabilities(cls):
        """Return the capabilities advertised for upload-pack."""
        return [
            CAPABILITY_MULTI_ACK_DETAILED,
            CAPABILITY_MULTI_ACK,
            CAPABILITY_SIDE_BAND_64K,
            CAPABILITY_THIN_PACK,
            CAPABILITY_OFS_DELTA,
            CAPABILITY_NO_PROGRESS,
            CAPABILITY_INCLUDE_TAG,
            CAPABILITY_SHALLOW,
            CAPABILITY_NO_DONE,
        ]

    @classmethod
    def required_capabilities(cls):
        """Return the capabilities a client has to request."""
        return (
            CAPABILITY_SIDE_BAND_64K,
            CAPABILITY_THIN_PACK,
            CAPABILITY_OFS_DELTA,
        )

    def progress(self, message):
        """Send a progress message on the side-band progress channel.

        Nothing is sent when the client asked for no-progress or while the
        have lines are still being processed.
        """
        muted = (self.has_capability(CAPABILITY_NO_PROGRESS) or
                 self._processing_have_lines)
        if not muted:
            self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)

    def get_tagged(self, refs=None, repo=None):
        """Get a dict of peeled values of tags to their original tag shas.

        :param refs: dict of refname -> sha of possible tags; defaults to all
            of the backend's refs.
        :param repo: optional Repo instance for getting peeled refs; defaults
            to the backend's repo, if available
        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
            tag whose peeled value is peeled_sha.
        """
        if not self.has_capability(CAPABILITY_INCLUDE_TAG):
            return {}
        if refs is None:
            refs = self.repo.get_refs()
        if repo is None:
            repo = getattr(self.repo, "repo", None)
        if repo is None:
            # No Repo to peel with.  Returning nothing is acceptable because
            # clients must cope with a server that does not include every
            # relevant tag.
            # TODO: fix behavior when missing
            return {}
        peeled_pairs = (
            (repo.get_peeled(name), sha) for name, sha in refs.items())
        return {peeled: sha for peeled, sha in peeled_pairs if peeled != sha}

    def handle(self):
        """Negotiate with the client and write the resulting pack."""

        def send_pack_data(data):
            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, data)

        repo = self.repo
        graph_walker = _ProtocolGraphWalker(
            self, repo.object_store, repo.get_peeled, repo.refs.get_symrefs)
        objects_iter = repo.fetch_objects(
            graph_walker.determine_wants, graph_walker, self.progress,
            get_tagged=self.get_tagged)

        # From here on the client only processes responses to the have lines
        # it sent; any other data (side-band included) would be treated as a
        # fatal error.
        self._processing_have_lines = True

        # Did the process short-circuit (e.g. in a stateless RPC call)?  The
        # client still expects a 0-object pack in most cases.  When
        # objects_iter was built with a graph walker that talks over the wire
        # (as this one does), taking its length actually iterates through
        # everything and writes to the wire.
        if len(objects_iter) == 0:
            return

        # All provided haves are processed; side-band data is safe again.
        self._processing_have_lines = False

        done_required = not self.has_capability(CAPABILITY_NO_DONE)
        if not graph_walker.handle_done(done_required, self._done_received):
            return

        self.progress(b"dul-daemon says what\n")
        count_message = "counting objects: %d, done.\n" % len(objects_iter)
        self.progress(count_message.encode('ascii'))
        write_pack_objects(ProtocolFile(None, send_pack_data), objects_iter)
        self.progress(b"how was that, then?\n")
        # Terminate the response with a flush-pkt.
        self.proto.write_pkt_line(None)
381
382
def _split_proto_line(line, allowed):
    """Split a line read from the wire.

    :param line: The line read from the wire; a false value (e.g. None)
        stands for a flush-pkt.
    :param allowed: An iterable of command names that should be allowed.
        A command that is not in it is rejected with UnexpectedCommandError.
        If None, any commands from the possible return values are allowed.
    :return: a tuple having one of the following forms:
        ('want', obj_id)
        ('have', obj_id)
        ('shallow', obj_id)
        ('unshallow', obj_id)
        ('deepen', depth)  (depth converted to an int)
        ('done', None)
        (None, None)  (for a flush-pkt)

    :raise UnexpectedCommandError: if the command is not one of the allowed
        ones.
    :raise GitProtocolError: if the line cannot be parsed into one of the
        return values above, including an invalid sha or a non-numeric
        deepen depth.
    """
    if not line:
        fields = [None]
    else:
        # At most two fields: the command and its single argument.
        fields = line.rstrip(b'\n').split(b' ', 1)
    command = fields[0]
    if allowed is not None and command not in allowed:
        raise UnexpectedCommandError(command)
    if len(fields) == 1 and command in (COMMAND_DONE, None):
        return (command, None)
    elif len(fields) == 2:
        if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
                       COMMAND_UNSHALLOW):
            if not valid_hexsha(fields[1]):
                raise GitProtocolError("Invalid sha")
            return tuple(fields)
        elif command == COMMAND_DEEPEN:
            try:
                return command, int(fields[1])
            except ValueError:
                # The depth comes straight from the client; a non-numeric
                # value is a protocol violation, so fall through to the
                # generic invalid-line error instead of leaking ValueError.
                pass
    raise GitProtocolError('Received invalid line from client: %r' % line)
418
419
def _find_shallow(store, heads, depth):
    """Find shallow commits according to a given depth.

    :param store: An ObjectStore for looking up objects.
    :param heads: Iterable of head SHAs to start walking from. Heads that do
        not peel to a Commit are skipped.
    :param depth: The depth of ancestors to include. A depth of one includes
        only the heads themselves.
    :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
        considered shallow and unshallow according to the arguments. Note that
        these sets may overlap if a commit is reachable along multiple paths.
    """
    parents = {}  # cache: commit sha -> parents of that commit

    def get_parents(sha):
        # Test for membership rather than truthiness: a root commit has an
        # empty parent list, which must count as a cache hit as well.
        if sha not in parents:
            parents[sha] = store[sha].parents
        return parents[sha]

    todo = []  # stack of (sha, depth)
    for head_sha in heads:
        obj = store.peel_sha(head_sha)
        if isinstance(obj, Commit):
            todo.append((obj.id, 1))

    not_shallow = set()
    shallow = set()
    while todo:
        sha, cur_depth = todo.pop()
        if cur_depth < depth:
            # Still above the cut-off: keep walking into the parents.
            not_shallow.add(sha)
            new_depth = cur_depth + 1
            todo.extend((p, new_depth) for p in get_parents(sha))
        else:
            # Reached the requested depth: this commit is a boundary.
            shallow.add(sha)

    return shallow, not_shallow
458
459
def _want_satisfied(store, haves, want, earliest):
    """Check whether a wanted object is reachable from something in haves.

    Walks breadth-first from ``want`` through commit parents and stops as
    soon as an object whose id is in ``haves`` is met.

    :param store: Object store to retrieve objects from
    :param haves: Container of SHAs the client is known to have
    :param want: SHA of the wanted object
    :param earliest: Parents with a commit time below this value are not
        walked into
    :return: True if an object in haves was reached, False otherwise
    """
    queue = collections.deque((store[want],))
    seen = {want}
    while queue:
        current = queue.popleft()
        if current.id in haves:
            return True
        if current.type_name == b"commit":
            # Only commits have parents to follow; any other object type
            # contributes nothing further to the walk.
            for parent_sha in current.parents:
                if parent_sha not in seen:
                    seen.add(parent_sha)
                    parent = store[parent_sha]
                    # TODO: handle parents with later commit times than
                    # children
                    if parent.commit_time >= earliest:
                        queue.append(parent)
    return False
480
481
def _all_wants_satisfied(store, haves, wants):
    """Check whether all the given wants are satisfied by a set of haves.

    :param store: Object store to retrieve objects from
    :param haves: A set of commits we know the client has.
    :param wants: A set of commits the client wants
    :return: True if every want is satisfied (trivially so when there are no
        wants), False as soon as one is not
    """
    known = set(haves)
    # The oldest commit time among the haves bounds the search; with no haves
    # the bound is 0.
    earliest = min(store[h].commit_time for h in known) if known else 0
    return all(
        _want_satisfied(store, known, want, earliest) for want in wants)
501
502
class _ProtocolGraphWalker(object):
    """A graph walker that knows the git protocol.

    As a graph walker, this class implements ack(), next(), and reset(). It
    also contains some base methods for interacting with the wire and walking
    the commit tree.

    The work of determining which acks to send is passed on to the
    implementation instance stored in _impl. The reason for this is that we do
    not know at object creation time what ack level the protocol requires. A
    call to set_ack_type() is required to set up the implementation, before
    any calls to next() or ack() are made.
    """

    def __init__(self, handler, object_store, get_peeled, get_symrefs):
        """Create a walker bound to a handler.

        :param handler: The handler driving this walker; its proto, http_req
            and advertise_refs attributes are copied onto the walker.
        :param object_store: Object store used to look up objects
        :param get_peeled: Callable mapping a ref name to its peeled SHA
        :param get_symrefs: Callable returning the symrefs; its result's
            items() are passed to symref_capabilities()
        """
        self.handler = handler
        self.store = object_store
        self.get_peeled = get_peeled
        self.get_symrefs = get_symrefs
        self.proto = handler.proto
        self.http_req = handler.http_req
        self.advertise_refs = handler.advertise_refs
        self._wants = []
        self.shallow = set()
        self.client_shallow = set()
        self.unshallow = set()
        # Replay state: once reset() has been called, next() serves SHAs from
        # _cache (starting at _cache_index) instead of reading from the wire.
        self._cached = False
        self._cache = []
        self._cache_index = 0
        self._impl = None

    def determine_wants(self, heads):
        """Determine the wants for a set of heads.

        The given heads are advertised to the client, who then specifies which
        refs he wants using 'want' lines. This portion of the protocol is the
        same regardless of ack type, and in fact is used to set the ack type of
        the ProtocolGraphWalker.

        If the client has the 'shallow' capability, this method also reads and
        responds to the 'shallow' and 'deepen' lines from the client. These are
        not part of the wants per se, but they set up necessary state for
        walking the graph. Additionally, later code depends on this method
        consuming everything up to the first 'have' line.

        :param heads: a dict of refname->SHA1 to advertise
        :return: a list of SHA1s requested by the client
        :raise GitProtocolError: if the client wants a SHA that is not one of
            the advertised values
        """
        symrefs = self.get_symrefs()
        values = set(heads.values())
        if self.advertise_refs or not self.http_req:
            for i, (ref, sha) in enumerate(sorted(heads.items())):
                line = sha + b' ' + ref
                if not i:
                    # The capability list rides on the first advertised ref,
                    # separated from it by a NUL byte.
                    line += (b'\x00' +
                             self.handler.capability_line(
                                 self.handler.capabilities() +
                                 symref_capabilities(symrefs.items())))
                self.proto.write_pkt_line(line + b'\n')
                peeled_sha = self.get_peeled(ref)
                if peeled_sha != sha:
                    self.proto.write_pkt_line(
                        peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')

            # i'm done..
            self.proto.write_pkt_line(None)

            if self.advertise_refs:
                return []

        # Now the client will be sending want lines.
        want = self.proto.read_pkt_line()
        if not want:
            return []
        # The first want line also carries the client's capabilities.
        line, caps = extract_want_line_capabilities(want)
        self.handler.set_client_capabilities(caps)
        self.set_ack_type(ack_type(caps))
        allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
        command, sha = _split_proto_line(line, allowed)

        want_revs = []
        while command == COMMAND_WANT:
            if sha not in values:
                raise GitProtocolError(
                  'Client wants invalid object %s' % sha)
            want_revs.append(sha)
            command, sha = self.read_proto_line(allowed)

        self.set_wants(want_revs)
        if command in (COMMAND_SHALLOW, COMMAND_DEEPEN):
            # Push the line back so _handle_shallow_request sees it first.
            self.unread_proto_line(command, sha)
            self._handle_shallow_request(want_revs)

        if self.http_req and self.proto.eof():
            # The client may close the socket at this point, expecting a
            # flush-pkt from the server. We might be ready to send a packfile
            # at this point, so we need to explicitly short-circuit in this
            # case.
            return []

        return want_revs

    def unread_proto_line(self, command, value):
        """Push a parsed (command, value) line back onto the protocol.

        :param command: Command name (bytes)
        :param value: The command's argument; an int is re-encoded as ASCII
            digits first
        """
        if isinstance(value, int):
            value = str(value).encode('ascii')
        self.proto.unread_pkt_line(command + b' ' + value)

    def ack(self, have_ref):
        """Acknowledge a have through the ack implementation.

        :param have_ref: SHA to acknowledge; must be 40 characters long
        :raise ValueError: if have_ref does not have length 40
        """
        if len(have_ref) != 40:
            raise ValueError("invalid sha %r" % have_ref)
        return self._impl.ack(have_ref)

    def reset(self):
        """Switch to cached mode and rewind the cache to its start."""
        self._cached = True
        self._cache_index = 0

    def next(self):
        """Return the next have SHA, or None when there are no more.

        Before reset() has been called the value comes from the ack
        implementation (None when no implementation is set for an HTTP
        request). Afterwards the cached SHAs are replayed in order, starting
        with the first entry, and None is returned once they are exhausted.
        """
        if not self._cached:
            if not self._impl and self.http_req:
                return None
            return next(self._impl)
        # Read the current entry before advancing so that the first cached
        # SHA is not skipped and the index never runs past the end.
        if self._cache_index >= len(self._cache):
            return None
        sha = self._cache[self._cache_index]
        self._cache_index += 1
        return sha

    __next__ = next

    def read_proto_line(self, allowed):
        """Read a line from the wire.

        :param allowed: An iterable of command names that should be allowed.
        :return: A tuple of (command, value); see _split_proto_line.
        :raise UnexpectedCommandError: If an error occurred reading the line.
        """
        return _split_proto_line(self.proto.read_pkt_line(), allowed)

    def _handle_shallow_request(self, wants):
        """Read the client's shallow/deepen lines and answer them.

        Collects the client's shallow SHAs until the deepen line arrives,
        consumes the following flush-pkt, then writes the resulting shallow
        and unshallow lines followed by a flush-pkt.

        :param wants: SHAs the client asked for; used as the heads for the
            depth computation
        """
        while True:
            command, val = self.read_proto_line(
                    (COMMAND_DEEPEN, COMMAND_SHALLOW))
            if command == COMMAND_DEEPEN:
                depth = val
                break
            self.client_shallow.add(val)
        self.read_proto_line((None,))  # consume client's flush-pkt

        shallow, not_shallow = _find_shallow(self.store, wants, depth)

        # Update self.shallow instead of reassigning it since we passed a
        # reference to it before this method was called.
        self.shallow.update(shallow - not_shallow)
        new_shallow = self.shallow - self.client_shallow
        unshallow = self.unshallow = not_shallow & self.client_shallow

        for sha in sorted(new_shallow):
            self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
        for sha in sorted(unshallow):
            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)

        self.proto.write_pkt_line(None)

    def notify_done(self):
        """Relay the 'done' notification down to the handler."""
        self.handler.notify_done()

    def send_ack(self, sha, ack_type=b''):
        """Write an ACK line for ``sha``.

        :param sha: SHA being acknowledged
        :param ack_type: Optional qualifier appended after the SHA, separated
            by a space; omitted when empty
        """
        if ack_type:
            ack_type = b' ' + ack_type
        self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')

    def send_nak(self):
        """Write a NAK line."""
        self.proto.write_pkt_line(b'NAK\n')

    def handle_done(self, done_required, done_received):
        """Delegate 'done' handling to the ack implementation.

        :return: Whatever the implementation's handle_done returns
        """
        return self._impl.handle_done(done_required, done_received)

    def set_wants(self, wants):
        """Store the wants later used by all_wants_satisfied()."""
        self._wants = wants

    def all_wants_satisfied(self, haves):
        """Check whether all the current wants are satisfied by a set of haves.

        :param haves: A set of commits we know the client has.
        :note: Wants are specified with set_wants rather than passed in since
            in the current interface they are determined outside this class.
        """
        return _all_wants_satisfied(self.store, haves, self._wants)

    def set_ack_type(self, ack_type):
        """Instantiate the ack implementation matching ``ack_type``.

        :param ack_type: One of MULTI_ACK, MULTI_ACK_DETAILED or SINGLE_ACK
        :raise KeyError: for any other value
        """
        impl_classes = {
          MULTI_ACK: MultiAckGraphWalkerImpl,
          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
          SINGLE_ACK: SingleAckGraphWalkerImpl,
          }
        self._impl = impl_classes[ack_type](self)
699
700
# Commands the graph walker implementations accept when reading a line from
# the client during negotiation; None stands for a flush-pkt.
_GRAPH_WALKER_COMMANDS = (COMMAND_HAVE, COMMAND_DONE, None)
702
703
class SingleAckGraphWalkerImpl(object):
    """Graph walker implementation for the single-ack protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Holds at most one entry: the first have that was acknowledged.
        self._common = []

    def ack(self, have_ref):
        """Acknowledge ``have_ref`` unless something was already acked."""
        if self._common:
            return
        self.walker.send_ack(have_ref)
        self._common.append(have_ref)

    def next(self):
        """Read one line from the client.

        :return: The SHA of a have line; None for a flush-pkt or 'done', in
            which case the walker is also notified that we are done.
        """
        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
        if command == COMMAND_HAVE:
            return sha
        if command in (None, COMMAND_DONE):
            # Handling of the done itself is deferred to handle_done().
            self.walker.notify_done()
        return None

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Decide whether the pack may be sent.

        A NAK is written first when no common commit was acknowledged.

        :param done_required: Whether the client must send 'done'
        :param done_received: Whether 'done' has been received
        :return: True if the pack should be sent, False to skip it
        """
        found_common = bool(self._common)
        if not found_common:
            self.walker.send_nak()

        # Without a 'done' we stop when one is required, and also when the
        # walker picked up no haves at all (usually a client pulling from a
        # source that has no common base commit).
        # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
        #          test_multi_ack_stateless_nodone
        if not done_received and (done_required or not found_common):
            return False

        return True
746
747
class MultiAckGraphWalkerImpl(object):
    """Graph walker implementation for the multi-ack protocol."""

    def __init__(self, walker):
        self.walker = walker
        # Haves acknowledged so far, in the order they were acked.
        self._common = []
        # Becomes True once the acked haves satisfy every want.
        self._found_base = False

    def ack(self, have_ref):
        """Record ``have_ref`` as common and acknowledge it if still needed."""
        self._common.append(have_ref)
        if self._found_base:
            # Once a base is found, acks are sent blindly from next().
            return
        self.walker.send_ack(have_ref, b'continue')
        if self.walker.all_wants_satisfied(self._common):
            self._found_base = True

    def next(self):
        """Read lines until a have or 'done' arrives.

        :return: The SHA of the next have line, or None after 'done'
        """
        while True:
            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
            if command == COMMAND_DONE:
                self.walker.notify_done()
                return None
            if command == COMMAND_HAVE:
                if self._found_base:
                    # blind ack
                    self.walker.send_ack(sha, b'continue')
                return sha
            if command is None:
                # in multi-ack mode, a flush-pkt indicates the client wants to
                # flush but more have lines are still coming
                self.walker.send_nak()

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Decide whether the pack may be sent and write the final ACK/NAK.

        :param done_required: Whether the client must send 'done'
        :param done_received: Whether 'done' has been received
        :return: True if the pack should be sent, False to skip it
        """
        # Without a 'done' we stop when one is required, and also when the
        # walker picked up no haves at all (usually a client pulling from a
        # source that has no common base commit).
        # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
        #          test_multi_ack_stateless_nodone
        if not done_received and (done_required or not self._common):
            return False

        # don't nak unless no common commits were found, even if not
        # everything is satisfied
        if not self._common:
            self.walker.send_nak()
        else:
            self.walker.send_ack(self._common[-1])
        return True
805
806
class MultiAckDetailedGraphWalkerImpl(object):
    """Graph walker implementation for the multi-ack-detailed protocol.

    Wraps a walker object and drives the have/done negotiation, keeping
    track of the shas that were acknowledged as common.
    """

    def __init__(self, walker):
        """Create the implementation around ``walker``.

        :param walker: Object used to read protocol lines and to send
            ACK/NAK responses
        """
        self.walker = walker
        # Shas acknowledged as common so far, in the order they were acked.
        self._common = []

    def ack(self, have_ref):
        """Record ``have_ref`` as common and acknowledge it to the client.

        Should only be called iff have_ref is common.
        """
        self._common.append(have_ref)
        self.walker.send_ack(have_ref, b'common')

    def next(self):
        """Return the next sha the client claims to have.

        :return: The sha of the next "have" line, or None when the
            negotiation round ends (a "done" was read, or a flush-pkt was
            read on an HTTP request)
        """
        walker = self.walker
        while True:
            command, sha = walker.read_proto_line(_GRAPH_WALKER_COMMANDS)

            if command == COMMAND_HAVE:
                # Hand the sha back; the caller ACKs it through ack() if
                # it turns out to be common.
                return sha

            if command == COMMAND_DONE:
                # Let the walker know that we got a done.
                walker.notify_done()
                return None

            if command is not None:
                continue

            # No command was read (flush-pkt): report readiness if every
            # want is already reachable, then NAK to close this batch.
            if walker.all_wants_satisfied(self._common):
                walker.send_ack(self._common[-1], b'ready')
            walker.send_nak()
            if walker.http_req:
                # In the HTTP version of this request a flush-pkt always
                # signifies the end of the request, so return nothing
                # here as if we are done (but not really, as that depends
                # on whether the no-done capability was specified, which
                # is handled in handle_done).
                return None

    __next__ = next

    def handle_done(self, done_required, done_received):
        """Send the final ACK/NAK for the negotiation, if it is over.

        :param done_required: Whether the client must send an explicit
            "done" before the negotiation may be concluded
        :param done_received: Whether a "done" has been received
        :return: False if the negotiation is not finished (nothing is
            sent), True once the closing ACK or NAK has been written
        """
        # Without a "done" we can only wrap up when it is optional *and*
        # at least one common commit was picked up; a walker with no haves
        # typically means the client is pulling from a source that shares
        # no base commit with it.
        # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
        #          test_multi_ack_stateless_nodone
        if not done_received and (done_required or not self._common):
            return False

        # Only NAK when no common commit was found at all, even if not
        # every want is satisfied.
        if not self._common:
            self.walker.send_nak()
        else:
            self.walker.send_ack(self._common[-1])
        return True
870
871
class ReceivePackHandler(PackHandler):
    """Protocol handler for downloading a pack from the client.

    Advertises the repository's refs, reads the ref update commands and
    the accompanying pack sent by the client, applies them to the
    repository and, if the client asked for it, reports the outcome.
    """

    def __init__(self, backend, args, proto, http_req=None,
                 advertise_refs=False):
        """Open the target repository and set up the handler.

        :param backend: Backend used to open the repository
        :param args: Command arguments; args[0] is passed to
            backend.open_repository
        :param proto: Protocol object used to talk to the client
        :param http_req: Optional HTTP request, forwarded to PackHandler
        :param advertise_refs: If true, handle() only advertises the refs
            and returns without reading any commands
        """
        super(ReceivePackHandler, self).__init__(
                backend, proto, http_req=http_req)
        self.repo = backend.open_repository(args[0])
        self.advertise_refs = advertise_refs

    @classmethod
    def capabilities(cls):
        """Return the list of capabilities advertised by this handler."""
        return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
                CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
                CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]

    def _apply_pack(self, refs):
        """Read the pack sent by the client (if any) and update the refs.

        :param refs: List of (oldsha, newsha, refname) commands
        :return: List of (name, message) status pairs, all bytes. The
            first entry is always the b'unpack' result, followed by one
            entry per ref whose message is b'ok' or a failure reason.
        :raise GitProtocolError: If a ref deletion is requested while
            delete-refs is not among the handler's capabilities
        """
        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
                          AssertionError, socket.error, zlib.error,
                          ObjectFormatException)
        status = []
        # A pack only follows the commands if at least one of them sets a
        # ref to a real object; pure deletions come without pack data.
        will_send_pack = any(command[1] != ZERO_SHA for command in refs)

        if will_send_pack:
            # TODO: more informative error messages than just the exception
            # string
            try:
                recv = getattr(self.proto, "recv", None)
                self.repo.object_store.add_thin_pack(self.proto.read, recv)
                status.append((b'unpack', b'ok'))
            except all_exceptions as e:
                # _report_status concatenates the message with bytes, so
                # it has to be bytes itself; a text message would raise
                # TypeError there and the client would get no report.
                message = str(e).replace('\n', '')
                if not isinstance(message, bytes):
                    message = message.encode('utf-8')
                status.append((b'unpack', message))
                # The pack may still have been moved in, but it may contain
                # broken objects. We trust a later GC to clean it up.
        else:
            # The git protocol wants to find a status entry related to the
            # unpack process even if no pack data has been sent.
            status.append((b'unpack', b'ok'))

        for oldsha, sha, ref in refs:
            ref_status = b'ok'
            try:
                if sha == ZERO_SHA:
                    # An all-zero new sha requests deletion of the ref.
                    if CAPABILITY_DELETE_REFS not in self.capabilities():
                        raise GitProtocolError(
                          'Attempted to delete refs without delete-refs '
                          'capability.')
                    try:
                        self.repo.refs.remove_if_equals(ref, oldsha)
                    except all_exceptions:
                        ref_status = b'failed to delete'
                else:
                    try:
                        self.repo.refs.set_if_equals(ref, oldsha, sha)
                    except all_exceptions:
                        ref_status = b'failed to write'
            except KeyError:
                ref_status = b'bad ref'
            status.append((ref, ref_status))

        return status

    def _report_status(self, status):
        """Write the status report for a push to the client.

        :param status: List of (name, message) byte pairs as returned by
            _apply_pack
        """
        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
            # Buffer the pkt-lines and ship them over the sideband data
            # channel, then terminate the sideband stream with a flush-pkt.
            writer = BufferedPktLineWriter(
              lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
            write = writer.write

            def flush():
                writer.flush()
                self.proto.write_pkt_line(None)
        else:
            write = self.proto.write_pkt_line

            def flush():
                pass

        for name, msg in status:
            if name == b'unpack':
                write(b'unpack ' + msg + b'\n')
            elif msg == b'ok':
                write(b'ok ' + name + b'\n')
            else:
                write(b'ng ' + name + b' ' + msg + b'\n')
        write(None)
        flush()

    def handle(self):
        """Run the receive-pack conversation with the client."""
        if self.advertise_refs or not self.http_req:
            refs = sorted(self.repo.get_refs().items())
            symrefs = sorted(self.repo.refs.get_symrefs().items())

            if not refs:
                # Nothing to advertise: send a placeholder entry so there
                # is still a line to carry the capabilities.
                refs = [(CAPABILITIES_REF, ZERO_SHA)]
            # The capabilities ride on the first advertised ref, after a
            # NUL byte.
            self.proto.write_pkt_line(
              refs[0][1] + b' ' + refs[0][0] + b'\0' +
              self.capability_line(
                  self.capabilities() + symref_capabilities(symrefs)) + b'\n')
            for name, sha in refs[1:]:
                self.proto.write_pkt_line(sha + b' ' + name + b'\n')

            self.proto.write_pkt_line(None)
            if self.advertise_refs:
                return

        client_refs = []
        ref = self.proto.read_pkt_line()

        # if ref is None then the client doesn't want to send us anything..
        if ref is None:
            return

        # The first command line also carries the client's capabilities.
        ref, caps = extract_capabilities(ref)
        self.set_client_capabilities(caps)

        # client will now send us a list of (oldsha, newsha, ref)
        while ref:
            client_refs.append(ref.split())
            ref = self.proto.read_pkt_line()

        # backend can now deal with these refs and read a pack using
        # self.read
        status = self._apply_pack(client_refs)

        # when we have read all the pack from the client, send a status report
        # if the client asked for it
        if self.has_capability(CAPABILITY_REPORT_STATUS):
            self._report_status(status)
1004
1005
class UploadArchiveHandler(Handler):
    """Placeholder handler for upload-archive requests (not implemented)."""

    def __init__(self, backend, proto, http_req=None):
        """Forward all arguments unchanged to the base Handler."""
        super(UploadArchiveHandler, self).__init__(
            backend, proto, http_req)

    def handle(self):
        """Always raise NotImplementedError; there is no implementation."""
        # TODO(jelmer)
        raise NotImplementedError(self.handle)
1014
1015
# Default handler classes for git services, keyed by the service name as
# it appears on the wire.
DEFAULT_HANDLERS = {
    b'git-upload-pack': UploadPackHandler,
    b'git-receive-pack': ReceivePackHandler,
    # Left disabled: UploadArchiveHandler.handle is not implemented.
    # b'git-upload-archive': UploadArchiveHandler,
}
1022
1023
class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
    """Stream request handler that dispatches one git service request."""

    def __init__(self, handlers, *args, **kwargs):
        """Remember the handler table, then initialise the base class.

        :param handlers: Mapping of service name to handler class
        """
        # Must be assigned first: the socketserver base initialiser is
        # what ends up invoking handle().
        self.handlers = handlers
        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)

    def handle(self):
        """Read the requested command and run the matching handler.

        :raise GitProtocolError: If no callable handler is registered for
            the requested command
        """
        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
        command, args = proto.read_cmd()
        logger.info('Handling %s request, args=%s', command, args)

        handler_cls = self.handlers.get(command)
        if not callable(handler_cls):
            raise GitProtocolError('Invalid service %s' % command)
        handler_cls(self.server.backend, args, proto).handle()
1040
1041
class TCPGitServer(SocketServer.TCPServer):
    """TCP server that answers git requests via TCPGitRequestHandler."""

    allow_reuse_address = True
    serve = SocketServer.TCPServer.serve_forever

    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
        """Create the server and bind it to ``listen_addr``:``port``.

        :param backend: Backend handed to the service handlers
        :param listen_addr: Address to bind to
        :param port: TCP port to bind to
        :param handlers: Optional mapping of service name to handler class;
            entries override or extend DEFAULT_HANDLERS
        """
        # Work on a copy so the module-level defaults are never mutated.
        merged = dict(DEFAULT_HANDLERS)
        if handlers is not None:
            merged.update(handlers)
        self.handlers = merged
        self.backend = backend

        address = (listen_addr, port)
        logger.info('Listening for TCP connections on %s:%d', *address)
        SocketServer.TCPServer.__init__(self, address, self._make_handler)

    def _make_handler(self, *args, **kwargs):
        """Build a request handler that knows this server's handlers."""
        return TCPGitRequestHandler(self.handlers, *args, **kwargs)

    def verify_request(self, request, client_address):
        """Log the incoming request and accept it unconditionally."""
        logger.info('Handling request from %s', client_address)
        return True

    def handle_error(self, request, client_address):
        """Log the exception raised while serving ``client_address``."""
        logger.exception(
            'Exception happened during processing of request '
            'from %s', client_address)
1067
1068
def main(argv=sys.argv):
    """Entry point for starting a TCP git server.

    :param argv: Command-line arguments, including the program name. The
        first positional argument, if given, is the directory to serve;
        otherwise the current directory is used.
    """
    import optparse

    parser = optparse.OptionParser()
    parser.add_option(
        "-l", "--listen_address", dest="listen_address",
        default="localhost", help="Binding IP address.")
    parser.add_option(
        "-p", "--port", dest="port", type=int,
        default=TCP_GIT_PORT, help="Binding TCP port.")
    options, args = parser.parse_args(argv)

    log_utils.default_logging_config()
    # args[0] is the program name, so the directory comes second.
    gitdir = args[1] if len(args) > 1 else '.'
    # TODO(jelmer): Support git-daemon-export-ok and --export-all.
    backend = FileSystemBackend(gitdir)
    TCPGitServer(
        backend, options.listen_address, options.port).serve_forever()
1090
1091
def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
                  outf=sys.stdout):
    """Serve a single command.

    This is mostly useful for the implementation of commands used by e.g.
    git+ssh.

    :param handler_cls: `Handler` class to use for the request
    :param argv: execv-style command-line arguments. Defaults to sys.argv.
    :param backend: `Backend` to use
    :param inf: File-like object to read from, defaults to standard input.
        If it is a text stream exposing a binary ``buffer`` (as the
        standard streams do on Python 3), that buffer is read instead.
    :param outf: File-like object to write to, defaults to standard output.
        If it is a text stream exposing a binary ``buffer``, that buffer
        is written to instead.
    :return: Exit code for use with sys.exit. Currently always 0; failures
        propagate as exceptions (see the FIXME below).
    """
    if backend is None:
        backend = FileSystemBackend()

    # The handlers in this module exchange bytes with the protocol. Text
    # streams such as Python 3's sys.stdin/sys.stdout would hand back str
    # and reject bytes, so fall through to their underlying binary buffer
    # when they have one. Binary streams have no such attribute and are
    # used as they are.
    reader = getattr(inf, 'buffer', inf)
    writer = getattr(outf, 'buffer', outf)

    def send_fn(data):
        writer.write(data)
        writer.flush()
    proto = Protocol(reader.read, send_fn)
    handler = handler_cls(backend, argv[1:], proto)
    # FIXME: Catch exceptions and write a single-line summary to outf.
    handler.handle()
    return 0
1117
1118
def generate_info_refs(repo):
    """Generate the contents of an info/refs file for ``repo``.

    :param repo: Repository providing get_refs() and an object_store
    :return: Whatever write_info_refs produces for the repository's refs
    """
    return write_info_refs(repo.get_refs(), repo.object_store)
1123
1124
def generate_objects_info_packs(repo):
    """Generate an index of the packs in the repository's object store.

    :param repo: Repository whose object_store.packs are listed
    :return: Iterator over byte lines of the form ``P <filename>\\n``, one
        per pack, with the filename encoded in the filesystem encoding
    """
    encoding = sys.getfilesystemencoding()
    for pack in repo.object_store.packs:
        filename = pack.data.filename.encode(encoding)
        yield b'P ' + filename + b'\n'
1131
1132
def update_server_info(repo):
    """Generate server info for dumb file access.

    This generates info/refs and objects/info/packs,
    similar to "git update-server-info".

    :param repo: Repository to write the two files into, through its
        _put_named_file method
    """
    # (path components, content generator), written in this order.
    targets = (
        (('info', 'refs'), generate_info_refs),
        (('objects', 'info', 'packs'), generate_objects_info_packs),
    )
    for path_parts, generate in targets:
        repo._put_named_file(
            os.path.join(*path_parts),
            b"".join(generate(repo)))
1146
1147
if __name__ == '__main__':
    # Executed as a script: start the TCP git server.
    main()