Standardize quote delimiters in test_protocol.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_protocol.py
1 # test_protocol.py -- Tests for the git protocol
2 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Tests for the smart protocol utility functions."""
20
21
22 from StringIO import StringIO
23
24 from dulwich.protocol import (
25     Protocol,
26     ReceivableProtocol,
27     extract_capabilities,
28     extract_want_line_capabilities,
29     ack_type,
30     SINGLE_ACK,
31     MULTI_ACK,
32     MULTI_ACK_DETAILED,
33     BufferedPktLineWriter,
34     )
35 from dulwich.tests import TestCase
36
37
class BaseProtocolTests(object):
    """Protocol tests shared by every protocol implementation under test.

    This is a mixin deriving from object rather than TestCase, so these tests
    only run through concrete subclasses. Such subclasses must also derive
    from TestCase and, in setUp, provide:

    * ``self.proto`` -- the protocol object under test;
    * ``self.rin`` -- a StringIO the protocol reads its input from;
    * ``self.rout`` -- a StringIO capturing everything the protocol writes.

    Assertions pass the expected value first, then the actual value.
    """

    def test_write_pkt_line_none(self):
        # None is written as the special zero-length '0000' packet.
        self.proto.write_pkt_line(None)
        self.assertEqual('0000', self.rout.getvalue())

    def test_write_pkt_line(self):
        # The 4-hex-digit length prefix counts itself plus the payload:
        # 4 + len('bla') == 7.
        self.proto.write_pkt_line('bla')
        self.assertEqual('0007bla', self.rout.getvalue())

    def test_read_pkt_line(self):
        self.rin.write('0008cmd ')
        self.rin.seek(0)
        self.assertEqual('cmd ', self.proto.read_pkt_line())

    def test_read_pkt_seq(self):
        # The sequence stops at the terminating '0000' packet.
        self.rin.write('0008cmd 0005l0000')
        self.rin.seek(0)
        self.assertEqual(['cmd ', 'l'], list(self.proto.read_pkt_seq()))

    def test_read_pkt_line_none(self):
        self.rin.write('0000')
        self.rin.seek(0)
        self.assertEqual(None, self.proto.read_pkt_line())

    def test_write_sideband(self):
        # The channel number is sent as a single byte ahead of the data.
        self.proto.write_sideband(3, 'bloe')
        self.assertEqual('0009\x03bloe', self.rout.getvalue())

    def test_send_cmd(self):
        # The command is followed by a space, then NUL-terminated arguments.
        self.proto.send_cmd('fetch', 'a', 'b')
        self.assertEqual('000efetch a\x00b\x00', self.rout.getvalue())

    def test_read_cmd(self):
        self.rin.write('0012cmd arg1\x00arg2\x00')
        self.rin.seek(0)
        self.assertEqual(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())

    def test_read_cmd_noend0(self):
        # Reading a command whose last argument lacks the trailing NUL fails.
        self.rin.write('0011cmd arg1\x00arg2')
        self.rin.seek(0)
        self.assertRaises(AssertionError, self.proto.read_cmd)
80
81
class ProtocolTests(BaseProtocolTests, TestCase):
    """Run the shared protocol tests against the plain Protocol class."""

    def setUp(self):
        TestCase.setUp(self)
        # Test input is staged in rin and handed to the protocol through
        # rin.read; whatever the protocol writes accumulates in rout.
        self.rin = StringIO()
        self.rout = StringIO()
        self.proto = Protocol(self.rin.read, self.rout.write)
89
90
class ReceivableStringIO(StringIO):
    """StringIO with socket-like recv semantics for testing.

    recv() simulates short reads by returning one byte fewer than requested
    (single-byte requests are honoured in full). Reading at end of data
    raises AssertionError instead of blocking.
    """

    def recv(self, size):
        exhausted = len(self.getvalue()) == self.tell()
        if exhausted:
            # A real socket would block forever here; fail fast instead.
            raise AssertionError('Blocking read past end of socket')
        # Hand back less than was asked for, except when only one byte
        # was requested.
        wanted = size if size == 1 else size - 1
        return self.read(wanted)
103
104
class ReceivableProtocolTests(BaseProtocolTests, TestCase):
    """Shared protocol tests plus recv()/read() interleaving tests for
    ReceivableProtocol, backed by the short-reading ReceivableStringIO."""

    def setUp(self):
        TestCase.setUp(self)
        self.rout = StringIO()
        self.rin = ReceivableStringIO()
        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
        # Use a tiny receive buffer so that even short test inputs need
        # several refills from the underlying recv.
        self.proto._rbufsize = 8

    def test_recv(self):
        all_data = '1234567' * 10  # not a multiple of bufsize
        self.rin.write(all_data)
        self.rin.seek(0)
        data = ''
        # The protocol asks the socket for 8 bytes (its _rbufsize) each time
        # and ReceivableStringIO returns only 7, so the 70 bytes take exactly
        # 10 iterations.
        for _ in xrange(10):
            data += self.proto.recv(10)
        # any more reads would block
        self.assertRaises(AssertionError, self.proto.recv, 10)
        self.assertEqual(all_data, data)

    def test_recv_read(self):
        all_data = '1234567'  # recv exactly in one call
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual('1234', self.proto.recv(4))
        self.assertEqual('567', self.proto.read(3))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_read_recv(self):
        all_data = '12345678abcdefg'
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual('1234', self.proto.read(4))
        self.assertEqual('5678abc', self.proto.recv(8))
        self.assertEqual('defg', self.proto.read(4))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_mixed(self):
        # arbitrary non-repeating string
        all_data = ','.join(str(i) for i in xrange(100))
        self.rin.write(all_data)
        self.rin.seek(0)
        data = ''

        for i in xrange(1, 100):
            data += self.proto.recv(i)
            # if we get to the end, do a non-blocking read instead of blocking
            if len(data) + i > len(all_data):
                data += self.proto.recv(i)
                # ReceivableStringIO leaves off the last byte unless we ask
                # nicely
                data += self.proto.recv(1)
                break
            else:
                data += self.proto.read(i)
        else:
            # didn't break, something must have gone wrong
            self.fail('never reached the end of the input data')

        self.assertEqual(all_data, data)
167
168
class CapabilitiesTestCase(TestCase):
    """Tests for splitting capability lists off protocol lines."""

    def test_plain(self):
        self.assertEqual(('bla', []), extract_capabilities('bla'))

    def test_caps(self):
        # Capabilities follow a NUL byte and are space-separated; a trailing
        # newline is not part of the last capability.
        self.assertEqual(('bla', ['la']), extract_capabilities('bla\x00la'))
        self.assertEqual(('bla', ['la']),
                         extract_capabilities('bla\x00la\n'))
        self.assertEqual(('bla', ['la', 'la']),
                         extract_capabilities('bla\x00la la'))

    def test_plain_want_line(self):
        self.assertEqual(('want bla', []),
                         extract_want_line_capabilities('want bla'))

    def test_caps_want_line(self):
        # On want lines, capabilities follow the first two words, separated
        # by spaces rather than a NUL byte.
        self.assertEqual(('want bla', ['la']),
                         extract_want_line_capabilities('want bla la'))
        self.assertEqual(('want bla', ['la']),
                         extract_want_line_capabilities('want bla la\n'))
        self.assertEqual(('want bla', ['la', 'la']),
                         extract_want_line_capabilities('want bla la la'))

    def test_ack_type(self):
        self.assertEqual(SINGLE_ACK, ack_type(['foo', 'bar']))
        self.assertEqual(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack_detailed']))
        # choose detailed when both present
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack',
                                   'multi_ack_detailed']))
196
197
class BufferedPktLineWriterTests(TestCase):
    """Tests for BufferedPktLineWriter with a deliberately small buffer.

    With bufsize=16, output stays buffered until flush() is called or the
    accumulated pkt-line bytes reach 16.
    """

    def setUp(self):
        TestCase.setUp(self)
        self._output = StringIO()
        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)

    def assertOutputEquals(self, expected):
        """Assert that everything written so far equals ``expected``."""
        self.assertEqual(expected, self._output.getvalue())

    def _truncate(self):
        """Discard output captured so far, so later checks see only new
        writes."""
        self._output.seek(0)
        self._output.truncate()

    def test_write(self):
        self._writer.write('foo')
        self.assertOutputEquals('')
        self._writer.flush()
        self.assertOutputEquals('0007foo')

    def test_write_none(self):
        self._writer.write(None)
        self.assertOutputEquals('')
        self._writer.flush()
        self.assertOutputEquals('0000')

    def test_flush_empty(self):
        self._writer.flush()
        self.assertOutputEquals('')

    def test_write_multiple(self):
        self._writer.write('foo')
        self._writer.write('bar')
        self.assertOutputEquals('')
        self._writer.flush()
        self.assertOutputEquals('0007foo0007bar')

    def test_write_across_boundary(self):
        # 7 bytes ('0007foo') + 10 bytes ('000abarbaz') = 17 > 16: the first
        # 16 bytes are written immediately and the final 'z' stays buffered.
        self._writer.write('foo')
        self._writer.write('barbaz')
        self.assertOutputEquals('0007foo000abarba')
        self._truncate()
        self._writer.flush()
        self.assertOutputEquals('z')

    def test_write_to_boundary(self):
        # 7 bytes ('0007foo') + 9 bytes ('0009barba') fill the buffer
        # exactly, so both lines are written without an explicit flush.
        self._writer.write('foo')
        self._writer.write('barba')
        self.assertOutputEquals('0007foo0009barba')
        self._truncate()
        self._writer.write('z')
        self._writer.flush()
        self.assertOutputEquals('0005z')