server: Explicitly specify allowed protocol commands.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_server.py
1 # test_server.py -- Tests for the git server
2 # Copyright (C) 2010 Google, Inc.
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Tests for the smart protocol server."""
20
21
22 from dulwich.errors import (
23     GitProtocolError,
24     UnexpectedCommandError,
25     )
26 from dulwich.server import (
27     Backend,
28     DictBackend,
29     BackendRepo,
30     Handler,
31     MultiAckGraphWalkerImpl,
32     MultiAckDetailedGraphWalkerImpl,
33     _split_proto_line,
34     ProtocolGraphWalker,
35     SingleAckGraphWalkerImpl,
36     UploadPackHandler,
37     )
38 from dulwich.tests import TestCase
39
40
41
# Stand-in 40-character object ids for the tests below; each is a single
# digit repeated so failures are easy to read.
ONE, TWO, THREE, FOUR, FIVE, SIX = [digit * 40 for digit in '123456']
48
49
class TestProto(object):
    """In-memory stand-in for a pkt-line protocol object.

    Lines queued with set_output() are handed out by read_pkt_line(); data
    written via write_pkt_line()/write_sideband() is recorded per band and
    can be read back with get_received_line().
    """

    def __init__(self):
        self._output = []
        # One list of recorded writes per band 0-3; write_pkt_line() records
        # into band 0.
        self._received = dict((band, []) for band in range(4))

    def set_output(self, output_lines):
        """Queue lines for read_pkt_line(), normalising each line ending.

        Trailing whitespace is stripped and exactly one newline appended.
        """
        self._output = [line.rstrip() + '\n' for line in output_lines]

    def read_pkt_line(self):
        """Pop the next queued line, or return None when nothing is queued."""
        if not self._output:
            return None
        return self._output.pop(0)

    def write_sideband(self, band, data):
        """Record data written to sideband channel *band*."""
        self._received[band].append(data)

    def write_pkt_line(self, data):
        """Record a pkt-line on band 0; None is stored as the string 'None'."""
        if data is None:
            data = 'None'
        self._received[0].append(data)

    def get_received_line(self, band=0):
        """Pop the oldest item recorded on *band*, or None if there is none."""
        pending = self._received[band]
        if not pending:
            return None
        return pending.pop(0)
79
80
class TestGenericHandler(Handler):
    """Minimal Handler subclass with a fixed, known capability set."""

    def __init__(self):
        Handler.__init__(self, Backend(), None)

    @classmethod
    def capabilities(cls):
        """Capabilities advertised by this handler: cap1, cap2 and cap3."""
        return tuple('cap%d' % n for n in (1, 2, 3))

    @classmethod
    def required_capabilities(cls):
        """Capabilities a client must request; here only cap2."""
        return ('cap2',)
93
94
95 class HandlerTestCase(TestCase):
96
97     def setUp(self):
98         super(HandlerTestCase, self).setUp()
99         self._handler = TestGenericHandler()
100
101     def assertSucceeds(self, func, *args, **kwargs):
102         try:
103             func(*args, **kwargs)
104         except GitProtocolError, e:
105             self.fail(e)
106
107     def test_capability_line(self):
108         self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
109
110     def test_set_client_capabilities(self):
111         set_caps = self._handler.set_client_capabilities
112         self.assertSucceeds(set_caps, ['cap2'])
113         self.assertSucceeds(set_caps, ['cap1', 'cap2'])
114
115         # different order
116         self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
117
118         # error cases
119         self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap2'])
120         self.assertRaises(GitProtocolError, set_caps, ['cap1', 'cap3'])
121
122         # ignore innocuous but unknown capabilities
123         self.assertRaises(GitProtocolError, set_caps, ['cap2', 'ignoreme'])
124         self.assertFalse('ignoreme' in self._handler.capabilities())
125         self._handler.innocuous_capabilities = lambda: ('ignoreme',)
126         self.assertSucceeds(set_caps, ['cap2', 'ignoreme'])
127
128     def test_has_capability(self):
129         self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
130         caps = self._handler.capabilities()
131         self._handler.set_client_capabilities(caps)
132         for cap in caps:
133             self.assertTrue(self._handler.has_capability(cap))
134         self.assertFalse(self._handler.has_capability('capxxx'))
135
136
class UploadPackHandlerTestCase(TestCase):
    """Tests for UploadPackHandler behaviour driven by client capabilities."""

    def setUp(self):
        super(UploadPackHandlerTestCase, self).setUp()
        self._backend = DictBackend({"/": BackendRepo()})
        self._handler = UploadPackHandler(self._backend,
                ["/", "host=lolcathost"], None, None)
        self._handler.proto = TestProto()

    def _set_caps(self, *extra):
        """Set the handler's required capabilities plus any *extra* ones."""
        caps = self._handler.required_capabilities()
        if extra:
            caps = list(caps) + list(extra)
        self._handler.set_client_capabilities(caps)

    def test_progress(self):
        self._set_caps()
        for message in ('first message', 'second message'):
            self._handler.progress(message)
        received = self._handler.proto.get_received_line
        # Progress messages arrive on band 2, in order.
        self.assertEqual('first message', received(2))
        self.assertEqual('second message', received(2))
        self.assertEqual(None, received(2))

    def test_no_progress(self):
        self._set_caps('no-progress')
        for message in ('first message', 'second message'):
            self._handler.progress(message)
        # With no-progress negotiated, nothing reaches band 2.
        self.assertEqual(None, self._handler.proto.get_received_line(2))

    def test_get_tagged(self):
        refs = {
            'refs/tags/tag1': ONE,
            'refs/tags/tag2': TWO,
            'refs/heads/master': FOUR,  # not a tag, no peeled value
            }
        peeled = {
            'refs/tags/tag1': '1234',
            'refs/tags/tag2': '5678',
            }

        class TestRepo(object):
            def get_peeled(self, ref):
                return peeled.get(ref, refs[ref])

        # With include-tag, each peeled value maps to its ref's value.
        self._set_caps('include-tag')
        self.assertEquals({'1234': ONE, '5678': TWO},
                          self._handler.get_tagged(refs, repo=TestRepo()))

        # non-include-tag case
        self._set_caps()
        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
