0a4817fa8e8e6b4be8cf8a89feccbad0c18529fc
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_web.py
1 # test_web.py -- Tests for the git HTTP server
2 # Copyright (C) 2010 Google, Inc.
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # or (at your option) any later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Tests for the Git HTTP server."""
20
21 from cStringIO import StringIO
22 import re
23 from unittest import TestCase
24
25 from dulwich.objects import (
26     Blob,
27     )
28 from dulwich.web import (
29     HTTP_OK,
30     HTTP_NOT_FOUND,
31     HTTP_FORBIDDEN,
32     send_file,
33     get_info_refs,
34     handle_service_request,
35     _LengthLimitedFile,
36     HTTPGitRequest,
37     HTTPGitApplication,
38     )
39
40
class WebTestCase(TestCase):
    """Base TestCase that sets up some useful instance vars.

    After setUp the following attributes are available:

    * ``_environ``: the dict handed to the request as its WSGI environ;
      tests in subclasses add keys to it before invoking a handler.
    * ``_req``: an HTTPGitRequest built from ``_environ`` and
      ``_start_response``.
    * ``_status`` / ``_headers``: the values most recently passed to
      ``_start_response`` (None and [] until it is called).
    """

    def setUp(self):
        self._environ = {}
        self._req = HTTPGitRequest(self._environ, self._start_response)
        self._status = None
        self._headers = []

    def _start_response(self, status, headers):
        """Record the status and headers of a response for later asserts.

        :param status: Status value supplied by the code under test.
        :param headers: Iterable of (name, value) header tuples.
        """
        self._status = status
        # Copy into a list so any iterable is captured as a stable snapshot.
        self._headers = list(headers)
52
53
class DumbHandlersTestCase(WebTestCase):
    """Tests for send_file and for get_info_refs with an empty query."""

    def test_send_file_not_found(self):
        """Passing None as the file yields a not-found status."""
        # Consume the returned iterable so the handler runs to completion.
        list(send_file(self._req, None, 'text/plain'))
        self.assertEqual(HTTP_NOT_FOUND, self._status)

    def test_send_file(self):
        """A small file is sent whole, with content type, and closed."""
        f = StringIO('foobar')
        output = ''.join(send_file(self._req, f, 'text/plain'))
        self.assertEqual('foobar', output)
        self.assertEqual(HTTP_OK, self._status)
        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
        self.assertTrue(f.closed)

    def test_send_file_buffered(self):
        """Content of twice the buffer size arrives as two chunks."""
        bufsize = 10240
        xs = 'x' * bufsize
        f = StringIO(2 * xs)
        self.assertEqual([xs, xs],
                         list(send_file(self._req, f, 'text/plain')))
        self.assertEqual(HTTP_OK, self._status)
        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
        self.assertTrue(f.closed)

    def test_send_file_error(self):
        """An IOError while reading gives not-found and closes the file."""

        class TestFile(object):
            """File stand-in whose read() always raises IOError."""

            def __init__(self):
                self.closed = False

            def read(self, size=-1):
                raise IOError

            def close(self):
                self.closed = True

        f = TestFile()
        list(send_file(self._req, f, 'text/plain'))
        self.assertEqual(HTTP_NOT_FOUND, self._status)
        self.assertTrue(f.closed)

    def test_get_info_refs(self):
        """Refs are listed sorted by name, with a peeled line for the tag."""
        self._environ['QUERY_STRING'] = ''

        class TestTag(object):
            """Minimal tag stand-in exposing sha() and object."""

            def __init__(self, sha, obj_class, obj_sha):
                self.sha = lambda: sha
                self.object = (obj_class, obj_sha)

        class TestBlob(object):
            """Minimal blob stand-in exposing sha()."""

            def __init__(self, sha):
                self.sha = lambda: sha

        blob1 = TestBlob('111')
        blob2 = TestBlob('222')
        blob3 = TestBlob('333')

        # tag1 points at blob2, so its ref peels to '222'.
        tag1 = TestTag('aaa', Blob, '222')

        class TestBackend(object):
            """Backend stand-in offering item lookup, refs and peeled refs."""

            def __init__(self):
                objects = [blob1, blob2, blob3, tag1]
                self._objects = dict((o.sha(), o) for o in objects)
                self._peeled = {
                    'HEAD': '000',
                    'refs/heads/master': blob1.sha(),
                    'refs/tags/tag-tag': blob2.sha(),
                    'refs/tags/blob-tag': blob3.sha(),
                    }

            def __getitem__(self, sha):
                return self._objects[sha]

            def get_peeled(self, sha):
                # NOTE: despite the parameter name, the lookup table is
                # keyed by ref name.
                return self._peeled[sha]

            def get_refs(self):
                return {
                    'HEAD': '000',
                    'refs/heads/master': blob1.sha(),
                    'refs/tags/tag-tag': tag1.sha(),
                    'refs/tags/blob-tag': blob3.sha(),
                    }

        # HEAD is expected to be absent from the output, and only the ref
        # whose peeled value differs from its own sha gets a '^{}' line.
        self.assertEqual(['111\trefs/heads/master\n',
                          '333\trefs/tags/blob-tag\n',
                          'aaa\trefs/tags/tag-tag\n',
                          '222\trefs/tags/tag-tag^{}\n'],
                         list(get_info_refs(self._req, TestBackend(), None)))
