Clean up file headers.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import socket
31 import zlib
32 import SocketServer
33
34 from dulwich.errors import (
35     ApplyDeltaError,
36     ChecksumMismatch,
37     GitProtocolError,
38     ObjectFormatException,
39     )
40 from dulwich.objects import (
41     hex_to_sha,
42     )
43 from dulwich.pack import (
44     PackStreamReader,
45     write_pack_data,
46     )
47 from dulwich.protocol import (
48     MULTI_ACK,
49     MULTI_ACK_DETAILED,
50     ProtocolFile,
51     ReceivableProtocol,
52     SINGLE_ACK,
53     TCP_GIT_PORT,
54     ZERO_SHA,
55     ack_type,
56     extract_capabilities,
57     extract_want_line_capabilities,
58     )
59
60
61
62 class Backend(object):
63     """A backend for the Git smart server implementation."""
64
65     def open_repository(self, path):
66         """Open the repository at a path."""
67         raise NotImplementedError(self.open_repository)
68
69
70 class BackendRepo(object):
71     """Repository abstraction used by the Git server.
72     
73     Please note that the methods required here are a 
74     subset of those provided by dulwich.repo.Repo.
75     """
76
77     object_store = None
78     refs = None
79
80     def get_refs(self):
81         """
82         Get all the refs in the repository
83
84         :return: dict of name -> sha
85         """
86         raise NotImplementedError
87
88     def get_peeled(self, name):
89         """Return the cached peeled value of a ref, if available.
90
91         :param name: Name of the ref to peel
92         :return: The peeled value of the ref. If the ref is known not point to
93             a tag, this will be the SHA the ref refers to. If no cached
94             information about a tag is available, this method may return None,
95             but it should attempt to peel the tag if possible.
96         """
97         return None
98
99     def fetch_objects(self, determine_wants, graph_walker, progress,
100                       get_tagged=None):
101         """
102         Yield the objects required for a list of commits.
103
104         :param progress: is a callback to send progress messages to the client
105         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
106             sha for including tags.
107         """
108         raise NotImplementedError
109
110
111 class PackStreamCopier(PackStreamReader):
112     """Class to verify a pack stream as it is being read.
113
114     The pack is read from a ReceivableProtocol using read() or recv() as
115     appropriate and written out to the given file-like object.
116     """
117
118     def __init__(self, read_all, read_some, outfile):
119         super(PackStreamCopier, self).__init__(read_all, read_some)
120         self.outfile = outfile
121
122     def _read(self, read, size):
123         data = super(PackStreamCopier, self)._read(read, size)
124         self.outfile.write(data)
125         return data
126
127     def verify(self):
128         """Verify a pack stream and write it to the output file.
129
130         See PackStreamReader.iterobjects for a list of exceptions this may
131         throw.
132         """
133         for _, _, _ in self.read_objects():
134             pass
135
136
137 class DictBackend(Backend):
138     """Trivial backend that looks up Git repositories in a dictionary."""
139
140     def __init__(self, repos):
141         self.repos = repos
142
143     def open_repository(self, path):
144         # FIXME: What to do in case there is no repo ?
145         return self.repos[path]
146
147
148 class Handler(object):
149     """Smart protocol command handler base class."""
150
151     def __init__(self, backend, proto):
152         self.backend = backend
153         self.proto = proto
154         self._client_capabilities = None
155
156     def capability_line(self):
157         return " ".join(self.capabilities())
158
159     def capabilities(self):
160         raise NotImplementedError(self.capabilities)
161
162     def innocuous_capabilities(self):
163         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
164
165     def required_capabilities(self):
166         """Return a list of capabilities that we require the client to have."""
167         return []
168
169     def set_client_capabilities(self, caps):
170         allowable_caps = set(self.innocuous_capabilities())
171         allowable_caps.update(self.capabilities())
172         for cap in caps:
173             if cap not in allowable_caps:
174                 raise GitProtocolError('Client asked for capability %s that '
175                                        'was not advertised.' % cap)
176         for cap in self.required_capabilities():
177             if cap not in caps:
178                 raise GitProtocolError('Client does not support required '
179                                        'capability %s.' % cap)
180         self._client_capabilities = set(caps)
181
182     def has_capability(self, cap):
183         if self._client_capabilities is None:
184             raise GitProtocolError('Server attempted to access capability %s '
185                                    'before asking client' % cap)
186         return cap in self._client_capabilities
187
188
189 class UploadPackHandler(Handler):
190     """Protocol handler for uploading a pack to the server."""
191
192     def __init__(self, backend, args, proto,
193                  stateless_rpc=False, advertise_refs=False):
194         Handler.__init__(self, backend, proto)
195         self.repo = backend.open_repository(args[0])
196         self._graph_walker = None
197         self.stateless_rpc = stateless_rpc
198         self.advertise_refs = advertise_refs
199
200     def capabilities(self):
201         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
202                 "ofs-delta", "no-progress", "include-tag")
203
204     def required_capabilities(self):
205         return ("side-band-64k", "thin-pack", "ofs-delta")
206
207     def progress(self, message):
208         if self.has_capability("no-progress"):
209             return
210         self.proto.write_sideband(2, message)
211
212     def get_tagged(self, refs=None, repo=None):
213         """Get a dict of peeled values of tags to their original tag shas.
214
215         :param refs: dict of refname -> sha of possible tags; defaults to all of
216             the backend's refs.
217         :param repo: optional Repo instance for getting peeled refs; defaults to
218             the backend's repo, if available
219         :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
220             tag whose peeled value is peeled_sha.
221         """
222         if not self.has_capability("include-tag"):
223             return {}
224         if refs is None:
225             refs = self.repo.get_refs()
226         if repo is None:
227             repo = getattr(self.repo, "repo", None)
228             if repo is None:
229                 # Bail if we don't have a Repo available; this is ok since
230                 # clients must be able to handle if the server doesn't include
231                 # all relevant tags.
232                 # TODO: fix behavior when missing
233                 return {}
234         tagged = {}
235         for name, sha in refs.iteritems():
236             peeled_sha = repo.get_peeled(name)
237             if peeled_sha != sha:
238                 tagged[peeled_sha] = sha
239         return tagged
240
241     def handle(self):
242         write = lambda x: self.proto.write_sideband(1, x)
243
244         graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
245             self.repo.get_peeled)
246         objects_iter = self.repo.fetch_objects(
247           graph_walker.determine_wants, graph_walker, self.progress,
248           get_tagged=self.get_tagged)
249
250         # Do they want any objects?
251         if len(objects_iter) == 0:
252             return
253
254         self.progress("dul-daemon says what\n")
255         self.progress("counting objects: %d, done.\n" % len(objects_iter))
256         write_pack_data(ProtocolFile(None, write), objects_iter, 
257                         len(objects_iter))
258         self.progress("how was that, then?\n")
259         # we are done
260         self.proto.write("0000")
261
262
263 class ProtocolGraphWalker(object):
264     """A graph walker that knows the git protocol.
265
266     As a graph walker, this class implements ack(), next(), and reset(). It
267     also contains some base methods for interacting with the wire and walking
268     the commit tree.
269
270     The work of determining which acks to send is passed on to the
271     implementation instance stored in _impl. The reason for this is that we do
272     not know at object creation time what ack level the protocol requires. A
273     call to set_ack_level() is required to set up the implementation, before any
274     calls to next() or ack() are made.
275     """
276     def __init__(self, handler, object_store, get_peeled):
277         self.handler = handler
278         self.store = object_store
279         self.get_peeled = get_peeled
280         self.proto = handler.proto
281         self.stateless_rpc = handler.stateless_rpc
282         self.advertise_refs = handler.advertise_refs
283         self._wants = []
284         self._cached = False
285         self._cache = []
286         self._cache_index = 0
287         self._impl = None
288
289     def determine_wants(self, heads):
290         """Determine the wants for a set of heads.
291
292         The given heads are advertised to the client, who then specifies which
293         refs he wants using 'want' lines. This portion of the protocol is the
294         same regardless of ack type, and in fact is used to set the ack type of
295         the ProtocolGraphWalker.
296
297         :param heads: a dict of refname->SHA1 to advertise
298         :return: a list of SHA1s requested by the client
299         """
300         if not heads:
301             raise GitProtocolError('No heads found')
302         values = set(heads.itervalues())
303         if self.advertise_refs or not self.stateless_rpc:
304             for i, (ref, sha) in enumerate(heads.iteritems()):
305                 line = "%s %s" % (sha, ref)
306                 if not i:
307                     line = "%s\x00%s" % (line, self.handler.capability_line())
308                 self.proto.write_pkt_line("%s\n" % line)
309                 peeled_sha = self.get_peeled(ref)
310                 if peeled_sha != sha:
311                     self.proto.write_pkt_line('%s %s^{}\n' %
312                                               (peeled_sha, ref))
313
314             # i'm done..
315             self.proto.write_pkt_line(None)
316
317             if self.advertise_refs:
318                 return []
319
320         # Now client will sending want want want commands
321         want = self.proto.read_pkt_line()
322         if not want:
323             return []
324         line, caps = extract_want_line_capabilities(want)
325         self.handler.set_client_capabilities(caps)
326         self.set_ack_type(ack_type(caps))
327         command, sha = self._split_proto_line(line)
328
329         want_revs = []
330         while command != None:
331             if command != 'want':
332                 raise GitProtocolError(
333                   'Protocol got unexpected command %s' % command)
334             if sha not in values:
335                 raise GitProtocolError(
336                   'Client wants invalid object %s' % sha)
337             want_revs.append(sha)
338             command, sha = self.read_proto_line()
339
340         self.set_wants(want_revs)
341         return want_revs
342
343     def ack(self, have_ref):
344         return self._impl.ack(have_ref)
345
346     def reset(self):
347         self._cached = True
348         self._cache_index = 0
349
350     def next(self):
351         if not self._cached:
352             if not self._impl and self.stateless_rpc:
353                 return None
354             return self._impl.next()
355         self._cache_index += 1
356         if self._cache_index > len(self._cache):
357             return None
358         return self._cache[self._cache_index]
359
360     def _split_proto_line(self, line):
361         fields = line.rstrip('\n').split(' ', 1)
362         if len(fields) == 1 and fields[0] == 'done':
363             return ('done', None)
364         elif len(fields) == 2 and fields[0] in ('want', 'have'):
365             try:
366                 hex_to_sha(fields[1])
367                 return tuple(fields)
368             except (TypeError, AssertionError), e:
369                 raise GitProtocolError(e)
370         raise GitProtocolError('Received invalid line from client:\n%s' % line)
371
372     def read_proto_line(self):
373         """Read a line from the wire.
374
375         :return: a tuple having one of the following forms:
376             ('want', obj_id)
377             ('have', obj_id)
378             ('done', None)
379             (None, None)  (for a flush-pkt)
380
381         :raise GitProtocolError: if the line cannot be parsed into one of the
382             possible return values.
383         """
384         line = self.proto.read_pkt_line()
385         if not line:
386             return (None, None)
387         return self._split_proto_line(line)
388
389     def send_ack(self, sha, ack_type=''):
390         if ack_type:
391             ack_type = ' %s' % ack_type
392         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
393
394     def send_nak(self):
395         self.proto.write_pkt_line('NAK\n')
396
397     def set_wants(self, wants):
398         self._wants = wants
399
400     def _is_satisfied(self, haves, want, earliest):
401         """Check whether a want is satisfied by a set of haves.
402
403         A want, typically a branch tip, is "satisfied" only if there exists a
404         path back from that want to one of the haves.
405
406         :param haves: A set of commits we know the client has.
407         :param want: The want to check satisfaction for.
408         :param earliest: A timestamp beyond which the search for haves will be
409             terminated, presumably because we're searching too far down the
410             wrong branch.
411         """
412         o = self.store[want]
413         pending = collections.deque([o])
414         while pending:
415             commit = pending.popleft()
416             if commit.id in haves:
417                 return True
418             if commit.type_name != "commit":
419                 # non-commit wants are assumed to be satisfied
420                 continue
421             for parent in commit.parents:
422                 parent_obj = self.store[parent]
423                 # TODO: handle parents with later commit times than children
424                 if parent_obj.commit_time >= earliest:
425                     pending.append(parent_obj)
426         return False
427
428     def all_wants_satisfied(self, haves):
429         """Check whether all the current wants are satisfied by a set of haves.
430
431         :param haves: A set of commits we know the client has.
432         :note: Wants are specified with set_wants rather than passed in since
433             in the current interface they are determined outside this class.
434         """
435         haves = set(haves)
436         earliest = min([self.store[h].commit_time for h in haves])
437         for want in self._wants:
438             if not self._is_satisfied(haves, want, earliest):
439                 return False
440         return True
441
442     def set_ack_type(self, ack_type):
443         impl_classes = {
444           MULTI_ACK: MultiAckGraphWalkerImpl,
445           MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
446           SINGLE_ACK: SingleAckGraphWalkerImpl,
447           }
448         self._impl = impl_classes[ack_type](self)
449
450
451 class SingleAckGraphWalkerImpl(object):
452     """Graph walker implementation that speaks the single-ack protocol."""
453
454     def __init__(self, walker):
455         self.walker = walker
456         self._sent_ack = False
457
458     def ack(self, have_ref):
459         if not self._sent_ack:
460             self.walker.send_ack(have_ref)
461             self._sent_ack = True
462
463     def next(self):
464         command, sha = self.walker.read_proto_line()
465         if command in (None, 'done'):
466             if not self._sent_ack:
467                 self.walker.send_nak()
468             return None
469         elif command == 'have':
470             return sha
471
472
473 class MultiAckGraphWalkerImpl(object):
474     """Graph walker implementation that speaks the multi-ack protocol."""
475
476     def __init__(self, walker):
477         self.walker = walker
478         self._found_base = False
479         self._common = []
480
481     def ack(self, have_ref):
482         self._common.append(have_ref)
483         if not self._found_base:
484             self.walker.send_ack(have_ref, 'continue')
485             if self.walker.all_wants_satisfied(self._common):
486                 self._found_base = True
487         # else we blind ack within next
488
489     def next(self):
490         while True:
491             command, sha = self.walker.read_proto_line()
492             if command is None:
493                 self.walker.send_nak()
494                 # in multi-ack mode, a flush-pkt indicates the client wants to
495                 # flush but more have lines are still coming
496                 continue
497             elif command == 'done':
498                 # don't nak unless no common commits were found, even if not
499                 # everything is satisfied
500                 if self._common:
501                     self.walker.send_ack(self._common[-1])
502                 else:
503                     self.walker.send_nak()
504                 return None
505             elif command == 'have':
506                 if self._found_base:
507                     # blind ack
508                     self.walker.send_ack(sha, 'continue')
509                 return sha
510
511
512 class MultiAckDetailedGraphWalkerImpl(object):
513     """Graph walker implementation speaking the multi-ack-detailed protocol."""
514
515     def __init__(self, walker):
516         self.walker = walker
517         self._found_base = False
518         self._common = []
519
520     def ack(self, have_ref):
521         self._common.append(have_ref)
522         if not self._found_base:
523             self.walker.send_ack(have_ref, 'common')
524             if self.walker.all_wants_satisfied(self._common):
525                 self._found_base = True
526                 self.walker.send_ack(have_ref, 'ready')
527         # else we blind ack within next
528
529     def next(self):
530         while True:
531             command, sha = self.walker.read_proto_line()
532             if command is None:
533                 self.walker.send_nak()
534                 if self.walker.stateless_rpc:
535                     return None
536                 continue
537             elif command == 'done':
538                 # don't nak unless no common commits were found, even if not
539                 # everything is satisfied
540                 if self._common:
541                     self.walker.send_ack(self._common[-1])
542                 else:
543                     self.walker.send_nak()
544                 return None
545             elif command == 'have':
546                 if self._found_base:
547                     # blind ack; can happen if the client has more requests
548                     # inflight
549                     self.walker.send_ack(sha, 'ready')
550                 return sha
551
552
553 class ReceivePackHandler(Handler):
554     """Protocol handler for downloading a pack from the client."""
555
556     def __init__(self, backend, args, proto,
557                  stateless_rpc=False, advertise_refs=False):
558         Handler.__init__(self, backend, proto)
559         self.repo = backend.open_repository(args[0])
560         self.stateless_rpc = stateless_rpc
561         self.advertise_refs = advertise_refs
562
563     def capabilities(self):
564         return ("report-status", "delete-refs")
565
566     def _apply_pack(self, refs):
567         f, commit = self.repo.object_store.add_thin_pack()
568         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
569                           AssertionError, socket.error, zlib.error,
570                           ObjectFormatException)
571         status = []
572         # TODO: more informative error messages than just the exception string
573         try:
574             PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
575             p = commit()
576             if not p:
577                 raise IOError('Failed to write pack')
578             p.check()
579             status.append(('unpack', 'ok'))
580         except all_exceptions, e:
581             status.append(('unpack', str(e).replace('\n', '')))
582             # The pack may still have been moved in, but it may contain broken
583             # objects. We trust a later GC to clean it up.
584
585         for oldsha, sha, ref in refs:
586             ref_status = 'ok'
587             try:
588                 if sha == ZERO_SHA:
589                     if not 'delete-refs' in self.capabilities():
590                         raise GitProtocolError(
591                           'Attempted to delete refs without delete-refs '
592                           'capability.')
593                     try:
594                         del self.repo.refs[ref]
595                     except all_exceptions:
596                         ref_status = 'failed to delete'
597                 else:
598                     try:
599                         self.repo.refs[ref] = sha
600                     except all_exceptions:
601                         ref_status = 'failed to write'
602             except KeyError, e:
603                 ref_status = 'bad ref'
604             status.append((ref, ref_status))
605
606         return status
607
608     def handle(self):
609         refs = self.repo.get_refs().items()
610
611         if self.advertise_refs or not self.stateless_rpc:
612             if refs:
613                 self.proto.write_pkt_line(
614                   "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
615                                      self.capability_line()))
616                 for i in range(1, len(refs)):
617                     ref = refs[i]
618                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
619             else:
620                 self.proto.write_pkt_line("%s capabilities^{} %s" % (
621                   ZERO_SHA, self.capability_line()))
622
623             self.proto.write("0000")
624             if self.advertise_refs:
625                 return
626
627         client_refs = []
628         ref = self.proto.read_pkt_line()
629
630         # if ref is none then client doesnt want to send us anything..
631         if ref is None:
632             return
633
634         ref, caps = extract_capabilities(ref)
635         self.set_client_capabilities(caps)
636
637         # client will now send us a list of (oldsha, newsha, ref)
638         while ref:
639             client_refs.append(ref.split())
640             ref = self.proto.read_pkt_line()
641
642         # backend can now deal with this refs and read a pack using self.read
643         status = self._apply_pack(client_refs)
644
645         # when we have read all the pack from the client, send a status report
646         # if the client asked for it
647         if self.has_capability('report-status'):
648             for name, msg in status:
649                 if name == 'unpack':
650                     self.proto.write_pkt_line('unpack %s\n' % msg)
651                 elif msg == 'ok':
652                     self.proto.write_pkt_line('ok %s\n' % name)
653                 else:
654                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
655             self.proto.write_pkt_line(None)
656
657
658 # Default handler classes for git services.
659 DEFAULT_HANDLERS = {
660   'git-upload-pack': UploadPackHandler,
661   'git-receive-pack': ReceivePackHandler,
662   }
663
664
665 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
666
667     def __init__(self, handlers, *args, **kwargs):
668         self.handlers = handlers and handlers or DEFAULT_HANDLERS
669         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
670
671     def handle(self):
672         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
673         command, args = proto.read_cmd()
674
675         cls = self.handlers.get(command, None)
676         if not callable(cls):
677             raise GitProtocolError('Invalid service %s' % command)
678         h = cls(self.server.backend, args, proto)
679         h.handle()
680
681
682 class TCPGitServer(SocketServer.TCPServer):
683
684     allow_reuse_address = True
685     serve = SocketServer.TCPServer.serve_forever
686
687     def _make_handler(self, *args, **kwargs):
688         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
689
690     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
691         self.backend = backend
692         self.handlers = handlers
693         SocketServer.TCPServer.__init__(self, (listen_addr, port),
694                                         self._make_handler)