Clean up file headers.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_protocol.py
1 # test_protocol.py -- Tests for the git protocol
2 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Tests for the smart protocol utility functions."""
20
21
22 from StringIO import StringIO
23 from unittest import TestCase
24
25 from dulwich.protocol import (
26     Protocol,
27     ReceivableProtocol,
28     extract_capabilities,
29     extract_want_line_capabilities,
30     ack_type,
31     SINGLE_ACK,
32     MULTI_ACK,
33     MULTI_ACK_DETAILED,
34     )
35
class BaseProtocolTests(object):
    """Tests shared by every Protocol implementation.

    This is a mixin: concrete subclasses also derive from TestCase and, in
    setUp, provide ``self.rin`` (buffer the protocol reads from),
    ``self.rout`` (buffer the protocol writes to) and ``self.proto`` (the
    protocol instance under test).
    """

    def test_write_pkt_line_none(self):
        # Writing None emits the "0000" flush packet.
        self.proto.write_pkt_line(None)
        self.assertEqual("0000", self.rout.getvalue())

    def test_write_pkt_line(self):
        # Four hex digits of length (counting the 4-byte header itself),
        # followed by the payload: 4 + len("bla") == 7.
        self.proto.write_pkt_line("bla")
        self.assertEqual("0007bla", self.rout.getvalue())

    def test_read_pkt_line(self):
        self.rin.write("0008cmd ")
        self.rin.seek(0)
        self.assertEqual("cmd ", self.proto.read_pkt_line())

    def test_read_pkt_seq(self):
        # The trailing "0000" flush packet terminates the sequence and is
        # not itself yielded.
        self.rin.write("0008cmd 0005l0000")
        self.rin.seek(0)
        self.assertEqual(["cmd ", "l"], list(self.proto.read_pkt_seq()))

    def test_read_pkt_line_none(self):
        # Reading a flush packet yields None.
        self.rin.write("0000")
        self.rin.seek(0)
        self.assertEqual(None, self.proto.read_pkt_line())

    def test_write_sideband(self):
        # The channel number is sent as one byte ahead of the data, so the
        # length is 4 (header) + 1 (channel) + 4 (data) == 9.
        self.proto.write_sideband(3, "bloe")
        self.assertEqual("0009\x03bloe", self.rout.getvalue())

    def test_send_cmd(self):
        # Command, a space, then each argument NUL-terminated.
        self.proto.send_cmd("fetch", "a", "b")
        self.assertEqual("000efetch a\x00b\x00", self.rout.getvalue())

    def test_read_cmd(self):
        self.rin.write("0012cmd arg1\x00arg2\x00")
        self.rin.seek(0)
        self.assertEqual(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())

    def test_read_cmd_noend0(self):
        # A command whose last argument lacks its terminating NUL is
        # rejected with AssertionError.
        self.rin.write("0011cmd arg1\x00arg2")
        self.rin.seek(0)
        self.assertRaises(AssertionError, self.proto.read_cmd)
78
79
class ProtocolTests(BaseProtocolTests, TestCase):
    """Run the shared protocol tests against the plain Protocol class."""

    def setUp(self):
        # BaseProtocolTests defines no setUp, so this reaches TestCase.setUp.
        super(ProtocolTests, self).setUp()
        self.rin, self.rout = StringIO(), StringIO()
        # Protocol reads via a read() callable and writes via write().
        self.proto = Protocol(self.rin.read, self.rout.write)
87
88
class ReceivableStringIO(StringIO):
    """StringIO with socket-like recv semantics for testing."""

    def recv(self, size):
        """Read up to ``size`` characters, mimicking ``socket.recv``.

        Requests for more than one character deliberately return one fewer
        than asked for, so callers must cope with short reads. While data
        remains, a request for zero (or fewer) characters returns an empty
        string, and a request for exactly one returns one character.

        :raises AssertionError: if nothing is left to read. A real socket
            would block forever here, so fail fast instead.
        """
        if self.tell() == len(self.getvalue()):
            raise AssertionError("Blocking read past end of socket")
        if size <= 1:
            # Never trim tiny requests: ``size - 1`` for size 0 would be
            # read(-1), which returns the *entire* remaining buffer.
            return self.read(max(size, 0))
        # calls shouldn't return quite as much as asked for
        return self.read(size - 1)
101
102
class ReceivableProtocolTests(BaseProtocolTests, TestCase):
    """Run the shared protocol tests against ReceivableProtocol, plus tests
    of its socket-like recv() and how recv() interleaves with read().
    """

    def setUp(self):
        TestCase.setUp(self)
        self.rout = StringIO()
        self.rin = ReceivableStringIO()
        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
        # Small receive buffer so the tests below cross buffer boundaries.
        self.proto._rbufsize = 8

    def test_recv(self):
        all_data = "1234567" * 10  # not a multiple of bufsize
        self.rin.write(all_data)
        self.rin.seek(0)
        # The protocol asks the socket for _rbufsize (8) bytes at a time and
        # ReceivableStringIO returns one fewer (7), so each recv(10) yields
        # 7 bytes and the 70 bytes take exactly 10 calls.
        data = "".join(self.proto.recv(10) for _ in xrange(10))
        # any more reads would block
        self.assertRaises(AssertionError, self.proto.recv, 10)
        self.assertEqual(all_data, data)

    def test_recv_read(self):
        all_data = "1234567"  # recv exactly in one call
        self.rin.write(all_data)
        self.rin.seek(0)
        # Bytes fetched from the socket but not returned by recv() stay
        # available to a later read().
        self.assertEqual("1234", self.proto.recv(4))
        self.assertEqual("567", self.proto.read(3))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_read_recv(self):
        all_data = "12345678abcdefg"
        self.rin.write(all_data)
        self.rin.seek(0)
        # read() and recv() can be interleaved without losing or reordering
        # data: the three results concatenate back to all_data.
        self.assertEqual("1234", self.proto.read(4))
        self.assertEqual("5678abc", self.proto.recv(8))
        self.assertEqual("defg", self.proto.read(4))
        self.assertRaises(AssertionError, self.proto.recv, 10)

    def test_mixed(self):
        # arbitrary non-repeating string
        all_data = ",".join(str(i) for i in xrange(100))
        self.rin.write(all_data)
        self.rin.seek(0)
        data = ""

        for i in xrange(1, 100):
            data += self.proto.recv(i)
            # if we get to the end, do a non-blocking read instead of blocking
            if len(data) + i > len(all_data):
                data += self.proto.recv(i)
                # ReceivableStringIO leaves off the last byte unless we ask
                # nicely
                data += self.proto.recv(1)
                break
            else:
                data += self.proto.read(i)
        else:
            # didn't break, something must have gone wrong
            self.fail()

        self.assertEqual(all_data, data)
165
166
class CapabilitiesTestCase(TestCase):
    """Tests for capability parsing and ack-mode selection."""

    def test_plain(self):
        # No NUL separator: the whole line is the payload, no capabilities.
        self.assertEqual(("bla", []), extract_capabilities("bla"))

    def test_caps(self):
        # Capabilities follow a NUL byte and are space-separated; a trailing
        # newline is dropped and duplicates are preserved.
        self.assertEqual(("bla", ["la"]), extract_capabilities("bla\0la"))
        self.assertEqual(("bla", ["la"]), extract_capabilities("bla\0la\n"))
        self.assertEqual(("bla", ["la", "la"]),
                         extract_capabilities("bla\0la la"))

    def test_plain_want_line(self):
        self.assertEqual(("want bla", []),
                         extract_want_line_capabilities("want bla"))

    def test_caps_want_line(self):
        # On want lines, capabilities are the space-separated words after
        # "want <id>" rather than following a NUL byte.
        self.assertEqual(("want bla", ["la"]),
                         extract_want_line_capabilities("want bla la"))
        self.assertEqual(("want bla", ["la"]),
                         extract_want_line_capabilities("want bla la\n"))
        self.assertEqual(("want bla", ["la", "la"]),
                         extract_want_line_capabilities("want bla la la"))

    def test_ack_type(self):
        # Without any multi_ack capability the single-ack mode is used.
        self.assertEqual(SINGLE_ACK, ack_type(['foo', 'bar']))
        self.assertEqual(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack_detailed']))
        # choose detailed when both present
        self.assertEqual(MULTI_ACK_DETAILED,
                         ack_type(['foo', 'bar', 'multi_ack',
                                   'multi_ack_detailed']))