188
189
class TestCommit(object):
    """Minimal stand-in for a commit object.

    Exposes only ``id``, ``parents``, ``commit_time`` and ``type_name``.
    """

    def __init__(self, sha, parents, commit_time):
        self.id = sha
        self.parents = parents
        self.commit_time = commit_time
        self.type_name = "commit"

    def __repr__(self):
        # The sha is stored as ``id``; the previous ``self._sha`` attribute
        # never existed, so repr() raised AttributeError.
        return '%s(%s)' % (self.__class__.__name__, self.id)
200
201
class TestRepo(object):
    """Fake repository whose peeled values are filled in directly by tests."""

    def __init__(self):
        # name -> peeled value; empty until a test populates it.
        self.peeled = dict()

    def get_peeled(self, name):
        """Return the peeled value for *name*; KeyError if never set."""
        table = self.peeled
        return table[name]
208
209
class TestBackend(object):
    """Fake backend holding a repo and an object mapping (as object_store)."""

    def __init__(self, repo, objects):
        self.repo, self.object_store = repo, objects
215
216
class TestUploadPackHandler(Handler):
    """Handler stub used to construct a ProtocolGraphWalker in tests.

    Handler.__init__ is not called; only backend, proto, stateless_rpc and
    advertise_refs are set.
    """

    def __init__(self, objects, proto):
        self.backend = TestBackend(TestRepo(), objects)
        self.proto = proto
        # Neither stateless RPC nor advertise-refs-only mode.
        self.stateless_rpc = self.advertise_refs = False

    @classmethod
    def capabilities(cls):
        """The only capability this stub advertises."""
        return ('multi_ack',)