143
144
class SmartHandlersTestCase(WebTestCase):
    """Tests for handle_service_request and service-aware get_info_refs.

    The real service handler is replaced by a stub (see ``services``) that
    echoes what it read, so the tests can check how it was constructed and
    what was written around its output.
    """

    class TestProtocol(object):
        """Protocol stand-in that writes readable markers via the handler."""

        def __init__(self, handler):
            self._handler = handler

        def write_pkt_line(self, line):
            """Write a marker for *line*; None is reported as a flush."""
            if line is None:
                self._handler.write('flush-pkt\n')
            else:
                self._handler.write('pkt-line: %s' % line)

    class _TestUploadPackHandler(object):
        """Handler stub recording its construction arguments."""

        def __init__(self, backend, read, write, stateless_rpc=False,
                     advertise_refs=False):
            self.read = read
            self.write = write
            self.proto = SmartHandlersTestCase.TestProtocol(self)
            self.stateless_rpc = stateless_rpc
            self.advertise_refs = advertise_refs

        def handle(self):
            """Echo everything available from read() back through write()."""
            self.write('handled input: %s' % self.read())

    def _MakeHandler(self, *args, **kwargs):
        """Build the stub handler and keep it on self for later asserts."""
        self._handler = self._TestUploadPackHandler(*args, **kwargs)
        return self._handler

    def services(self):
        """Return the service table handed to the functions under test."""
        return {'git-upload-pack': self._MakeHandler}

    def test_handle_service_request_unknown(self):
        """A service name outside the service table is forbidden."""
        mat = re.search('.*', '/git-evil-handler')
        # Consume the returned iterable so the handler runs to completion.
        list(handle_service_request(self._req, 'backend', mat))
        self.assertEqual(HTTP_FORBIDDEN, self._status)

    def test_handle_service_request(self):
        """The handler receives the request body and runs stateless."""
        self._environ['wsgi.input'] = StringIO('foo')
        mat = re.search('.*', '/git-upload-pack')
        output = ''.join(handle_service_request(self._req, 'backend', mat,
                                                services=self.services()))
        self.assertEqual('handled input: foo', output)
        response_type = 'application/x-git-upload-pack-response'
        self.assertTrue(('Content-Type', response_type) in self._headers)
        self.assertFalse(self._handler.advertise_refs)
        self.assertTrue(self._handler.stateless_rpc)

    def test_handle_service_request_with_length(self):
        """CONTENT_LENGTH limits how much of the body the handler reads."""
        self._environ['wsgi.input'] = StringIO('foobar')
        self._environ['CONTENT_LENGTH'] = 3
        mat = re.search('.*', '/git-upload-pack')
        output = ''.join(handle_service_request(self._req, 'backend', mat,
                                                services=self.services()))
        # Only the first three bytes of 'foobar' reach the handler.
        self.assertEqual('handled input: foo', output)
        response_type = 'application/x-git-upload-pack-response'
        self.assertTrue(('Content-Type', response_type) in self._headers)

    def test_get_info_refs_unknown(self):
        """Requesting refs for a service outside the table is forbidden."""
        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
        list(get_info_refs(self._req, 'backend', None,
                           services=self.services()))
        self.assertEqual(HTTP_FORBIDDEN, self._status)

    def test_get_info_refs(self):
        """A service announcement and flush precede the handler output."""
        self._environ['wsgi.input'] = StringIO('foo')
        self._environ['QUERY_STRING'] = 'service=git-upload-pack'

        output = ''.join(get_info_refs(self._req, 'backend', None,
                                       services=self.services()))
        self.assertEqual(('pkt-line: # service=git-upload-pack\n'
                          'flush-pkt\n'
                          # input is ignored by the handler
                          'handled input: '), output)
        self.assertTrue(self._handler.advertise_refs)
        self.assertTrue(self._handler.stateless_rpc)
