More typo. I suck at refactoring :(
[jelmer/dulwich-libgit2.git] / dulwich / server.py
1 # server.py -- Implementation of the server side git protocols
2 # Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 import SocketServer
20
21 class Backend(object):
22
23     def get_refs(self):
24         """
25         Get all the refs in the repository
26
27         :return: list of tuple(name, sha)
28         """
29         raise NotImplementedError
30
31     def has_revision(self, sha):
32         """
33         Is a given sha in this repository?
34
35         :return: True or False
36         """
37         raise NotImplementedError
38
39     def apply_pack(self, refs, read):
40         """ Import a set of changes into a repository and update the refs
41
42         :param refs: list of tuple(name, sha)
43         :param read: callback to read from the incoming pack
44         """
45         raise NotImplementedError
46
47     def generate_pack(self, want, have, write, progress):
48         """
49         Generate a pack containing all commits a client is missing
50
51         :param want: is a list of sha's the client desires
52         :param have: is a list of sha's the client has (allowing us to send the minimal pack)
53         :param write: is a callback to write pack data to the client
54         :param progress: is a callback to send progress messages to the client
55         """
56         raise NotImplementedError
57
58 from dulwich.repo import Repo
59 from dulwich.pack import PackData, Pack
60 import sha, tempfile, os
61 from dulwich.pack import write_pack_object
62
63 class PackWriteWrapper(object):
64
65     def __init__(self, write):
66         self.writefn = write
67         self.sha = sha.sha()
68
69     def write(self, blob):
70         self.sha.update(blob)
71         self.writefn(blob)
72
73     def tell(self):
74         pass
75
76     @property
77     def digest(self):
78         return self.sha.digest()
79
80 class GitBackend(Backend):
81
82     def __init__(self, gitdir=None):
83         self.gitdir = gitdir
84
85         if not self.gitdir:
86             self.gitdir = tempfile.mkdtemp()
87             Repo.create(self.gitdir)
88
89         self.repo = Repo(self.gitdir)
90
91     def get_refs(self):
92         refs = []
93         if self.repo.head():
94             refs.append(('HEAD', self.repo.head()))
95         for ref, sha in self.repo.heads().items():
96             refs.append(('refs/heads/'+ref,sha))
97         return refs
98
99     def has_revision(self, sha):
100         return self.repo.get_object(sha) != None
101
102     def apply_pack(self, refs, read):
103         # store the incoming pack in the repository
104         fd, name = tempfile.mkstemp(suffix='.pack', prefix='', dir=self.repo.pack_dir())
105         os.write(fd, read())
106         os.close(fd)
107
108         # strip '.pack' off our filename
109         basename = name[:-5]
110
111         # generate an index for it
112         pd = PackData(name)
113         pd.create_index_v2(basename+".idx")
114
115         for oldsha, sha, ref in refs:
116             if ref == "0" * 40:
117                 self.repo.remove_ref(ref)
118             else:
119                 self.repo.set_ref(ref, sha)
120
121         print "pack applied"
122
123     def generate_pack(self, want, have, write, progress):
124         progress("dul-daemon says what\n")
125
126         sha_queue = []
127
128         commits_to_send = want[:]
129         for sha in commits_to_send:
130             if sha in sha_queue:
131                 continue
132
133             sha_queue.append((1,sha))
134
135             c = self.repo.commit(sha)
136             for p in c.parents():
137                 if not p in commits_to_send:
138                     commits_to_send.append(p)
139
140             def parse_tree(tree, sha_queue):
141                 for mode, name, x in tree.entries():
142                     if not x in sha_queue:
143                         try:
144                             t = self.repo.get_tree(x)
145                             sha_queue.append((2, x))
146                             parse_tree(t, sha_queue)
147                         except:
148                             sha_queue.append((3, x))
149
150             treesha = c.tree()
151             if treesha not in sha_queue:
152                 sha_queue.append((2, treesha))
153                 t = self.repo.get_tree(treesha)
154                 parse_tree(t, sha_queue)
155
156             progress("counting objects: %d\r" % len(sha_queue))
157
158         progress("counting objects: %d, done.\n" % len(sha_queue))
159
160         write_pack_data(write, (self.repo.get_object(sha).as_raw_string() for sha in sha_queue))
161
162         progress("how was that, then?\n")
163
164
165 class Handler(object):
166
167     def __init__(self, backend, read, write):
168         self.backend = backend
169         self.read = read
170         self.write = write
171
172     def read_pkt_line(self):
173         """
174         Reads a 'pkt line' from the remote git process
175
176         :return: The next string from the stream
177         """
178         sizestr = self.read(4)
179         if not sizestr:
180             return None
181         size = int(sizestr, 16)
182         if size == 0:
183             return None
184         return self.read(size-4)
185
186     def write_pkt_line(self, line):
187         """
188         Sends a 'pkt line' to the remote git process
189
190         :param line: A string containing the data to send
191         """
192         self.write("%04x%s" % (len(line)+4, line))
193
194     def write_sideband(self, channel, blob):
195         """
196         Write data to the sideband (a git multiplexing method)
197
198         :param channel: int specifying which channel to write to
199         :param blob: a blob of data (as a string) to send on this channel
200         """
201         # a pktline can be a max of 65535. a sideband line can therefore be
202         # 65535-5 = 65530
203         # WTF: Why have the len in ASCII, but the channel in binary.
204         while blob:
205             self.write_pkt_line("%s%s" % (chr(channel), blob[:65530]))
206             blob = blob[65530:]
207
208     def capabilities(self):
209         return " ".join(self.default_capabilities())
210
211     def handshake(self, blob):
212         """
213         Compare remote capabilites with our own and alter protocol accordingly
214
215         :param blob: space seperated list of capabilities (i.e. wire format)
216         """
217         if not "\x00" in blob:
218             return blob
219         blob, caps = blob.split("\x00")
220
221         # FIXME: Do something with this..
222         caps = caps.split()
223
224         return blob
225
226     def handle(self):
227         """
228         Deal with the request
229         """
230         raise NotImplementedError
231
232
233 class UploadPackHandler(Handler):
234
235     def default_capabilities(self):
236         return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
237
238     def handle(self):
239         refs = self.backend.get_refs()
240
241         if refs:
242             self.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
243             for i in range(1, len(refs)):
244                 ref = refs[i]
245                 self.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
246
247         # i'm done...
248         self.write("0000")
249
250         # Now client will either send "0000", meaning that it doesnt want to pull.
251         # or it will start sending want want want commands
252         want = self.read_pkt_line()
253         if want == None:
254             return
255
256         want = self.handshake(want)
257
258         # Keep reading the list of demands until we hit another "0000" 
259         want_revs = []
260         while want and want[:4] == 'want':
261             want_rev = want[5:45]
262             # FIXME: This check probably isnt needed?
263             if self.backend.has_revision(want_rev):
264                want_revs.append(want_rev)
265             want = self.read_pkt_line()
266         
267         # Client will now tell us which commits it already has - if we have them we ACK them
268         # this allows client to stop looking at that commits parents (main reason why git pull is fast)
269         last_sha = None
270         have_revs = []
271         have = self.read_pkt_line()
272         while have and have[:4] == 'have':
273             have_ref = have[6:46]
274             if self.backend.has_revision(have_ref):
275                 self.write_pkt_line("ACK %s continue\n" % have_ref)
276                 last_sha = have_ref
277                 have_revs.append(have_ref)
278             have = self.read_pkt_line()
279
280         # At some point client will stop sending commits and will tell us it is done
281         assert(have[:4] == "done")
282
283         # Oddness: Git seems to resend the last ACK, without the "continue" statement
284         if last_sha:
285             self.write_pkt_line("ACK %s\n" % last_sha)
286
287         # The exchange finishes with a NAK
288         self.write_pkt_line("NAK\n")
289       
290         self.backend.generate_pack(want_revs, have_revs, lambda x: self.write_sideband(1, x), lambda x: self.write_sideband(2, x))
291
292         # we are done
293         self.write("0000")
294
295
296 class ReceivePackHandler(Handler):
297
298     def default_capabilities(self):
299         return ("report-status", "delete-refs")
300
301     def handle(self):
302         refs = self.backend.get_refs()
303
304         if refs:
305             self.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
306             for i in range(1, len(refs)):
307                 ref = refs[i]
308                 self.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
309         else:
310             self.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
311
312         self.write("0000")
313
314         client_refs = []
315         ref = self.read_pkt_line()
316
317         # if ref is none then client doesnt want to send us anything..
318         if ref is None:
319             return
320
321         ref = self.handshake(ref)
322
323         # client will now send us a list of (oldsha, newsha, ref)
324         while ref:
325             client_refs.append(ref.split())
326             ref = self.read_pkt_line()
327
328         # backend can now deal with this refs and read a pack using self.read
329         self.backend.apply_pack(client_refs, self.read)
330
331         # when we have read all the pack from the client, it assumes everything worked OK
332         # there is NO ack from the server before it reports victory.
333
334
335 class TCPGitRequestHandler(SocketServer.StreamRequestHandler, Handler):
336
337     def __init__(self, request, client_address, server):
338         SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
339
340     def handle(self):
341         #FIXME: StreamRequestHandler seems to be the thing that calls handle(),
342         #so we can't call this in a sane place??
343         Handler.__init__(self, self.server.backend, self.rfile.read, self.wfile.write)
344
345         request = self.read_pkt_line()
346
347         # up until the space is the command to run, everything after is parameters
348         splice_point = request.find(' ')
349         command, params = request[:splice_point], request[splice_point+1:]
350
351         # params are null seperated
352         params = params.split(chr(0))
353
354         # switch case to handle the specific git command
355         if command == 'git-upload-pack':
356             cls = UploadPackHandler
357         elif command == 'git-receive-pack':
358             cls = ReceivePackHandler
359         else:
360             return
361
362         h = cls(self.backend, self.read, self.write)
363         h.handle()
364
365
366 class TCPGitServer(SocketServer.TCPServer):
367
368     allow_reuse_address = True
369     serve = SocketServer.TCPServer.serve_forever
370
371     def __init__(self, backend, addr):
372         self.backend = backend
373         SocketServer.TCPServer.__init__(self, addr, TCPGitRequestHandler)
374
375