Fix import.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31 import tempfile
32
33 from dulwich.errors import (
34     ApplyDeltaError,
35     ChecksumMismatch,
36     GitProtocolError,
37     )
38 from dulwich.objects import (
39     hex_to_sha,
40     )
41 from dulwich.protocol import (
42     Protocol,
43     ProtocolFile,
44     TCP_GIT_PORT,
45     ZERO_SHA,
46     extract_capabilities,
47     extract_want_line_capabilities,
48     SINGLE_ACK,
49     MULTI_ACK,
50     MULTI_ACK_DETAILED,
51     ack_type,
52     )
53 from dulwich.repo import (
54     Repo,
55     )
56 from dulwich.pack import (
57     write_pack_data,
58     )
59
60 class Backend(object):
61
62     def get_refs(self):
63         """
64         Get all the refs in the repository
65
66         :return: dict of name -> sha
67         """
68         raise NotImplementedError
69
70     def apply_pack(self, refs, read, delete_refs=True):
71         """ Import a set of changes into a repository and update the refs
72
73         :param refs: list of tuple(name, sha)
74         :param read: callback to read from the incoming pack
75         :param delete_refs: whether to allow deleting refs
76         """
77         raise NotImplementedError
78
79     def fetch_objects(self, determine_wants, graph_walker, progress,
80                       get_tagged=None):
81         """
82         Yield the objects required for a list of commits.
83
84         :param progress: is a callback to send progress messages to the client
85         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
86             sha for including tags.
87         """
88         raise NotImplementedError
89
90
91 class GitBackend(Backend):
92
93     def __init__(self, repo=None):
94         if repo is None:
95             repo = Repo(tempfile.mkdtemp())
96         self.repo = repo
97         self.refs = self.repo.refs
98         self.object_store = self.repo.object_store
99         self.fetch_objects = self.repo.fetch_objects
100         self.get_refs = self.repo.get_refs
101
102     def apply_pack(self, refs, read, delete_refs=True):
103         f, commit = self.repo.object_store.add_thin_pack()
104         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
105         status = []
106         unpack_error = None
107         # TODO: more informative error messages than just the exception string
108         try:
109             # TODO: decode the pack as we stream to avoid blocking reads beyond
110             # the end of data (when using HTTP/1.1 chunked encoding)
111             while True:
112                 data = read(10240)
113                 if not data:
114                     break
115                 f.write(data)
116         except all_exceptions, e:
117             unpack_error = str(e).replace('\n', '')
118         try:
119             commit()
120         except all_exceptions, e:
121             if not unpack_error:
122                 unpack_error = str(e).replace('\n', '')
123
124         if unpack_error:
125             status.append(('unpack', unpack_error))
126         else:
127             status.append(('unpack', 'ok'))
128
129         for oldsha, sha, ref in refs:
130             ref_error = None
131             try:
132                 if sha == ZERO_SHA:
133                     if not delete_refs:
134                         raise GitProtocolError(
135                           'Attempted to delete refs without delete-refs '
136                           'capability.')
137                     try:
138                         del self.repo.refs[ref]
139                     except all_exceptions:
140                         ref_error = 'failed to delete'
141                 else:
142                     try:
143                         self.repo.refs[ref] = sha
144                     except all_exceptions:
145                         ref_error = 'failed to write'
146             except KeyError, e:
147                 ref_error = 'bad ref'
148             if ref_error:
149                 status.append((ref, ref_error))
150             else:
151                 status.append((ref, 'ok'))
152
153
154         print "pack applied"
155         return status
156
157
158 class Handler(object):
159     """Smart protocol command handler base class."""
160
161     def __init__(self, backend, read, write):
162         self.backend = backend
163         self.proto = Protocol(read, write)
164         self._client_capabilities = None
165
166     def capability_line(self):
167         return " ".join(self.capabilities())
168
169     def capabilities(self):
170         raise NotImplementedError(self.capabilities)
171
172     def innocuous_capabilities(self):
173         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
174
175     def required_capabilities(self):
176         """Return a list of capabilities that we require the client to have."""
177         return []
178
179     def set_client_capabilities(self, caps):
180         allowable_caps = set(self.innocuous_capabilities())
181         allowable_caps.update(self.capabilities())
182         for cap in caps:
183             if cap not in allowable_caps:
184                 raise GitProtocolError('Client asked for capability %s that '
185                                        'was not advertised.' % cap)
186         for cap in self.required_capabilities():
187             if cap not in caps:
188                 raise GitProtocolError('Client does not support required '
189                                        'capability %s.' % cap)
190         self._client_capabilities = set(caps)
191
192     def has_capability(self, cap):
193         if self._client_capabilities is None:
194             raise GitProtocolError('Server attempted to access capability %s '
195                                    'before asking client' % cap)
196         return cap in self._client_capabilities
197
198
199 class UploadPackHandler(Handler):
200     """Protocol handler for uploading a pack to the server."""
201
202     def __init__(self, backend, read, write,
203                  stateless_rpc=False, advertise_refs=False):
204         Handler.__init__(self, backend, read, write)
205         self._graph_walker = None
206         self.stateless_rpc = stateless_rpc
207         self.advertise_refs = advertise_refs
208
209     def capabilities(self):
210         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
211                 "ofs-delta", "no-progress", "include-tag")
212
213     def required_capabilities(self):
214         return ("side-band-64k", "thin-pack", "ofs-delta")
215
216     def progress(self, message):
217         if self.has_capability("no-progress"):
218             return
219         self.proto.write_sideband(2, message)
220
221     def get_tagged(self, refs=None, repo=None):
222         """Get a dict of peeled values of tags to their original tag shas.
223
224         :param refs: dict of refname -> sha of possible tags; defaults to all of
225             the backend's refs.
226         :param repo: optional Repo instance for getting peeled refs; defaults to
227             the backend's repo, if available
228         :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
229             tag whose peeled value is peeled_sha.
230         """
231         if not self.has_capability("include-tag"):
232             return {}
233         if refs is None:
234             refs = self.backend.get_refs()
235         if repo is None:
236             repo = getattr(self.backend, "repo", None)
237             if repo is None:
238                 # Bail if we don't have a Repo available; this is ok since
239                 # clients must be able to handle if the server doesn't include
240                 # all relevant tags.
241                 # TODO: either guarantee a Repo, or fix behavior when missing
242                 return {}
243         tagged = {}
244         for name, sha in refs.iteritems():
245             peeled_sha = repo.get_peeled(name)
246             if peeled_sha != sha:
247                 tagged[peeled_sha] = sha
248         return tagged
249
250     def handle(self):
251         write = lambda x: self.proto.write_sideband(1, x)
252
253         graph_walker = ProtocolGraphWalker(self)
254         objects_iter = self.backend.fetch_objects(
255           graph_walker.determine_wants, graph_walker, self.progress,
256           get_tagged=self.get_tagged)
257
258         # Do they want any objects?
259         if len(objects_iter) == 0:
260             return
261
262         self.progress("dul-daemon says what\n")
263         self.progress("counting objects: %d, done.\n" % len(objects_iter))
264         write_pack_data(ProtocolFile(None, write), objects_iter, 
265                         len(objects_iter))
266         self.progress("how was that, then?\n")
267         # we are done
268         self.proto.write("0000")
269
270
271 class ProtocolGraphWalker(object):
272     """A graph walker that knows the git protocol.
273
274     As a graph walker, this class implements ack(), next(), and reset(). It also
275     contains some base methods for interacting with the wire and walking the
276     commit tree.
277
278     The work of determining which acks to send is passed on to the
279     implementation instance stored in _impl. The reason for this is that we do
280     not know at object creation time what ack level the protocol requires. A
281     call to set_ack_level() is required to set up the implementation, before any
282     calls to next() or ack() are made.
283     """
284     def __init__(self, handler):
285         self.handler = handler
286         self.store = handler.backend.object_store
287         self.proto = handler.proto
288         self.stateless_rpc = handler.stateless_rpc
289         self.advertise_refs = handler.advertise_refs
290         self._wants = []
291         self._cached = False
292         self._cache = []
293         self._cache_index = 0
294         self._impl = None
295
296     def determine_wants(self, heads):
297         """Determine the wants for a set of heads.
298
299         The given heads are advertised to the client, who then specifies which
300         refs he wants using 'want' lines. This portion of the protocol is the
301         same regardless of ack type, and in fact is used to set the ack type of
302         the ProtocolGraphWalker.
303
304         :param heads: a dict of refname->SHA1 to advertise
305         :return: a list of SHA1s requested by the client
306         """
307         if not heads:
308             raise GitProtocolError('No heads found')
309         values = set(heads.itervalues())
310         if self.advertise_refs or not self.stateless_rpc:
311             for i, (ref, sha) in enumerate(heads.iteritems()):
312                 line = "%s %s" % (sha, ref)
313                 if not i:
314                     line = "%s\x00%s" % (line, self.handler.capability_line())
315                 self.proto.write_pkt_line("%s\n" % line)
316                 peeled_sha = self.handler.backend.repo.get_peeled(ref)
317                 if peeled_sha != sha:
318                     self.proto.write_pkt_line('%s %s^{}\n' %
319                                               (peeled_sha, ref))
320
321             # i'm done..
322             self.proto.write_pkt_line(None)
323
324             if self.advertise_refs:
325                 return []
326
327         # Now client will sending want want want commands
328         want = self.proto.read_pkt_line()
329         if not want:
330             return []
331         line, caps = extract_want_line_capabilities(want)
332         self.handler.set_client_capabilities(caps)
333         self.set_ack_type(ack_type(caps))
334         command, sha = self._split_proto_line(line)
335
336         want_revs = []
337         while command != None:
338             if command != 'want':
339                 raise GitProtocolError(
340                     'Protocol got unexpected command %s' % command)
341             if sha not in values:
342                 raise GitProtocolError(
343                     'Client wants invalid object %s' % sha)
344             want_revs.append(sha)
345             command, sha = self.read_proto_line()
346
347         self.set_wants(want_revs)
348         return want_revs
349
350     def ack(self, have_ref):
351         return self._impl.ack(have_ref)
352
353     def reset(self):
354         self._cached = True
355         self._cache_index = 0
356
357     def next(self):
358         if not self._cached:
359             if not self._impl and self.stateless_rpc:
360                 return None
361             return self._impl.next()
362         self._cache_index += 1
363         if self._cache_index > len(self._cache):
364             return None
365         return self._cache[self._cache_index]
366
367     def _split_proto_line(self, line):
368         fields = line.rstrip('\n').split(' ', 1)
369         if len(fields) == 1 and fields[0] == 'done':
370             return ('done', None)
371         elif len(fields) == 2 and fields[0] in ('want', 'have'):
372             try:
373                 hex_to_sha(fields[1])
374                 return tuple(fields)
375             except (TypeError, AssertionError), e:
376                 raise GitProtocolError(e)
377         raise GitProtocolError('Received invalid line from client:\n%s' % line)
378
379     def read_proto_line(self):
380         """Read a line from the wire.
381
382         :return: a tuple having one of the following forms:
383             ('want', obj_id)
384             ('have', obj_id)
385             ('done', None)
386             (None, None)  (for a flush-pkt)
387
388         :raise GitProtocolError: if the line cannot be parsed into one of the
389             possible return values.
390         """
391         line = self.proto.read_pkt_line()
392         if not line:
393             return (None, None)
394         return self._split_proto_line(line)
395
396     def send_ack(self, sha, ack_type=''):
397         if ack_type:
398             ack_type = ' %s' % ack_type
399         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
400
401     def send_nak(self):
402         self.proto.write_pkt_line('NAK\n')
403
404     def set_wants(self, wants):
405         self._wants = wants
406
407     def _is_satisfied(self, haves, want, earliest):
408         """Check whether a want is satisfied by a set of haves.
409
410         A want, typically a branch tip, is "satisfied" only if there exists a
411         path back from that want to one of the haves.
412
413         :param haves: A set of commits we know the client has.
414         :param want: The want to check satisfaction for.
415         :param earliest: A timestamp beyond which the search for haves will be
416             terminated, presumably because we're searching too far down the
417             wrong branch.
418         """
419         o = self.store[want]
420         pending = collections.deque([o])
421         while pending:
422             commit = pending.popleft()
423             if commit.id in haves:
424                 return True
425             if not getattr(commit, 'get_parents', None):
426                 # non-commit wants are assumed to be satisfied
427                 continue
428             for parent in commit.get_parents():
429                 parent_obj = self.store[parent]
430                 # TODO: handle parents with later commit times than children
431                 if parent_obj.commit_time >= earliest:
432                     pending.append(parent_obj)
433         return False
434
435     def all_wants_satisfied(self, haves):
436         """Check whether all the current wants are satisfied by a set of haves.
437
438         :param haves: A set of commits we know the client has.
439         :note: Wants are specified with set_wants rather than passed in since
440             in the current interface they are determined outside this class.
441         """
442         haves = set(haves)
443         earliest = min([self.store[h].commit_time for h in haves])
444         for want in self._wants:
445             if not self._is_satisfied(haves, want, earliest):
446                 return False
447         return True
448
449     def set_ack_type(self, ack_type):
450         impl_classes = {
451             MULTI_ACK: MultiAckGraphWalkerImpl,
452             MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
453             SINGLE_ACK: SingleAckGraphWalkerImpl,
454             }
455         self._impl = impl_classes[ack_type](self)
456
457
458 class SingleAckGraphWalkerImpl(object):
459     """Graph walker implementation that speaks the single-ack protocol."""
460
461     def __init__(self, walker):
462         self.walker = walker
463         self._sent_ack = False
464
465     def ack(self, have_ref):
466         if not self._sent_ack:
467             self.walker.send_ack(have_ref)
468             self._sent_ack = True
469
470     def next(self):
471         command, sha = self.walker.read_proto_line()
472         if command in (None, 'done'):
473             if not self._sent_ack:
474                 self.walker.send_nak()
475             return None
476         elif command == 'have':
477             return sha
478
479
480 class MultiAckGraphWalkerImpl(object):
481     """Graph walker implementation that speaks the multi-ack protocol."""
482
483     def __init__(self, walker):
484         self.walker = walker
485         self._found_base = False
486         self._common = []
487
488     def ack(self, have_ref):
489         self._common.append(have_ref)
490         if not self._found_base:
491             self.walker.send_ack(have_ref, 'continue')
492             if self.walker.all_wants_satisfied(self._common):
493                 self._found_base = True
494         # else we blind ack within next
495
496     def next(self):
497         while True:
498             command, sha = self.walker.read_proto_line()
499             if command is None:
500                 self.walker.send_nak()
501                 # in multi-ack mode, a flush-pkt indicates the client wants to
502                 # flush but more have lines are still coming
503                 continue
504             elif command == 'done':
505                 # don't nak unless no common commits were found, even if not
506                 # everything is satisfied
507                 if self._common:
508                     self.walker.send_ack(self._common[-1])
509                 else:
510                     self.walker.send_nak()
511                 return None
512             elif command == 'have':
513                 if self._found_base:
514                     # blind ack
515                     self.walker.send_ack(sha, 'continue')
516                 return sha
517
518
519 class MultiAckDetailedGraphWalkerImpl(object):
520     """Graph walker implementation speaking the multi-ack-detailed protocol."""
521
522     def __init__(self, walker):
523         self.walker = walker
524         self._found_base = False
525         self._common = []
526
527     def ack(self, have_ref):
528         self._common.append(have_ref)
529         if not self._found_base:
530             self.walker.send_ack(have_ref, 'common')
531             if self.walker.all_wants_satisfied(self._common):
532                 self._found_base = True
533                 self.walker.send_ack(have_ref, 'ready')
534         # else we blind ack within next
535
536     def next(self):
537         while True:
538             command, sha = self.walker.read_proto_line()
539             if command is None:
540                 self.walker.send_nak()
541                 if self.walker.stateless_rpc:
542                     return None
543                 continue
544             elif command == 'done':
545                 # don't nak unless no common commits were found, even if not
546                 # everything is satisfied
547                 if self._common:
548                     self.walker.send_ack(self._common[-1])
549                 else:
550                     self.walker.send_nak()
551                 return None
552             elif command == 'have':
553                 if self._found_base:
554                     # blind ack; can happen if the client has more requests
555                     # inflight
556                     self.walker.send_ack(sha, 'ready')
557                 return sha
558
559
560 class ReceivePackHandler(Handler):
561     """Protocol handler for downloading a pack from the client."""
562
563     def __init__(self, backend, read, write,
564                  stateless_rpc=False, advertise_refs=False):
565         Handler.__init__(self, backend, read, write)
566         self.stateless_rpc = stateless_rpc
567         self.advertise_refs = advertise_refs
568
569     def capabilities(self):
570         return ("report-status", "delete-refs")
571
572     def handle(self):
573         refs = self.backend.get_refs().items()
574
575         if self.advertise_refs or not self.stateless_rpc:
576             if refs:
577                 self.proto.write_pkt_line(
578                   "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
579                                      self.capability_line()))
580                 for i in range(1, len(refs)):
581                     ref = refs[i]
582                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
583             else:
584                 self.proto.write_pkt_line("%s capabilities^{} %s" % (
585                   ZERO_SHA, self.capability_line()))
586
587             self.proto.write("0000")
588             if self.advertise_refs:
589                 return
590
591         client_refs = []
592         ref = self.proto.read_pkt_line()
593
594         # if ref is none then client doesnt want to send us anything..
595         if ref is None:
596             return
597
598         ref, caps = extract_capabilities(ref)
599         self.set_client_capabilities(caps)
600
601         # client will now send us a list of (oldsha, newsha, ref)
602         while ref:
603             client_refs.append(ref.split())
604             ref = self.proto.read_pkt_line()
605
606         # backend can now deal with this refs and read a pack using self.read
607         status = self.backend.apply_pack(client_refs, self.proto.read,
608                                          self.has_capability('delete-refs'))
609
610         # when we have read all the pack from the client, send a status report
611         # if the client asked for it
612         if self.has_capability('report-status'):
613             for name, msg in status:
614                 if name == 'unpack':
615                     self.proto.write_pkt_line('unpack %s\n' % msg)
616                 elif msg == 'ok':
617                     self.proto.write_pkt_line('ok %s\n' % name)
618                 else:
619                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
620             self.proto.write_pkt_line(None)
621
622
623 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
624
625     def handle(self):
626         proto = Protocol(self.rfile.read, self.wfile.write)
627         command, args = proto.read_cmd()
628
629         # switch case to handle the specific git command
630         if command == 'git-upload-pack':
631             cls = UploadPackHandler
632         elif command == 'git-receive-pack':
633             cls = ReceivePackHandler
634         else:
635             return
636
637         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
638         h.handle()
639
640
641 class TCPGitServer(SocketServer.TCPServer):
642
643     allow_reuse_address = True
644     serve = SocketServer.TCPServer.serve_forever
645
646     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
647         self.backend = backend
648         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)