1 # test_protocol.py -- Tests for the git protocol
2 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA  02110-1301, USA.
19 """Tests for the smart protocol utility functions."""
22 from StringIO import StringIO
23 from unittest import TestCase
25 from dulwich.protocol import (
29 extract_want_line_capabilities,
class BaseProtocolTests(object):
    """Protocol test cases shared by every transport flavour.

    This is a mixin: it defines no fixtures of its own.  The concrete
    ``TestCase`` subclass must provide, in its ``setUp``:

    * ``self.rin``   -- file-like object the protocol reads from,
    * ``self.rout``  -- file-like object the protocol writes to,
    * ``self.proto`` -- the protocol object under test, wired to both.

    Read tests first write the wire data into ``self.rin`` and then rewind
    it, because writing leaves the stream position at the end of the buffer.
    """

    def test_write_pkt_line_none(self):
        """Writing ``None`` emits the bare "0000" marker."""
        self.proto.write_pkt_line(None)
        self.assertEqual(self.rout.getvalue(), "0000")

    def test_write_pkt_line(self):
        """A payload is prefixed with its 4-hex-digit total length."""
        self.proto.write_pkt_line("bla")
        self.assertEqual(self.rout.getvalue(), "0007bla")

    def test_read_pkt_line(self):
        """A single pkt-line is returned without its length prefix."""
        self.rin.write("0008cmd ")
        # Rewind so the protocol reads what was just written.
        self.rin.seek(0)
        self.assertEqual("cmd ", self.proto.read_pkt_line())

    def test_read_pkt_seq(self):
        """A sequence of pkt-lines is yielded up to the "0000" marker."""
        self.rin.write("0008cmd 0005l0000")
        self.rin.seek(0)
        self.assertEqual(["cmd ", "l"], list(self.proto.read_pkt_seq()))

    def test_read_pkt_line_none(self):
        """Reading the "0000" marker yields ``None``."""
        self.rin.write("0000")
        self.rin.seek(0)
        self.assertEqual(None, self.proto.read_pkt_line())

    def test_write_sideband(self):
        """Sideband data is framed with the channel number as first byte."""
        self.proto.write_sideband(3, "bloe")
        self.assertEqual(self.rout.getvalue(), "0009\x03bloe")

    def test_send_cmd(self):
        """A command is sent as ``name SP arg NUL arg NUL``."""
        self.proto.send_cmd("fetch", "a", "b")
        self.assertEqual(self.rout.getvalue(), "000efetch a\x00b\x00")

    def test_read_cmd(self):
        """A well-formed command parses into ``(name, [args])``."""
        self.rin.write("0012cmd arg1\x00arg2\x00")
        self.rin.seek(0)
        self.assertEqual(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())

    def test_read_cmd_noend0(self):
        """A command whose last argument lacks the trailing NUL is rejected."""
        self.rin.write("0011cmd arg1\x00arg2")
        self.rin.seek(0)
        self.assertRaises(AssertionError, self.proto.read_cmd)
class ProtocolTests(BaseProtocolTests, TestCase):
    """Run the shared protocol tests against ``Protocol`` over StringIO."""

    def setUp(self):
        """Create fresh in-memory streams and a ``Protocol`` wired to them."""
        super(ProtocolTests, self).setUp()
        self.rout = StringIO()
        # The shared tests write wire data here and the protocol reads it
        # back through ``self.rin.read``.
        self.rin = StringIO()
        self.proto = Protocol(self.rin.read, self.rout.write)
class ReceivableStringIO(StringIO):
    """StringIO with socket-like recv semantics for testing."""

    def recv(self, size):
        """Return up to ``size`` bytes, deliberately one fewer than asked.

        :param size: maximum number of bytes the caller is asking for
        :return: ``size - 1`` bytes (or fewer if the buffer runs out); a
            request for exactly one byte returns that byte
        :raise AssertionError: if no unread bytes remain
        """
        # fail fast if no bytes are available; in a real socket, this would
        # block forever
        if self.tell() == len(self.getvalue()):
            raise AssertionError("Blocking read past end of socket")
        # A one-byte request must still make progress, otherwise the final
        # byte of the buffer could never be retrieved.
        if size == 1:
            return self.read(1)
        # calls shouldn't return quite as much as asked for
        return self.read(size - 1)
class ReceivableProtocolTests(BaseProtocolTests, TestCase):
    """Run the shared tests, plus recv-specific ones, on ReceivableProtocol.

    The input stream is a ``ReceivableStringIO``, whose ``recv`` returns
    short reads and raises ``AssertionError`` where a real socket would
    block, so any attempt to read past the available data fails the test
    instead of hanging.
    """

    def setUp(self):
        """Create the streams and a ``ReceivableProtocol`` wired to them."""
        super(ReceivableProtocolTests, self).setUp()
        self.rout = StringIO()
        self.rin = ReceivableStringIO()
        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
        # Small receive buffer so the tests exercise multiple refills.
        self.proto._rbufsize = 8

    def test_recv(self):
        """Repeated recv() calls return all data, one short read at a time."""
        all_data = "1234567" * 10  # not a multiple of bufsize
        self.rin.write(all_data)
        self.rin.seek(0)
        data = ""
        # We ask for 8 bytes each time and actually read 7, so it should take
        # exactly 10 iterations.
        for _ in xrange(10):
            data += self.proto.recv(10)
        # any more reads would block
        self.assertRaises(AssertionError, self.proto.recv, 10)
        self.assertEqual(all_data, data)

    def test_recv_read(self):
        """read() after recv() continues from the buffered remainder."""
        all_data = "1234567"  # recv exactly in one call
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual("1234", self.proto.recv(4))
        self.assertEqual("567", self.proto.read(3))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_read_recv(self):
        """recv() after read() returns buffered data before refilling."""
        all_data = "12345678abcdefg"
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual("1234", self.proto.read(4))
        self.assertEqual("5678abc", self.proto.recv(8))
        self.assertEqual("defg", self.proto.read(4))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_mixed(self):
        """Interleaved recv()/read() calls of growing size lose no data."""
        # arbitrary non-repeating string
        all_data = ",".join(str(i) for i in xrange(100))
        self.rin.write(all_data)
        self.rin.seek(0)
        data = ""

        for i in xrange(1, 100):
            data += self.proto.recv(i)
            # if we get to the end, do a non-blocking read instead of blocking
            if len(data) + i > len(all_data):
                data += self.proto.recv(i)
                # ReceivableStringIO leaves off the last byte unless we ask
                # nicely
                data += self.proto.recv(1)
                break
            else:
                data += self.proto.read(i)
        else:
            # didn't break, something must have gone wrong
            self.fail()

        self.assertEqual(all_data, data)
class CapabilitiesTestCase(TestCase):
    """Tests for capability extraction and ack-type selection helpers."""

    def test_plain(self):
        """A line without a NUL separator carries no capabilities."""
        self.assertEqual(("bla", []), extract_capabilities("bla"))

    def test_caps(self):
        """Capabilities follow a NUL and are split on spaces."""
        self.assertEqual(("bla", ["la"]), extract_capabilities("bla\0la"))
        # A trailing newline is not part of the last capability.
        self.assertEqual(("bla", ["la"]), extract_capabilities("bla\0la\n"))
        self.assertEqual(("bla", ["la", "la"]),
                         extract_capabilities("bla\0la la"))

    def test_plain_want_line(self):
        """A bare want line carries no capabilities."""
        self.assertEqual(("want bla", []),
                         extract_want_line_capabilities("want bla"))

    def test_caps_want_line(self):
        """On a want line, words after the second one are capabilities."""
        self.assertEqual(("want bla", ["la"]),
                         extract_want_line_capabilities("want bla la"))
        self.assertEqual(("want bla", ["la"]),
                         extract_want_line_capabilities("want bla la\n"))
        self.assertEqual(("want bla", ["la", "la"]),
                         extract_want_line_capabilities("want bla la la"))

    def test_ack_type(self):
        """ack_type picks the ack mode from the capability list."""
        self.assertEqual(SINGLE_ACK, ack_type(['foo', 'bar']))
        self.assertEqual(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack_detailed']))
        # choose detailed when both present
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack',
                                   'multi_ack_detailed']))