Add no-progress capability support to UploadPackHandler.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31 import tempfile
32
33 from dulwich.errors import (
34     ApplyDeltaError,
35     ChecksumMismatch,
36     GitProtocolError,
37     )
38 from dulwich.objects import (
39     hex_to_sha,
40     )
41 from dulwich.protocol import (
42     Protocol,
43     ProtocolFile,
44     TCP_GIT_PORT,
45     extract_capabilities,
46     extract_want_line_capabilities,
47     SINGLE_ACK,
48     MULTI_ACK,
49     MULTI_ACK_DETAILED,
50     ack_type,
51     )
52 from dulwich.repo import (
53     Repo,
54     )
55 from dulwich.pack import (
56     write_pack_data,
57     )
58
59 class Backend(object):
60
61     def get_refs(self):
62         """
63         Get all the refs in the repository
64
65         :return: dict of name -> sha
66         """
67         raise NotImplementedError
68
69     def apply_pack(self, refs, read):
70         """ Import a set of changes into a repository and update the refs
71
72         :param refs: list of tuple(name, sha)
73         :param read: callback to read from the incoming pack
74         """
75         raise NotImplementedError
76
77     def fetch_objects(self, determine_wants, graph_walker, progress):
78         """
79         Yield the objects required for a list of commits.
80
81         :param progress: is a callback to send progress messages to the client
82         """
83         raise NotImplementedError
84
85
86 class GitBackend(Backend):
87
88     def __init__(self, repo=None):
89         if repo is None:
90             repo = Repo(tmpfile.mkdtemp())
91         self.repo = repo
92         self.object_store = self.repo.object_store
93         self.fetch_objects = self.repo.fetch_objects
94         self.get_refs = self.repo.get_refs
95
96     def apply_pack(self, refs, read):
97         f, commit = self.repo.object_store.add_thin_pack()
98         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
99         status = []
100         unpack_error = None
101         # TODO: more informative error messages than just the exception string
102         try:
103             # TODO: decode the pack as we stream to avoid blocking reads beyond
104             # the end of data (when using HTTP/1.1 chunked encoding)
105             while True:
106                 data = read(10240)
107                 if not data:
108                     break
109                 f.write(data)
110         except all_exceptions, e:
111             unpack_error = str(e).replace('\n', '')
112         try:
113             commit()
114         except all_exceptions, e:
115             if not unpack_error:
116                 unpack_error = str(e).replace('\n', '')
117
118         if unpack_error:
119             status.append(('unpack', unpack_error))
120         else:
121             status.append(('unpack', 'ok'))
122
123         for oldsha, sha, ref in refs:
124             # TODO: check refname
125             ref_error = None
126             try:
127                 if ref == "0" * 40:
128                     try:
129                         del self.repo.refs[ref]
130                     except all_exceptions:
131                         ref_error = 'failed to delete'
132                 else:
133                     try:
134                         self.repo.refs[ref] = sha
135                     except all_exceptions:
136                         ref_error = 'failed to write'
137             except KeyError, e:
138                 ref_error = 'bad ref'
139             if ref_error:
140                 status.append((ref, ref_error))
141             else:
142                 status.append((ref, 'ok'))
143
144
145         print "pack applied"
146         return status
147
148
149 class Handler(object):
150     """Smart protocol command handler base class."""
151
152     def __init__(self, backend, read, write):
153         self.backend = backend
154         self.proto = Protocol(read, write)
155         self._client_capabilities = None
156
157     def capability_line(self):
158         return " ".join(self.capabilities())
159
160     def capabilities(self):
161         raise NotImplementedError(self.capabilities)
162
163     def set_client_capabilities(self, caps):
164         my_caps = self.capabilities()
165         for cap in caps:
166             if cap not in my_caps:
167                 raise GitProtocolError('Client asked for capability %s that '
168                                        'was not advertised.' % cap)
169         self._client_capabilities = set(caps)
170
171     def has_capability(self, cap):
172         if self._client_capabilities is None:
173             raise GitProtocolError('Server attempted to access capability %s '
174                                    'before asking client' % cap)
175         return cap in self._client_capabilities
176
177
178 class UploadPackHandler(Handler):
179     """Protocol handler for uploading a pack to the server."""
180
181     def __init__(self, backend, read, write,
182                  stateless_rpc=False, advertise_refs=False):
183         Handler.__init__(self, backend, read, write)
184         self._graph_walker = None
185         self.stateless_rpc = stateless_rpc
186         self.advertise_refs = advertise_refs
187
188     def capabilities(self):
189         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
190                 "ofs-delta", "no-progress")
191
192     def progress(self, message):
193         if self.has_capability("no-progress"):
194             return
195         self.proto.write_sideband(2, message)
196
197     def handle(self):
198         write = lambda x: self.proto.write_sideband(1, x)
199
200         graph_walker = ProtocolGraphWalker(self)
201         objects_iter = self.backend.fetch_objects(
202           graph_walker.determine_wants, graph_walker, self.progress)
203
204         # Do they want any objects?
205         if len(objects_iter) == 0:
206             return
207
208         self.progress("dul-daemon says what\n")
209         self.progress("counting objects: %d, done.\n" % len(objects_iter))
210         write_pack_data(ProtocolFile(None, write), objects_iter, 
211                         len(objects_iter))
212         self.progress("how was that, then?\n")
213         # we are done
214         self.proto.write("0000")
215
216
217 class ProtocolGraphWalker(object):
218     """A graph walker that knows the git protocol.
219
220     As a graph walker, this class implements ack(), next(), and reset(). It also
221     contains some base methods for interacting with the wire and walking the
222     commit tree.
223
224     The work of determining which acks to send is passed on to the
225     implementation instance stored in _impl. The reason for this is that we do
226     not know at object creation time what ack level the protocol requires. A
227     call to set_ack_level() is required to set up the implementation, before any
228     calls to next() or ack() are made.
229     """
230     def __init__(self, handler):
231         self.handler = handler
232         self.store = handler.backend.object_store
233         self.proto = handler.proto
234         self.stateless_rpc = handler.stateless_rpc
235         self.advertise_refs = handler.advertise_refs
236         self._wants = []
237         self._cached = False
238         self._cache = []
239         self._cache_index = 0
240         self._impl = None
241
242     def determine_wants(self, heads):
243         """Determine the wants for a set of heads.
244
245         The given heads are advertised to the client, who then specifies which
246         refs he wants using 'want' lines. This portion of the protocol is the
247         same regardless of ack type, and in fact is used to set the ack type of
248         the ProtocolGraphWalker.
249
250         :param heads: a dict of refname->SHA1 to advertise
251         :return: a list of SHA1s requested by the client
252         """
253         if not heads:
254             raise GitProtocolError('No heads found')
255         values = set(heads.itervalues())
256         if self.advertise_refs or not self.stateless_rpc:
257             for i, (ref, sha) in enumerate(heads.iteritems()):
258                 line = "%s %s" % (sha, ref)
259                 if not i:
260                     line = "%s\x00%s" % (line, self.handler.capability_line())
261                 self.proto.write_pkt_line("%s\n" % line)
262                 # TODO: include peeled value of any tags
263
264             # i'm done..
265             self.proto.write_pkt_line(None)
266
267             if self.advertise_refs:
268                 return []
269
270         # Now client will sending want want want commands
271         want = self.proto.read_pkt_line()
272         if not want:
273             return []
274         line, caps = extract_want_line_capabilities(want)
275         self.handler.set_client_capabilities(caps)
276         self.set_ack_type(ack_type(caps))
277         command, sha = self._split_proto_line(line)
278
279         want_revs = []
280         while command != None:
281             if command != 'want':
282                 raise GitProtocolError(
283                     'Protocol got unexpected command %s' % command)
284             if sha not in values:
285                 raise GitProtocolError(
286                     'Client wants invalid object %s' % sha)
287             want_revs.append(sha)
288             command, sha = self.read_proto_line()
289
290         self.set_wants(want_revs)
291         return want_revs
292
293     def ack(self, have_ref):
294         return self._impl.ack(have_ref)
295
296     def reset(self):
297         self._cached = True
298         self._cache_index = 0
299
300     def next(self):
301         if not self._cached:
302             if not self._impl and self.stateless_rpc:
303                 return None
304             return self._impl.next()
305         self._cache_index += 1
306         if self._cache_index > len(self._cache):
307             return None
308         return self._cache[self._cache_index]
309
310     def _split_proto_line(self, line):
311         fields = line.rstrip('\n').split(' ', 1)
312         if len(fields) == 1 and fields[0] == 'done':
313             return ('done', None)
314         elif len(fields) == 2 and fields[0] in ('want', 'have'):
315             try:
316                 hex_to_sha(fields[1])
317                 return tuple(fields)
318             except (TypeError, AssertionError), e:
319                 raise GitProtocolError(e)
320         raise GitProtocolError('Received invalid line from client:\n%s' % line)
321
322     def read_proto_line(self):
323         """Read a line from the wire.
324
325         :return: a tuple having one of the following forms:
326             ('want', obj_id)
327             ('have', obj_id)
328             ('done', None)
329             (None, None)  (for a flush-pkt)
330
331         :raise GitProtocolError: if the line cannot be parsed into one of the
332             possible return values.
333         """
334         line = self.proto.read_pkt_line()
335         if not line:
336             return (None, None)
337         return self._split_proto_line(line)
338
339     def send_ack(self, sha, ack_type=''):
340         if ack_type:
341             ack_type = ' %s' % ack_type
342         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
343
344     def send_nak(self):
345         self.proto.write_pkt_line('NAK\n')
346
347     def set_wants(self, wants):
348         self._wants = wants
349
350     def _is_satisfied(self, haves, want, earliest):
351         """Check whether a want is satisfied by a set of haves.
352
353         A want, typically a branch tip, is "satisfied" only if there exists a
354         path back from that want to one of the haves.
355
356         :param haves: A set of commits we know the client has.
357         :param want: The want to check satisfaction for.
358         :param earliest: A timestamp beyond which the search for haves will be
359             terminated, presumably because we're searching too far down the
360             wrong branch.
361         """
362         o = self.store[want]
363         pending = collections.deque([o])
364         while pending:
365             commit = pending.popleft()
366             if commit.id in haves:
367                 return True
368             if not getattr(commit, 'get_parents', None):
369                 # non-commit wants are assumed to be satisfied
370                 continue
371             for parent in commit.get_parents():
372                 parent_obj = self.store[parent]
373                 # TODO: handle parents with later commit times than children
374                 if parent_obj.commit_time >= earliest:
375                     pending.append(parent_obj)
376         return False
377
378     def all_wants_satisfied(self, haves):
379         """Check whether all the current wants are satisfied by a set of haves.
380
381         :param haves: A set of commits we know the client has.
382         :note: Wants are specified with set_wants rather than passed in since
383             in the current interface they are determined outside this class.
384         """
385         haves = set(haves)
386         earliest = min([self.store[h].commit_time for h in haves])
387         for want in self._wants:
388             if not self._is_satisfied(haves, want, earliest):
389                 return False
390         return True
391
392     def set_ack_type(self, ack_type):
393         impl_classes = {
394             MULTI_ACK: MultiAckGraphWalkerImpl,
395             MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
396             SINGLE_ACK: SingleAckGraphWalkerImpl,
397             }
398         self._impl = impl_classes[ack_type](self)
399
400
401 class SingleAckGraphWalkerImpl(object):
402     """Graph walker implementation that speaks the single-ack protocol."""
403
404     def __init__(self, walker):
405         self.walker = walker
406         self._sent_ack = False
407
408     def ack(self, have_ref):
409         if not self._sent_ack:
410             self.walker.send_ack(have_ref)
411             self._sent_ack = True
412
413     def next(self):
414         command, sha = self.walker.read_proto_line()
415         if command in (None, 'done'):
416             if not self._sent_ack:
417                 self.walker.send_nak()
418             return None
419         elif command == 'have':
420             return sha
421
422
423 class MultiAckGraphWalkerImpl(object):
424     """Graph walker implementation that speaks the multi-ack protocol."""
425
426     def __init__(self, walker):
427         self.walker = walker
428         self._found_base = False
429         self._common = []
430
431     def ack(self, have_ref):
432         self._common.append(have_ref)
433         if not self._found_base:
434             self.walker.send_ack(have_ref, 'continue')
435             if self.walker.all_wants_satisfied(self._common):
436                 self._found_base = True
437         # else we blind ack within next
438
439     def next(self):
440         while True:
441             command, sha = self.walker.read_proto_line()
442             if command is None:
443                 self.walker.send_nak()
444                 # in multi-ack mode, a flush-pkt indicates the client wants to
445                 # flush but more have lines are still coming
446                 continue
447             elif command == 'done':
448                 # don't nak unless no common commits were found, even if not
449                 # everything is satisfied
450                 if self._common:
451                     self.walker.send_ack(self._common[-1])
452                 else:
453                     self.walker.send_nak()
454                 return None
455             elif command == 'have':
456                 if self._found_base:
457                     # blind ack
458                     self.walker.send_ack(sha, 'continue')
459                 return sha
460
461
462 class MultiAckDetailedGraphWalkerImpl(object):
463     """Graph walker implementation speaking the multi-ack-detailed protocol."""
464
465     def __init__(self, walker):
466         self.walker = walker
467         self._found_base = False
468         self._common = []
469
470     def ack(self, have_ref):
471         self._common.append(have_ref)
472         if not self._found_base:
473             self.walker.send_ack(have_ref, 'common')
474             if self.walker.all_wants_satisfied(self._common):
475                 self._found_base = True
476                 self.walker.send_ack(have_ref, 'ready')
477         # else we blind ack within next
478
479     def next(self):
480         while True:
481             command, sha = self.walker.read_proto_line()
482             if command is None:
483                 self.walker.send_nak()
484                 if self.walker.stateless_rpc:
485                     return None
486                 continue
487             elif command == 'done':
488                 # don't nak unless no common commits were found, even if not
489                 # everything is satisfied
490                 if self._common:
491                     self.walker.send_ack(self._common[-1])
492                 else:
493                     self.walker.send_nak()
494                 return None
495             elif command == 'have':
496                 if self._found_base:
497                     # blind ack; can happen if the client has more requests
498                     # inflight
499                     self.walker.send_ack(sha, 'ready')
500                 return sha
501
502
503 class ReceivePackHandler(Handler):
504     """Protocol handler for downloading a pack from the client."""
505
506     def __init__(self, backend, read, write,
507                  stateless_rpc=False, advertise_refs=False):
508         Handler.__init__(self, backend, read, write)
509         self.stateless_rpc = stateless_rpc
510         self.advertise_refs = advertise_refs
511
512     def __init__(self, backend, read, write,
513                  stateless_rpc=False, advertise_refs=False):
514         Handler.__init__(self, backend, read, write)
515         self._stateless_rpc = stateless_rpc
516         self._advertise_refs = advertise_refs
517
518     def capabilities(self):
519         return ("report-status", "delete-refs")
520
521     def handle(self):
522         refs = self.backend.get_refs().items()
523
524         if self.advertise_refs or not self.stateless_rpc:
525             if refs:
526                 self.proto.write_pkt_line(
527                     "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
528                                        self.capability_line()))
529                 for i in range(1, len(refs)):
530                     ref = refs[i]
531                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
532             else:
533                 self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
534
535             self.proto.write("0000")
536             if self.advertise_refs:
537                 return
538
539         client_refs = []
540         ref = self.proto.read_pkt_line()
541
542         # if ref is none then client doesnt want to send us anything..
543         if ref is None:
544             return
545
546         ref, caps = extract_capabilities(ref)
547         self.set_client_capabilities(caps)
548
549         # client will now send us a list of (oldsha, newsha, ref)
550         while ref:
551             client_refs.append(ref.split())
552             ref = self.proto.read_pkt_line()
553
554         # backend can now deal with this refs and read a pack using self.read
555         status = self.backend.apply_pack(client_refs, self.proto.read)
556
557         # when we have read all the pack from the client, send a status report
558         # if the client asked for it
559         if self.has_capability('report-status'):
560             for name, msg in status:
561                 if name == 'unpack':
562                     self.proto.write_pkt_line('unpack %s\n' % msg)
563                 elif msg == 'ok':
564                     self.proto.write_pkt_line('ok %s\n' % name)
565                 else:
566                     self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
567             self.proto.write_pkt_line(None)
568
569
570 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
571
572     def handle(self):
573         proto = Protocol(self.rfile.read, self.wfile.write)
574         command, args = proto.read_cmd()
575
576         # switch case to handle the specific git command
577         if command == 'git-upload-pack':
578             cls = UploadPackHandler
579         elif command == 'git-receive-pack':
580             cls = ReceivePackHandler
581         else:
582             return
583
584         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
585         h.handle()
586
587
588 class TCPGitServer(SocketServer.TCPServer):
589
590     allow_reuse_address = True
591     serve = SocketServer.TCPServer.serve_forever
592
593     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
594         self.backend = backend
595         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)