220
221
class LengthLimitedFileTestCase(TestCase):
    """Tests for _LengthLimitedFile, a read wrapper with a byte limit."""

    def test_no_cutoff(self):
        """A limit larger than the content returns the content unchanged."""
        f = _LengthLimitedFile(StringIO('foobar'), 1024)
        self.assertEqual('foobar', f.read())

    def test_cutoff(self):
        """read() stops at the limit and returns '' afterwards."""
        f = _LengthLimitedFile(StringIO('foobar'), 3)
        self.assertEqual('foo', f.read())
        self.assertEqual('', f.read())

    def test_multiple_reads(self):
        """The limit applies to the total across successive sized reads."""
        f = _LengthLimitedFile(StringIO('foobar'), 3)
        self.assertEqual('fo', f.read(2))
        # Only one byte of the limit remains, even though two were requested.
        self.assertEqual('o', f.read(2))
        self.assertEqual('', f.read())
237
238
class HTTPGitRequestTestCase(WebTestCase):
    """Tests for the response helpers on HTTPGitRequest."""

    def test_not_found(self):
        """not_found() returns the message with only a plain-text header."""
        self._req.cache_forever()  # cache headers should be discarded
        message = 'Something not found'
        self.assertEqual(message, self._req.not_found(message))
        self.assertEqual(HTTP_NOT_FOUND, self._status)
        self.assertEqual(set([('Content-Type', 'text/plain')]),
                         set(self._headers))

    def test_forbidden(self):
        """forbidden() returns the message with only a plain-text header."""
        self._req.cache_forever()  # cache headers should be discarded
        message = 'Something not found'
        self.assertEqual(message, self._req.forbidden(message))
        self.assertEqual(HTTP_FORBIDDEN, self._status)
        self.assertEqual(set([('Content-Type', 'text/plain')]),
                         set(self._headers))

    def test_respond_ok(self):
        """respond() without arguments sends HTTP_OK and no headers."""
        self._req.respond()
        self.assertEqual([], self._headers)
        self.assertEqual(HTTP_OK, self._status)

    def test_respond(self):
        """respond() merges explicit, content-type and no-cache headers."""
        self._req.nocache()
        self._req.respond(status=402, content_type='some/type',
                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
        # Compared as sets: the order of the headers is not asserted.
        self.assertEqual(set([
            ('X-Foo', 'foo'),
            ('X-Bar', 'bar'),
            ('Content-Type', 'some/type'),
            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
            ('Pragma', 'no-cache'),
            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
            ]), set(self._headers))
        self.assertEqual(402, self._status)
274
275
class HTTPGitApplicationTestCase(TestCase):
    """Tests for request dispatch through HTTPGitApplication."""

    def setUp(self):
        self._app = HTTPGitApplication('backend')

    def test_call(self):
        """Calling the app invokes the handler registered for the request."""
        # Defined before the handler so the closure below reads naturally.
        environ = {
            'PATH_INFO': '/foo',
            'REQUEST_METHOD': 'GET',
            }

        def test_handler(req, backend, mat):
            # tests interface used by all handlers
            self.assertEqual(environ, req.environ)
            self.assertEqual('backend', backend)
            self.assertEqual('/foo', mat.group(0))
            return 'output'

        # Replace the routing table with a single (method, pattern) entry.
        self._app.services = {
            ('GET', re.compile('/foo$')): test_handler,
        }
        # start_response is None: this test expects the app not to call it.
        self.assertEqual('output', self._app(environ, None))