Add non-bare repository tests.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_repository.py
1 # test_repository.py -- tests for repository.py
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License or (at your option) any later version of 
8 # the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Tests for the repository."""
22
23 from cStringIO import StringIO
24 import os
25 import shutil
26 import tempfile
27 import unittest
28 import warnings
29
30 from dulwich import errors
31 from dulwich.object_store import (
32     tree_lookup_path,
33     )
34 from dulwich import objects
35 from dulwich.repo import (
36     check_ref_format,
37     DictRefsContainer,
38     Repo,
39     read_packed_refs,
40     read_packed_refs_with_peeled,
41     write_packed_refs,
42     _split_ref_line,
43     )
44 from dulwich.tests.utils import (
45     open_repo,
46     tear_down_repo,
47     )
48
49 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
50
51
52 class CreateRepositoryTests(unittest.TestCase):
53
54     def test_create(self):
55         tmp_dir = tempfile.mkdtemp()
56         try:
57             repo = Repo.init_bare(tmp_dir)
58             self.assertEquals(tmp_dir, repo._controldir)
59         finally:
60             shutil.rmtree(tmp_dir)
61
62
63 class RepositoryTests(unittest.TestCase):
64
65     def setUp(self):
66         self._repo = None
67
68     def tearDown(self):
69         if self._repo is not None:
70             tear_down_repo(self._repo)
71   
72     def test_simple_props(self):
73         r = self._repo = open_repo('a.git')
74         self.assertEqual(r.controldir(), r.path)
75   
76     def test_ref(self):
77         r = self._repo = open_repo('a.git')
78         self.assertEqual(r.ref('refs/heads/master'),
79                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
80
81     def test_setitem(self):
82         r = self._repo = open_repo('a.git')
83         r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
84         self.assertEquals('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
85                           r["refs/tags/foo"].id)
86   
87     def test_get_refs(self):
88         r = self._repo = open_repo('a.git')
89         self.assertEqual({
90             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
91             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
92             'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
93             'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
94             }, r.get_refs())
95   
96     def test_head(self):
97         r = self._repo = open_repo('a.git')
98         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
99
100     def test_get_object(self):
101         r = self._repo = open_repo('a.git')
102         obj = r.get_object(r.head())
103         self.assertEqual(obj.type_name, 'commit')
104
105     def test_get_object_non_existant(self):
106         r = self._repo = open_repo('a.git')
107         self.assertRaises(KeyError, r.get_object, missing_sha)
108
109     def test_contains_object(self):
110         r = self._repo = open_repo('a.git')
111         self.assertTrue(r.head() in r)
112
113     def test_contains_ref(self):
114         r = self._repo = open_repo('a.git')
115         self.assertTrue("HEAD" in r)
116
117     def test_contains_missing(self):
118         r = self._repo = open_repo('a.git')
119         self.assertFalse("bar" in r)
120
121     def test_commit(self):
122         r = self._repo = open_repo('a.git')
123         warnings.simplefilter("ignore", DeprecationWarning)
124         try:
125             obj = r.commit(r.head())
126         finally:
127             warnings.resetwarnings()
128         self.assertEqual(obj.type_name, 'commit')
129
130     def test_commit_not_commit(self):
131         r = self._repo = open_repo('a.git')
132         warnings.simplefilter("ignore", DeprecationWarning)
133         try:
134             self.assertRaises(errors.NotCommitError,
135                 r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
136         finally:
137             warnings.resetwarnings()
138
139     def test_tree(self):
140         r = self._repo = open_repo('a.git')
141         commit = r[r.head()]
142         warnings.simplefilter("ignore", DeprecationWarning)
143         try:
144             tree = r.tree(commit.tree)
145         finally:
146             warnings.resetwarnings()
147         self.assertEqual(tree.type_name, 'tree')
148         self.assertEqual(tree.sha().hexdigest(), commit.tree)
149
150     def test_tree_not_tree(self):
151         r = self._repo = open_repo('a.git')
152         warnings.simplefilter("ignore", DeprecationWarning)
153         try:
154             self.assertRaises(errors.NotTreeError, r.tree, r.head())
155         finally:
156             warnings.resetwarnings()
157
158     def test_tag(self):
159         r = self._repo = open_repo('a.git')
160         tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
161         warnings.simplefilter("ignore", DeprecationWarning)
162         try:
163             tag = r.tag(tag_sha)
164         finally:
165             warnings.resetwarnings()
166         self.assertEqual(tag.type_name, 'tag')
167         self.assertEqual(tag.sha().hexdigest(), tag_sha)
168         obj_class, obj_sha = tag.object
169         self.assertEqual(obj_class, objects.Commit)
170         self.assertEqual(obj_sha, r.head())
171
172     def test_tag_not_tag(self):
173         r = self._repo = open_repo('a.git')
174         warnings.simplefilter("ignore", DeprecationWarning)
175         try:
176             self.assertRaises(errors.NotTagError, r.tag, r.head())
177         finally:
178             warnings.resetwarnings()
179
180     def test_get_peeled(self):
181         # unpacked ref
182         r = self._repo = open_repo('a.git')
183         tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
184         self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
185         self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())
186
187         # packed ref with cached peeled value
188         packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
189         parent_sha = r[r.head()].parents[0]
190         self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
191         self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)
192
193         # TODO: add more corner cases to test repo
194
195     def test_get_peeled_not_tag(self):
196         r = self._repo = open_repo('a.git')
197         self.assertEqual(r.get_peeled('HEAD'), r.head())
198
199     def test_get_blob(self):
200         r = self._repo = open_repo('a.git')
201         commit = r[r.head()]
202         tree = r[commit.tree]
203         blob_sha = tree.entries()[0][2]
204         warnings.simplefilter("ignore", DeprecationWarning)
205         try:
206             blob = r.get_blob(blob_sha)
207         finally:
208             warnings.resetwarnings()
209         self.assertEqual(blob.type_name, 'blob')
210         self.assertEqual(blob.sha().hexdigest(), blob_sha)
211
212     def test_get_blob_notblob(self):
213         r = self._repo = open_repo('a.git')
214         warnings.simplefilter("ignore", DeprecationWarning)
215         try:
216             self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
217         finally:
218             warnings.resetwarnings()
219     
220     def test_linear_history(self):
221         r = self._repo = open_repo('a.git')
222         history = r.revision_history(r.head())
223         shas = [c.sha().hexdigest() for c in history]
224         self.assertEqual(shas, [r.head(),
225                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
226   
227     def test_merge_history(self):
228         r = self._repo = open_repo('simple_merge.git')
229         history = r.revision_history(r.head())
230         shas = [c.sha().hexdigest() for c in history]
231         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
232                                 'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
233                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
234                                 '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
235                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
236   
237     def test_revision_history_missing_commit(self):
238         r = self._repo = open_repo('simple_merge.git')
239         self.assertRaises(errors.MissingCommitError, r.revision_history,
240                           missing_sha)
241   
242     def test_out_of_order_merge(self):
243         """Test that revision history is ordered by date, not parent order."""
244         r = self._repo = open_repo('ooo_merge.git')
245         history = r.revision_history(r.head())
246         shas = [c.sha().hexdigest() for c in history]
247         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
248                                 'f507291b64138b875c28e03469025b1ea20bc614',
249                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
250                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
251   
252     def test_get_tags_empty(self):
253         r = self._repo = open_repo('ooo_merge.git')
254         self.assertEqual({}, r.refs.as_dict('refs/tags'))
255
256     def test_get_config(self):
257         r = self._repo = open_repo('ooo_merge.git')
258         self.assertEquals({}, r.get_config())
259
260     def test_common_revisions(self):
261         """
262         This test demonstrates that ``find_common_revisions()`` actually returns
263         common heads, not revisions; dulwich already uses
264         ``find_common_revisions()`` in such a manner (see
265         ``Repo.fetch_objects()``).
266         """
267
268         expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])
269
270         # Source for objects.
271         r_base = open_repo('simple_merge.git')
272
273         # Re-create each-side of the merge in simple_merge.git.
274         #
275         # Since the trees and blobs are missing, the repository created is
276         # corrupted, but we're only checking for commits for the purpose of this
277         # test, so it's immaterial.
278         r1_dir = tempfile.mkdtemp()
279         r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
280                       '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
281                       '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
282
283         r2_dir = tempfile.mkdtemp()
284         r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
285                       '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
286                       '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
287
288         try:
289             r1 = Repo.init_bare(r1_dir)
290             map(lambda c: r1.object_store.add_object(r_base.get_object(c)), \
291                 r1_commits)
292             r1.refs['HEAD'] = r1_commits[0]
293
294             r2 = Repo.init_bare(r2_dir)
295             map(lambda c: r2.object_store.add_object(r_base.get_object(c)), \
296                 r2_commits)
297             r2.refs['HEAD'] = r2_commits[0]
298
299             # Finally, the 'real' testing!
300             shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
301             self.assertEqual(set(shas), expected_shas)
302
303             shas = r1.object_store.find_common_revisions(r2.get_graph_walker())
304             self.assertEqual(set(shas), expected_shas)
305         finally:
306             shutil.rmtree(r1_dir)
307             shutil.rmtree(r2_dir)
308
309     def _build_initial_repo(self):
310         repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
311         os.makedirs(repo_dir)
312         r = self._repo = Repo.init(repo_dir)
313         self.assertFalse(r.bare)
314         self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
315         self.assertRaises(KeyError, lambda: r.refs['refs/heads/master'])
316
317         f = open(os.path.join(r.path, 'a'), 'wb')
318         try:
319             f.write('file contents')
320         finally:
321             f.close()
322         r.stage(['a'])
323         commit_sha = r.do_commit('msg',
324                                  committer='Test Committer <test@nodomain.com>',
325                                  author='Test Author <test@nodomain.com>',
326                                  commit_timestamp=12345, commit_timezone=0,
327                                  author_timestamp=12345, author_timezone=0)
328         self.assertEqual([], r[commit_sha].parents)
329         return commit_sha
330
331     def test_build_repo(self):
332         commit_sha = self._build_initial_repo()
333         r = self._repo
334         self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
335         self.assertEqual(commit_sha, r.refs['refs/heads/master'])
336         expected_blob = objects.Blob.from_string('file contents')
337         self.assertEqual(expected_blob.data, r[expected_blob.id].data)
338         actual_commit = r[commit_sha]
339         self.assertEqual('msg', actual_commit.message)
340
341     def test_commit_modified(self):
342         parent_sha = self._build_initial_repo()
343         r = self._repo
344         f = open(os.path.join(r.path, 'a'), 'wb')
345         try:
346             f.write('new contents')
347         finally:
348             f.close()
349         r.stage(['a'])
350         commit_sha = r.do_commit('modified a',
351                                  committer='Test Committer <test@nodomain.com>',
352                                  author='Test Author <test@nodomain.com>',
353                                  commit_timestamp=12395, commit_timezone=0,
354                                  author_timestamp=12395, author_timezone=0)
355         self.assertEqual([parent_sha], r[commit_sha].parents)
356         _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
357         self.assertEqual('new contents', r[blob_id].data)
358
359     def test_commit_deleted(self):
360         parent_sha = self._build_initial_repo()
361         r = self._repo
362         os.remove(os.path.join(r.path, 'a'))
363         r.stage(['a'])
364         commit_sha = r.do_commit('deleted a',
365                                  committer='Test Committer <test@nodomain.com>',
366                                  author='Test Author <test@nodomain.com>',
367                                  commit_timestamp=12395, commit_timezone=0,
368                                  author_timestamp=12395, author_timezone=0)
369         self.assertEqual([parent_sha], r[commit_sha].parents)
370         self.assertEqual([], list(r.open_index()))
371         tree = r[r[commit_sha].tree]
372         self.assertEqual([], tree.iteritems())
373
374     def test_commit_fail_ref(self):
375         repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
376         os.makedirs(repo_dir)
377         r = self._repo = Repo.init(repo_dir)
378
379         def set_if_equals(name, old_ref, new_ref):
380             self.fail('Unexpected call to set_if_equals')
381         r.refs.set_if_equals = set_if_equals
382
383         def add_if_new(name, new_ref):
384             return False
385         r.refs.add_if_new = add_if_new
386
387         self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
388                           committer='Test Committer <test@nodomain.com>',
389                           author='Test Author <test@nodomain.com>',
390                           commit_timestamp=12345, commit_timezone=0,
391                           author_timestamp=12345, author_timezone=0)
392         shas = list(r.object_store)
393         self.assertEqual(2, len(shas))
394         for sha in shas:
395             obj = r[sha]
396             if isinstance(obj, objects.Commit):
397                 commit = obj
398             elif isinstance(obj, objects.Tree):
399                 tree = obj
400             else:
401                 self.fail('Unexpected object found: %s' % sha)
402         self.assertEqual(tree.id, commit.tree)
403
404
class CheckRefFormatTests(unittest.TestCase):
    """Tests for the check_ref_format function.

    These are the same tests as in the git test suite.
    """

    def test_valid(self):
        accepted = [
            'heads/foo',
            'foo/bar/baz',
            'refs///heads/foo',
            'foo./bar',
            'heads/foo@bar',
            'heads/fix.lock.error',
            ]
        for refname in accepted:
            self.assertTrue(check_ref_format(refname))

    def test_invalid(self):
        # NOTE(review): in 'heads/foo\bar' below, '\b' is a Python backspace
        # escape, so that case checks a control character, not a literal
        # backslash as in git's own test suite.
        rejected = [
            'foo',
            'heads/foo/',
            './foo',
            '.refs/foo',
            'heads/foo..bar',
            'heads/foo?bar',
            'heads/foo.lock',
            'heads/v@{ation',
            'heads/foo\bar',
            ]
        for refname in rejected:
            self.assertFalse(check_ref_format(refname))
