Use common, shared, code
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 import SocketServer
20
21 class Backend(object):
22
23     def get_refs(self):
24         """
25         Get all the refs in the repository
26
27         :return: list of tuple(name, sha)
28         """
29         raise NotImplementedError
30
31     def has_revision(self, sha):
32         """
33         Is a given sha in this repository?
34
35         :return: True or False
36         """
37         raise NotImplementedError
38
39     def apply_pack(self, refs, read):
40         """ Import a set of changes into a repository and update the refs
41
42         :param refs: list of tuple(name, sha)
43         :param read: callback to read from the incoming pack
44         """
45         raise NotImplementedError
46
47     def generate_pack(self, want, have, write, progress):
48         """
49         Generate a pack containing all commits a client is missing
50
51         :param want: is a list of sha's the client desires
52         :param have: is a list of sha's the client has (allowing us to send the minimal pack)
53         :param write: is a callback to write pack data to the client
54         :param progress: is a callback to send progress messages to the client
55         """
56         raise NotImplementedError
57
58 from dulwich.repo import Repo
59 from dulwich.pack import PackData, Pack
60 import sha, tempfile, os
61 from dulwich.pack import write_pack_object
62
63 class PackWriteWrapper(object):
64
65     def __init__(self, write):
66         self.writefn = write
67         self.sha = sha.sha()
68
69     def write(self, blob):
70         self.sha.update(blob)
71         self.writefn(blob)
72
73     def tell(self):
74         pass
75
76     @property
77     def digest(self):
78         return self.sha.digest()
79
80 class GitBackend(Backend):
81
82     def __init__(self, gitdir=None):
83         self.gitdir = gitdir
84
85         if not self.gitdir:
86             self.gitdir = tempfile.mkdtemp()
87             Repo.create(self.gitdir)
88
89         self.repo = Repo(self.gitdir)
90
91     def get_refs(self):
92         refs = []
93         if self.repo.head():
94             refs.append(('HEAD', self.repo.head()))
95         for ref, sha in self.repo.heads().items():
96             refs.append(('refs/heads/'+ref,sha))
97         return refs
98
99     def has_revision(self, sha):
100         return self.repo.get_object(sha) != None
101
102     def apply_pack(self, refs, read):
103         # store the incoming pack in the repository
104         fd, name = tempfile.mkstemp(suffix='.pack', prefix='', dir=self.repo.pack_dir())
105         os.write(fd, read())
106         os.close(fd)
107
108         # strip '.pack' off our filename
109         basename = name[:-5]
110
111         # generate an index for it
112         pd = PackData(name)
113         pd.create_index_v2(basename+".idx")
114
115         for oldsha, sha, ref in refs:
116             if ref == "0" * 40:
117                 self.repo.remove_ref(ref)
118             else:
119                 self.repo.set_ref(ref, sha)
120
121         print "pack applied"
122
123     def generate_pack(self, want, have, write, progress):
124         progress("dul-daemon says what\n")
125
126         sha_queue = []
127
128         commits_to_send = want[:]
129         for sha in commits_to_send:
130             if sha in sha_queue:
131                 continue
132
133             sha_queue.append((1,sha))
134
135             c = self.repo.commit(sha)
136             for p in c.parents():
137                 if not p in commits_to_send:
138                     commits_to_send.append(p)
139
140             def parse_tree(tree, sha_queue):
141                 for mode, name, x in tree.entries():
142                     if not x in sha_queue:
143                         try:
144                             t = self.repo.get_tree(x)
145                             sha_queue.append((2, x))
146                             parse_tree(t, sha_queue)
147                         except:
148                             sha_queue.append((3, x))
149
150             treesha = c.tree()
151             if treesha not in sha_queue:
152                 sha_queue.append((2, treesha))
153                 t = self.repo.get_tree(treesha)
154                 parse_tree(t, sha_queue)
155
156             progress("counting objects: %d\r" % len(sha_queue))
157
158         progress("counting objects: %d, done.\n" % len(sha_queue))
159
160         write_pack_data(write, (self.repo.get_object(sha).as_raw_string() for sha in sha_queue))
161
162         progress("how was that, then?\n")
163
164
165 class Handler(object):
166
167     def __init__(self, backend, read, write):
168         self.backend = backend
169         self.proto = Protocol(read, write)
170
171     def capabilities(self):
172         return " ".join(self.default_capabilities())
173
174     def handshake(self, blob):
175         """
176         Compare remote capabilites with our own and alter protocol accordingly
177
178         :param blob: space seperated list of capabilities (i.e. wire format)
179         """
180         if not "\x00" in blob:
181             return blob
182         blob, caps = blob.split("\x00")
183
184         # FIXME: Do something with this..
185         caps = caps.split()
186
187         return blob
188
189
190 class UploadPackHandler(Handler):
191
192     def default_capabilities(self):
193         return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
194
195     def handle(self):
196         refs = self.backend.get_refs()
197
198         if refs:
199             self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
200             for i in range(1, len(refs)):
201                 ref = refs[i]
202                 self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
203
204         # i'm done..
205         self.proto.write("0000")
206
207         # Now client will either send "0000", meaning that it doesnt want to pull.
208         # or it will start sending want want want commands
209         want = self.proto.read_pkt_line()
210         if want == None:
211             return
212
213         want = self.handshake(want)
214
215         # Keep reading the list of demands until we hit another "0000" 
216         want_revs = []
217         while want and want[:4] == 'want':
218             want_rev = want[5:45]
219             # FIXME: This check probably isnt needed?
220             if self.backend.has_revision(want_rev):
221                want_revs.append(want_rev)
222             want = self.proto.read_pkt_line()
223         
224         # Client will now tell us which commits it already has - if we have them we ACK them
225         # this allows client to stop looking at that commits parents (main reason why git pull is fast)
226         last_sha = None
227         have_revs = []
228         have = self.proto.read_pkt_line()
229         while have and have[:4] == 'have':
230             have_ref = have[6:46]
231             if self.backend.has_revision(have_ref):
232                 self.proto.write_pkt_line("ACK %s continue\n" % have_ref)
233                 last_sha = have_ref
234                 have_revs.append(have_ref)
235             have = self.proto.read_pkt_line()
236
237         # At some point client will stop sending commits and will tell us it is done
238         assert(have[:4] == "done")
239
240         # Oddness: Git seems to resend the last ACK, without the "continue" statement
241         if last_sha:
242             self.proto.write_pkt_line("ACK %s\n" % last_sha)
243
244         # The exchange finishes with a NAK
245         self.proto.write_pkt_line("NAK\n")
246       
247         self.backend.generate_pack(want_revs, have_revs, lambda x: self.proto.write_sideband(1, x), lambda x: self.proto.write_sideband(2, x))
248
249         # we are done
250         self.proto.write("0000")
251
252
253 class ReceivePackHandler(Handler):
254
255     def default_capabilities(self):
256         return ("report-status", "delete-refs")
257
258     def handle(self):
259         refs = self.backend.get_refs()
260
261         if refs:
262             self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
263             for i in range(1, len(refs)):
264                 ref = refs[i]
265                 self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
266         else:
267             self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
268
269         self.proto.write("0000")
270
271         client_refs = []
272         ref = self.proto.read_pkt_line()
273
274         # if ref is none then client doesnt want to send us anything..
275         if ref is None:
276             return
277
278         ref = self.handshake(ref)
279
280         # client will now send us a list of (oldsha, newsha, ref)
281         while ref:
282             client_refs.append(ref.split())
283             ref = self.proto.read_pkt_line()
284
285         # backend can now deal with this refs and read a pack using self.read
286         self.backend.apply_pack(client_refs, self.proto.read)
287
288         # when we have read all the pack from the client, it assumes everything worked OK
289         # there is NO ack from the server before it reports victory.
290
291
292 class TCPGitRequestHandler(SocketServer.StreamRequestHandler, Handler):
293
294     def __init__(self, request, client_address, server):
295         SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
296
297     def handle(self):
298         #FIXME: StreamRequestHandler seems to be the thing that calls handle(),
299         #so we can't call this in a sane place??
300         Handler.__init__(self, self.server.backend, self.rfile.read, self.wfile.write)
301
302         request = self.proto.read_pkt_line()
303
304         # up until the space is the command to run, everything after is parameters
305         splice_point = request.find(' ')
306         command, params = request[:splice_point], request[splice_point+1:]
307
308         # params are null seperated
309         params = params.split(chr(0))
310
311         # switch case to handle the specific git command
312         if command == 'git-upload-pack':
313             cls = UploadPackHandler
314         elif command == 'git-receive-pack':
315             cls = ReceivePackHandler
316         else:
317             return
318
319         h = cls(self.backend, self.proto.read, self.proto.write)
320         h.handle()
321
322
323 class TCPGitServer(SocketServer.TCPServer):
324
325     allow_reuse_address = True
326     serve = SocketServer.TCPServer.serve_forever
327
328     def __init__(self, backend, addr):
329         self.backend = backend
330         SocketServer.TCPServer.__init__(self, addr, TCPGitRequestHandler)
331
332