1824cfd82b89dcbb2a4e2ddf5e33d0cf02544fe9
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation.
21
22 For more detailed implementation on the network protocol, see the
23 Documentation/technical directory in the cgit distribution, and in particular:
24     Documentation/technical/protocol-capabilities.txt
25     Documentation/technical/pack-protocol.txt
26 """
27
28
29 import collections
30 import SocketServer
31 import tempfile
32
33 from dulwich.errors import (
34     GitProtocolError,
35     )
36 from dulwich.objects import (
37     hex_to_sha,
38     )
39 from dulwich.protocol import (
40     Protocol,
41     ProtocolFile,
42     TCP_GIT_PORT,
43     extract_capabilities,
44     extract_want_line_capabilities,
45     SINGLE_ACK,
46     MULTI_ACK,
47     ack_type,
48     )
49 from dulwich.repo import (
50     Repo,
51     )
52 from dulwich.pack import (
53     write_pack_data,
54     )
55
56 class Backend(object):
57
58     def get_refs(self):
59         """
60         Get all the refs in the repository
61
62         :return: dict of name -> sha
63         """
64         raise NotImplementedError
65
66     def apply_pack(self, refs, read):
67         """ Import a set of changes into a repository and update the refs
68
69         :param refs: list of tuple(name, sha)
70         :param read: callback to read from the incoming pack
71         """
72         raise NotImplementedError
73
74     def fetch_objects(self, determine_wants, graph_walker, progress):
75         """
76         Yield the objects required for a list of commits.
77
78         :param progress: is a callback to send progress messages to the client
79         """
80         raise NotImplementedError
81
82
83 class GitBackend(Backend):
84
85     def __init__(self, repo=None):
86         if repo is None:
87             repo = Repo(tmpfile.mkdtemp())
88         self.repo = repo
89         self.object_store = self.repo.object_store
90         self.fetch_objects = self.repo.fetch_objects
91         self.get_refs = self.repo.get_refs
92
93     def apply_pack(self, refs, read):
94         f, commit = self.repo.object_store.add_thin_pack()
95         try:
96             f.write(read())
97         finally:
98             commit()
99
100         for oldsha, sha, ref in refs:
101             if ref == "0" * 40:
102                 del self.repo.refs[ref]
103             else:
104                 self.repo.refs[ref] = sha
105
106         print "pack applied"
107
108
109 class Handler(object):
110     """Smart protocol command handler base class."""
111
112     def __init__(self, backend, read, write):
113         self.backend = backend
114         self.proto = Protocol(read, write)
115
116     def capabilities(self):
117         return " ".join(self.default_capabilities())
118
119
120 class UploadPackHandler(Handler):
121     """Protocol handler for uploading a pack to the server."""
122
123     def __init__(self, backend, read, write,
124                  stateless_rpc=False, advertise_refs=False):
125         Handler.__init__(self, backend, read, write)
126         self._client_capabilities = None
127         self._graph_walker = None
128         self._stateless_rpc = stateless_rpc
129         self._advertise_refs = advertise_refs
130
131     def default_capabilities(self):
132         return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
133
134     def set_client_capabilities(self, caps):
135         my_caps = self.default_capabilities()
136         for cap in caps:
137             if '_ack' in cap and cap not in my_caps:
138                 raise GitProtocolError('Client asked for capability %s that '
139                                        'was not advertised.' % cap)
140         self._client_capabilities = caps
141
142     def get_client_capabilities(self):
143         return self._client_capabilities
144
145     client_capabilities = property(get_client_capabilities,
146                                    set_client_capabilities)
147
148     def handle(self):
149
150         progress = lambda x: self.proto.write_sideband(2, x)
151         write = lambda x: self.proto.write_sideband(1, x)
152
153         graph_walker = ProtocolGraphWalker(self)
154         objects_iter = self.backend.fetch_objects(
155           graph_walker.determine_wants, graph_walker, progress)
156
157         # Do they want any objects?
158         if len(objects_iter) == 0:
159             return
160
161         progress("dul-daemon says what\n")
162         progress("counting objects: %d, done.\n" % len(objects_iter))
163         write_pack_data(ProtocolFile(None, write), objects_iter, 
164                         len(objects_iter))
165         progress("how was that, then?\n")
166         # we are done
167         self.proto.write("0000")
168
169
170 class ProtocolGraphWalker(object):
171     """A graph walker that knows the git protocol.
172
173     As a graph walker, this class implements ack(), next(), and reset(). It also
174     contains some base methods for interacting with the wire and walking the
175     commit tree.
176
177     The work of determining which acks to send is passed on to the
178     implementation instance stored in _impl. The reason for this is that we do
179     not know at object creation time what ack level the protocol requires. A
180     call to set_ack_level() is required to set up the implementation, before any
181     calls to next() or ack() are made.
182     """
183     def __init__(self, handler):
184         self.handler = handler
185         self.store = handler.backend.object_store
186         self.proto = handler.proto
187         self._wants = []
188         self._cached = False
189         self._cache = []
190         self._cache_index = 0
191         self._impl = None
192
193     def determine_wants(self, heads):
194         """Determine the wants for a set of heads.
195
196         The given heads are advertised to the client, who then specifies which
197         refs he wants using 'want' lines. This portion of the protocol is the
198         same regardless of ack type, and in fact is used to set the ack type of
199         the ProtocolGraphWalker.
200
201         :param heads: a dict of refname->SHA1 to advertise
202         :return: a list of SHA1s requested by the client
203         """
204         if not heads:
205             raise GitProtocolError('No heads found')
206         values = set(heads.itervalues())
207         for i, (ref, sha) in enumerate(heads.iteritems()):
208             line = "%s %s" % (sha, ref)
209             if not i:
210                 line = "%s\x00%s" % (line, self.handler.capabilities())
211             self.proto.write_pkt_line("%s\n" % line)
212             # TODO: include peeled value of any tags
213
214         # i'm done..
215         self.proto.write_pkt_line(None)
216
217         # Now client will sending want want want commands
218         want = self.proto.read_pkt_line()
219         if not want:
220             return []
221         line, caps = extract_want_line_capabilities(want)
222         self.handler.client_capabilities = caps
223         self.set_ack_type(ack_type(caps))
224         command, sha = self._split_proto_line(line)
225
226         want_revs = []
227         while command != None:
228             if command != 'want':
229                 raise GitProtocolError(
230                     'Protocol got unexpected command %s' % command)
231             if sha not in values:
232                 raise GitProtocolError(
233                     'Client wants invalid object %s' % sha)
234             want_revs.append(sha)
235             command, sha = self.read_proto_line()
236
237         self.set_wants(want_revs)
238         return want_revs
239
240     def ack(self, have_ref):
241         return self._impl.ack(have_ref)
242
243     def reset(self):
244         self._cached = True
245         self._cache_index = 0
246
247     def next(self):
248         if not self._cached:
249             if not self._impl:
250                 return None
251             return self._impl.next()
252         self._cache_index += 1
253         if self._cache_index > len(self._cache):
254             return None
255         return self._cache[self._cache_index]
256
257     def _split_proto_line(self, line):
258         fields = line.rstrip('\n').split(' ', 1)
259         if len(fields) == 1 and fields[0] == 'done':
260             return ('done', None)
261         elif len(fields) == 2 and fields[0] in ('want', 'have'):
262             try:
263                 hex_to_sha(fields[1])
264                 return tuple(fields)
265             except (TypeError, AssertionError), e:
266                 raise GitProtocolError(e)
267         raise GitProtocolError('Received invalid line from client:\n%s' % line)
268
269     def read_proto_line(self):
270         """Read a line from the wire.
271
272         :return: a tuple having one of the following forms:
273             ('want', obj_id)
274             ('have', obj_id)
275             ('done', None)
276             (None, None)  (for a flush-pkt)
277         """
278         line = self.proto.read_pkt_line()
279         if not line:
280             return (None, None)
281         return self._split_proto_line(line)
282
283     def send_ack(self, sha, ack_type=''):
284         if ack_type:
285             ack_type = ' %s' % ack_type
286         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
287
288     def send_nak(self):
289         self.proto.write_pkt_line('NAK\n')
290
291     def set_wants(self, wants):
292         self._wants = wants
293
294     def _is_satisfied(self, haves, want, earliest):
295         """Check whether a want is satisfied by a set of haves.
296
297         A want, typically a branch tip, is "satisfied" only if there exists a
298         path back from that want to one of the haves.
299
300         :param haves: A set of commits we know the client has.
301         :param want: The want to check satisfaction for.
302         :param earliest: A timestamp beyond which the search for haves will be
303             terminated, presumably because we're searching too far down the
304             wrong branch.
305         """
306         o = self.store[want]
307         pending = collections.deque([o])
308         while pending:
309             commit = pending.popleft()
310             if commit.id in haves:
311                 return True
312             if not getattr(commit, 'get_parents', None):
313                 # non-commit wants are assumed to be satisfied
314                 continue
315             for parent in commit.get_parents():
316                 parent_obj = self.store[parent]
317                 # TODO: handle parents with later commit times than children
318                 if parent_obj.commit_time >= earliest:
319                     pending.append(parent_obj)
320         return False
321
322     def all_wants_satisfied(self, haves):
323         """Check whether all the current wants are satisfied by a set of haves.
324
325         :param haves: A set of commits we know the client has.
326         :note: Wants are specified with set_wants rather than passed in since
327             in the current interface they are determined outside this class.
328         """
329         haves = set(haves)
330         earliest = min([self.store[h].commit_time for h in haves])
331         for want in self._wants:
332             if not self._is_satisfied(haves, want, earliest):
333                 return False
334         return True
335
336     def set_ack_type(self, ack_type):
337         impl_classes = {
338             MULTI_ACK: MultiAckGraphWalkerImpl,
339             SINGLE_ACK: SingleAckGraphWalkerImpl,
340             }
341         self._impl = impl_classes[ack_type](self)
342
343
344 class SingleAckGraphWalkerImpl(object):
345     """Graph walker implementation that speaks the single-ack protocol."""
346
347     def __init__(self, walker):
348         self.walker = walker
349         self._sent_ack = False
350
351     def ack(self, have_ref):
352         if not self._sent_ack:
353             self.walker.send_ack(have_ref)
354             self._sent_ack = True
355
356     def next(self):
357         command, sha = self.walker.read_proto_line()
358         if command in (None, 'done'):
359             if not self._sent_ack:
360                 self.walker.send_nak()
361             return None
362         elif command == 'have':
363             return sha
364
365
366 class MultiAckGraphWalkerImpl(object):
367     """Graph walker implementation that speaks the multi-ack protocol."""
368
369     def __init__(self, walker):
370         self.walker = walker
371         self._found_base = False
372         self._common = []
373
374     def ack(self, have_ref):
375         self._common.append(have_ref)
376         if not self._found_base:
377             self.walker.send_ack(have_ref, 'continue')
378             if self.walker.all_wants_satisfied(self._common):
379                 self._found_base = True
380         # else we blind ack within next
381
382     def next(self):
383         while True:
384             command, sha = self.walker.read_proto_line()
385             if command is None:
386                 self.walker.send_nak()
387                 # in multi-ack mode, a flush-pkt indicates the client wants to
388                 # flush but more have lines are still coming
389                 continue
390             elif command == 'done':
391                 # don't nak unless no common commits were found, even if not
392                 # everything is satisfied
393                 if self._common:
394                     self.walker.send_ack(self._common[-1])
395                 else:
396                     self.walker.send_nak()
397                 return None
398             elif command == 'have':
399                 if self._found_base:
400                     # blind ack
401                     self.walker.send_ack(sha, 'continue')
402                 return sha
403
404
405 class ReceivePackHandler(Handler):
406     """Protocol handler for downloading a pack to the client."""
407
408     def __init__(self, backend, read, write,
409                  stateless_rpc=False, advertise_refs=False):
410         Handler.__init__(self, backend, read, write)
411         self._stateless_rpc = stateless_rpc
412         self._advertise_refs = advertise_refs
413
414     def default_capabilities(self):
415         return ("report-status", "delete-refs")
416
417     def handle(self):
418         refs = self.backend.get_refs().items()
419
420         if refs:
421             self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
422             for i in range(1, len(refs)):
423                 ref = refs[i]
424                 self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
425         else:
426             self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
427
428         self.proto.write("0000")
429
430         client_refs = []
431         ref = self.proto.read_pkt_line()
432
433         # if ref is none then client doesnt want to send us anything..
434         if ref is None:
435             return
436
437         ref, client_capabilities = extract_capabilities(ref)
438
439         # client will now send us a list of (oldsha, newsha, ref)
440         while ref:
441             client_refs.append(ref.split())
442             ref = self.proto.read_pkt_line()
443
444         # backend can now deal with this refs and read a pack using self.read
445         self.backend.apply_pack(client_refs, self.proto.read)
446
447         # when we have read all the pack from the client, it assumes 
448         # everything worked OK.
449         # there is NO ack from the server before it reports victory.
450
451
452 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
453
454     def handle(self):
455         proto = Protocol(self.rfile.read, self.wfile.write)
456         command, args = proto.read_cmd()
457
458         # switch case to handle the specific git command
459         if command == 'git-upload-pack':
460             cls = UploadPackHandler
461         elif command == 'git-receive-pack':
462             cls = ReceivePackHandler
463         else:
464             return
465
466         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
467         h.handle()
468
469
470 class TCPGitServer(SocketServer.TCPServer):
471
472     allow_reuse_address = True
473     serve = SocketServer.TCPServer.serve_forever
474
475     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
476         self.backend = backend
477         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)