Remove GitBackendRepo, use plain Repo's instead.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31
32 from dulwich.errors import (
33     ApplyDeltaError,
34     ChecksumMismatch,
35     GitProtocolError,
36     )
37 from dulwich.objects import (
38     hex_to_sha,
39     )
40 from dulwich.protocol import (
41     Protocol,
42     ProtocolFile,
43     TCP_GIT_PORT,
44     ZERO_SHA,
45     extract_capabilities,
46     extract_want_line_capabilities,
47     SINGLE_ACK,
48     MULTI_ACK,
49     MULTI_ACK_DETAILED,
50     ack_type,
51     )
52 from dulwich.pack import (
53     write_pack_data,
54     )
55
56 class Backend(object):
57     """A backend for the Git smart server implementation."""
58
59     def open_repository(self, path):
60         """Open the repository at a path."""
61         raise NotImplementedError(self.open_repository)
62
63
64 class BackendRepo(object):
65     """Repository abstraction used by the Git server.
66     
67     Please note that the methods required here are a 
68     subset of those provided by dulwich.repo.Repo.
69     """
70
71     object_store = None
72     refs = None
73
74     def get_refs(self):
75         """
76         Get all the refs in the repository
77
78         :return: dict of name -> sha
79         """
80         raise NotImplementedError
81
82     def get_peeled(self, name):
83         """Return the cached peeled value of a ref, if available.
84
85         :param name: Name of the ref to peel
86         :return: The peeled value of the ref. If the ref is known not point to
87             a tag, this will be the SHA the ref refers to. If the ref may 
88             point to a tag, but no cached information is available, None is 
89             returned.
90         """
91         return None
92
93     def fetch_objects(self, determine_wants, graph_walker, progress,
94                       get_tagged=None):
95         """
96         Yield the objects required for a list of commits.
97
98         :param progress: is a callback to send progress messages to the client
99         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
100             sha for including tags.
101         """
102         raise NotImplementedError
103
104
105 class DictBackend(Backend):
106     """Trivial backend that looks up Git repositories in a dictionary."""
107
108     def __init__(self, repos):
109         self.repos = repos
110
111     def open_repository(self, path):
112         # FIXME: What to do in case there is no repo ?
113         return self.repos[path]
114
115
116 class Handler(object):
117     """Smart protocol command handler base class."""
118
119     def __init__(self, backend, read, write):
120         self.backend = backend
121         self.proto = Protocol(read, write)
122         self._client_capabilities = None
123
124     def capability_line(self):
125         return " ".join(self.capabilities())
126
127     def capabilities(self):
128         raise NotImplementedError(self.capabilities)
129
130     def innocuous_capabilities(self):
131         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
132
133     def required_capabilities(self):
134         """Return a list of capabilities that we require the client to have."""
135         return []
136
137     def set_client_capabilities(self, caps):
138         allowable_caps = set(self.innocuous_capabilities())
139         allowable_caps.update(self.capabilities())
140         for cap in caps:
141             if cap not in allowable_caps:
142                 raise GitProtocolError('Client asked for capability %s that '
143                                        'was not advertised.' % cap)
144         for cap in self.required_capabilities():
145             if cap not in caps:
146                 raise GitProtocolError('Client does not support required '
147                                        'capability %s.' % cap)
148         self._client_capabilities = set(caps)
149
150     def has_capability(self, cap):
151         if self._client_capabilities is None:
152             raise GitProtocolError('Server attempted to access capability %s '
153                                    'before asking client' % cap)
154         return cap in self._client_capabilities
155
156
157 class UploadPackHandler(Handler):
158     """Protocol handler for uploading a pack to the server."""
159
160     def __init__(self, backend, args, read, write,
161                  stateless_rpc=False, advertise_refs=False):
162         Handler.__init__(self, backend, read, write)
163         self.repo = backend.open_repository(args[0])
164         self._graph_walker = None
165         self.stateless_rpc = stateless_rpc
166         self.advertise_refs = advertise_refs
167
168     def capabilities(self):
169         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
170                 "ofs-delta", "no-progress", "include-tag")
171
172     def required_capabilities(self):
173         return ("side-band-64k", "thin-pack", "ofs-delta")
174
175     def progress(self, message):
176         if self.has_capability("no-progress"):
177             return
178         self.proto.write_sideband(2, message)
179
180     def get_tagged(self, refs=None, repo=None):
181         """Get a dict of peeled values of tags to their original tag shas.
182
183         :param refs: dict of refname -> sha of possible tags; defaults to all of
184             the backend's refs.
185         :param repo: optional Repo instance for getting peeled refs; defaults to
186             the backend's repo, if available
187         :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
188             tag whose peeled value is peeled_sha.
189         """
190         if not self.has_capability("include-tag"):
191             return {}
192         if refs is None:
193             refs = self.repo.get_refs()
194         if repo is None:
195             repo = getattr(self.repo, "repo", None)
196             if repo is None:
197                 # Bail if we don't have a Repo available; this is ok since
198                 # clients must be able to handle if the server doesn't include
199                 # all relevant tags.
200                 # TODO: fix behavior when missing
201                 return {}
202         tagged = {}
203         for name, sha in refs.iteritems():
204             peeled_sha = repo.get_peeled(name)
205             if peeled_sha != sha:
206                 tagged[peeled_sha] = sha
207         return tagged
208
209     def handle(self):
210         write = lambda x: self.proto.write_sideband(1, x)
211
212         graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
213             self.repo.get_peeled)
214         objects_iter = self.repo.fetch_objects(
215           graph_walker.determine_wants, graph_walker, self.progress,
216           get_tagged=self.get_tagged)
217
218         # Do they want any objects?
219         if len(objects_iter) == 0:
220             return
221
222         self.progress("dul-daemon says what\n")
223         self.progress("counting objects: %d, done.\n" % len(objects_iter))
224         write_pack_data(ProtocolFile(None, write), objects_iter, 
225                         len(objects_iter))
226         self.progress("how was that, then?\n")
227         # we are done
228         self.proto.write("0000")
229
230
231 class ProtocolGraphWalker(object):
232     """A graph walker that knows the git protocol.
233
234     As a graph walker, this class implements ack(), next(), and reset(). It
235     also contains some base methods for interacting with the wire and walking
236     the commit tree.
237
238     The work of determining which acks to send is passed on to the
239     implementation instance stored in _impl. The reason for this is that we do
240     not know at object creation time what ack level the protocol requires. A
241     call to set_ack_level() is required to set up the implementation, before any
242     calls to next() or ack() are made.
243     """
244     def __init__(self, handler, object_store, get_peeled):
245         self.handler = handler
246         self.store = object_store
247         self.get_peeled = get_peeled
248         self.proto = handler.proto
249         self.stateless_rpc = handler.stateless_rpc
250         self.advertise_refs = handler.advertise_refs
251         self._wants = []
252         self._cached = False
253         self._cache = []
254         self._cache_index = 0
255         self._impl = None
256
257     def determine_wants(self, heads):
258         """Determine the wants for a set of heads.
259
260         The given heads are advertised to the client, who then specifies which
261         refs he wants using 'want' lines. This portion of the protocol is the
262         same regardless of ack type, and in fact is used to set the ack type of
263         the ProtocolGraphWalker.
264
265         :param heads: a dict of refname->SHA1 to advertise
266         :return: a list of SHA1s requested by the client
267         """
268         if not heads:
269             raise GitProtocolError('No heads found')
270         values = set(heads.itervalues())
271         if self.advertise_refs or not self.stateless_rpc:
272             for i, (ref, sha) in enumerate(heads.iteritems()):
273                 line = "%s %s" % (sha, ref)
274                 if not i:
275                     line = "%s\x00%s" % (line, self.handler.capability_line())
276                 self.proto.write_pkt_line("%s\n" % line)
277                 peeled_sha = self.get_peeled(ref)
278                 if peeled_sha != sha:
279                     self.proto.write_pkt_line('%s %s^{}\n' %
280                                               (peeled_sha, ref))
281
282             # i'm done..
283             self.proto.write_pkt_line(None)
284
285             if self.advertise_refs:
286                 return []
287
288         # Now client will sending want want want commands
289         want = self.proto.read_pkt_line()
290         if not want:
291             return []
292         line, caps = extract_want_line_capabilities(want)
293         self.handler.set_client_capabilities(caps)
294         self.set_ack_type(ack_type(caps))
295         command, sha = self._split_proto_line(line)
296
297         want_revs = []
298         while command != None:
299             if command != 'want':
300                 raise GitProtocolError(
301                     'Protocol got unexpected command %s' % command)
302             if sha not in values:
303                 raise GitProtocolError(
304                     'Client wants invalid object %s' % sha)
305             want_revs.append(sha)
306             command, sha = self.read_proto_line()
307
308         self.set_wants(want_revs)
309         return want_revs
310
311     def ack(self, have_ref):
312         return self._impl.ack(have_ref)
313
314     def reset(self):
315         self._cached = True
316         self._cache_index = 0
317
318     def next(self):
319         if not self._cached:
320             if not self._impl and self.stateless_rpc:
321                 return None
322             return self._impl.next()
323         self._cache_index += 1
324         if self._cache_index > len(self._cache):
325             return None
326         return self._cache[self._cache_index]
327
328     def _split_proto_line(self, line):
329         fields = line.rstrip('\n').split(' ', 1)
330         if len(fields) == 1 and fields[0] == 'done':
331             return ('done', None)
332         elif len(fields) == 2 and fields[0] in ('want', 'have'):
333             try:
334                 hex_to_sha(fields[1])
335                 return tuple(fields)
336             except (TypeError, AssertionError), e:
337                 raise GitProtocolError(e)
338         raise GitProtocolError('Received invalid line from client:\n%s' % line)
339
340     def read_proto_line(self):
341         """Read a line from the wire.
342
343         :return: a tuple having one of the following forms:
344             ('want', obj_id)
345             ('have', obj_id)
346             ('done', None)
347             (None, None)  (for a flush-pkt)
348
349         :raise GitProtocolError: if the line cannot be parsed into one of the
350             possible return values.
351         """
352         line = self.proto.read_pkt_line()
353         if not line:
354             return (None, None)
355         return self._split_proto_line(line)
356
357     def send_ack(self, sha, ack_type=''):
358         if ack_type:
359             ack_type = ' %s' % ack_type
360         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
361
362     def send_nak(self):
363         self.proto.write_pkt_line('NAK\n')
364
365     def set_wants(self, wants):
366         self._wants = wants
367
368     def _is_satisfied(self, haves, want, earliest):
369         """Check whether a want is satisfied by a set of haves.
370
371         A want, typically a branch tip, is "satisfied" only if there exists a
372         path back from that want to one of the haves.
373
374         :param haves: A set of commits we know the client has.
375         :param want: The want to check satisfaction for.
376         :param earliest: A timestamp beyond which the search for haves will be
377             terminated, presumably because we're searching too far down the
378             wrong branch.
379         """
380         o = self.store[want]
381         pending = collections.deque([o])
382         while pending:
383             commit = pending.popleft()
384             if commit.id in haves:
385                 return True
386             if commit.type_name != "commit":
387                 # non-commit wants are assumed to be satisfied
388                 continue
389             for parent in commit.parents:
390                 parent_obj = self.store[parent]
391                 # TODO: handle parents with later commit times than children
392                 if parent_obj.commit_time >= earliest:
393                     pending.append(parent_obj)
394         return False
395
396     def all_wants_satisfied(self, haves):
397         """Check whether all the current wants are satisfied by a set of haves.
398
399         :param haves: A set of commits we know the client has.
400         :note: Wants are specified with set_wants rather than passed in since
401             in the current interface they are determined outside this class.
402         """
403         haves = set(haves)
404         earliest = min([self.store[h].commit_time for h in haves])
405         for want in self._wants:
406             if not self._is_satisfied(haves, want, earliest):
407                 return False
408         return True
409
410     def set_ack_type(self, ack_type):
411         impl_classes = {
412             MULTI_ACK: MultiAckGraphWalkerImpl,
413             MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
414             SINGLE_ACK: SingleAckGraphWalkerImpl,
415             }
416         self._impl = impl_classes[ack_type](self)
417
418
419 class SingleAckGraphWalkerImpl(object):
420     """Graph walker implementation that speaks the single-ack protocol."""
421
422     def __init__(self, walker):
423         self.walker = walker
424         self._sent_ack = False
425
426     def ack(self, have_ref):
427         if not self._sent_ack:
428             self.walker.send_ack(have_ref)
429             self._sent_ack = True
430
431     def next(self):
432         command, sha = self.walker.read_proto_line()
433         if command in (None, 'done'):
434             if not self._sent_ack:
435                 self.walker.send_nak()
436             return None
437         elif command == 'have':
438             return sha
439
440
441 class MultiAckGraphWalkerImpl(object):
442     """Graph walker implementation that speaks the multi-ack protocol."""
443
444     def __init__(self, walker):
445         self.walker = walker
446         self._found_base = False
447         self._common = []
448
449     def ack(self, have_ref):
450         self._common.append(have_ref)
451         if not self._found_base:
452             self.walker.send_ack(have_ref, 'continue')
453             if self.walker.all_wants_satisfied(self._common):
454                 self._found_base = True
455         # else we blind ack within next
456
457     def next(self):
458         while True:
459             command, sha = self.walker.read_proto_line()
460             if command is None:
461                 self.walker.send_nak()
462                 # in multi-ack mode, a flush-pkt indicates the client wants to
463                 # flush but more have lines are still coming
464                 continue
465             elif command == 'done':
466                 # don't nak unless no common commits were found, even if not
467                 # everything is satisfied
468                 if self._common:
469                     self.walker.send_ack(self._common[-1])
470                 else:
471                     self.walker.send_nak()
472                 return None
473             elif command == 'have':
474                 if self._found_base:
475                     # blind ack
476                     self.walker.send_ack(sha, 'continue')
477                 return sha
478
479
480 class MultiAckDetailedGraphWalkerImpl(object):
481     """Graph walker implementation speaking the multi-ack-detailed protocol."""
482
483     def __init__(self, walker):
484         self.walker = walker
485         self._found_base = False
486         self._common = []
487
488     def ack(self, have_ref):
489         self._common.append(have_ref)
490         if not self._found_base:
491             self.walker.send_ack(have_ref, 'common')
492             if self.walker.all_wants_satisfied(self._common):
493                 self._found_base = True
494                 self.walker.send_ack(have_ref, 'ready')
495         # else we blind ack within next
496
497     def next(self):
498         while True:
499             command, sha = self.walker.read_proto_line()
500             if command is None:
501                 self.walker.send_nak()
502                 if self.walker.stateless_rpc:
503                     return None
504                 continue
505             elif command == 'done':
506                 # don't nak unless no common commits were found, even if not
507                 # everything is satisfied
508                 if self._common:
509                     self.walker.send_ack(self._common[-1])
510                 else:
511                     self.walker.send_nak()
512                 return None
513             elif command == 'have':
514                 if self._found_base:
515                     # blind ack; can happen if the client has more requests
516                     # inflight
517                     self.walker.send_ack(sha, 'ready')
518                 return sha
519
520
521 class ReceivePackHandler(Handler):
522     """Protocol handler for downloading a pack from the client."""
523
524     def __init__(self, backend, args, read, write,
525                  stateless_rpc=False, advertise_refs=False):
526         Handler.__init__(self, backend, read, write)
527         self.repo = backend.open_repository(args[0])
528         self.stateless_rpc = stateless_rpc
529         self.advertise_refs = advertise_refs
530
531     def capabilities(self):
532         return ("report-status", "delete-refs")
533
534     def _apply_pack(self, refs, read, delete_refs=True):
535         f, commit = self.repo.object_store.add_thin_pack()
536         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
537         status = []
538         unpack_error = None
539         # TODO: more informative error messages than just the exception string
540         try:
541             # TODO: decode the pack as we stream to avoid blocking reads beyond
542             # the end of data (when using HTTP/1.1 chunked encoding)
543             while True:
544                 data = read(10240)
545                 if not data:
546                     break
547                 f.write(data)
548         except all_exceptions, e:
549             unpack_error = str(e).replace('\n', '')
550         try:
551             commit()
552         except all_exceptions, e:
553             if not unpack_error:
554                 unpack_error = str(e).replace('\n', '')
555
556         if unpack_error:
557             status.append(('unpack', unpack_error))
558         else:
559             status.append(('unpack', 'ok'))
560
561         for oldsha, sha, ref in refs:
562             ref_error = None
563             try:
564                 if sha == ZERO_SHA:
565                     if not delete_refs:
566                         raise GitProtocolError(
567                           'Attempted to delete refs without delete-refs '
568                           'capability.')
569                     try:
570                         del self.repo.refs[ref]
571                     except all_exceptions:
572                         ref_error = 'failed to delete'
573                 else:
574                     try:
575                         self.repo.refs[ref] = sha
576                     except all_exceptions:
577                         ref_error = 'failed to write'
578             except KeyError, e:
579                 ref_error = 'bad ref'
580             if ref_error:
581                 status.append((ref, ref_error))
582             else:
583                 status.append((ref, 'ok'))
584
585         print "pack applied"
586         return status
587
588     def handle(self):
589         refs = self.repo.get_refs().items()
590
591         if self.advertise_refs or not self.stateless_rpc:
592             if refs:
593                 self.proto.write_pkt_line(
594                   "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
595                                      self.capability_line()))
596                 for i in range(1, len(refs)):
597                     ref = refs[i]
598                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
599             else:
600                 self.proto.write_pkt_line("%s capabilities^{} %s" % (
601                   ZERO_SHA, self.capability_line()))
602
603             self.proto.write("0000")
604             if self.advertise_refs:
605                 return
606
607         client_refs = []
608         ref = self.proto.read_pkt_line()
609
610         # if ref is none then client doesnt want to send us anything..
611         if ref is None:
612             return
613
614         ref, caps = extract_capabilities(ref)
615         self.set_client_capabilities(caps)
616
617         # client will now send us a list of (oldsha, newsha, ref)
618         while ref:
619             client_refs.append(ref.split())
620             ref = self.proto.read_pkt_line()
621
622         # backend can now deal with this refs and read a pack using self.read
623         status = self.repo._apply_pack(client_refs, self.proto.read,
624             self.has_capability('delete-refs'))
625
626         # when we have read all the pack from the client, send a status report
627         # if the client asked for it
628         if self.has_capability('report-status'):
629             for name, msg in status:
630                 if name == 'unpack':
631                     self.proto.write_pkt_line('unpack %s\n' % msg)
632                 elif msg == 'ok':
633                     self.proto.write_pkt_line('ok %s\n' % name)
634                 else:
635                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
636             self.proto.write_pkt_line(None)
637
638
639 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
640
641     def handle(self):
642         proto = Protocol(self.rfile.read, self.wfile.write)
643         command, args = proto.read_cmd()
644
645         # switch case to handle the specific git command
646         if command == 'git-upload-pack':
647             cls = UploadPackHandler
648         elif command == 'git-receive-pack':
649             cls = ReceivePackHandler
650         else:
651             return
652
653         h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
654         h.handle()
655
656
657 class TCPGitServer(SocketServer.TCPServer):
658
659     allow_reuse_address = True
660     serve = SocketServer.TCPServer.serve_forever
661
662     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
663         self.backend = backend
664         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)