429
430
# 40-character dummy SHA-1 hex strings used by the packed-refs tests.
ONES, TWOS, THREES, FOURS = (digit * 40 for digit in "1234")
435
436 class PackedRefsFileTests(unittest.TestCase):
437
438     def test_split_ref_line_errors(self):
439         self.assertRaises(errors.PackedRefsException, _split_ref_line,
440                           'singlefield')
441         self.assertRaises(errors.PackedRefsException, _split_ref_line,
442                           'badsha name')
443         self.assertRaises(errors.PackedRefsException, _split_ref_line,
444                           '%s bad/../refname' % ONES)
445
446     def test_read_without_peeled(self):
447         f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
448         self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
449                          list(read_packed_refs(f)))
450
451     def test_read_without_peeled_errors(self):
452         f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
453         self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
454
455     def test_read_with_peeled(self):
456         f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
457           ONES, TWOS, THREES, FOURS))
458         self.assertEqual([
459           (ONES, 'ref/1', None),
460           (TWOS, 'ref/2', THREES),
461           (FOURS, 'ref/4', None),
462           ], list(read_packed_refs_with_peeled(f)))
463
464     def test_read_with_peeled_errors(self):
465         f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
466         self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
467
468         f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
469         self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
470
471     def test_write_with_peeled(self):
472         f = StringIO()
473         write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
474                           {'ref/1': THREES})
475         self.assertEqual(
476           "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
477           ONES, THREES, TWOS), f.getvalue())
478
479     def test_write_without_peeled(self):
480         f = StringIO()
481         write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
482         self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
483
484
485 # Dict of refs that we expect all RefsContainerTests subclasses to define.
486 _TEST_REFS = {
487   'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
488   'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
489   'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
490   'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
491   'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
492   }
493
494
class RefsContainerTests(object):
    """Mixin holding tests that every refs container implementation runs.

    Concrete subclasses also derive from unittest.TestCase and set
    ``self._refs`` in ``setUp`` to a container whose refs match _TEST_REFS
    (a self-referencing ``refs/heads/loop`` symref may also be present).
    """

    def test_keys(self):
        all_names = set(self._refs.keys())
        self.assertEqual(set(self._refs.allkeys()), all_names)
        # The optional refs/heads/loop symref is not part of _TEST_REFS.
        all_names.discard('refs/heads/loop')
        self.assertEqual(set(_TEST_REFS), all_names)

        branch_names = self._refs.keys('refs/heads')
        branch_names.discard('loop')
        self.assertEqual(['master', 'packed'], sorted(branch_names))
        tag_names = sorted(self._refs.keys('refs/tags'))
        self.assertEqual(['refs-0.1', 'refs-0.2'], tag_names)

    def test_as_dict(self):
        # as_dict() must leave out refs/heads/loop even when it exists.
        expected = _TEST_REFS
        self.assertEqual(expected, self._refs.as_dict())

    def test_setitem(self):
        target = '42d06bd4b77fed026b154d16493e5deab78f02ec'
        self._refs['refs/some/ref'] = target
        self.assertEqual(target, self._refs['refs/some/ref'])

    def test_set_if_equals(self):
        old_sha = '42d06bd4b77fed026b154d16493e5deab78f02ec'
        new_sha = '9' * 40

        # Wrong expected value: the ref must stay untouched.
        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', new_sha))
        self.assertEqual(old_sha, self._refs['HEAD'])

        # Correct expected value: the update goes through.
        self.assertTrue(self._refs.set_if_equals('HEAD', old_sha, new_sha))
        self.assertEqual(new_sha, self._refs['HEAD'])

        # An expected value of None is accepted as well.
        self.assertTrue(
            self._refs.set_if_equals('refs/heads/master', None, new_sha))
        self.assertEqual(new_sha, self._refs['refs/heads/master'])

    def test_add_if_new(self):
        new_sha = '9' * 40
        # An existing ref is not overwritten...
        self.assertFalse(self._refs.add_if_new('refs/heads/master', new_sha))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])
        # ...but a brand new one is created.
        self.assertTrue(self._refs.add_if_new('refs/some/ref', new_sha))
        self.assertEqual(new_sha, self._refs['refs/some/ref'])

    def test_check_refname(self):
        for good_name in ('HEAD', 'refs/heads/foo'):
            try:
                self._refs._check_refname(good_name)
            except KeyError:
                self.fail()

        for bad_name in ('refs', 'notrefs/foo'):
            self.assertRaises(KeyError, self._refs._check_refname, bad_name)

    def test_contains(self):
        refs = self._refs
        self.assertTrue('refs/heads/master' in refs)
        self.assertFalse('refs/heads/bar' in refs)

    def test_delitem(self):
        master = 'refs/heads/master'
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[master])
        del self._refs[master]
        self.assertRaises(KeyError, lambda: self._refs[master])

    def test_remove_if_equals(self):
        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['HEAD'])

        tag = 'refs/tags/refs-0.2'
        self.assertTrue(self._refs.remove_if_equals(
            tag, '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
        self.assertFalse(tag in self._refs)
573
574
class DictRefsContainerTests(RefsContainerTests, unittest.TestCase):
    """Run the shared refs-container tests against DictRefsContainer."""

    def setUp(self):
        # Hand the container a copy, so that any mutation it makes to the
        # dict it is given cannot affect _TEST_REFS or other tests.
        initial_refs = _TEST_REFS.copy()
        self._refs = DictRefsContainer(initial_refs)
579
580
class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
    """Run the shared refs-container tests against an on-disk repository.

    Each test uses the ``refs.git`` repository opened with ``open_repo`` and
    released with ``tear_down_repo``; the extra tests here also inspect the
    files the container writes below ``self._refs.path``.
    """

    def setUp(self):
        self._repo = open_repo('refs.git')
        self._refs = self._repo.refs

    def tearDown(self):
        tear_down_repo(self._repo)

    def _read_file(self, *path_parts):
        """Return the contents of a file below the refs container's path.

        The file is closed even when reading fails, so a failing assertion
        never leaves an open handle behind in the repository directory.
        """
        f = open(os.path.join(self._refs.path, *path_parts), 'rb')
        try:
            return f.read()
        finally:
            f.close()

    def test_get_packed_refs(self):
        self.assertEqual({
          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
          }, self._refs.get_packed_refs())

    def test_get_peeled_not_packed(self):
        # not packed
        self.assertEqual(None, self._refs.get_peeled('refs/tags/refs-0.2'))
        self.assertEqual('3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
                         self._refs['refs/tags/refs-0.2'])

        # packed, known not peelable
        self.assertEqual(self._refs['refs/heads/packed'],
                         self._refs.get_peeled('refs/heads/packed'))

        # packed, peeled
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs.get_peeled('refs/tags/refs-0.1'))

    def test_setitem(self):
        RefsContainerTests.test_setitem(self)
        contents = self._read_file('refs', 'some', 'ref')
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         contents[:40])

    def test_setitem_symbolic(self):
        ones = '1' * 40
        self._refs['HEAD'] = ones
        self.assertEqual(ones, self._refs['HEAD'])

        # ensure HEAD was not modified
        head_first_line = self._read_file('HEAD').split('\n', 1)[0]
        self.assertEqual('ref: refs/heads/master', head_first_line)

        # ensure the symbolic link was written through
        self.assertEqual(ones, self._read_file('refs', 'heads', 'master')[:40])

    def test_set_if_equals(self):
        RefsContainerTests.test_set_if_equals(self)

        # ensure symref was followed
        self.assertEqual('9' * 40, self._refs['refs/heads/master'])

        # ensure lockfile was deleted
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'HEAD.lock')))

    def test_add_if_new_packed(self):
        # don't overwrite packed ref
        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])

    def test_add_if_new_symbolic(self):
        # Use an empty repo instead of the default.
        tear_down_repo(self._repo)
        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
        os.makedirs(repo_dir)
        self._repo = Repo.init(repo_dir)
        refs = self._repo.refs

        nines = '9' * 40
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertFalse('refs/heads/master' in refs)
        self.assertTrue(refs.add_if_new('HEAD', nines))
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])
        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])

    def test_follow(self):
        self.assertEqual(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('HEAD'))
        self.assertEqual(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('refs/heads/master'))
        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')

    def test_delitem(self):
        RefsContainerTests.test_delitem(self)
        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
        self.assertFalse(os.path.exists(ref_file))
        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())

    def test_delitem_symbolic(self):
        self.assertEqual('ref: refs/heads/master',
                          self._refs.read_loose_ref('HEAD'))
        del self._refs['HEAD']
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])
        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))

    def test_remove_if_equals_symref(self):
        # HEAD is a symref, so shouldn't equal its dereferenced value
        self.assertFalse(self._refs.remove_if_equals(
          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertTrue(self._refs.remove_if_equals(
          'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])

        # HEAD is now a broken symref
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        self.assertEqual('ref: refs/heads/master',
                          self._refs.read_loose_ref('HEAD'))

        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'HEAD.lock')))

    def test_remove_if_equals_packed(self):
        # test removing ref that is only packed
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])
        self.assertTrue(
          self._refs.remove_if_equals('refs/tags/refs-0.1',
          'df6800012397fb85c56e7418dd4eb9405dee075c'))
        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])

    def test_read_ref(self):
        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
            self._refs.read_ref("refs/heads/packed"))
        self.assertEqual(None,
            self._refs.read_ref("nonexistant"))