28331e66f1d53de058b2e4e609863a8d0ff5f1c9
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Git smart network protocol server implementation.
20
21 For more detailed implementation on the network protocol, see the
22 Documentation/technical directory in the cgit distribution, and in particular:
23
24 * Documentation/technical/protocol-capabilities.txt
25 * Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import socket
31 import SocketServer
32 import sys
33 import zlib
34
35 from dulwich.errors import (
36     ApplyDeltaError,
37     ChecksumMismatch,
38     GitProtocolError,
39     ObjectFormatException,
40     )
41 from dulwich import log_utils
42 from dulwich.objects import (
43     hex_to_sha,
44     )
45 from dulwich.pack import (
46     PackStreamReader,
47     write_pack_data,
48     )
49 from dulwich.protocol import (
50     MULTI_ACK,
51     MULTI_ACK_DETAILED,
52     ProtocolFile,
53     ReceivableProtocol,
54     SINGLE_ACK,
55     TCP_GIT_PORT,
56     ZERO_SHA,
57     ack_type,
58     extract_capabilities,
59     extract_want_line_capabilities,
60     BufferedPktLineWriter,
61     )
62 from dulwich.repo import (
63     Repo,
64     )
65
66
67 logger = log_utils.getLogger(__name__)
68
69
70 class Backend(object):
71     """A backend for the Git smart server implementation."""
72
73     def open_repository(self, path):
74         """Open the repository at a path."""
75         raise NotImplementedError(self.open_repository)
76
77
78 class BackendRepo(object):
79     """Repository abstraction used by the Git server.
80     
81     Please note that the methods required here are a 
82     subset of those provided by dulwich.repo.Repo.
83     """
84
85     object_store = None
86     refs = None
87
88     def get_refs(self):
89         """
90         Get all the refs in the repository
91
92         :return: dict of name -> sha
93         """
94         raise NotImplementedError
95
96     def get_peeled(self, name):
97         """Return the cached peeled value of a ref, if available.
98
99         :param name: Name of the ref to peel
100         :return: The peeled value of the ref. If the ref is known not point to
101             a tag, this will be the SHA the ref refers to. If no cached
102             information about a tag is available, this method may return None,
103             but it should attempt to peel the tag if possible.
104         """
105         return None
106
107     def fetch_objects(self, determine_wants, graph_walker, progress,
108                       get_tagged=None):
109         """
110         Yield the objects required for a list of commits.
111
112         :param progress: is a callback to send progress messages to the client
113         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
114             sha for including tags.
115         """
116         raise NotImplementedError
117
118
119 class PackStreamCopier(PackStreamReader):
120     """Class to verify a pack stream as it is being read.
121
122     The pack is read from a ReceivableProtocol using read() or recv() as
123     appropriate and written out to the given file-like object.
124     """
125
126     def __init__(self, read_all, read_some, outfile):
127         super(PackStreamCopier, self).__init__(read_all, read_some)
128         self.outfile = outfile
129
130     def _read(self, read, size):
131         data = super(PackStreamCopier, self)._read(read, size)
132         self.outfile.write(data)
133         return data
134
135     def verify(self):
136         """Verify a pack stream and write it to the output file.
137
138         See PackStreamReader.iterobjects for a list of exceptions this may
139         throw.
140         """
141         for _, _, _ in self.read_objects():
142             pass
143
144
145 class DictBackend(Backend):
146     """Trivial backend that looks up Git repositories in a dictionary."""
147
148     def __init__(self, repos):
149         self.repos = repos
150
151     def open_repository(self, path):
152         logger.debug('Opening repository at %s', path)
153         # FIXME: What to do in case there is no repo ?
154         return self.repos[path]
155
156
157 class Handler(object):
158     """Smart protocol command handler base class."""
159
160     def __init__(self, backend, proto):
161         self.backend = backend
162         self.proto = proto
163         self._client_capabilities = None
164
165     @classmethod
166     def capability_line(cls):
167         return " ".join(cls.capabilities())
168
169     @classmethod
170     def capabilities(cls):
171         raise NotImplementedError(cls.capabilities)
172
173     @classmethod
174     def innocuous_capabilities(cls):
175         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
176
177     @classmethod
178     def required_capabilities(cls):
179         """Return a list of capabilities that we require the client to have."""
180         return []
181
182     def set_client_capabilities(self, caps):
183         allowable_caps = set(self.innocuous_capabilities())
184         allowable_caps.update(self.capabilities())
185         for cap in caps:
186             if cap not in allowable_caps:
187                 raise GitProtocolError('Client asked for capability %s that '
188                                        'was not advertised.' % cap)
189         for cap in self.required_capabilities():
190             if cap not in caps:
191                 raise GitProtocolError('Client does not support required '
192                                        'capability %s.' % cap)
193         self._client_capabilities = set(caps)
194         logger.info('Client capabilities: %s', caps)
195
196     def has_capability(self, cap):
197         if self._client_capabilities is None:
198             raise GitProtocolError('Server attempted to access capability %s '
199                                    'before asking client' % cap)
200         return cap in self._client_capabilities
201
202
203 class UploadPackHandler(Handler):
204     """Protocol handler for uploading a pack to the server."""
205
206     def __init__(self, backend, args, proto,
207                  stateless_rpc=False, advertise_refs=False):
208         Handler.__init__(self, backend, proto)
209         self.repo = backend.open_repository(args[0])
210         self._graph_walker = None
211         self.stateless_rpc = stateless_rpc
212         self.advertise_refs = advertise_refs
213
214     @classmethod
215     def capabilities(cls):
216         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
217                 "ofs-delta", "no-progress", "include-tag")
218
219     @classmethod
220     def required_capabilities(cls):
221         return ("side-band-64k", "thin-pack", "ofs-delta")
222
223     def progress(self, message):
224         if self.has_capability("no-progress"):
225             return
226         self.proto.write_sideband(2, message)
227
228     def get_tagged(self, refs=None, repo=None):
229         """Get a dict of peeled values of tags to their original tag shas.
230
231         :param refs: dict of refname -> sha of possible tags; defaults to all of
232             the backend's refs.
233         :param repo: optional Repo instance for getting peeled refs; defaults to
234             the backend's repo, if available
235         :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
236             tag whose peeled value is peeled_sha.
237         """
238         if not self.has_capability("include-tag"):
239             return {}
240         if refs is None:
241             refs = self.repo.get_refs()
242         if repo is None:
243             repo = getattr(self.repo, "repo", None)
244             if repo is None:
245                 # Bail if we don't have a Repo available; this is ok since
246                 # clients must be able to handle if the server doesn't include
247                 # all relevant tags.
248                 # TODO: fix behavior when missing
249                 return {}
250         tagged = {}
251         for name, sha in refs.iteritems():
252             peeled_sha = repo.get_peeled(name)
253             if peeled_sha != sha:
254                 tagged[peeled_sha] = sha
255         return tagged
256
257     def handle(self):
258         write = lambda x: self.proto.write_sideband(1, x)
259
260         graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
261             self.repo.get_peeled)
262         objects_iter = self.repo.fetch_objects(
263           graph_walker.determine_wants, graph_walker, self.progress,
264           get_tagged=self.get_tagged)
265
266         # Do they want any objects?
267         if len(objects_iter) == 0:
268             return
269
270         self.progress("dul-daemon says what\n")
271         self.progress("counting objects: %d, done.\n" % len(objects_iter))
272         write_pack_data(ProtocolFile(None, write), objects_iter, 
273                         len(objects_iter))
274         self.progress("how was that, then?\n")
275         # we are done
276         self.proto.write("0000")
277
278
279 def _split_proto_line(line):
280     """Split a line read from the wire.
281
282     :return: a tuple having one of the following forms:
283         ('want', obj_id)
284         ('have', obj_id)
285         ('done', None)
286         (None, None)  (for a flush-pkt)
287
288     :raise GitProtocolError: if the line cannot be parsed into one of the
289         possible return values.
290     """
291     if not line:
292         fields = [None]
293     else:
294         fields = line.rstrip('\n').split(' ', 1)
295     if len(fields) == 1 and fields[0] in ('done', None):
296         return (fields[0], None)
297     elif len(fields) == 2 and fields[0] in ('want', 'have'):
298         try:
299             hex_to_sha(fields[1])
300             return tuple(fields)
301         except (TypeError, AssertionError), e:
302             raise GitProtocolError(e)
303     raise GitProtocolError('Received invalid line from client:\n%s' % line)
304
305
306 class ProtocolGraphWalker(object):
307     """A graph walker that knows the git protocol.
308
309     As a graph walker, this class implements ack(), next(), and reset(). It
310     also contains some base methods for interacting with the wire and walking
311     the commit tree.
312
313     The work of determining which acks to send is passed on to the
314     implementation instance stored in _impl. The reason for this is that we do
315     not know at object creation time what ack level the protocol requires. A
316     call to set_ack_level() is required to set up the implementation, before any
317     calls to next() or ack() are made.
318     """
319     def __init__(self, handler, object_store, get_peeled):
320         self.handler = handler
321         self.store = object_store
322         self.get_peeled = get_peeled
323         self.proto = handler.proto
324         self.stateless_rpc = handler.stateless_rpc
325         self.advertise_refs = handler.advertise_refs
326         self._wants = []
327         self._cached = False
328         self._cache = []
329         self._cache_index = 0
330         self._impl = None
331
332     def determine_wants(self, heads):
333         """Determine the wants for a set of heads.
334
335         The given heads are advertised to the client, who then specifies which
336         refs he wants using 'want' lines. This portion of the protocol is the
337         same regardless of ack type, and in fact is used to set the ack type of
338         the ProtocolGraphWalker.
339
340         :param heads: a dict of refname->SHA1 to advertise
341         :return: a list of SHA1s requested by the client
342         """
343         if not heads:
344             raise GitProtocolError('No heads found')
345         values = set(heads.itervalues())
346         if self.advertise_refs or not self.stateless_rpc:
347             for i, (ref, sha) in enumerate(heads.iteritems()):
348                 line = "%s %s" % (sha, ref)
349                 if not i:
350                     line = "%s\x00%s" % (line, self.handler.capability_line())
351                 self.proto.write_pkt_line("%s\n" % line)
352                 peeled_sha = self.get_peeled(ref)
353                 if peeled_sha != sha:
354                     self.proto.write_pkt_line('%s %s^{}\n' %
355                                               (peeled_sha, ref))
356
357             # i'm done..
358             self.proto.write_pkt_line(None)
359
360             if self.advertise_refs:
361                 return []
362
363         # Now client will sending want want want commands
364         want = self.proto.read_pkt_line()
365         if not want:
366             return []
367         line, caps = extract_want_line_capabilities(want)
368         self.handler.set_client_capabilities(caps)
369         self.set_ack_type(ack_type(caps))
370         command, sha = _split_proto_line(line)
371
372         want_revs = []
373         while command != None:
374             if command != 'want':
375                 raise GitProtocolError(
376                   'Protocol got unexpected command %s' % command)
377             if sha not in values:
378                 raise GitProtocolError(
379                   'Client wants invalid object %s' % sha)
380             want_revs.append(sha)
381             command, sha = self.read_proto_line()
382
383         self.set_wants(want_revs)
384         return want_revs
385
386     def ack(self, have_ref):
387         return self._impl.ack(have_ref)
388
389     def reset(self):
390         self._cached = True
391         self._cache_index = 0
392
393     def next(self):
394         if not self._cached:
395             if not self._impl and self.stateless_rpc:
396                 return None
397             return self._impl.next()
398         self._cache_index += 1
399         if self._cache_index > len(self._cache):
400             return None
401         return self._cache[self._cache_index]
402
403     def read_proto_line(self):
404         """Read and split a line from the wire.
405
406         :return: A tuple of (command, value); see _split_proto_line.
407         :raise GitProtocolError: If an error occurred reading the line.
408         """
409         return _split_proto_line(self.proto.read_pkt_line())
410
411     def send_ack(self, sha, ack_type=''):
412         if ack_type:
413             ack_type = ' %s' % ack_type
414         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
415
416     def send_nak(self):
417         self.proto.write_pkt_line('NAK\n')
418
419     def set_wants(self, wants):
420         self._wants = wants
421
422     def _is_satisfied(self, haves, want, earliest):
423         """Check whether a want is satisfied by a set of haves.
424
425         A want, typically a branch tip, is "satisfied" only if there exists a
426         path back from that want to one of the haves.
427
428         :param haves: A set of commits we know the client has.
429         :param want: The want to check satisfaction for.
430         :param earliest: A timestamp beyond which the search for haves will be
431             terminated, presumably because we're searching too far down the
432             wrong branch.
433         """
434         o = self.store[want]
435         pending = collections.deque([o])
436         while pending:
437             commit = pending.popleft()
438             if commit.id in haves:
439                 return True
440             if commit.type_name != "commit":
441                 # non-commit wants are assumed to be satisfied
442                 continue
443             for parent in commit.parents:
444                 parent_obj = self.store[parent]
445                 # TODO: handle parents with later commit times than children
446                 if parent_obj.commit_time >= earliest:
447                     pending.append(parent_obj)
448         return False
449
450     def all_wants_satisfied(self, haves):
451         """Check whether all the current wants are satisfied by a set of haves.
452
453         :param haves: A set of commits we know the client has.
454         :note: Wants are specified with set_wants rather than passed in since
455             in the current interface they are determined outside this class.
456         """
457         haves = set(haves)
458         earliest = min([self.store[h].commit_time for h in haves])
459         for want in self._wants:
460             if not self._is_satisfied(haves, want, earliest):
461                 return False
462         return True
463
464     def set_ack_type(self, ack_type):
465         impl_classes = {
466           MULTI_ACK: MultiAckGraphWalkerImpl,
467           MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
468           SINGLE_ACK: SingleAckGraphWalkerImpl,
469           }
470         self._impl = impl_classes[ack_type](self)
471
472
473 class SingleAckGraphWalkerImpl(object):
474     """Graph walker implementation that speaks the single-ack protocol."""
475
476     def __init__(self, walker):
477         self.walker = walker
478         self._sent_ack = False
479
480     def ack(self, have_ref):
481         if not self._sent_ack:
482             self.walker.send_ack(have_ref)
483             self._sent_ack = True
484
485     def next(self):
486         command, sha = self.walker.read_proto_line()
487         if command in (None, 'done'):
488             if not self._sent_ack:
489                 self.walker.send_nak()
490             return None
491         elif command == 'have':
492             return sha
493
494
495 class MultiAckGraphWalkerImpl(object):
496     """Graph walker implementation that speaks the multi-ack protocol."""
497
498     def __init__(self, walker):
499         self.walker = walker
500         self._found_base = False
501         self._common = []
502
503     def ack(self, have_ref):
504         self._common.append(have_ref)
505         if not self._found_base:
506             self.walker.send_ack(have_ref, 'continue')
507             if self.walker.all_wants_satisfied(self._common):
508                 self._found_base = True
509         # else we blind ack within next
510
511     def next(self):
512         while True:
513             command, sha = self.walker.read_proto_line()
514             if command is None:
515                 self.walker.send_nak()
516                 # in multi-ack mode, a flush-pkt indicates the client wants to
517                 # flush but more have lines are still coming
518                 continue
519             elif command == 'done':
520                 # don't nak unless no common commits were found, even if not
521                 # everything is satisfied
522                 if self._common:
523                     self.walker.send_ack(self._common[-1])
524                 else:
525                     self.walker.send_nak()
526                 return None
527             elif command == 'have':
528                 if self._found_base:
529                     # blind ack
530                     self.walker.send_ack(sha, 'continue')
531                 return sha
532
533
534 class MultiAckDetailedGraphWalkerImpl(object):
535     """Graph walker implementation speaking the multi-ack-detailed protocol."""
536
537     def __init__(self, walker):
538         self.walker = walker
539         self._found_base = False
540         self._common = []
541
542     def ack(self, have_ref):
543         self._common.append(have_ref)
544         if not self._found_base:
545             self.walker.send_ack(have_ref, 'common')
546             if self.walker.all_wants_satisfied(self._common):
547                 self._found_base = True
548                 self.walker.send_ack(have_ref, 'ready')
549         # else we blind ack within next
550
551     def next(self):
552         while True:
553             command, sha = self.walker.read_proto_line()
554             if command is None:
555                 self.walker.send_nak()
556                 if self.walker.stateless_rpc:
557                     return None
558                 continue
559             elif command == 'done':
560                 # don't nak unless no common commits were found, even if not
561                 # everything is satisfied
562                 if self._common:
563                     self.walker.send_ack(self._common[-1])
564                 else:
565                     self.walker.send_nak()
566                 return None
567             elif command == 'have':
568                 if self._found_base:
569                     # blind ack; can happen if the client has more requests
570                     # inflight
571                     self.walker.send_ack(sha, 'ready')
572                 return sha
573
574
575 class ReceivePackHandler(Handler):
576     """Protocol handler for downloading a pack from the client."""
577
578     def __init__(self, backend, args, proto,
579                  stateless_rpc=False, advertise_refs=False):
580         Handler.__init__(self, backend, proto)
581         self.repo = backend.open_repository(args[0])
582         self.stateless_rpc = stateless_rpc
583         self.advertise_refs = advertise_refs
584
585     @classmethod
586     def capabilities(cls):
587         return ("report-status", "delete-refs", "side-band-64k")
588
589     def _apply_pack(self, refs):
590         f, commit = self.repo.object_store.add_thin_pack()
591         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
592                           AssertionError, socket.error, zlib.error,
593                           ObjectFormatException)
594         status = []
595         # TODO: more informative error messages than just the exception string
596         try:
597             PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
598             p = commit()
599             if not p:
600                 raise IOError('Failed to write pack')
601             p.check()
602             status.append(('unpack', 'ok'))
603         except all_exceptions, e:
604             status.append(('unpack', str(e).replace('\n', '')))
605             # The pack may still have been moved in, but it may contain broken
606             # objects. We trust a later GC to clean it up.
607
608         for oldsha, sha, ref in refs:
609             ref_status = 'ok'
610             try:
611                 if sha == ZERO_SHA:
612                     if not 'delete-refs' in self.capabilities():
613                         raise GitProtocolError(
614                           'Attempted to delete refs without delete-refs '
615                           'capability.')
616                     try:
617                         del self.repo.refs[ref]
618                     except all_exceptions:
619                         ref_status = 'failed to delete'
620                 else:
621                     try:
622                         self.repo.refs[ref] = sha
623                     except all_exceptions:
624                         ref_status = 'failed to write'
625             except KeyError, e:
626                 ref_status = 'bad ref'
627             status.append((ref, ref_status))
628
629         return status
630
631     def _report_status(self, status):
632         if self.has_capability('side-band-64k'):
633             writer = BufferedPktLineWriter(
634               lambda d: self.proto.write_sideband(1, d))
635             write = writer.write
636
637             def flush():
638                 writer.flush()
639                 self.proto.write_pkt_line(None)
640         else:
641             write = self.proto.write_pkt_line
642             flush = lambda: None
643
644         for name, msg in status:
645             if name == 'unpack':
646                 write('unpack %s\n' % msg)
647             elif msg == 'ok':
648                 write('ok %s\n' % name)
649             else:
650                 write('ng %s %s\n' % (name, msg))
651         write(None)
652         flush()
653
654     def handle(self):
655         refs = self.repo.get_refs().items()
656
657         if self.advertise_refs or not self.stateless_rpc:
658             if refs:
659                 self.proto.write_pkt_line(
660                   "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
661                                      self.capability_line()))
662                 for i in range(1, len(refs)):
663                     ref = refs[i]
664                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
665             else:
666                 self.proto.write_pkt_line("%s capabilities^{} %s" % (
667                   ZERO_SHA, self.capability_line()))
668
669             self.proto.write("0000")
670             if self.advertise_refs:
671                 return
672
673         client_refs = []
674         ref = self.proto.read_pkt_line()
675
676         # if ref is none then client doesnt want to send us anything..
677         if ref is None:
678             return
679
680         ref, caps = extract_capabilities(ref)
681         self.set_client_capabilities(caps)
682
683         # client will now send us a list of (oldsha, newsha, ref)
684         while ref:
685             client_refs.append(ref.split())
686             ref = self.proto.read_pkt_line()
687
688         # backend can now deal with this refs and read a pack using self.read
689         status = self._apply_pack(client_refs)
690
691         # when we have read all the pack from the client, send a status report
692         # if the client asked for it
693         if self.has_capability('report-status'):
694             self._report_status(status)
695
696
697 # Default handler classes for git services.
698 DEFAULT_HANDLERS = {
699   'git-upload-pack': UploadPackHandler,
700   'git-receive-pack': ReceivePackHandler,
701   }
702
703
704 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
705
706     def __init__(self, handlers, *args, **kwargs):
707         self.handlers = handlers
708         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
709
710     def handle(self):
711         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
712         command, args = proto.read_cmd()
713         logger.info('Handling %s request, args=%s', command, args)
714
715         cls = self.handlers.get(command, None)
716         if not callable(cls):
717             raise GitProtocolError('Invalid service %s' % command)
718         h = cls(self.server.backend, args, proto)
719         h.handle()
720
721
722 class TCPGitServer(SocketServer.TCPServer):
723
724     allow_reuse_address = True
725     serve = SocketServer.TCPServer.serve_forever
726
727     def _make_handler(self, *args, **kwargs):
728         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
729
730     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
731         self.handlers = dict(DEFAULT_HANDLERS)
732         if handlers is not None:
733             self.handlers.update(handlers)
734         self.backend = backend
735         logger.info('Listening for TCP connections on %s:%d', listen_addr, port)
736         SocketServer.TCPServer.__init__(self, (listen_addr, port),
737                                         self._make_handler)
738
739     def verify_request(self, request, client_address):
740         logger.info('Handling request from %s', client_address)
741         return True
742
743     def handle_error(self, request, client_address):
744         logger.exception('Exception happened during processing of request '
745                          'from %s', client_address)
746
747
748 def main(argv=sys.argv):
749     """Entry point for starting a TCP git server."""
750     if len(argv) > 1:
751         gitdir = argv[1]
752     else:
753         gitdir = '.'
754
755     log_utils.default_logging_config()
756     backend = DictBackend({'/': Repo(gitdir)})
757     server = TCPGitServer(backend, 'localhost')
758     server.serve_forever()