Standardize quote delimiters in test_protocol.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_protocol.py
index a182b35df3e848b4482b01e42819914cb5d68ae7..78011e410509b2f7f8c002e7914f2ffcfb3d600b 100644 (file)
@@ -20,7 +20,6 @@
 
 
 from StringIO import StringIO
-from unittest import TestCase
 
 from dulwich.protocol import (
     Protocol,
@@ -31,48 +30,51 @@ from dulwich.protocol import (
     SINGLE_ACK,
     MULTI_ACK,
     MULTI_ACK_DETAILED,
+    BufferedPktLineWriter,
     )
+from dulwich.tests import TestCase
+
 
 class BaseProtocolTests(object):
 
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
-        self.assertEquals(self.rout.getvalue(), "0000")
+        self.assertEquals(self.rout.getvalue(), '0000')
 
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line("bla")
-        self.assertEquals(self.rout.getvalue(), "0007bla")
+        self.proto.write_pkt_line('bla')
+        self.assertEquals(self.rout.getvalue(), '0007bla')
 
     def test_read_pkt_line(self):
-        self.rin.write("0008cmd ")
+        self.rin.write('0008cmd ')
         self.rin.seek(0)
-        self.assertEquals("cmd ", self.proto.read_pkt_line())
+        self.assertEquals('cmd ', self.proto.read_pkt_line())
 
     def test_read_pkt_seq(self):
-        self.rin.write("0008cmd 0005l0000")
+        self.rin.write('0008cmd 0005l0000')
         self.rin.seek(0)
-        self.assertEquals(["cmd ", "l"], list(self.proto.read_pkt_seq()))
+        self.assertEquals(['cmd ', 'l'], list(self.proto.read_pkt_seq()))
 
     def test_read_pkt_line_none(self):
-        self.rin.write("0000")
+        self.rin.write('0000')
         self.rin.seek(0)
         self.assertEquals(None, self.proto.read_pkt_line())
 
     def test_write_sideband(self):
-        self.proto.write_sideband(3, "bloe")
-        self.assertEquals(self.rout.getvalue(), "0009\x03bloe")
+        self.proto.write_sideband(3, 'bloe')
+        self.assertEquals(self.rout.getvalue(), '0009\x03bloe')
 
     def test_send_cmd(self):
-        self.proto.send_cmd("fetch", "a", "b")
-        self.assertEquals(self.rout.getvalue(), "000efetch a\x00b\x00")
+        self.proto.send_cmd('fetch', 'a', 'b')
+        self.assertEquals(self.rout.getvalue(), '000efetch a\x00b\x00')
 
     def test_read_cmd(self):
-        self.rin.write("0012cmd arg1\x00arg2\x00")
+        self.rin.write('0012cmd arg1\x00arg2\x00')
         self.rin.seek(0)
-        self.assertEquals(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())
+        self.assertEquals(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())
 
     def test_read_cmd_noend0(self):
-        self.rin.write("0011cmd arg1\x00arg2")
+        self.rin.write('0011cmd arg1\x00arg2')
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
@@ -93,7 +95,7 @@ class ReceivableStringIO(StringIO):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         if self.tell() == len(self.getvalue()):
-            raise AssertionError("Blocking read past end of socket")
+            raise AssertionError('Blocking read past end of socket')
         if size == 1:
             return self.read(1)
         # calls shouldn't return quite as much as asked for
@@ -110,10 +112,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.proto._rbufsize = 8
 
     def test_recv(self):
-        all_data = "1234567" * 10  # not a multiple of bufsize
+        all_data = '1234567' * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ""
+        data = ''
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         for _ in xrange(10):
@@ -123,28 +125,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEquals(all_data, data)
 
     def test_recv_read(self):
-        all_data = "1234567"  # recv exactly in one call
+        all_data = '1234567'  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.recv(4))
-        self.assertEquals("567", self.proto.read(3))
+        self.assertEquals('1234', self.proto.recv(4))
+        self.assertEquals('567', self.proto.read(3))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_read_recv(self):
-        all_data = "12345678abcdefg"
+        all_data = '12345678abcdefg'
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.read(4))
-        self.assertEquals("5678abc", self.proto.recv(8))
-        self.assertEquals("defg", self.proto.read(4))
+        self.assertEquals('1234', self.proto.read(4))
+        self.assertEquals('5678abc', self.proto.recv(8))
+        self.assertEquals('defg', self.proto.read(4))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string
-        all_data = ",".join(str(i) for i in xrange(100))
+        all_data = ','.join(str(i) for i in xrange(100))
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ""
+        data = ''
 
         for i in xrange(1, 100):
             data += self.proto.recv(i)
@@ -167,20 +169,20 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):
-        self.assertEquals(("bla", []), extract_capabilities("bla"))
+        self.assertEquals(('bla', []), extract_capabilities('bla'))
 
     def test_caps(self):
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
-        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la'))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la\n'))
+        self.assertEquals(('bla', ['la', 'la']), extract_capabilities('bla\0la la'))
 
     def test_plain_want_line(self):
-        self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
+        self.assertEquals(('want bla', []), extract_want_line_capabilities('want bla'))
 
     def test_caps_want_line(self):
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
-        self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la'))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la\n'))
+        self.assertEquals(('want bla', ['la', 'la']), extract_want_line_capabilities('want bla la la'))
 
     def test_ack_type(self):
         self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
@@ -191,3 +193,58 @@ class CapabilitiesTestCase(TestCase):
         self.assertEquals(MULTI_ACK_DETAILED,
                           ack_type(['foo', 'bar', 'multi_ack',
                                     'multi_ack_detailed']))
+
+
+class BufferedPktLineWriterTests(TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self._output = StringIO()
+        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
+
+    def assertOutputEquals(self, expected):
+        self.assertEquals(expected, self._output.getvalue())
+
+    def _truncate(self):
+        self._output.seek(0)
+        self._output.truncate()
+
+    def test_write(self):
+        self._writer.write('foo')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo')
+
+    def test_write_none(self):
+        self._writer.write(None)
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0000')
+
+    def test_flush_empty(self):
+        self._writer.flush()
+        self.assertOutputEquals('')
+
+    def test_write_multiple(self):
+        self._writer.write('foo')
+        self._writer.write('bar')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo0007bar')
+
+    def test_write_across_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barbaz')
+        self.assertOutputEquals('0007foo000abarba')
+        self._truncate()
+        self._writer.flush()
+        self.assertOutputEquals('z')
+
+    def test_write_to_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barba')
+        self.assertOutputEquals('0007foo0009barba')
+        self._truncate()
+        self._writer.write('z')
+        self._writer.flush()
+        self.assertOutputEquals('0005z')