Work towards making Dulwich less dependent on the filesystem.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git smart network protocol server implementation."""
21
22
23 import collections
24 import SocketServer
25 import tempfile
26
27 from dulwich.errors import (
28     GitProtocolError,
29     )
30 from dulwich.objects import (
31     hex_to_sha,
32     )
33 from dulwich.protocol import (
34     Protocol,
35     ProtocolFile,
36     TCP_GIT_PORT,
37     extract_capabilities,
38     extract_want_line_capabilities,
39     SINGLE_ACK,
40     MULTI_ACK,
41     ack_type,
42     )
43 from dulwich.repo import (
44     Repo,
45     )
46 from dulwich.pack import (
47     write_pack_data,
48     )
49
50 class Backend(object):
51
52     def get_refs(self):
53         """
54         Get all the refs in the repository
55
56         :return: dict of name -> sha
57         """
58         raise NotImplementedError
59
60     def apply_pack(self, refs, read):
61         """ Import a set of changes into a repository and update the refs
62
63         :param refs: list of tuple(name, sha)
64         :param read: callback to read from the incoming pack
65         """
66         raise NotImplementedError
67
68     def fetch_objects(self, determine_wants, graph_walker, progress):
69         """
70         Yield the objects required for a list of commits.
71
72         :param progress: is a callback to send progress messages to the client
73         """
74         raise NotImplementedError
75
76
77 class GitBackend(Backend):
78
79     def __init__(self, repo=None):
80         if repo is None:
81             repo = Repo(tmpfile.mkdtemp())
82         self.repo = repo
83         self.object_store = self.repo.object_store
84         self.fetch_objects = self.repo.fetch_objects
85         self.get_refs = self.repo.get_refs
86
87     def apply_pack(self, refs, read):
88         f, commit = self.repo.object_store.add_thin_pack()
89         try:
90             f.write(read())
91         finally:
92             commit()
93
94         for oldsha, sha, ref in refs:
95             if ref == "0" * 40:
96                 del self.repo.refs[ref]
97             else:
98                 self.repo.refs[ref] = sha
99
100         print "pack applied"
101
102
103 class Handler(object):
104     """Smart protocol command handler base class."""
105
106     def __init__(self, backend, read, write):
107         self.backend = backend
108         self.proto = Protocol(read, write)
109
110     def capabilities(self):
111         return " ".join(self.default_capabilities())
112
113
114 class UploadPackHandler(Handler):
115     """Protocol handler for uploading a pack to the server."""
116
117     def __init__(self, backend, read, write):
118         Handler.__init__(self, backend, read, write)
119         self._client_capabilities = None
120         self._graph_walker = None
121
122     def default_capabilities(self):
123         return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
124
125     def set_client_capabilities(self, caps):
126         my_caps = self.default_capabilities()
127         for cap in caps:
128             if '_ack' in cap and cap not in my_caps:
129                 raise GitProtocolError('Client asked for capability %s that '
130                                        'was not advertised.' % cap)
131         self._client_capabilities = caps
132
133     def get_client_capabilities(self):
134         return self._client_capabilities
135
136     client_capabilities = property(get_client_capabilities,
137                                    set_client_capabilities)
138
139     def handle(self):
140         def determine_wants(heads):
141             keys = heads.keys()
142             values = set(heads.itervalues())
143             if keys:
144                 self.proto.write_pkt_line("%s %s\x00%s\n" % ( heads[keys[0]], keys[0], self.capabilities()))
145                 for k in keys[1:]:
146                     self.proto.write_pkt_line("%s %s\n" % (heads[k], k))
147
148             # i'm done..
149             self.proto.write("0000")
150
151             # Now client will either send "0000", meaning that it doesnt want to pull.
152             # or it will start sending want want want commands
153             want = self.proto.read_pkt_line()
154             if want == None:
155                 return []
156
157             want, self.client_capabilities = extract_want_line_capabilities(want)
158             graph_walker.set_ack_type(ack_type(self.client_capabilities))
159
160             want_revs = []
161             while want and want[:4] == 'want':
162                 sha = want[5:45]
163                 try:
164                     hex_to_sha(sha)
165                 except (TypeError, AssertionError), e:
166                     raise GitProtocolError(e)
167
168                 if sha not in values:
169                     raise GitProtocolError(
170                         'Client wants invalid object %s' % sha)
171                 want_revs.append(sha)
172                 want = self.proto.read_pkt_line()
173             graph_walker.set_wants(want_revs)
174             return want_revs
175
176         progress = lambda x: self.proto.write_sideband(2, x)
177         write = lambda x: self.proto.write_sideband(1, x)
178
179         graph_walker = ProtocolGraphWalker(self.backend.object_store, self.proto)
180         objects_iter = self.backend.fetch_objects(determine_wants, graph_walker, progress)
181
182         # Do they want any objects?
183         if len(objects_iter) == 0:
184             return
185
186         progress("dul-daemon says what\n")
187         progress("counting objects: %d, done.\n" % len(objects_iter))
188         write_pack_data(ProtocolFile(None, write), objects_iter, 
189                         len(objects_iter))
190         progress("how was that, then?\n")
191         # we are done
192         self.proto.write("0000")
193
194
195 class ProtocolGraphWalker(object):
196     """A graph walker that knows the git protocol.
197
198     As a graph walker, this class implements ack(), next(), and reset(). It also
199     contains some base methods for interacting with the wire and walking the
200     commit tree.
201
202     The work of determining which acks to send is passed on to the
203     implementation instance stored in _impl. The reason for this is that we do
204     not know at object creation time what ack level the protocol requires. A
205     call to set_ack_level() is required to set up the implementation, before any
206     calls to next() or ack() are made.
207     """
208     def __init__(self, object_store, proto):
209         self.store = object_store
210         self.proto = proto
211         self._wants = []
212         self._cached = False
213         self._cache = []
214         self._cache_index = 0
215         self._impl = None
216
217     def ack(self, have_ref):
218         return self._impl.ack(have_ref)
219
220     def reset(self):
221         self._cached = True
222         self._cache_index = 0
223
224     def next(self):
225         if not self._cached:
226             return self._impl.next()
227         self._cache_index += 1
228         if self._cache_index > len(self._cache):
229             return None
230         return self._cache[self._cache_index]
231
232     def read_proto_line(self):
233         """Read a line from the wire.
234
235         :return: a tuple having one of the following forms:
236             ('have', obj_id)
237             ('done', None)
238             (None, None)  (for a flush-pkt)
239         """
240         line = self.proto.read_pkt_line()
241         if not line:
242             return (None, None)
243         fields = line.rstrip('\n').split(' ', 1)
244         if len(fields) == 1 and fields[0] == 'done':
245             return ('done', None)
246         if len(fields) == 2 and fields[0] == 'have':
247             try:
248                 hex_to_sha(fields[1])
249                 return fields
250             except (TypeError, AssertionError), e:
251                 raise GitProtocolError(e)
252         raise GitProtocolError('Received invalid line from client:\n%s' % line)
253
254     def send_ack(self, sha, ack_type=''):
255         if ack_type:
256             ack_type = ' %s' % ack_type
257         self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
258
259     def send_nak(self):
260         self.proto.write_pkt_line('NAK\n')
261
262     def set_wants(self, wants):
263         self._wants = wants
264
265     def _is_satisfied(self, haves, want, earliest):
266         """Check whether a want is satisfied by a set of haves.
267
268         A want, typically a branch tip, is "satisfied" only if there exists a
269         path back from that want to one of the haves.
270
271         :param haves: A set of commits we know the client has.
272         :param want: The want to check satisfaction for.
273         :param earliest: A timestamp beyond which the search for haves will be
274             terminated, presumably because we're searching too far down the
275             wrong branch.
276         """
277         o = self.store[want]
278         pending = collections.deque([o])
279         while pending:
280             commit = pending.popleft()
281             if commit.id in haves:
282                 return True
283             if not getattr(commit, 'get_parents', None):
284                 # non-commit wants are assumed to be satisfied
285                 continue
286             for parent in commit.get_parents():
287                 parent_obj = self.store[parent]
288                 # TODO: handle parents with later commit times than children
289                 if parent_obj.commit_time >= earliest:
290                     pending.append(parent_obj)
291         return False
292
293     def all_wants_satisfied(self, haves):
294         """Check whether all the current wants are satisfied by a set of haves.
295
296         :param haves: A set of commits we know the client has.
297         :note: Wants are specified with set_wants rather than passed in since
298             in the current interface they are determined outside this class.
299         """
300         haves = set(haves)
301         earliest = min([self.store[h].commit_time for h in haves])
302         for want in self._wants:
303             if not self._is_satisfied(haves, want, earliest):
304                 return False
305         return True
306
307     def set_ack_type(self, ack_type):
308         impl_classes = {
309             MULTI_ACK: MultiAckGraphWalkerImpl,
310             SINGLE_ACK: SingleAckGraphWalkerImpl,
311             }
312         self._impl = impl_classes[ack_type](self)
313
314
315 class SingleAckGraphWalkerImpl(object):
316     """Graph walker implementation that speaks the single-ack protocol."""
317
318     def __init__(self, walker):
319         self.walker = walker
320         self._sent_ack = False
321
322     def ack(self, have_ref):
323         if not self._sent_ack:
324             self.walker.send_ack(have_ref)
325             self._sent_ack = True
326
327     def next(self):
328         command, sha = self.walker.read_proto_line()
329         if command in (None, 'done'):
330             if not self._sent_ack:
331                 self.walker.send_nak()
332             return None
333         elif command == 'have':
334             return sha
335
336
337 class MultiAckGraphWalkerImpl(object):
338     """Graph walker implementation that speaks the multi-ack protocol."""
339
340     def __init__(self, walker):
341         self.walker = walker
342         self._found_base = False
343         self._common = []
344
345     def ack(self, have_ref):
346         self._common.append(have_ref)
347         if not self._found_base:
348             self.walker.send_ack(have_ref, 'continue')
349             if self.walker.all_wants_satisfied(self._common):
350                 self._found_base = True
351         # else we blind ack within next
352
353     def next(self):
354         command, sha = self.walker.read_proto_line()
355         if command is None:
356             self.walker.send_nak()
357             return None
358         elif command == 'done':
359             # don't nak unless no common commits were found, even if not
360             # everything is satisfied
361             if self._common:
362                 self.walker.send_ack(self._common[-1])
363             else:
364                 self.walker.send_nak()
365             return None
366         elif command == 'have':
367             if self._found_base:
368                 # blind ack
369                 self.walker.send_ack(sha, 'continue')
370             return sha
371
372
373 class ReceivePackHandler(Handler):
374     """Protocol handler for downloading a pack to the client."""
375
376     def default_capabilities(self):
377         return ("report-status", "delete-refs")
378
379     def handle(self):
380         refs = self.backend.get_refs().items()
381
382         if refs:
383             self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
384             for i in range(1, len(refs)):
385                 ref = refs[i]
386                 self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
387         else:
388             self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
389
390         self.proto.write("0000")
391
392         client_refs = []
393         ref = self.proto.read_pkt_line()
394
395         # if ref is none then client doesnt want to send us anything..
396         if ref is None:
397             return
398
399         ref, client_capabilities = extract_capabilities(ref)
400
401         # client will now send us a list of (oldsha, newsha, ref)
402         while ref:
403             client_refs.append(ref.split())
404             ref = self.proto.read_pkt_line()
405
406         # backend can now deal with this refs and read a pack using self.read
407         self.backend.apply_pack(client_refs, self.proto.read)
408
409         # when we have read all the pack from the client, it assumes 
410         # everything worked OK.
411         # there is NO ack from the server before it reports victory.
412
413
414 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
415
416     def handle(self):
417         proto = Protocol(self.rfile.read, self.wfile.write)
418         command, args = proto.read_cmd()
419
420         # switch case to handle the specific git command
421         if command == 'git-upload-pack':
422             cls = UploadPackHandler
423         elif command == 'git-receive-pack':
424             cls = ReceivePackHandler
425         else:
426             return
427
428         h = cls(self.server.backend, self.rfile.read, self.wfile.write)
429         h.handle()
430
431
432 class TCPGitServer(SocketServer.TCPServer):
433
434     allow_reuse_address = True
435     serve = SocketServer.TCPServer.serve_forever
436
437     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
438         self.backend = backend
439         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)