228
229
class ProtocolGraphWalkerTestCase(TestCase):
    """Tests for ProtocolGraphWalker want parsing and ref advertisement."""

    def setUp(self):
        super(ProtocolGraphWalkerTestCase, self).setUp()
        # Create the following commit tree:
        #   3---5
        #  /
        # 1---2---4
        self._objects = {
            ONE: TestCommit(ONE, [], 111),
            TWO: TestCommit(TWO, [ONE], 222),
            THREE: TestCommit(THREE, [ONE], 333),
            FOUR: TestCommit(FOUR, [TWO], 444),
            FIVE: TestCommit(FIVE, [THREE], 555),
            }

        self._walker = ProtocolGraphWalker(
            TestUploadPackHandler(self._objects, TestProto()),
            self._objects, None)

    def test_is_satisfied_no_haves(self):
        self.assertFalse(self._walker._is_satisfied([], ONE, 0))
        self.assertFalse(self._walker._is_satisfied([], TWO, 0))
        self.assertFalse(self._walker._is_satisfied([], THREE, 0))

    def test_is_satisfied_have_root(self):
        self.assertTrue(self._walker._is_satisfied([ONE], ONE, 0))
        self.assertTrue(self._walker._is_satisfied([ONE], TWO, 0))
        self.assertTrue(self._walker._is_satisfied([ONE], THREE, 0))

    def test_is_satisfied_have_branch(self):
        self.assertTrue(self._walker._is_satisfied([TWO], TWO, 0))
        # wrong branch
        self.assertFalse(self._walker._is_satisfied([TWO], THREE, 0))

    def test_all_wants_satisfied(self):
        self._walker.set_wants([FOUR, FIVE])
        # trivial case: wants == haves
        self.assertTrue(self._walker.all_wants_satisfied([FOUR, FIVE]))
        # cases that require walking the commit tree
        self.assertTrue(self._walker.all_wants_satisfied([ONE]))
        self.assertFalse(self._walker.all_wants_satisfied([TWO]))
        self.assertFalse(self._walker.all_wants_satisfied([THREE]))
        self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))

    def test_split_proto_line(self):
        allowed = ('want', 'done', None)
        self.assertEquals(('want', ONE),
                          _split_proto_line('want %s\n' % ONE, allowed))
        self.assertEquals(('want', TWO),
                          _split_proto_line('want %s\n' % TWO, allowed))
        # malformed sha
        self.assertRaises(GitProtocolError, _split_proto_line,
                          'want xxxx\n', allowed)
        # well-formed but not in the allowed set
        self.assertRaises(UnexpectedCommandError, _split_proto_line,
                          'have %s\n' % THREE, allowed)
        # unknown commands
        self.assertRaises(GitProtocolError, _split_proto_line,
                          'foo %s\n' % FOUR, allowed)
        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
        self.assertEquals((None, None), _split_proto_line('', allowed))

    def test_determine_wants(self):
        self.assertRaises(GitProtocolError, self._walker.determine_wants, {})

        # both wants name advertised refs
        self._walker.proto.set_output([
            'want %s multi_ack' % ONE,
            'want %s' % TWO,
            ])
        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
        self._walker.get_peeled = heads.get
        self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))

        # FOUR is not among the advertised heads
        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)

        # client sends no wants at all
        self._walker.proto.set_output([])
        self.assertEquals([], self._walker.determine_wants(heads))

        # a malformed line following a valid want
        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo'])
        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)

        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)

    def test_determine_wants_advertisement(self):
        self._walker.proto.set_output([])
        # advertise branch tips plus tag
        heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
        peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
        self._walker.get_peeled = peeled.get
        self._walker.determine_wants(heads)
        lines = []
        while True:
            line = self._walker.proto.get_received_line()
            # TestProto records write_pkt_line(None) as the string 'None',
            # which ends the advertisement. A real None means the recorded
            # output ran out without it; fail clearly instead of hitting a
            # TypeError on the '\x00' check below.
            if line is None:
                self.fail('ref advertisement ended without '
                          'write_pkt_line(None)')
            if line == 'None':
                break
            # strip capabilities list if present
            if '\x00' in line:
                line = line[:line.index('\x00')]
            lines.append(line.rstrip())

        self.assertEquals([
            '%s ref4' % FOUR,
            '%s ref5' % FIVE,
            '%s tag6^{}' % FIVE,
            '%s tag6' % SIX,
            ], sorted(lines))

        # ensure peeled tag was advertised immediately following tag
        for i, line in enumerate(lines):
            if line.endswith(' tag6'):
                # Guard the lookahead so a misordered advertisement fails as
                # an assertion rather than an IndexError.
                self.assertTrue(i + 1 < len(lines),
                                'tag6 is not followed by its peeled value')
                self.assertEquals('%s tag6^{}' % FIVE, lines[i+1])

    # TODO: test commit time cutoff
