Merge server capability refactoring from Dave.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31 import tempfile
32
33 from dulwich.errors import (
34     ApplyDeltaError,
35     ChecksumMismatch,
36     GitProtocolError,
37     )
38 from dulwich.objects import (
39     hex_to_sha,
40     )
41 from dulwich.protocol import (
42     Protocol,
43     ProtocolFile,
44     TCP_GIT_PORT,
45     extract_capabilities,
46     extract_want_line_capabilities,
47     SINGLE_ACK,
48     MULTI_ACK,
49     MULTI_ACK_DETAILED,
50     ack_type,
51     )
52 from dulwich.repo import (
53     Repo,
54     )
55 from dulwich.pack import (
56     write_pack_data,
57     )
58
59 class Backend(object):
60
61     def get_refs(self):
62         """
63         Get all the refs in the repository
64
65         :return: dict of name -> sha
66         """
67         raise NotImplementedError
68
69     def apply_pack(self, refs, read):
70         """ Import a set of changes into a repository and update the refs
71
72         :param refs: list of tuple(name, sha)
73         :param read: callback to read from the incoming pack
74         """
75         raise NotImplementedError
76
77     def fetch_objects(self, determine_wants, graph_walker, progress):
78         """
79         Yield the objects required for a list of commits.
80
81         :param progress: is a callback to send progress messages to the client
82         """
83         raise NotImplementedError
84
85
86 class GitBackend(Backend):
87
88     def __init__(self, repo=None):
89         if repo is None:
90             repo = Repo(tmpfile.mkdtemp())
91         self.repo = repo
92         self.object_store = self.repo.object_store
93         self.fetch_objects = self.repo.fetch_objects
94         self.get_refs = self.repo.get_refs
95
96     def apply_pack(self, refs, read):
97         f, commit = self.repo.object_store.add_thin_pack()
98         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
99         status = []
100         unpack_error = None
101         # TODO: more informative error messages than just the exception string
102         try:
103             # TODO: decode the pack as we stream to avoid blocking reads beyond
104             # the end of data (when using HTTP/1.1 chunked encoding)
105             while True:
106                 data = read(10240)
107                 if not data:
108                     break
109                 f.write(data)
110         except all_exceptions, e:
111             unpack_error = str(e).replace('\n', '')
112         try:
113             commit()
114         except all_exceptions, e:
115             if not unpack_error:
116                 unpack_error = str(e).replace('\n', '')
117
118         if unpack_error:
119             status.append(('unpack', unpack_error))
120         else:
121             status.append(('unpack', 'ok'))
122
123         for oldsha, sha, ref in refs:
124             # TODO: check refname
125             ref_error = None
126             try:
127                 if ref == "0" * 40:
128                     try:
129                         del self.repo.refs[ref]
130                     except all_exceptions:
131                         ref_error = 'failed to delete'
132                 else:
133                     try:
134                         self.repo.refs[ref] = sha
135                     except all_exceptions:
136                         ref_error = 'failed to write'
137             except KeyError, e:
138                 ref_error = 'bad ref'
139             if ref_error:
140                 status.append((ref, ref_error))
141             else:
142                 status.append((ref, 'ok'))
143
144
145         print "pack applied"
146         return status
147
148
149 class Handler(object):
150     """Smart protocol command handler base class."""
151
152     def __init__(self, backend, read, write):
153         self.backend = backend
154         self.proto = Protocol(read, write)
155         self._client_capabilities = None
156
157     def capability_line(self):
158         return " ".join(self.capabilities())
159
160     def capabilities(self):
161         raise NotImplementedError(self.capabilities)
162
163     def set_client_capabilities(self, caps):
164         my_caps = self.capabilities()
165         for cap in caps:
166             if cap not in my_caps:
167                 raise GitProtocolError('Client asked for capability %s that '
168                                        'was not advertised.' % cap)
169         self._client_capabilities = caps
170
171     def has_capability(self, cap):
172         if self._client_capabilities is None:
173             raise GitProtocolError('Server attempted to access capability %s '
174                                    'before asking client' % cap)
175         return cap in self._client_capabilities
176
177
178 class UploadPackHandler(Handler):
179     """Protocol handler for uploading a pack to the server."""
180
181     def __init__(self, backend, read, write,
182                  stateless_rpc=False, advertise_refs=False):
183         Handler.__init__(self, backend, read, write)
184         self._graph_walker = None
185         self.stateless_rpc = stateless_rpc
186         self.advertise_refs = advertise_refs
187
188     def capabilities(self):
189         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
190                 "ofs-delta")
191
192     def handle(self):
193
194         progress = lambda x: self.proto.write_sideband(2, x)
195         write = lambda x: self.proto.write_sideband(1, x)
196
197         graph_walker = ProtocolGraphWalker(self)
198         objects_iter = self.backend.fetch_objects(
199           graph_walker.determine_wants, graph_walker, progress)
200
201         # Do they want any objects?
202         if len(objects_iter) == 0:
203             return
204
205         progress("dul-daemon says what\n")
206         progress("counting objects: %d, done.\n" % len(objects_iter))
207         write_pack_data(ProtocolFile(None, write), objects_iter, 
208                         len(objects_iter))
209         progress("how was that, then?\n")
210         # we are done
211         self.proto.write("0000")
212
213
214 class ProtocolGraphWalker(object):
215     """A graph walker that knows the git protocol.
216
217     As a graph walker, this class implements ack(), next(), and reset(). It also
218     contains some base methods for interacting with the wire and walking the
219     commit tree.
220
221     The work of determining which acks to send is passed on to the
222     implementation instance stored in _impl. The reason for this is that we do
223     not know at object creation time what ack level the protocol requires. A
224     call to set_ack_level() is required to set up the implementation, before any
225     calls to next() or ack() are made.
226     """
227     def __init__(self, handler):
228         self.handler = handler
229         self.store = handler.backend.object_store
230         self.proto = handler.proto
231         self.stateless_rpc = handler.stateless_rpc
232         self.advertise_refs = handler.advertise_refs
233         self._wants = []
234         self._cached = False
235         self._cache = []
236         self._cache_index = 0
237         self._impl = None
238
239     def determine_wants(self, heads):
240         """Determine the wants for a set of heads.
241
242         The given heads are advertised to the client, who then specifies which
243         refs he wants using 'want' lines. This portion of the protocol is the
244         same regardless of ack type, and in fact is used to set the ack type of
245         the ProtocolGraphWalker.
246
247         :param heads: a dict of refname->SHA1 to advertise
248         :return: a list of SHA1s requested by the client
249         """
250         if not heads:
251             raise GitProtocolError('No heads found')
252         values = set(heads.itervalues())
253         if self.advertise_refs or not self.stateless_rpc:
254             for i, (ref, sha) in enumerate(heads.iteritems()):
255                 line = "%s %s" % (sha, ref)
256                 if not i:
257                     line = "%s\x00%s" % (line, self.handler.capability_line())
258                 self.proto.write_pkt_line("%s\n" % line)
259                 # TODO: include peeled value of any tags
260
261             # i'm done..
262             self.proto.write_pkt_line(None)
263
264             if self.advertise_refs:
265                 return []
266
267         # Now client will sending want want want commands
268         want = self.proto.read_pkt_line()
269         if not want:
270             return []
271         line, caps = extract_want_line_capabilities(want)
272         self.handler.set_client_capabilities(caps)
273         self.set_ack_type(ack_type(caps))
274         command, sha = self._split_proto_line(line)
275
276         want_revs = []
277         while command != None:
278             if command != 'want':
279                 raise GitProtocolError(
280                     'Protocol got unexpected command %s' % command)
281             if sha not in values:
282                 raise GitProtocolError(
283                     'Client wants invalid object %s' % sha)
284             want_revs.append(sha)
285             command, sha = self.read_proto_line()
286
287         self.set_wants(want_revs)
288         return want_revs
289
290     def ack(self, have_ref):
291         return self._impl.ack(have_ref)
292
293     def reset(self):
294         self._cached = True
295         self._cache_index = 0
296
297     def next(self):
298         if not self._cached:
299             if not self._impl and self.stateless_rpc:
300                 return None
301             return self._impl.next()
302         self._cache_index += 1
303         if self._cache_index > len(self._cache):
304             return None
305         return self._cache[self._cache_index]
306
307     def _split_proto_line(self, line):
308         fields = line.rstrip('\n').split(' ', 1)
309         if len(fields) == 1 and fields[0] == 'done':
310             return ('done', None)
311         elif len(fields) == 2 and fields[0] in ('want', 'have'):
312             try:
313                 hex_to_sha(fields[1])
314                 return tuple(fields)
315             except (TypeError, AssertionError), e:
316                 raise GitProtocolError(e)
317         raise GitProtocolError('Received invalid line from client:\n%s' % line)
318
319     def read_proto_line(self):
320         """Read a line from the wire.
321
322         :return: a tuple having one of the following forms:
323             ('want', obj_id)
324             ('have', obj_id)
325             ('done', None)
326             (None, None)  (for a flush-pkt)
327
328         :raise GitProtocolError: if the line cannot be parsed into one of the
329             possible return values.
330         """
331         line = self.proto.read_pkt_line()
332         if not line:
333             return (None, None)
334         return self._split_proto_line(line)
335
336     def send_ack(self, sha, ack_type=''):
337         if ack_type:
338             ack_type = ' %s' % ack_type
339         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
340
341     def send_nak(self):
342         self.proto.write_pkt_line('NAK\n')
343
344     def set_wants(self, wants):
345         self._wants = wants
346
347     def _is_satisfied(self, haves, want, earliest):
348         """Check whether a want is satisfied by a set of haves.
349
350         A want, typically a branch tip, is "satisfied" only if there exists a
351         path back from that want to one of the haves.
352
353         :param haves: A set of commits we know the client has.
354         :param want: The want to check satisfaction for.
355         :param earliest: A timestamp beyond which the search for haves will be
356             terminated, presumably because we're searching too far down the
357             wrong branch.
358         """
359         o = self.store[want]
360         pending = collections.deque([o])
361         while pending:
362             commit = pending.popleft()
363             if commit.id in haves:
364                 return True
365             if not getattr(commit, 'get_parents', None):
366                 # non-commit wants are assumed to be satisfied
367                 continue
368             for parent in commit.get_parents():
369                 parent_obj = self.store[parent]
370                 # TODO: handle parents with later commit times than children
371                 if parent_obj.commit_time >= earliest:
372                     pending.append(parent_obj)
373         return False
374
375     def all_wants_satisfied(self, haves):
376         """Check whether all the current wants are satisfied by a set of haves.
377
378         :param haves: A set of commits we know the client has.
379         :note: Wants are specified with set_wants rather than passed in since
380             in the current interface they are determined outside this class.
381         """
382         haves = set(haves)
383         earliest = min([self.store[h].commit_time for h in haves])
384         for want in self._wants:
385             if not self._is_satisfied(haves, want, earliest):
386                 return False
387         return True
388
389     def set_ack_type(self, ack_type):
390         impl_classes = {
391             MULTI_ACK: MultiAckGraphWalkerImpl,
392             MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
393             SINGLE_ACK: SingleAckGraphWalkerImpl,
394             }
395         self._impl = impl_classes[ack_type](self)
396
397
398 class SingleAckGraphWalkerImpl(object):
399     """Graph walker implementation that speaks the single-ack protocol."""
400
401     def __init__(self, walker):
402         self.walker = walker
403         self._sent_ack = False
404
405     def ack(self, have_ref):
406         if not self._sent_ack:
407             self.walker.send_ack(have_ref)
408             self._sent_ack = True
409
410     def next(self):
411         command, sha = self.walker.read_proto_line()
412         if command in (None, 'done'):
413             if not self._sent_ack:
414                 self.walker.send_nak()
415             return None
416         elif command == 'have':
417             return sha
418
419
420 class MultiAckGraphWalkerImpl(object):
421     """Graph walker implementation that speaks the multi-ack protocol."""
422
423     def __init__(self, walker):
424         self.walker = walker
425         self._found_base = False
426         self._common = []
427
428     def ack(self, have_ref):
429         self._common.append(have_ref)
430         if not self._found_base:
431             self.walker.send_ack(have_ref, 'continue')
432             if self.walker.all_wants_satisfied(self._common):
433                 self._found_base = True
434         # else we blind ack within next
435
436     def next(self):
437         while True:
438             command, sha = self.walker.read_proto_line()
439             if command is None:
440                 self.walker.send_nak()
441                 # in multi-ack mode, a flush-pkt indicates the client wants to
442                 # flush but more have lines are still coming
443                 continue
444             elif command == 'done':
445                 # don't nak unless no common commits were found, even if not
446                 # everything is satisfied
447                 if self._common:
448                     self.walker.send_ack(self._common[-1])
449                 else:
450                     self.walker.send_nak()
451                 return None
452             elif command == 'have':
453                 if self._found_base:
454                     # blind ack
455                     self.walker.send_ack(sha, 'continue')
456                 return sha
457
458
459 class MultiAckDetailedGraphWalkerImpl(object):
460     """Graph walker implementation speaking the multi-ack-detailed protocol."""
461
462     def __init__(self, walker):
463         self.walker = walker
464         self._found_base = False
465         self._common = []
466
467     def ack(self, have_ref):
468         self._common.append(have_ref)
469         if not self._found_base:
470             self.walker.send_ack(have_ref, 'common')
471             if self.walker.all_wants_satisfied(self._common):
472                 self._found_base = True
473                 self.walker.send_ack(have_ref, 'ready')
474         # else we blind ack within next
475
476     def next(self):
477         while True:
478             command, sha = self.walker.read_proto_line()
479             if command is None:
480                 self.walker.send_nak()
481                 if self.walker.stateless_rpc:
482                     return None
483                 continue
484             elif command == 'done':
485                 # don't nak unless no common commits were found, even if not
486                 # everything is satisfied
487                 if self._common:
488                     self.walker.send_ack(self._common[-1])
489                 else:
490                     self.walker.send_nak()
491                 return None
492             elif command == 'have':
493                 if self._found_base:
494                     # blind ack; can happen if the client has more requests
495                     # inflight
496                     self.walker.send_ack(sha, 'ready')
497                 return sha
498
499
500 class ReceivePackHandler(Handler):
501     """Protocol handler for downloading a pack from the client."""
502
503     def __init__(self, backend, read, write,
504                  stateless_rpc=False, advertise_refs=False):
505         Handler.__init__(self, backend, read, write)
506         self.stateless_rpc = stateless_rpc
507         self.advertise_refs = advertise_refs
508
509     def __init__(self, backend, read, write,
510                  stateless_rpc=False, advertise_refs=False):
511         Handler.__init__(self, backend, read, write)
512         self._stateless_rpc = stateless_rpc
513         self._advertise_refs = advertise_refs
514
515     def capabilities(self):
516         return ("report-status", "delete-refs")
517
518     def handle(self):
519         refs = self.backend.get_refs().items()
520
521         if self.advertise_refs or not self.stateless_rpc:
522             if refs:
523                 self.proto.write_pkt_line(
524                     "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
525                                        self.capability_line()))
526                 for i in range(1, len(refs)):
527                     ref = refs[i]
528                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
529             else:
530                 self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
531
532             self.proto.write("0000")
533             if self.advertise_refs:
534                 return
535
536         client_refs = []
537         ref = self.proto.read_pkt_line()
538
539         # if ref is none then client doesnt want to send us anything..
540         if ref is None:
541             return
542
543         ref, caps = extract_capabilities(ref)
544         self.set_client_capabilities(caps)
545
546         # client will now send us a list of (oldsha, newsha, ref)
547         while ref:
548             client_refs.append(ref.split())
549             ref = self.proto.read_pkt_line()
550
551         # backend can now deal with this refs and read a pack using self.read
552         status = self.backend.apply_pack(client_refs, self.proto.read)
553
554         # when we have read all the pack from the client, send a status report
555         # if the client asked for it
556         if self.has_capability('report-status'):
557             for name, msg in status:
558                 if name == 'unpack':
559                     self.proto.write_pkt_line('unpack %s\n' % msg)
560                 elif msg == 'ok':
561                     self.proto.write_pkt_line('ok %s\n' % name)
562                 else:
563                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
564             self.proto.write_pkt_line(None)
565
566
567 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
568
569     def handle(self):
570         proto = Protocol(self.rfile.read, self.wfile.write)
571         command, args = proto.read_cmd()
572
573         # switch case to handle the specific git command
574         if command == 'git-upload-pack':
575             cls = UploadPackHandler
576         elif command == 'git-receive-pack':
577             cls = ReceivePackHandler
578         else:
579             return
580
581         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
582         h.handle()
583
584
585 class TCPGitServer(SocketServer.TCPServer):
586
587     allow_reuse_address = True
588     serve = SocketServer.TCPServer.serve_forever
589
590     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
591         self.backend = backend
592         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)