server: Explicitly specify allowed protocol commands.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Git smart network protocol server implementation.
20
21 For more detailed implementation on the network protocol, see the
22 Documentation/technical directory in the cgit distribution, and in particular:
23
24 * Documentation/technical/protocol-capabilities.txt
25 * Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import socket
31 import SocketServer
32 import sys
33 import zlib
34
35 from dulwich.errors import (
36     ApplyDeltaError,
37     ChecksumMismatch,
38     GitProtocolError,
39     UnexpectedCommandError,
40     ObjectFormatException,
41     )
42 from dulwich import log_utils
43 from dulwich.objects import (
44     hex_to_sha,
45     )
46 from dulwich.pack import (
47     PackStreamReader,
48     write_pack_data,
49     )
50 from dulwich.protocol import (
51     MULTI_ACK,
52     MULTI_ACK_DETAILED,
53     ProtocolFile,
54     ReceivableProtocol,
55     SINGLE_ACK,
56     TCP_GIT_PORT,
57     ZERO_SHA,
58     ack_type,
59     extract_capabilities,
60     extract_want_line_capabilities,
61     BufferedPktLineWriter,
62     )
63 from dulwich.repo import (
64     Repo,
65     )
66
67
68 logger = log_utils.getLogger(__name__)
69
70
71 class Backend(object):
72     """A backend for the Git smart server implementation."""
73
74     def open_repository(self, path):
75         """Open the repository at a path."""
76         raise NotImplementedError(self.open_repository)
77
78
79 class BackendRepo(object):
80     """Repository abstraction used by the Git server.
81     
82     Please note that the methods required here are a 
83     subset of those provided by dulwich.repo.Repo.
84     """
85
86     object_store = None
87     refs = None
88
89     def get_refs(self):
90         """
91         Get all the refs in the repository
92
93         :return: dict of name -> sha
94         """
95         raise NotImplementedError
96
97     def get_peeled(self, name):
98         """Return the cached peeled value of a ref, if available.
99
100         :param name: Name of the ref to peel
101         :return: The peeled value of the ref. If the ref is known not point to
102             a tag, this will be the SHA the ref refers to. If no cached
103             information about a tag is available, this method may return None,
104             but it should attempt to peel the tag if possible.
105         """
106         return None
107
108     def fetch_objects(self, determine_wants, graph_walker, progress,
109                       get_tagged=None):
110         """
111         Yield the objects required for a list of commits.
112
113         :param progress: is a callback to send progress messages to the client
114         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
115             sha for including tags.
116         """
117         raise NotImplementedError
118
119
120 class PackStreamCopier(PackStreamReader):
121     """Class to verify a pack stream as it is being read.
122
123     The pack is read from a ReceivableProtocol using read() or recv() as
124     appropriate and written out to the given file-like object.
125     """
126
127     def __init__(self, read_all, read_some, outfile):
128         super(PackStreamCopier, self).__init__(read_all, read_some)
129         self.outfile = outfile
130
131     def _read(self, read, size):
132         data = super(PackStreamCopier, self)._read(read, size)
133         self.outfile.write(data)
134         return data
135
136     def verify(self):
137         """Verify a pack stream and write it to the output file.
138
139         See PackStreamReader.iterobjects for a list of exceptions this may
140         throw.
141         """
142         for _, _, _ in self.read_objects():
143             pass
144
145
146 class DictBackend(Backend):
147     """Trivial backend that looks up Git repositories in a dictionary."""
148
149     def __init__(self, repos):
150         self.repos = repos
151
152     def open_repository(self, path):
153         logger.debug('Opening repository at %s', path)
154         # FIXME: What to do in case there is no repo ?
155         return self.repos[path]
156
157
158 class Handler(object):
159     """Smart protocol command handler base class."""
160
161     def __init__(self, backend, proto):
162         self.backend = backend
163         self.proto = proto
164         self._client_capabilities = None
165
166     @classmethod
167     def capability_line(cls):
168         return " ".join(cls.capabilities())
169
170     @classmethod
171     def capabilities(cls):
172         raise NotImplementedError(cls.capabilities)
173
174     @classmethod
175     def innocuous_capabilities(cls):
176         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
177
178     @classmethod
179     def required_capabilities(cls):
180         """Return a list of capabilities that we require the client to have."""
181         return []
182
183     def set_client_capabilities(self, caps):
184         allowable_caps = set(self.innocuous_capabilities())
185         allowable_caps.update(self.capabilities())
186         for cap in caps:
187             if cap not in allowable_caps:
188                 raise GitProtocolError('Client asked for capability %s that '
189                                        'was not advertised.' % cap)
190         for cap in self.required_capabilities():
191             if cap not in caps:
192                 raise GitProtocolError('Client does not support required '
193                                        'capability %s.' % cap)
194         self._client_capabilities = set(caps)
195         logger.info('Client capabilities: %s', caps)
196
197     def has_capability(self, cap):
198         if self._client_capabilities is None:
199             raise GitProtocolError('Server attempted to access capability %s '
200                                    'before asking client' % cap)
201         return cap in self._client_capabilities
202
203
204 class UploadPackHandler(Handler):
205     """Protocol handler for uploading a pack to the server."""
206
207     def __init__(self, backend, args, proto,
208                  stateless_rpc=False, advertise_refs=False):
209         Handler.__init__(self, backend, proto)
210         self.repo = backend.open_repository(args[0])
211         self._graph_walker = None
212         self.stateless_rpc = stateless_rpc
213         self.advertise_refs = advertise_refs
214
215     @classmethod
216     def capabilities(cls):
217         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
218                 "ofs-delta", "no-progress", "include-tag")
219
220     @classmethod
221     def required_capabilities(cls):
222         return ("side-band-64k", "thin-pack", "ofs-delta")
223
224     def progress(self, message):
225         if self.has_capability("no-progress"):
226             return
227         self.proto.write_sideband(2, message)
228
229     def get_tagged(self, refs=None, repo=None):
230         """Get a dict of peeled values of tags to their original tag shas.
231
232         :param refs: dict of refname -> sha of possible tags; defaults to all of
233             the backend's refs.
234         :param repo: optional Repo instance for getting peeled refs; defaults to
235             the backend's repo, if available
236         :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
237             tag whose peeled value is peeled_sha.
238         """
239         if not self.has_capability("include-tag"):
240             return {}
241         if refs is None:
242             refs = self.repo.get_refs()
243         if repo is None:
244             repo = getattr(self.repo, "repo", None)
245             if repo is None:
246                 # Bail if we don't have a Repo available; this is ok since
247                 # clients must be able to handle if the server doesn't include
248                 # all relevant tags.
249                 # TODO: fix behavior when missing
250                 return {}
251         tagged = {}
252         for name, sha in refs.iteritems():
253             peeled_sha = repo.get_peeled(name)
254             if peeled_sha != sha:
255                 tagged[peeled_sha] = sha
256         return tagged
257
258     def handle(self):
259         write = lambda x: self.proto.write_sideband(1, x)
260
261         graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
262             self.repo.get_peeled)
263         objects_iter = self.repo.fetch_objects(
264           graph_walker.determine_wants, graph_walker, self.progress,
265           get_tagged=self.get_tagged)
266
267         # Do they want any objects?
268         if len(objects_iter) == 0:
269             return
270
271         self.progress("dul-daemon says what\n")
272         self.progress("counting objects: %d, done.\n" % len(objects_iter))
273         write_pack_data(ProtocolFile(None, write), objects_iter, 
274                         len(objects_iter))
275         self.progress("how was that, then?\n")
276         # we are done
277         self.proto.write("0000")
278
279
280 def _split_proto_line(line, allowed):
281     """Split a line read from the wire.
282
283     :param line: The line read from the wire.
284     :param allowed: An iterable of command names that should be allowed.
285         Command names not listed below as possible return values will be
286         ignored.  If None, any commands from the possible return values are
287         allowed.
288     :return: a tuple having one of the following forms:
289         ('want', obj_id)
290         ('have', obj_id)
291         ('done', None)
292         (None, None)  (for a flush-pkt)
293
294     :raise UnexpectedCommandError: if the line cannot be parsed into one of the
295         allowed return values.
296     """
297     if not line:
298         fields = [None]
299     else:
300         fields = line.rstrip('\n').split(' ', 1)
301     command = fields[0]
302     if allowed is not None and command not in allowed:
303         raise UnexpectedCommandError(command)
304     try:
305         if len(fields) == 1 and command in ('done', None):
306             return (command, None)
307         elif len(fields) == 2 and command in ('want', 'have'):
308             hex_to_sha(fields[1])
309             return tuple(fields)
310     except (TypeError, AssertionError), e:
311         raise GitProtocolError(e)
312     raise GitProtocolError('Received invalid line from client: %s' % line)
313
314
315 class ProtocolGraphWalker(object):
316     """A graph walker that knows the git protocol.
317
318     As a graph walker, this class implements ack(), next(), and reset(). It
319     also contains some base methods for interacting with the wire and walking
320     the commit tree.
321
322     The work of determining which acks to send is passed on to the
323     implementation instance stored in _impl. The reason for this is that we do
324     not know at object creation time what ack level the protocol requires. A
325     call to set_ack_level() is required to set up the implementation, before any
326     calls to next() or ack() are made.
327     """
328     def __init__(self, handler, object_store, get_peeled):
329         self.handler = handler
330         self.store = object_store
331         self.get_peeled = get_peeled
332         self.proto = handler.proto
333         self.stateless_rpc = handler.stateless_rpc
334         self.advertise_refs = handler.advertise_refs
335         self._wants = []
336         self._cached = False
337         self._cache = []
338         self._cache_index = 0
339         self._impl = None
340
341     def determine_wants(self, heads):
342         """Determine the wants for a set of heads.
343
344         The given heads are advertised to the client, who then specifies which
345         refs he wants using 'want' lines. This portion of the protocol is the
346         same regardless of ack type, and in fact is used to set the ack type of
347         the ProtocolGraphWalker.
348
349         :param heads: a dict of refname->SHA1 to advertise
350         :return: a list of SHA1s requested by the client
351         """
352         if not heads:
353             raise GitProtocolError('No heads found')
354         values = set(heads.itervalues())
355         if self.advertise_refs or not self.stateless_rpc:
356             for i, (ref, sha) in enumerate(heads.iteritems()):
357                 line = "%s %s" % (sha, ref)
358                 if not i:
359                     line = "%s\x00%s" % (line, self.handler.capability_line())
360                 self.proto.write_pkt_line("%s\n" % line)
361                 peeled_sha = self.get_peeled(ref)
362                 if peeled_sha != sha:
363                     self.proto.write_pkt_line('%s %s^{}\n' %
364                                               (peeled_sha, ref))
365
366             # i'm done..
367             self.proto.write_pkt_line(None)
368
369             if self.advertise_refs:
370                 return []
371
372         # Now client will sending want want want commands
373         want = self.proto.read_pkt_line()
374         if not want:
375             return []
376         line, caps = extract_want_line_capabilities(want)
377         self.handler.set_client_capabilities(caps)
378         self.set_ack_type(ack_type(caps))
379         allowed = ('want', None)
380         command, sha = _split_proto_line(line, allowed)
381
382         want_revs = []
383         while command != None:
384             if sha not in values:
385                 raise GitProtocolError(
386                   'Client wants invalid object %s' % sha)
387             want_revs.append(sha)
388             command, sha = self.read_proto_line(allowed)
389
390         self.set_wants(want_revs)
391         return want_revs
392
393     def ack(self, have_ref):
394         return self._impl.ack(have_ref)
395
396     def reset(self):
397         self._cached = True
398         self._cache_index = 0
399
400     def next(self):
401         if not self._cached:
402             if not self._impl and self.stateless_rpc:
403                 return None
404             return self._impl.next()
405         self._cache_index += 1
406         if self._cache_index > len(self._cache):
407             return None
408         return self._cache[self._cache_index]
409
410     def read_proto_line(self, allowed):
411         """Read a line from the wire.
412
413         :param allowed: An iterable of command names that should be allowed.
414         :return: A tuple of (command, value); see _split_proto_line.
415         :raise GitProtocolError: If an error occurred reading the line.
416         """
417         return _split_proto_line(self.proto.read_pkt_line(), allowed)
418
419     def send_ack(self, sha, ack_type=''):
420         if ack_type:
421             ack_type = ' %s' % ack_type
422         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
423
424     def send_nak(self):
425         self.proto.write_pkt_line('NAK\n')
426
427     def set_wants(self, wants):
428         self._wants = wants
429
430     def _is_satisfied(self, haves, want, earliest):
431         """Check whether a want is satisfied by a set of haves.
432
433         A want, typically a branch tip, is "satisfied" only if there exists a
434         path back from that want to one of the haves.
435
436         :param haves: A set of commits we know the client has.
437         :param want: The want to check satisfaction for.
438         :param earliest: A timestamp beyond which the search for haves will be
439             terminated, presumably because we're searching too far down the
440             wrong branch.
441         """
442         o = self.store[want]
443         pending = collections.deque([o])
444         while pending:
445             commit = pending.popleft()
446             if commit.id in haves:
447                 return True
448             if commit.type_name != "commit":
449                 # non-commit wants are assumed to be satisfied
450                 continue
451             for parent in commit.parents:
452                 parent_obj = self.store[parent]
453                 # TODO: handle parents with later commit times than children
454                 if parent_obj.commit_time >= earliest:
455                     pending.append(parent_obj)
456         return False
457
458     def all_wants_satisfied(self, haves):
459         """Check whether all the current wants are satisfied by a set of haves.
460
461         :param haves: A set of commits we know the client has.
462         :note: Wants are specified with set_wants rather than passed in since
463             in the current interface they are determined outside this class.
464         """
465         haves = set(haves)
466         earliest = min([self.store[h].commit_time for h in haves])
467         for want in self._wants:
468             if not self._is_satisfied(haves, want, earliest):
469                 return False
470         return True
471
472     def set_ack_type(self, ack_type):
473         impl_classes = {
474           MULTI_ACK: MultiAckGraphWalkerImpl,
475           MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
476           SINGLE_ACK: SingleAckGraphWalkerImpl,
477           }
478         self._impl = impl_classes[ack_type](self)
479
480
481 _GRAPH_WALKER_COMMANDS = ('have', 'done', None)
482
483
484 class SingleAckGraphWalkerImpl(object):
485     """Graph walker implementation that speaks the single-ack protocol."""
486
487     def __init__(self, walker):
488         self.walker = walker
489         self._sent_ack = False
490
491     def ack(self, have_ref):
492         if not self._sent_ack:
493             self.walker.send_ack(have_ref)
494             self._sent_ack = True
495
496     def next(self):
497         command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
498         if command in (None, 'done'):
499             if not self._sent_ack:
500                 self.walker.send_nak()
501             return None
502         elif command == 'have':
503             return sha
504
505
506 class MultiAckGraphWalkerImpl(object):
507     """Graph walker implementation that speaks the multi-ack protocol."""
508
509     def __init__(self, walker):
510         self.walker = walker
511         self._found_base = False
512         self._common = []
513
514     def ack(self, have_ref):
515         self._common.append(have_ref)
516         if not self._found_base:
517             self.walker.send_ack(have_ref, 'continue')
518             if self.walker.all_wants_satisfied(self._common):
519                 self._found_base = True
520         # else we blind ack within next
521
522     def next(self):
523         while True:
524             command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
525             if command is None:
526                 self.walker.send_nak()
527                 # in multi-ack mode, a flush-pkt indicates the client wants to
528                 # flush but more have lines are still coming
529                 continue
530             elif command == 'done':
531                 # don't nak unless no common commits were found, even if not
532                 # everything is satisfied
533                 if self._common:
534                     self.walker.send_ack(self._common[-1])
535                 else:
536                     self.walker.send_nak()
537                 return None
538             elif command == 'have':
539                 if self._found_base:
540                     # blind ack
541                     self.walker.send_ack(sha, 'continue')
542                 return sha
543
544
545 class MultiAckDetailedGraphWalkerImpl(object):
546     """Graph walker implementation speaking the multi-ack-detailed protocol."""
547
548     def __init__(self, walker):
549         self.walker = walker
550         self._found_base = False
551         self._common = []
552
553     def ack(self, have_ref):
554         self._common.append(have_ref)
555         if not self._found_base:
556             self.walker.send_ack(have_ref, 'common')
557             if self.walker.all_wants_satisfied(self._common):
558                 self._found_base = True
559                 self.walker.send_ack(have_ref, 'ready')
560         # else we blind ack within next
561
562     def next(self):
563         while True:
564             command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
565             if command is None:
566                 self.walker.send_nak()
567                 if self.walker.stateless_rpc:
568                     return None
569                 continue
570             elif command == 'done':
571                 # don't nak unless no common commits were found, even if not
572                 # everything is satisfied
573                 if self._common:
574                     self.walker.send_ack(self._common[-1])
575                 else:
576                     self.walker.send_nak()
577                 return None
578             elif command == 'have':
579                 if self._found_base:
580                     # blind ack; can happen if the client has more requests
581                     # inflight
582                     self.walker.send_ack(sha, 'ready')
583                 return sha
584
585
586 class ReceivePackHandler(Handler):
587     """Protocol handler for downloading a pack from the client."""
588
589     def __init__(self, backend, args, proto,
590                  stateless_rpc=False, advertise_refs=False):
591         Handler.__init__(self, backend, proto)
592         self.repo = backend.open_repository(args[0])
593         self.stateless_rpc = stateless_rpc
594         self.advertise_refs = advertise_refs
595
596     @classmethod
597     def capabilities(cls):
598         return ("report-status", "delete-refs", "side-band-64k")
599
600     def _apply_pack(self, refs):
601         f, commit = self.repo.object_store.add_thin_pack()
602         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
603                           AssertionError, socket.error, zlib.error,
604                           ObjectFormatException)
605         status = []
606         # TODO: more informative error messages than just the exception string
607         try:
608             PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
609             p = commit()
610             if not p:
611                 raise IOError('Failed to write pack')
612             p.check()
613             status.append(('unpack', 'ok'))
614         except all_exceptions, e:
615             status.append(('unpack', str(e).replace('\n', '')))
616             # The pack may still have been moved in, but it may contain broken
617             # objects. We trust a later GC to clean it up.
618
619         for oldsha, sha, ref in refs:
620             ref_status = 'ok'
621             try:
622                 if sha == ZERO_SHA:
623                     if not 'delete-refs' in self.capabilities():
624                         raise GitProtocolError(
625                           'Attempted to delete refs without delete-refs '
626                           'capability.')
627                     try:
628                         del self.repo.refs[ref]
629                     except all_exceptions:
630                         ref_status = 'failed to delete'
631                 else:
632                     try:
633                         self.repo.refs[ref] = sha
634                     except all_exceptions:
635                         ref_status = 'failed to write'
636             except KeyError, e:
637                 ref_status = 'bad ref'
638             status.append((ref, ref_status))
639
640         return status
641
642     def _report_status(self, status):
643         if self.has_capability('side-band-64k'):
644             writer = BufferedPktLineWriter(
645               lambda d: self.proto.write_sideband(1, d))
646             write = writer.write
647
648             def flush():
649                 writer.flush()
650                 self.proto.write_pkt_line(None)
651         else:
652             write = self.proto.write_pkt_line
653             flush = lambda: None
654
655         for name, msg in status:
656             if name == 'unpack':
657                 write('unpack %s\n' % msg)
658             elif msg == 'ok':
659                 write('ok %s\n' % name)
660             else:
661                 write('ng %s %s\n' % (name, msg))
662         write(None)
663         flush()
664
665     def handle(self):
666         refs = self.repo.get_refs().items()
667
668         if self.advertise_refs or not self.stateless_rpc:
669             if refs:
670                 self.proto.write_pkt_line(
671                   "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
672                                      self.capability_line()))
673                 for i in range(1, len(refs)):
674                     ref = refs[i]
675                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
676             else:
677                 self.proto.write_pkt_line("%s capabilities^{} %s" % (
678                   ZERO_SHA, self.capability_line()))
679
680             self.proto.write("0000")
681             if self.advertise_refs:
682                 return
683
684         client_refs = []
685         ref = self.proto.read_pkt_line()
686
687         # if ref is none then client doesnt want to send us anything..
688         if ref is None:
689             return
690
691         ref, caps = extract_capabilities(ref)
692         self.set_client_capabilities(caps)
693
694         # client will now send us a list of (oldsha, newsha, ref)
695         while ref:
696             client_refs.append(ref.split())
697             ref = self.proto.read_pkt_line()
698
699         # backend can now deal with this refs and read a pack using self.read
700         status = self._apply_pack(client_refs)
701
702         # when we have read all the pack from the client, send a status report
703         # if the client asked for it
704         if self.has_capability('report-status'):
705             self._report_status(status)
706
707
708 # Default handler classes for git services.
709 DEFAULT_HANDLERS = {
710   'git-upload-pack': UploadPackHandler,
711   'git-receive-pack': ReceivePackHandler,
712   }
713
714
715 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
716
717     def __init__(self, handlers, *args, **kwargs):
718         self.handlers = handlers
719         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
720
721     def handle(self):
722         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
723         command, args = proto.read_cmd()
724         logger.info('Handling %s request, args=%s', command, args)
725
726         cls = self.handlers.get(command, None)
727         if not callable(cls):
728             raise GitProtocolError('Invalid service %s' % command)
729         h = cls(self.server.backend, args, proto)
730         h.handle()
731
732
733 class TCPGitServer(SocketServer.TCPServer):
734
735     allow_reuse_address = True
736     serve = SocketServer.TCPServer.serve_forever
737
738     def _make_handler(self, *args, **kwargs):
739         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
740
741     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
742         self.handlers = dict(DEFAULT_HANDLERS)
743         if handlers is not None:
744             self.handlers.update(handlers)
745         self.backend = backend
746         logger.info('Listening for TCP connections on %s:%d', listen_addr, port)
747         SocketServer.TCPServer.__init__(self, (listen_addr, port),
748                                         self._make_handler)
749
750     def verify_request(self, request, client_address):
751         logger.info('Handling request from %s', client_address)
752         return True
753
754     def handle_error(self, request, client_address):
755         logger.exception('Exception happened during processing of request '
756                          'from %s', client_address)
757
758
759 def main(argv=sys.argv):
760     """Entry point for starting a TCP git server."""
761     if len(argv) > 1:
762         gitdir = argv[1]
763     else:
764         gitdir = '.'
765
766     log_utils.default_logging_config()
767     backend = DictBackend({'/': Repo(gitdir)})
768     server = TCPGitServer(backend, 'localhost')
769     server.serve_forever()