344
345
class TestProtocolGraphWalker(object):
    """Scripted stand-in for ProtocolGraphWalker used by the ack impl tests.

    Tests fill ``lines`` with (command, sha) pairs to be read, set ``done``
    to drive all_wants_satisfied(), and inspect sent acks via pop_ack().
    """

    def __init__(self):
        self.acks = []
        self.lines = []
        self.done = False
        self.stateless_rpc = False
        self.advertise_refs = False

    def read_proto_line(self, allowed):
        """Pop the next scripted (command, sha) pair.

        When *allowed* is not None, the command is asserted to be in it.
        """
        command, sha = self.lines.pop(0)
        if allowed is not None:
            assert command in allowed
        return (command, sha)

    def send_ack(self, sha, ack_type=''):
        """Record an ack as a (sha, ack_type) pair."""
        self.acks.append((sha, ack_type))

    def send_nak(self):
        """Record a nak, represented as the pair (None, 'nak')."""
        self.send_ack(None, 'nak')

    def all_wants_satisfied(self, haves):
        """Ignore *haves* and report the scripted ``done`` flag."""
        return self.done

    def pop_ack(self):
        """Pop the oldest recorded ack pair, or return None if none remain."""
        try:
            return self.acks.pop(0)
        except IndexError:
            return None
374
375
class AckGraphWalkerImplTestCase(TestCase):
    """Base setup and asserts for AckGraphWalker tests.

    Subclasses define ``impl_cls``. setUp scripts the client to send haves
    for TWO, ONE and THREE, followed by 'done'.
    """

    def setUp(self):
        super(AckGraphWalkerImplTestCase, self).setUp()
        walker = TestProtocolGraphWalker()
        walker.lines = [('have', sha) for sha in (TWO, ONE, THREE)]
        walker.lines.append(('done', None))
        self._walker = walker
        self._impl = self.impl_cls(walker)

    def assertNoAck(self):
        """Assert that no further ack (or nak) has been recorded."""
        self.assertEquals(None, self._walker.pop_ack())

    def assertAcks(self, acks):
        """Assert exactly this sequence of (sha, ack_type) pairs was sent."""
        for sha, ack_type in acks:
            expected = (sha, ack_type)
            self.assertEquals(expected, self._walker.pop_ack())
        self.assertNoAck()

    def assertAck(self, sha, ack_type=''):
        """Assert exactly one ack with the given sha and type was sent."""
        self.assertAcks([(sha, ack_type)])

    def assertNak(self):
        """Assert exactly one nak was sent."""
        self.assertAck(None, 'nak')

    def assertNextEquals(self, sha):
        """Advance the impl under test and check the sha it yields."""
        self.assertEquals(sha, self._impl.next())
406
407
class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
    """Tests for SingleAckGraphWalkerImpl.

    In these scenarios at most one ACK is sent per negotiation. A NAK is
    sent at the end only if nothing was acked.
    """

    impl_cls = SingleAckGraphWalkerImpl

    def test_single_ack(self):
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAck(ONE)

        self.assertNextEquals(THREE)
        self._impl.ack(THREE)
        # THREE is also common, but no second ACK is expected.
        self.assertNoAck()

        self.assertNextEquals(None)
        # ONE was already acked, so no NAK at the end.
        self.assertNoAck()

    def test_single_ack_flush(self):
        # same as ack test but ends with a flush-pkt instead of done
        self._walker.lines[-1] = (None, None)

        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAck(ONE)

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        self.assertNoAck()

    def test_single_ack_nak(self):
        # Nothing is ever acked in this scenario.
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNoAck()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        # No common commit was acked, so the end of haves yields a NAK.
        self.assertNak()

    def test_single_ack_nak_flush(self):
        # same as nak test but ends with a flush-pkt instead of done
        self._walker.lines[-1] = (None, None)

        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNoAck()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        self.assertNak()
