Add convenience function for parsing info/refs file.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_web.py
1 # test_web.py -- Tests for the git HTTP server
2 # Copyright (C) 2010 Google, Inc.
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Tests for the Git HTTP server."""
20
21 from cStringIO import StringIO
22 import re
23 from unittest import TestCase
24
25 from dulwich.objects import (
26     type_map,
27     Tag,
28     Blob,
29     )
30 from dulwich.web import (
31     HTTP_OK,
32     HTTP_NOT_FOUND,
33     HTTP_FORBIDDEN,
34     send_file,
35     get_info_refs,
36     handle_service_request,
37     _LengthLimitedFile,
38     HTTPGitRequest,
39     HTTPGitApplication,
40     )
41
42
class WebTestCase(TestCase):
    """Shared fixture: a request object plus a recording response callback.

    After setUp, tests can mutate ``self._environ`` (the same dict that was
    handed to the request), use ``self._req`` as the request under test, and
    inspect ``self._status`` / ``self._headers`` to see what was passed to
    the response callback.
    """

    def setUp(self):
        environ = {}
        self._environ = environ
        # The request keeps a reference to this very dict, so later edits to
        # self._environ made by individual tests are visible through it.
        self._req = HTTPGitRequest(environ, self._start_response)
        # Nothing has been recorded yet.
        self._status, self._headers = None, []

    def _start_response(self, status, headers):
        # Record the arguments of the most recent call; headers are copied
        # into a fresh list.
        self._status = status
        self._headers = [entry for entry in headers]
54
55
class DumbHandlersTestCase(WebTestCase):
    """Tests for the send_file and get_info_refs handlers (no service query)."""

    def test_send_file_not_found(self):
        # Passing None instead of a file object must yield a not-found status.
        list(send_file(self._req, None, 'text/plain'))
        self.assertEqual(HTTP_NOT_FOUND, self._status)

    def test_send_file(self):
        # A small file is streamed back verbatim, with the requested content
        # type, and is closed once the output has been consumed.
        f = StringIO('foobar')
        output = ''.join(send_file(self._req, f, 'text/plain'))
        self.assertEqual('foobar', output)
        self.assertEqual(HTTP_OK, self._status)
        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
        self.assertTrue(f.closed)

    def test_send_file_buffered(self):
        # Two buffers' worth of data is expected to come back as exactly two
        # chunks of ``bufsize`` characters each.
        bufsize = 10240
        xs = 'x' * bufsize
        f = StringIO(2 * xs)
        self.assertEqual([xs, xs],
                         list(send_file(self._req, f, 'text/plain')))
        self.assertEqual(HTTP_OK, self._status)
        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
        self.assertTrue(f.closed)

    def test_send_file_error(self):
        # A file whose read() raises IOError must produce a not-found status
        # and must still be closed.
        class TestFile(object):
            def __init__(self):
                self.closed = False

            def read(self, size=-1):
                raise IOError

            def close(self):
                self.closed = True

        f = TestFile()
        list(send_file(self._req, f, 'text/plain'))
        self.assertEqual(HTTP_NOT_FOUND, self._status)
        self.assertTrue(f.closed)

    def test_get_info_refs(self):
        # An empty query string means no service is requested.
        self._environ['QUERY_STRING'] = ''

        # Minimal stand-ins for tag and blob objects: they expose only
        # ``type``, ``sha()`` and (for tags) ``object``.
        class TestTag(object):
            type = Tag().type

            def __init__(self, sha, obj_type, obj_sha):
                self.sha = lambda: sha
                self.object = (obj_type, obj_sha)

        class TestBlob(object):
            type = Blob().type

            def __init__(self, sha):
                self.sha = lambda: sha

        blob1 = TestBlob('111')
        blob2 = TestBlob('222')
        blob3 = TestBlob('333')

        # tag1 points at tag2 (sha 'bbb'), which in turn points at blob2
        # (sha '222'), i.e. a tag of a tag.
        tag1 = TestTag('aaa', TestTag.type, 'bbb')
        tag2 = TestTag('bbb', TestBlob.type, '222')

        class TestBackend(object):
            def __init__(self):
                objects = [blob1, blob2, blob3, tag1, tag2]
                # ``repo`` is a plain sha -> object mapping.
                self.repo = dict((o.sha(), o) for o in objects)

            def get_refs(self):
                return {
                    'HEAD': '000',
                    'refs/heads/master': blob1.sha(),
                    'refs/tags/tag-tag': tag1.sha(),
                    'refs/tags/blob-tag': blob3.sha(),
                    }

        # Expected output: HEAD is absent, refs are listed in sorted order,
        # and the tag ref is followed by a '^{}' line carrying the sha of the
        # blob at the end of the tag chain.
        self.assertEqual(['111\trefs/heads/master\n',
                          '333\trefs/tags/blob-tag\n',
                          'aaa\trefs/tags/tag-tag\n',
                          '222\trefs/tags/tag-tag^{}\n'],
                         list(get_info_refs(self._req, TestBackend(), None)))
137
138
class SmartHandlersTestCase(WebTestCase):
    """Tests for handle_service_request and get_info_refs with a service."""

    class TestProtocol(object):
        """Fake protocol that renders pkt-lines as readable text."""

        def __init__(self, handler):
            self._handler = handler

        def write_pkt_line(self, line):
            # None stands for a flush packet; anything else is echoed with a
            # recognisable prefix so tests can assert on the raw output.
            if line is None:
                self._handler.write('flush-pkt\n')
            else:
                self._handler.write('pkt-line: %s' % line)

    class _TestUploadPackHandler(object):
        """Fake handler that records its flags and echoes its input."""

        def __init__(self, backend, read, write, stateless_rpc=False,
                     advertise_refs=False):
            self.read = read
            self.write = write
            self.proto = SmartHandlersTestCase.TestProtocol(self)
            self.stateless_rpc = stateless_rpc
            self.advertise_refs = advertise_refs

        def handle(self):
            self.write('handled input: %s' % self.read())

    def _MakeHandler(self, *args, **kwargs):
        # Factory passed in as the service; keeps the created handler on the
        # test case so its flags can be inspected afterwards.
        self._handler = self._TestUploadPackHandler(*args, **kwargs)
        return self._handler

    def services(self):
        # Service table containing only the fake upload-pack handler.
        return {'git-upload-pack': self._MakeHandler}

    def test_handle_service_request_unknown(self):
        # No ``services`` override is passed here, so the handler's own
        # default table is used; an unrecognised service must be forbidden.
        mat = re.search('.*', '/git-evil-handler')
        list(handle_service_request(self._req, 'backend', mat))
        self.assertEqual(HTTP_FORBIDDEN, self._status)

    def test_handle_service_request(self):
        # The whole request body is handed to the handler, which must be
        # created in stateless-RPC mode without ref advertisement.
        self._environ['wsgi.input'] = StringIO('foo')
        mat = re.search('.*', '/git-upload-pack')
        output = ''.join(handle_service_request(self._req, 'backend', mat,
                                                services=self.services()))
        self.assertEqual('handled input: foo', output)
        response_type = 'application/x-git-upload-pack-response'
        self.assertTrue(('Content-Type', response_type) in self._headers)
        self.assertFalse(self._handler.advertise_refs)
        self.assertTrue(self._handler.stateless_rpc)

    def test_handle_service_request_with_length(self):
        # With CONTENT_LENGTH set to 3, only the first three characters of
        # the six-character body may reach the handler.
        self._environ['wsgi.input'] = StringIO('foobar')
        self._environ['CONTENT_LENGTH'] = 3
        mat = re.search('.*', '/git-upload-pack')
        output = ''.join(handle_service_request(self._req, 'backend', mat,
                                                services=self.services()))
        self.assertEqual('handled input: foo', output)
        response_type = 'application/x-git-upload-pack-response'
        self.assertTrue(('Content-Type', response_type) in self._headers)

    def test_get_info_refs_unknown(self):
        # A service name missing from the service table must be forbidden.
        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
        list(get_info_refs(self._req, 'backend', None,
                           services=self.services()))
        self.assertEqual(HTTP_FORBIDDEN, self._status)

    def test_get_info_refs(self):
        self._environ['wsgi.input'] = StringIO('foo')
        self._environ['QUERY_STRING'] = 'service=git-upload-pack'

        output = ''.join(get_info_refs(self._req, 'backend', None,
                                       services=self.services()))
        # Expected output: a service announcement pkt-line, a flush packet,
        # then the handler's output.
        self.assertEqual(('pkt-line: # service=git-upload-pack\n'
                          'flush-pkt\n'
                          # input is ignored by the handler
                          'handled input: '), output)
        self.assertTrue(self._handler.advertise_refs)
        self.assertTrue(self._handler.stateless_rpc)
214
215
class LengthLimitedFileTestCase(TestCase):
    """Tests for the _LengthLimitedFile wrapper."""

    def test_no_cutoff(self):
        # A limit larger than the content leaves the content untouched.
        f = _LengthLimitedFile(StringIO('foobar'), 1024)
        self.assertEqual('foobar', f.read())

    def test_cutoff(self):
        # An unbounded read stops at the limit; further reads return ''.
        f = _LengthLimitedFile(StringIO('foobar'), 3)
        self.assertEqual('foo', f.read())
        self.assertEqual('', f.read())

    def test_multiple_reads(self):
        # The limit applies across reads: the second read(2) may only return
        # the single character still allowed.
        f = _LengthLimitedFile(StringIO('foobar'), 3)
        self.assertEqual('fo', f.read(2))
        self.assertEqual('o', f.read(2))
        self.assertEqual('', f.read())
231
232
class HTTPGitRequestTestCase(WebTestCase):
    """Tests for the response helpers on HTTPGitRequest."""

    def test_not_found(self):
        self._req.cache_forever()  # cache headers should be discarded
        message = 'Something not found'
        # not_found() returns the message it was given and responds with a
        # plain-text content type as the only header.
        self.assertEqual(message, self._req.not_found(message))
        self.assertEqual(HTTP_NOT_FOUND, self._status)
        self.assertEqual(set([('Content-Type', 'text/plain')]),
                         set(self._headers))

    def test_forbidden(self):
        self._req.cache_forever()  # cache headers should be discarded
        message = 'Something not found'
        # forbidden() returns the message it was given and responds with a
        # plain-text content type as the only header.
        self.assertEqual(message, self._req.forbidden(message))
        self.assertEqual(HTTP_FORBIDDEN, self._status)
        self.assertEqual(set([('Content-Type', 'text/plain')]),
                         set(self._headers))

    def test_respond_ok(self):
        # respond() without arguments: OK status and no headers at all.
        self._req.respond()
        self.assertEqual([], self._headers)
        self.assertEqual(HTTP_OK, self._status)

    def test_respond(self):
        # Headers set by nocache() are expected to be merged with the
        # explicit headers and the content type; order is not significant,
        # hence the comparison as sets.
        self._req.nocache()
        self._req.respond(status=402, content_type='some/type',
                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
        self.assertEqual(set([
            ('X-Foo', 'foo'),
            ('X-Bar', 'bar'),
            ('Content-Type', 'some/type'),
            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
            ('Pragma', 'no-cache'),
            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
            ]), set(self._headers))
        self.assertEqual(402, self._status)
268
269
class HTTPGitApplicationTestCase(TestCase):
    """Tests for request dispatching in HTTPGitApplication."""

    def setUp(self):
        self._app = HTTPGitApplication('backend')

    def test_call(self):
        def test_handler(req, backend, mat):
            # tests interface used by all handlers
            # ``environ`` is bound further down in test_call; the closure
            # resolves it only when the application invokes this handler.
            self.assertEqual(environ, req.environ)
            self.assertEqual('backend', backend)
            self.assertEqual('/foo', mat.group(0))
            return 'output'

        # Replace the routing table with a single (method, pattern) entry.
        self._app.services = {
            ('GET', re.compile('/foo$')): test_handler,
        }
        environ = {
            'PATH_INFO': '/foo',
            'REQUEST_METHOD': 'GET',
            }
        # The application must return whatever the matched handler returns.
        self.assertEqual('output', self._app(environ, None))