2e19838d402790a49b86682b7a1525ac075b2b27
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31 import tempfile
32
33 from dulwich.errors import (
34     ApplyDeltaError,
35     ChecksumMismatch,
36     GitProtocolError,
37     )
38 from dulwich.objects import (
39     hex_to_sha,
40     )
41 from dulwich.protocol import (
42     Protocol,
43     ProtocolFile,
44     TCP_GIT_PORT,
45     extract_capabilities,
46     extract_want_line_capabilities,
47     SINGLE_ACK,
48     MULTI_ACK,
49     MULTI_ACK_DETAILED,
50     ack_type,
51     )
52 from dulwich.repo import (
53     Repo,
54     )
55 from dulwich.pack import (
56     write_pack_data,
57     )
58
59 class Backend(object):
60
61     def get_refs(self):
62         """
63         Get all the refs in the repository
64
65         :return: dict of name -> sha
66         """
67         raise NotImplementedError
68
69     def apply_pack(self, refs, read):
70         """ Import a set of changes into a repository and update the refs
71
72         :param refs: list of tuple(name, sha)
73         :param read: callback to read from the incoming pack
74         """
75         raise NotImplementedError
76
77     def fetch_objects(self, determine_wants, graph_walker, progress):
78         """
79         Yield the objects required for a list of commits.
80
81         :param progress: is a callback to send progress messages to the client
82         """
83         raise NotImplementedError
84
85
86 class GitBackend(Backend):
87
88     def __init__(self, repo=None):
89         if repo is None:
90             repo = Repo(tmpfile.mkdtemp())
91         self.repo = repo
92         self.object_store = self.repo.object_store
93         self.fetch_objects = self.repo.fetch_objects
94         self.get_refs = self.repo.get_refs
95
96     def apply_pack(self, refs, read):
97         f, commit = self.repo.object_store.add_thin_pack()
98         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
99         status = []
100         unpack_error = None
101         # TODO: more informative error messages than just the exception string
102         try:
103             # TODO: decode the pack as we stream to avoid blocking reads beyond
104             # the end of data (when using HTTP/1.1 chunked encoding)
105             while True:
106                 data = read(10240)
107                 if not data:
108                     break
109                 f.write(data)
110         except all_exceptions, e:
111             unpack_error = str(e).replace('\n', '')
112         try:
113             commit()
114         except all_exceptions, e:
115             if not unpack_error:
116                 unpack_error = str(e).replace('\n', '')
117
118         if unpack_error:
119             status.append(('unpack', unpack_error))
120         else:
121             status.append(('unpack', 'ok'))
122
123         for oldsha, sha, ref in refs:
124             # TODO: check refname
125             ref_error = None
126             try:
127                 if ref == "0" * 40:
128                     try:
129                         del self.repo.refs[ref]
130                     except all_exceptions:
131                         ref_error = 'failed to delete'
132                 else:
133                     try:
134                         self.repo.refs[ref] = sha
135                     except all_exceptions:
136                         ref_error = 'failed to write'
137             except KeyError, e:
138                 ref_error = 'bad ref'
139             if ref_error:
140                 status.append((ref, ref_error))
141             else:
142                 status.append((ref, 'ok'))
143
144
145         print "pack applied"
146         return status
147
148
149 class Handler(object):
150     """Smart protocol command handler base class."""
151
152     def __init__(self, backend, read, write):
153         self.backend = backend
154         self.proto = Protocol(read, write)
155
156     def capabilities(self):
157         return " ".join(self.default_capabilities())
158
159
160 class UploadPackHandler(Handler):
161     """Protocol handler for uploading a pack to the server."""
162
163     def __init__(self, backend, read, write,
164                  stateless_rpc=False, advertise_refs=False):
165         Handler.__init__(self, backend, read, write)
166         self._client_capabilities = None
167         self._graph_walker = None
168         self.stateless_rpc = stateless_rpc
169         self.advertise_refs = advertise_refs
170
171     def default_capabilities(self):
172         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
173                 "ofs-delta")
174
175     def set_client_capabilities(self, caps):
176         my_caps = self.default_capabilities()
177         for cap in caps:
178             if '_ack' in cap and cap not in my_caps:
179                 raise GitProtocolError('Client asked for capability %s that '
180                                        'was not advertised.' % cap)
181         self._client_capabilities = caps
182
183     def get_client_capabilities(self):
184         return self._client_capabilities
185
186     client_capabilities = property(get_client_capabilities,
187                                    set_client_capabilities)
188
189     def handle(self):
190
191         progress = lambda x: self.proto.write_sideband(2, x)
192         write = lambda x: self.proto.write_sideband(1, x)
193
194         graph_walker = ProtocolGraphWalker(self)
195         objects_iter = self.backend.fetch_objects(
196           graph_walker.determine_wants, graph_walker, progress)
197
198         # Do they want any objects?
199         if len(objects_iter) == 0:
200             return
201
202         progress("dul-daemon says what\n")
203         progress("counting objects: %d, done.\n" % len(objects_iter))
204         write_pack_data(ProtocolFile(None, write), objects_iter, 
205                         len(objects_iter))
206         progress("how was that, then?\n")
207         # we are done
208         self.proto.write("0000")
209
210
211 class ProtocolGraphWalker(object):
212     """A graph walker that knows the git protocol.
213
214     As a graph walker, this class implements ack(), next(), and reset(). It also
215     contains some base methods for interacting with the wire and walking the
216     commit tree.
217
218     The work of determining which acks to send is passed on to the
219     implementation instance stored in _impl. The reason for this is that we do
220     not know at object creation time what ack level the protocol requires. A
221     call to set_ack_level() is required to set up the implementation, before any
222     calls to next() or ack() are made.
223     """
224     def __init__(self, handler):
225         self.handler = handler
226         self.store = handler.backend.object_store
227         self.proto = handler.proto
228         self.stateless_rpc = handler.stateless_rpc
229         self.advertise_refs = handler.advertise_refs
230         self._wants = []
231         self._cached = False
232         self._cache = []
233         self._cache_index = 0
234         self._impl = None
235
236     def determine_wants(self, heads):
237         """Determine the wants for a set of heads.
238
239         The given heads are advertised to the client, who then specifies which
240         refs he wants using 'want' lines. This portion of the protocol is the
241         same regardless of ack type, and in fact is used to set the ack type of
242         the ProtocolGraphWalker.
243
244         :param heads: a dict of refname->SHA1 to advertise
245         :return: a list of SHA1s requested by the client
246         """
247         if not heads:
248             raise GitProtocolError('No heads found')
249         values = set(heads.itervalues())
250         if self.advertise_refs or not self.stateless_rpc:
251             for i, (ref, sha) in enumerate(heads.iteritems()):
252                 line = "%s %s" % (sha, ref)
253                 if not i:
254                     line = "%s\x00%s" % (line, self.handler.capabilities())
255                 self.proto.write_pkt_line("%s\n" % line)
256                 # TODO: include peeled value of any tags
257
258             # i'm done..
259             self.proto.write_pkt_line(None)
260
261             if self.advertise_refs:
262                 return []
263
264         # Now client will sending want want want commands
265         want = self.proto.read_pkt_line()
266         if not want:
267             return []
268         line, caps = extract_want_line_capabilities(want)
269         self.handler.client_capabilities = caps
270         self.set_ack_type(ack_type(caps))
271         command, sha = self._split_proto_line(line)
272
273         want_revs = []
274         while command != None:
275             if command != 'want':
276                 raise GitProtocolError(
277                     'Protocol got unexpected command %s' % command)
278             if sha not in values:
279                 raise GitProtocolError(
280                     'Client wants invalid object %s' % sha)
281             want_revs.append(sha)
282             command, sha = self.read_proto_line()
283
284         self.set_wants(want_revs)
285         return want_revs
286
287     def ack(self, have_ref):
288         return self._impl.ack(have_ref)
289
290     def reset(self):
291         self._cached = True
292         self._cache_index = 0
293
294     def next(self):
295         if not self._cached:
296             if not self._impl and self.stateless_rpc:
297                 return None
298             return self._impl.next()
299         self._cache_index += 1
300         if self._cache_index > len(self._cache):
301             return None
302         return self._cache[self._cache_index]
303
304     def _split_proto_line(self, line):
305         fields = line.rstrip('\n').split(' ', 1)
306         if len(fields) == 1 and fields[0] == 'done':
307             return ('done', None)
308         elif len(fields) == 2 and fields[0] in ('want', 'have'):
309             try:
310                 hex_to_sha(fields[1])
311                 return tuple(fields)
312             except (TypeError, AssertionError), e:
313                 raise GitProtocolError(e)
314         raise GitProtocolError('Received invalid line from client:\n%s' % line)
315
316     def read_proto_line(self):
317         """Read a line from the wire.
318
319         :return: a tuple having one of the following forms:
320             ('want', obj_id)
321             ('have', obj_id)
322             ('done', None)
323             (None, None)  (for a flush-pkt)
324
325         :raise GitProtocolError: if the line cannot be parsed into one of the
326             possible return values.
327         """
328         line = self.proto.read_pkt_line()
329         if not line:
330             return (None, None)
331         return self._split_proto_line(line)
332
333     def send_ack(self, sha, ack_type=''):
334         if ack_type:
335             ack_type = ' %s' % ack_type
336         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
337
338     def send_nak(self):
339         self.proto.write_pkt_line('NAK\n')
340
341     def set_wants(self, wants):
342         self._wants = wants
343
344     def _is_satisfied(self, haves, want, earliest):
345         """Check whether a want is satisfied by a set of haves.
346
347         A want, typically a branch tip, is "satisfied" only if there exists a
348         path back from that want to one of the haves.
349
350         :param haves: A set of commits we know the client has.
351         :param want: The want to check satisfaction for.
352         :param earliest: A timestamp beyond which the search for haves will be
353             terminated, presumably because we're searching too far down the
354             wrong branch.
355         """
356         o = self.store[want]
357         pending = collections.deque([o])
358         while pending:
359             commit = pending.popleft()
360             if commit.id in haves:
361                 return True
362             if not getattr(commit, 'get_parents', None):
363                 # non-commit wants are assumed to be satisfied
364                 continue
365             for parent in commit.get_parents():
366                 parent_obj = self.store[parent]
367                 # TODO: handle parents with later commit times than children
368                 if parent_obj.commit_time >= earliest:
369                     pending.append(parent_obj)
370         return False
371
372     def all_wants_satisfied(self, haves):
373         """Check whether all the current wants are satisfied by a set of haves.
374
375         :param haves: A set of commits we know the client has.
376         :note: Wants are specified with set_wants rather than passed in since
377             in the current interface they are determined outside this class.
378         """
379         haves = set(haves)
380         earliest = min([self.store[h].commit_time for h in haves])
381         for want in self._wants:
382             if not self._is_satisfied(haves, want, earliest):
383                 return False
384         return True
385
386     def set_ack_type(self, ack_type):
387         impl_classes = {
388             MULTI_ACK: MultiAckGraphWalkerImpl,
389             MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
390             SINGLE_ACK: SingleAckGraphWalkerImpl,
391             }
392         self._impl = impl_classes[ack_type](self)
393
394
395 class SingleAckGraphWalkerImpl(object):
396     """Graph walker implementation that speaks the single-ack protocol."""
397
398     def __init__(self, walker):
399         self.walker = walker
400         self._sent_ack = False
401
402     def ack(self, have_ref):
403         if not self._sent_ack:
404             self.walker.send_ack(have_ref)
405             self._sent_ack = True
406
407     def next(self):
408         command, sha = self.walker.read_proto_line()
409         if command in (None, 'done'):
410             if not self._sent_ack:
411                 self.walker.send_nak()
412             return None
413         elif command == 'have':
414             return sha
415
416
417 class MultiAckGraphWalkerImpl(object):
418     """Graph walker implementation that speaks the multi-ack protocol."""
419
420     def __init__(self, walker):
421         self.walker = walker
422         self._found_base = False
423         self._common = []
424
425     def ack(self, have_ref):
426         self._common.append(have_ref)
427         if not self._found_base:
428             self.walker.send_ack(have_ref, 'continue')
429             if self.walker.all_wants_satisfied(self._common):
430                 self._found_base = True
431         # else we blind ack within next
432
433     def next(self):
434         while True:
435             command, sha = self.walker.read_proto_line()
436             if command is None:
437                 self.walker.send_nak()
438                 # in multi-ack mode, a flush-pkt indicates the client wants to
439                 # flush but more have lines are still coming
440                 continue
441             elif command == 'done':
442                 # don't nak unless no common commits were found, even if not
443                 # everything is satisfied
444                 if self._common:
445                     self.walker.send_ack(self._common[-1])
446                 else:
447                     self.walker.send_nak()
448                 return None
449             elif command == 'have':
450                 if self._found_base:
451                     # blind ack
452                     self.walker.send_ack(sha, 'continue')
453                 return sha
454
455
456 class MultiAckDetailedGraphWalkerImpl(object):
457     """Graph walker implementation speaking the multi-ack-detailed protocol."""
458
459     def __init__(self, walker):
460         self.walker = walker
461         self._found_base = False
462         self._common = []
463
464     def ack(self, have_ref):
465         self._common.append(have_ref)
466         if not self._found_base:
467             self.walker.send_ack(have_ref, 'common')
468             if self.walker.all_wants_satisfied(self._common):
469                 self._found_base = True
470                 self.walker.send_ack(have_ref, 'ready')
471         # else we blind ack within next
472
473     def next(self):
474         while True:
475             command, sha = self.walker.read_proto_line()
476             if command is None:
477                 self.walker.send_nak()
478                 if self.walker.stateless_rpc:
479                     return None
480                 continue
481             elif command == 'done':
482                 # don't nak unless no common commits were found, even if not
483                 # everything is satisfied
484                 if self._common:
485                     self.walker.send_ack(self._common[-1])
486                 else:
487                     self.walker.send_nak()
488                 return None
489             elif command == 'have':
490                 if self._found_base:
491                     # blind ack; can happen if the client has more requests
492                     # inflight
493                     self.walker.send_ack(sha, 'ready')
494                 return sha
495
496
497 class ReceivePackHandler(Handler):
498     """Protocol handler for downloading a pack from the client."""
499
500     def __init__(self, backend, read, write,
501                  stateless_rpc=False, advertise_refs=False):
502         Handler.__init__(self, backend, read, write)
503         self.stateless_rpc = stateless_rpc
504         self.advertise_refs = advertise_refs
505
506     def __init__(self, backend, read, write,
507                  stateless_rpc=False, advertise_refs=False):
508         Handler.__init__(self, backend, read, write)
509         self._stateless_rpc = stateless_rpc
510         self._advertise_refs = advertise_refs
511
512     def default_capabilities(self):
513         return ("report-status", "delete-refs")
514
515     def handle(self):
516         refs = self.backend.get_refs().items()
517
518         if self.advertise_refs or not self.stateless_rpc:
519             if refs:
520                 self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
521                 for i in range(1, len(refs)):
522                     ref = refs[i]
523                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
524             else:
525                 self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
526
527             self.proto.write("0000")
528             if self.advertise_refs:
529                 return
530
531         client_refs = []
532         ref = self.proto.read_pkt_line()
533
534         # if ref is none then client doesnt want to send us anything..
535         if ref is None:
536             return
537
538         ref, client_capabilities = extract_capabilities(ref)
539
540         # client will now send us a list of (oldsha, newsha, ref)
541         while ref:
542             client_refs.append(ref.split())
543             ref = self.proto.read_pkt_line()
544
545         # backend can now deal with this refs and read a pack using self.read
546         status = self.backend.apply_pack(client_refs, self.proto.read)
547
548         # when we have read all the pack from the client, send a status report
549         # if the client asked for it
550         if 'report-status' in client_capabilities:
551             for name, msg in status:
552                 if name == 'unpack':
553                     self.proto.write_pkt_line('unpack %s\n' % msg)
554                 elif msg == 'ok':
555                     self.proto.write_pkt_line('ok %s\n' % name)
556                 else:
557                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
558             self.proto.write_pkt_line(None)
559
560
561 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
562
563     def handle(self):
564         proto = Protocol(self.rfile.read, self.wfile.write)
565         command, args = proto.read_cmd()
566
567         # switch case to handle the specific git command
568         if command == 'git-upload-pack':
569             cls = UploadPackHandler
570         elif command == 'git-receive-pack':
571             cls = ReceivePackHandler
572         else:
573             return
574
575         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
576         h.handle()
577
578
579 class TCPGitServer(SocketServer.TCPServer):
580
581     allow_reuse_address = True
582     serve = SocketServer.TCPServer.serve_forever
583
584     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
585         self.backend = backend
586         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)