474
475
class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
    """Tests for MultiAckGraphWalkerImpl.

    Every acked common commit gets an immediate 'continue' ACK. At the end
    of the haves, the last common commit is acked again with a plain ACK,
    or a NAK is sent if nothing was acked.
    """

    impl_cls = MultiAckGraphWalkerImpl

    def test_multi_ack(self):
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAck(ONE, 'continue')

        self.assertNextEquals(THREE)
        self._impl.ack(THREE)
        self.assertAck(THREE, 'continue')

        self.assertNextEquals(None)
        # final plain ACK of the last common commit
        self.assertAck(THREE)

    def test_multi_ack_partial(self):
        # Here walker.done stays False throughout.
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._impl.ack(ONE)
        self.assertAck(ONE, 'continue')

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        # done, re-send ack of last common
        self.assertAck(ONE)

    def test_multi_ack_flush(self):
        # A flush-pkt (None, None) arrives between the first two haves.
        self._walker.lines = [
          ('have', TWO),
          (None, None),
          ('have', ONE),
          ('have', THREE),
          ('done', None),
          ]
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNak()  # nak the flush-pkt

        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAck(ONE, 'continue')

        self.assertNextEquals(THREE)
        self._impl.ack(THREE)
        self.assertAck(THREE, 'continue')

        self.assertNextEquals(None)
        self.assertAck(THREE)

    def test_multi_ack_nak(self):
        # Nothing is ever acked, so only a final NAK is expected.
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNoAck()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        self.assertNak()
548
549
class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
    """Tests for MultiAckDetailedGraphWalkerImpl.

    Acked commits are reported as 'common'. Once walker.done is True they
    are reported as 'ready': the first such ack sends both 'common' and
    'ready', later ones only 'ready'. The end of the haves re-acks the last
    common commit with a plain ACK, or sends a NAK if nothing was acked.
    """

    impl_cls = MultiAckDetailedGraphWalkerImpl

    def test_multi_ack(self):
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])

        self.assertNextEquals(THREE)
        self._impl.ack(THREE)
        self.assertAck(THREE, 'ready')

        self.assertNextEquals(None)
        # final plain ACK of the last common commit
        self.assertAck(THREE)

    def test_multi_ack_partial(self):
        # Here walker.done stays False, so only 'common' is reported.
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self._impl.ack(ONE)
        self.assertAck(ONE, 'common')

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        # done, re-send ack of last common
        self.assertAck(ONE)

    def test_multi_ack_flush(self):
        # same as ack test but contains a flush-pkt in the middle
        self._walker.lines = [
          ('have', TWO),
          (None, None),
          ('have', ONE),
          ('have', THREE),
          ('done', None),
          ]
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNak()  # nak the flush-pkt

        self._walker.done = True
        self._impl.ack(ONE)
        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])

        self.assertNextEquals(THREE)
        self._impl.ack(THREE)
        self.assertAck(THREE, 'ready')

        self.assertNextEquals(None)
        self.assertAck(THREE)

    def test_multi_ack_nak(self):
        # Nothing is ever acked, so only a final NAK is expected.
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNoAck()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        self.assertNak()

    def test_multi_ack_nak_flush(self):
        # same as nak test but contains a flush-pkt in the middle
        self._walker.lines = [
          ('have', TWO),
          (None, None),
          ('have', ONE),
          ('have', THREE),
          ('done', None),
          ]
        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        # NAK for the mid-stream flush-pkt
        self.assertNak()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        self.assertNextEquals(None)
        # and a NAK at the end, since nothing was acked
        self.assertNak()

    def test_multi_ack_stateless(self):
        # transmission ends with a flush-pkt
        self._walker.lines[-1] = (None, None)
        self._walker.stateless_rpc = True

        self.assertNextEquals(TWO)
        self.assertNoAck()

        self.assertNextEquals(ONE)
        self.assertNoAck()

        self.assertNextEquals(THREE)
        self.assertNoAck()

        # With stateless_rpc set, the trailing flush-pkt produces a single
        # NAK when nothing was acked.
        self.assertNextEquals(None)
        self.assertNak()