3ac9a326586cd93106f065350c1d90b77c6f2c16
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_repository.py
1 # test_repository.py -- tests for repository.py
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License or (at your option) any later version of 
8 # the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Tests for the repository."""
22
23 from cStringIO import StringIO
24 import os
25 import shutil
26 import tempfile
27 import unittest
28
29 from dulwich import errors
30 from dulwich.repo import (
31     check_ref_format,
32     Repo,
33     read_packed_refs,
34     read_packed_refs_with_peeled,
35     write_packed_refs,
36     _split_ref_line,
37     )
38
# A well-formed 40-character hex SHA that the tests below assume is absent
# from every fixture repository; used to exercise missing-object handling.
missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
40
41
def open_repo(name):
    """Open a copy of a repo in a temporary directory.

    Use this function for accessing repos in dulwich/tests/data/repos to avoid
    accidentally or intentionally modifying those repos in place. Use
    tear_down_repo to delete any temp files created.

    :param name: The name of the repository, relative to
        dulwich/tests/data/repos
    :returns: An initialized Repo object that lives in a temporary directory.
    :raises: Whatever copying the fixture or constructing the Repo raises; in
        that case the temporary directory is removed before re-raising.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos',
                                name)
        temp_repo_dir = os.path.join(temp_dir, name)
        # Copy symbolic links as links rather than following them.
        shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
        return Repo(temp_repo_dir)
    except:
        # The caller only receives a Repo (and so a way to call
        # tear_down_repo) on success; on any failure remove the temporary
        # directory here so it does not leak, then propagate the error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
58
def tear_down_repo(repo):
    """Remove the temporary directory holding a repo created by open_repo.

    :param repo: An object whose ``path`` is a direct child of the temporary
        directory to delete (a trailing path separator is tolerated).
    """
    repo_path = repo.path.rstrip(os.sep)
    parent_dir = os.path.dirname(repo_path)
    shutil.rmtree(parent_dir)
63
64
65
class CreateRepositoryTests(unittest.TestCase):
    """Tests for creating new repositories."""

    def test_create(self):
        # Initialising a bare repo in an existing directory should use that
        # directory itself as the control directory.
        tmp_dir = tempfile.mkdtemp()
        try:
            repo = Repo.init_bare(tmp_dir)
            self.assertEqual(tmp_dir, repo._controldir)
        finally:
            shutil.rmtree(tmp_dir)
75
76
class RepositoryTests(unittest.TestCase):
    """Tests for reading refs, objects and history through Repo.

    Each test opens a temporary copy of a fixture repository with open_repo
    and stores it in self._repo so tearDown can delete the copy.
    """

    def setUp(self):
        self._repo = None

    def tearDown(self):
        # Tests that never opened a repository leave self._repo as None.
        if self._repo is not None:
            tear_down_repo(self._repo)

    def test_simple_props(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.controldir(), r.path)

    def test_ref(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.ref('refs/heads/master'),
                         'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_get_refs(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual({
            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
            }, r.get_refs())

    def test_head(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_get_object(self):
        r = self._repo = open_repo('a.git')
        obj = r.get_object(r.head())
        self.assertEqual(obj._type, 'commit')

    def test_get_object_non_existant(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(KeyError, r.get_object, missing_sha)

    def test_commit(self):
        r = self._repo = open_repo('a.git')
        obj = r.commit(r.head())
        self.assertEqual(obj._type, 'commit')

    def test_commit_not_commit(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(errors.NotCommitError,
                          r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')

    def test_tree(self):
        r = self._repo = open_repo('a.git')
        commit = r.commit(r.head())
        tree = r.tree(commit.tree)
        self.assertEqual(tree._type, 'tree')
        self.assertEqual(tree.sha().hexdigest(), commit.tree)

    def test_tree_not_tree(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(errors.NotTreeError, r.tree, r.head())

    def test_get_blob(self):
        r = self._repo = open_repo('a.git')
        commit = r.commit(r.head())
        tree = r.tree(commit.tree)
        # Item 2 of a tree entry is the entry's hex SHA (verified below).
        blob_sha = tree.entries()[0][2]
        blob = r.get_blob(blob_sha)
        self.assertEqual(blob._type, 'blob')
        self.assertEqual(blob.sha().hexdigest(), blob_sha)

    def test_get_blob_notblob(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(errors.NotBlobError, r.get_blob, r.head())

    def test_linear_history(self):
        r = self._repo = open_repo('a.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, [r.head(),
                                '2a72d929692c41d8554c07f6301757ba18a65d91'])

    def test_merge_history(self):
        r = self._repo = open_repo('simple_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
                                'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])

    def test_revision_history_missing_commit(self):
        r = self._repo = open_repo('simple_merge.git')
        self.assertRaises(errors.MissingCommitError, r.revision_history,
                          missing_sha)

    def test_out_of_order_merge(self):
        """Test that revision history is ordered by date, not parent order."""
        r = self._repo = open_repo('ooo_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
                                'f507291b64138b875c28e03469025b1ea20bc614',
                                'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                'f9e39b120c68182a4ba35349f832d0e4e61f485c'])

    def test_get_tags_empty(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEqual({}, r.refs.as_dict('refs/tags'))

    def test_get_config(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEqual({}, r.get_config())
188
189
class CheckRefFormatTests(unittest.TestCase):
    """Tests for the check_ref_format function.

    The cases are the same ones used by the git test suite.
    """

    def test_valid(self):
        accepted = (
            'heads/foo',
            'foo/bar/baz',
            'refs///heads/foo',
            'foo./bar',
            'heads/foo@bar',
            'heads/fix.lock.error',
        )
        for refname in accepted:
            self.assertTrue(check_ref_format(refname), repr(refname))

    def test_invalid(self):
        rejected = (
            'foo',
            'heads/foo/',
            './foo',
            '.refs/foo',
            'heads/foo..bar',
            'heads/foo?bar',
            'heads/foo.lock',
            'heads/v@{ation',
            # NOTE(review): in a Python literal '\b' is a backspace control
            # character, not a backslash; git's suite appears to test a
            # literal backslash ('heads/foo\\bar'). Confirm which is intended.
            'heads/foo\bar',
        )
        for refname in rejected:
            self.assertFalse(check_ref_format(refname), repr(refname))
214
215
# Placeholder 40-character hex SHA strings used as fixture values by the
# packed-refs tests.
ONES = "1" * 40
TWOS = "2" * 40
THREES = "3" * 40
FOURS = "4" * 40
220
class PackedRefsFileTests(unittest.TestCase):
    """Tests for reading and writing the packed-refs file format."""

    def test_split_ref_line_errors(self):
        # A ref line needs a valid SHA followed by a valid refname.
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'singlefield')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'badsha name')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          '%s bad/../refname' % ONES)

    def test_read_without_peeled(self):
        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
                         list(read_packed_refs(f)))

    def test_read_without_peeled_errors(self):
        # The plain reader does not accept peeled ("^") lines.
        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))

    def test_read_with_peeled(self):
        # A "^" line carries the peeled value of the ref on the line before.
        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
            ONES, TWOS, THREES, FOURS))
        self.assertEqual([
            (ONES, 'ref/1', None),
            (TWOS, 'ref/2', THREES),
            (FOURS, 'ref/4', None),
            ], list(read_packed_refs_with_peeled(f)))

    def test_read_with_peeled_errors(self):
        # These cases must go through the peeled-aware reader: the plain
        # reader rejects every "^" line (see test_read_without_peeled_errors),
        # so using it here would pass without testing the peeled parser.

        # A peeled line with no preceding ref line.
        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

        # Two consecutive peeled lines after a single ref.
        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

    def test_write_with_peeled(self):
        # With peeled values the header line is emitted and each peeled ref
        # is followed by its "^" line.
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
                          {'ref/1': THREES})
        self.assertEqual(
            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
            ONES, THREES, TWOS), f.getvalue())

    def test_write_without_peeled(self):
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
268
269
270 class RefsContainerTests(unittest.TestCase):
271
272     def setUp(self):
273         self._repo = open_repo('refs.git')
274         self._refs = self._repo.refs
275
276     def tearDown(self):
277         tear_down_repo(self._repo)
278
279     def test_get_packed_refs(self):
280         self.assertEqual(
281             {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
282             self._refs.get_packed_refs())
283
284     def test_keys(self):
285         self.assertEqual([
286             'HEAD',
287             'refs/heads/loop',
288             'refs/heads/master',
289             'refs/tags/refs-0.1',
290             ], sorted(list(self._refs.keys())))
291         self.assertEqual(['loop', 'master'],
292                          sorted(self._refs.keys('refs/heads')))
293         self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
294
295     def test_as_dict(self):
296         # refs/heads/loop does not show up
297         self.assertEqual({
298             'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
299             'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
300             'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
301             }, self._refs.as_dict())
302
303     def test_setitem(self):
304         self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
305         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
306                          self._refs['refs/some/ref'])
307         f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
308         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
309                           f.read()[:40])
310         f.close()
311
312     def test_setitem_symbolic(self):
313         ones = '1' * 40
314         self._refs['HEAD'] = ones
315         self.assertEqual(ones, self._refs['HEAD'])
316
317         # ensure HEAD was not modified
318         f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
319         self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
320         f.close()
321
322         # ensure the symbolic link was written through
323         f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
324         self.assertEqual(ones, f.read()[:40])
325         f.close()
326
327     def test_set_if_equals(self):
328         nines = '9' * 40
329         self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
330         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
331                          self._refs['HEAD'])
332
333         self.assertTrue(self._refs.set_if_equals(
334             'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
335         self.assertEqual(nines, self._refs['HEAD'])
336
337         # ensure symref was followed
338         self.assertEqual(nines, self._refs['refs/heads/master'])
339
340         self.assertFalse(os.path.exists(
341             os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
342         self.assertFalse(os.path.exists(
343             os.path.join(self._refs.path, 'HEAD.lock')))
344
345     def test_add_if_new(self):
346         nines = '9' * 40
347         self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
348         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
349                          self._refs['refs/heads/master'])
350
351         self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
352         self.assertEqual(nines, self._refs['refs/some/ref'])
353
354         # don't overwrite packed ref
355         self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
356         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
357                          self._refs['refs/tags/refs-0.1'])
358
359     def test_check_refname(self):
360         try:
361             self._refs._check_refname('HEAD')
362         except KeyError:
363             self.fail()
364
365         try:
366             self._refs._check_refname('refs/heads/foo')
367         except KeyError:
368             self.fail()
369
370         self.assertRaises(KeyError, self._refs._check_refname, 'refs')
371         self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
372
373     def test_follow(self):
374         self.assertEquals(
375             ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
376             self._refs._follow('HEAD'))
377         self.assertEquals(
378             ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
379             self._refs._follow('refs/heads/master'))
380         self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
381         self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
382
383     def test_delitem(self):
384         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
385                           self._refs['refs/heads/master'])
386         del self._refs['refs/heads/master']
387         self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
388         ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
389         self.assertFalse(os.path.exists(ref_file))
390         self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
391
392     def test_delitem_symbolic(self):
393         self.assertEqual('ref: refs/heads/master',
394                           self._refs.read_loose_ref('HEAD'))
395         del self._refs['HEAD']
396         self.assertRaises(KeyError, lambda: self._refs['HEAD'])
397         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
398                          self._refs['refs/heads/master'])
399         self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
400
401     def test_remove_if_equals(self):
402         nines = '9' * 40
403         self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
404         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
405                          self._refs['HEAD'])
406
407         # HEAD is a symref, so shouldn't equal its dereferenced value
408         self.assertFalse(self._refs.remove_if_equals(
409             'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
410         self.assertTrue(self._refs.remove_if_equals(
411             'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
412         self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
413
414         # HEAD is now a broken symref
415         self.assertRaises(KeyError, lambda: self._refs['HEAD'])
416         self.assertEqual('ref: refs/heads/master',
417                           self._refs.read_loose_ref('HEAD'))
418
419         self.assertFalse(os.path.exists(
420             os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
421         self.assertFalse(os.path.exists(
422             os.path.join(self._refs.path, 'HEAD.lock')))
423
424         # test removing ref that is only packed
425         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
426                          self._refs['refs/tags/refs-0.1'])
427         self.assertTrue(
428             self._refs.remove_if_equals('refs/tags/refs-0.1',
429             'df6800012397fb85c56e7418dd4eb9405dee075c'))
430         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])