Add basic Index object based on libgit2.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_repository.py
1 # test_repository.py -- tests for repository.py
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License or (at your option) any later version of
8 # the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 """Tests for the repository."""
21
22 from cStringIO import StringIO
23 import os
24 import shutil
25 import tempfile
26 import warnings
27
28 from dulwich import errors
29 from dulwich.file import (
30     GitFile,
31     )
32 from dulwich.object_store import (
33     tree_lookup_path,
34     )
35 from dulwich import objects
36 from dulwich.repo import (
37     check_ref_format,
38     DictRefsContainer,
39     Repo,
40     MemoryRepo,
41     read_packed_refs,
42     read_packed_refs_with_peeled,
43     write_packed_refs,
44     _split_ref_line,
45     )
46 from dulwich.tests import (
47     TestCase,
48     )
49 from dulwich.tests.utils import (
50     open_repo,
51     tear_down_repo,
52     )
53
# A well-formed 40-character hex SHA-1 that the tests below use as an object
# id which is expected to be absent from the test repositories.
missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
55
56
class CreateRepositoryTests(TestCase):
    """Tests that newly initialised bare repositories have the default files.

    The same checks are run against an on-disk repo (Repo.init_bare) and an
    in-memory one (MemoryRepo.init_bare).
    """

    def assertFileContentsEqual(self, expected, repo, path):
        """Assert that a named file of repo holds exactly expected.

        :param expected: Expected file contents, or None if the file should
            not exist.
        :param repo: Repository whose get_named_file() is queried.
        :param path: Path of the file, relative to the control directory.
        """
        f = repo.get_named_file(path)
        # A missing file is reported as None; compare explicitly rather than
        # via truthiness so a file-like object with length 0 is not mistaken
        # for a missing one.
        if f is None:
            self.assertEqual(expected, None)
            return
        try:
            self.assertEqual(expected, f.read())
        finally:
            f.close()

    def _check_repo_contents(self, repo):
        """Check the layout shared by every freshly created bare repo."""
        self.assertTrue(repo.bare)
        self.assertFileContentsEqual('Unnamed repository', repo, 'description')
        self.assertFileContentsEqual('', repo, os.path.join('info', 'exclude'))
        self.assertFileContentsEqual(None, repo, 'nonexistent file')

    def test_create_disk(self):
        tmp_dir = tempfile.mkdtemp()
        try:
            repo = Repo.init_bare(tmp_dir)
            # assertEquals is a deprecated alias; use assertEqual.
            self.assertEqual(tmp_dir, repo._controldir)
            self._check_repo_contents(repo)
        finally:
            shutil.rmtree(tmp_dir)

    def test_create_memory(self):
        repo = MemoryRepo.init_bare([], {})
        self._check_repo_contents(repo)
87
88
class RepositoryTests(TestCase):
    """Tests against the prebuilt repositories in the test data.

    Each test opens a repository with open_repo() and assigns it to
    self._repo, so that tearDown() can dispose of it with tear_down_repo().
    """

    def setUp(self):
        super(RepositoryTests, self).setUp()
        self._repo = None

    def tearDown(self):
        if self._repo is not None:
            tear_down_repo(self._repo)
        super(RepositoryTests, self).tearDown()

    def _call_ignoring_deprecation(self, func, *args, **kwargs):
        """Return func(*args, **kwargs), evaluated with DeprecationWarning ignored.

        The warning filter list that was active before the call is restored
        afterwards. warnings.resetwarnings() is deliberately not used: it
        would discard every filter, including ones installed by the
        interpreter (-W) or by the test runner.
        """
        saved_filters = warnings.filters[:]
        warnings.simplefilter("ignore", DeprecationWarning)
        try:
            return func(*args, **kwargs)
        finally:
            warnings.filters[:] = saved_filters

    def test_simple_props(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.controldir(), r.path)

    def test_ref(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.ref('refs/heads/master'),
                         'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_setitem(self):
        r = self._repo = open_repo('a.git')
        r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
        self.assertEqual('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                         r["refs/tags/foo"].id)

    def test_get_refs(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual({
            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
            'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
            }, r.get_refs())

    def test_head(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_get_object(self):
        r = self._repo = open_repo('a.git')
        obj = r.get_object(r.head())
        self.assertEqual(obj.type_name, 'commit')

    def test_get_object_non_existant(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(KeyError, r.get_object, missing_sha)

    def test_contains_object(self):
        r = self._repo = open_repo('a.git')
        self.assertTrue(r.head() in r)

    def test_contains_ref(self):
        r = self._repo = open_repo('a.git')
        self.assertTrue("HEAD" in r)

    def test_contains_missing(self):
        r = self._repo = open_repo('a.git')
        self.assertFalse("bar" in r)

    def test_commit(self):
        r = self._repo = open_repo('a.git')
        obj = self._call_ignoring_deprecation(r.commit, r.head())
        self.assertEqual(obj.type_name, 'commit')

    def test_commit_not_commit(self):
        r = self._repo = open_repo('a.git')
        self._call_ignoring_deprecation(
            self.assertRaises, errors.NotCommitError,
            r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')

    def test_tree(self):
        r = self._repo = open_repo('a.git')
        commit = r[r.head()]
        tree = self._call_ignoring_deprecation(r.tree, commit.tree)
        self.assertEqual(tree.type_name, 'tree')
        self.assertEqual(tree.sha().hexdigest(), commit.tree)

    def test_tree_not_tree(self):
        r = self._repo = open_repo('a.git')
        self._call_ignoring_deprecation(
            self.assertRaises, errors.NotTreeError, r.tree, r.head())

    def test_tag(self):
        r = self._repo = open_repo('a.git')
        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
        tag = self._call_ignoring_deprecation(r.tag, tag_sha)
        self.assertEqual(tag.type_name, 'tag')
        self.assertEqual(tag.sha().hexdigest(), tag_sha)
        obj_class, obj_sha = tag.object
        self.assertEqual(obj_class, objects.Commit)
        self.assertEqual(obj_sha, r.head())

    def test_tag_not_tag(self):
        r = self._repo = open_repo('a.git')
        self._call_ignoring_deprecation(
            self.assertRaises, errors.NotTagError, r.tag, r.head())

    def test_get_peeled(self):
        # unpacked ref
        r = self._repo = open_repo('a.git')
        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
        self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
        self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())

        # packed ref with cached peeled value
        packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
        parent_sha = r[r.head()].parents[0]
        self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
        self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)

        # TODO: add more corner cases to test repo

    def test_get_peeled_not_tag(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.get_peeled('HEAD'), r.head())

    def test_get_blob(self):
        r = self._repo = open_repo('a.git')
        commit = r[r.head()]
        tree = r[commit.tree]
        blob_sha = tree.entries()[0][2]
        blob = self._call_ignoring_deprecation(r.get_blob, blob_sha)
        self.assertEqual(blob.type_name, 'blob')
        self.assertEqual(blob.sha().hexdigest(), blob_sha)

    def test_get_blob_notblob(self):
        r = self._repo = open_repo('a.git')
        self._call_ignoring_deprecation(
            self.assertRaises, errors.NotBlobError, r.get_blob, r.head())

    def test_linear_history(self):
        r = self._repo = open_repo('a.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, [r.head(),
                                '2a72d929692c41d8554c07f6301757ba18a65d91'])

    def test_merge_history(self):
        r = self._repo = open_repo('simple_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
                                'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])

    def test_revision_history_missing_commit(self):
        r = self._repo = open_repo('simple_merge.git')
        self.assertRaises(errors.MissingCommitError, r.revision_history,
                          missing_sha)

    def test_out_of_order_merge(self):
        """Test that revision history is ordered by date, not parent order."""
        r = self._repo = open_repo('ooo_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
                                'f507291b64138b875c28e03469025b1ea20bc614',
                                'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                'f9e39b120c68182a4ba35349f832d0e4e61f485c'])

    def test_get_tags_empty(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEqual({}, r.refs.as_dict('refs/tags'))

    def test_get_config(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEqual({}, r.get_config())

    def _init_bare_with_commits(self, path, source, commit_shas):
        """Create a bare repo at path holding commit_shas copied from source.

        Only the named objects are copied (not their trees or blobs), and
        HEAD is pointed at the first sha in commit_shas.
        """
        repo = Repo.init_bare(path)
        # A plain loop: map() must not be used for side effects (it is lazy
        # on Python 3, so nothing would be added).
        for sha in commit_shas:
            repo.object_store.add_object(source.get_object(sha))
        repo.refs['HEAD'] = commit_shas[0]
        return repo

    def test_common_revisions(self):
        """
        This test demonstrates that ``find_common_revisions()`` actually returns
        common heads, not revisions; dulwich already uses
        ``find_common_revisions()`` in such a manner (see
        ``Repo.fetch_objects()``).
        """

        expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])

        # Source for objects. Registered as self._repo so that tearDown()
        # disposes of it; previously it was never torn down.
        r_base = self._repo = open_repo('simple_merge.git')

        # Re-create each-side of the merge in simple_merge.git.
        #
        # Since the trees and blobs are missing, the repository created is
        # corrupted, but we're only checking for commits for the purpose of this
        # test, so it's immaterial.
        r1_dir = tempfile.mkdtemp()
        r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        r2_dir = tempfile.mkdtemp()
        r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        try:
            r1 = self._init_bare_with_commits(r1_dir, r_base, r1_commits)
            r2 = self._init_bare_with_commits(r2_dir, r_base, r2_commits)

            # Finally, the 'real' testing!
            shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
            self.assertEqual(set(shas), expected_shas)

            shas = r1.object_store.find_common_revisions(r2.get_graph_walker())
            self.assertEqual(set(shas), expected_shas)
        finally:
            shutil.rmtree(r1_dir)
            shutil.rmtree(r2_dir)
336
337
class BuildRepoTests(TestCase):
    """Tests that build on-disk repos from scratch.

    Every test gets a fresh non-bare repository under a temporary directory,
    disposed of with tear_down_repo() in tearDown(). setUp() writes a single
    file named 'a' and commits it, so each test starts with exactly one
    commit on master.
    """

    def _write_a(self, contents):
        """Overwrite the working-tree file 'a' of self._repo with contents."""
        handle = open(os.path.join(self._repo.path, 'a'), 'wb')
        try:
            handle.write(contents)
        finally:
            handle.close()

    def _commit(self, message, timestamp):
        """Commit self._repo's staged changes using fixed test identities.

        Author and committer both get the given timestamp and a timezone
        offset of 0. Returns the result of do_commit(), which the tests use
        as the new commit's id.
        """
        return self._repo.do_commit(
            message,
            committer='Test Committer <test@nodomain.com>',
            author='Test Author <test@nodomain.com>',
            commit_timestamp=timestamp, commit_timezone=0,
            author_timestamp=timestamp, author_timezone=0)

    def setUp(self):
        super(BuildRepoTests, self).setUp()
        work_tree = os.path.join(tempfile.mkdtemp(), 'test')
        os.makedirs(work_tree)
        repo = self._repo = Repo.init(work_tree)
        self.assertFalse(repo.bare)
        self.assertEqual('ref: refs/heads/master', repo.refs.read_ref('HEAD'))
        # HEAD points at master, but master has no commit yet.
        self.assertRaises(KeyError, lambda: repo.refs['refs/heads/master'])

        self._write_a('file contents')
        repo.stage(['a'])
        root_sha = self._commit('msg', 12345)
        self.assertEqual([], repo[root_sha].parents)
        self._root_commit = root_sha

    def tearDown(self):
        tear_down_repo(self._repo)
        super(BuildRepoTests, self).tearDown()

    def test_build_repo(self):
        repo = self._repo
        self.assertEqual('ref: refs/heads/master', repo.refs.read_ref('HEAD'))
        self.assertEqual(self._root_commit, repo.refs['refs/heads/master'])
        blob = objects.Blob.from_string('file contents')
        self.assertEqual(blob.data, repo[blob.id].data)
        root = repo[self._root_commit]
        self.assertEqual('msg', root.message)

    def test_commit_modified(self):
        repo = self._repo
        self._write_a('new contents')
        repo.stage(['a'])
        new_sha = self._commit('modified a', 12395)
        new_commit = repo[new_sha]
        self.assertEqual([self._root_commit], new_commit.parents)
        _, a_blob_id = tree_lookup_path(repo.get_object, new_commit.tree, 'a')
        self.assertEqual('new contents', repo[a_blob_id].data)

    def test_commit_deleted(self):
        repo = self._repo
        os.remove(os.path.join(repo.path, 'a'))
        repo.stage(['a'])
        new_sha = self._commit('deleted a', 12395)
        self.assertEqual([self._root_commit], repo[new_sha].parents)
        # Both the index and the committed tree end up empty.
        self.assertEqual([], list(repo.open_index()))
        new_tree = repo[repo[new_sha].tree]
        self.assertEqual([], list(new_tree.iteritems()))

    def test_commit_fail_ref(self):
        repo = self._repo

        # Simulate losing the race to update the ref.
        def set_if_equals(name, old_ref, new_ref):
            return False
        repo.refs.set_if_equals = set_if_equals

        def add_if_new(name, new_ref):
            self.fail('Unexpected call to add_if_new')
        repo.refs.add_if_new = add_if_new

        shas_before = set(repo.object_store)
        self.assertRaises(errors.CommitError, self._commit, 'failed commit',
                          12345)
        added = set(repo.object_store) - shas_before
        self.assertEqual(1, len(added))
        # The commit object itself was still written (it is now garbage).
        orphan = repo[added.pop()]
        self.assertEqual(repo[self._root_commit].tree, orphan.tree)
        self.assertEqual('failed commit', orphan.message)

    def test_stage_deleted(self):
        repo = self._repo
        os.remove(os.path.join(repo.path, 'a'))
        for _ in range(2):
            # Staging an already-removed path a second time must not fail.
            repo.stage(['a'])
441
442
class CheckRefFormatTests(TestCase):
    """Tests for the check_ref_format function.

    These are the same tests as in the git test suite.
    """

    def test_valid(self):
        accepted = (
            'heads/foo',
            'foo/bar/baz',
            'refs///heads/foo',
            'foo./bar',
            'heads/foo@bar',
            'heads/fix.lock.error',
            )
        for refname in accepted:
            self.assertTrue(check_ref_format(refname))

    def test_invalid(self):
        rejected = (
            'foo',              # a single component
            'heads/foo/',       # trailing slash
            './foo',            # component starting with '.'
            '.refs/foo',        # component starting with '.'
            'heads/foo..bar',   # contains '..'
            'heads/foo?bar',    # contains '?'
            'heads/foo.lock',   # ends in '.lock'
            'heads/v@{ation',   # contains '@{'
            # NOTE(review): '\b' is Python's backspace escape, so this entry
            # checks a control character, not a literal backslash as in the
            # git suite.
            'heads/foo\bar',
            )
        for refname in rejected:
            self.assertFalse(check_ref_format(refname))
467
468
# Placeholder 40-character "SHAs" (one repeated digit each) used by the
# packed-refs tests. A generator expression is used so no loop variable
# leaks into the module namespace.
ONES, TWOS, THREES, FOURS = (digit * 40 for digit in '1234')
473
class PackedRefsFileTests(TestCase):
    """Tests for parsing and writing the packed-refs file format."""

    def test_split_ref_line_errors(self):
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'singlefield')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'badsha name')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          '%s bad/../refname' % ONES)

    def test_read_without_peeled(self):
        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
                         list(read_packed_refs(f)))

    def test_read_without_peeled_errors(self):
        # The non-peeling reader rejects any '^' (peeled) line.
        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))

    def test_read_with_peeled(self):
        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
          ONES, TWOS, THREES, FOURS))
        self.assertEqual([
          (ONES, 'ref/1', None),
          (TWOS, 'ref/2', THREES),
          (FOURS, 'ref/4', None),
          ], list(read_packed_refs_with_peeled(f)))

    def test_read_with_peeled_errors(self):
        # These inputs must be rejected by the *peeling* reader itself;
        # read_packed_refs() would reject any '^' line regardless, so
        # using it here would not exercise the peeled-line handling.

        # Input: a peeled line that precedes every ref line.
        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

        # Input: two consecutive peeled lines after a single ref line.
        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

    def test_write_with_peeled(self):
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
                          {'ref/1': THREES})
        self.assertEqual(
          "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
          ONES, THREES, TWOS), f.getvalue())

    def test_write_without_peeled(self):
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
521
522
523 # Dict of refs that we expect all RefsContainerTests subclasses to define.
524 _TEST_REFS = {
525   'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
526   'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
527   'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
528   'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
529   'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
530   }
531
532
class RefsContainerTests(object):
    """Mixin of tests shared by every refs container implementation.

    Concrete subclasses also derive from TestCase and must set self._refs in
    setUp() to a container holding the refs in _TEST_REFS (optionally plus a
    looping symref at refs/heads/loop).
    """

    def test_keys(self):
        all_keys = set(self._refs.keys())
        self.assertEqual(set(self._refs.allkeys()), all_keys)
        # Disregard the symref loop some containers define.
        all_keys.discard('refs/heads/loop')
        self.assertEqual(set(_TEST_REFS), all_keys)

        head_names = self._refs.keys('refs/heads')
        head_names.discard('loop')
        self.assertEqual(['master', 'packed'], sorted(head_names))
        tag_names = self._refs.keys('refs/tags')
        self.assertEqual(['refs-0.1', 'refs-0.2'], sorted(tag_names))

    def test_as_dict(self):
        # A refs/heads/loop symref, if present, is not part of as_dict().
        self.assertEqual(_TEST_REFS, self._refs.as_dict())

    def test_setitem(self):
        sha = '42d06bd4b77fed026b154d16493e5deab78f02ec'
        self._refs['refs/some/ref'] = sha
        self.assertEqual(sha, self._refs['refs/some/ref'])

    def test_set_if_equals(self):
        nines = '9' * 40
        master_sha = '42d06bd4b77fed026b154d16493e5deab78f02ec'

        # Mismatched expected value: refused, HEAD untouched.
        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
        self.assertEqual(master_sha, self._refs['HEAD'])

        # Matching expected value: accepted.
        self.assertTrue(self._refs.set_if_equals('HEAD', master_sha, nines))
        self.assertEqual(nines, self._refs['HEAD'])

        # An expected value of None is accepted here as well.
        self.assertTrue(
            self._refs.set_if_equals('refs/heads/master', None, nines))
        self.assertEqual(nines, self._refs['refs/heads/master'])

    def test_add_if_new(self):
        nines = '9' * 40
        # An existing ref is left alone.
        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])

        # A new ref is created.
        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
        self.assertEqual(nines, self._refs['refs/some/ref'])

    def test_set_symbolic_ref(self):
        name = 'refs/heads/symbolic'
        self._refs.set_symbolic_ref(name, 'refs/heads/master')
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref(name))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])

    def test_set_symbolic_ref_overwrite(self):
        name = 'refs/heads/symbolic'
        nines = '9' * 40
        self.assertFalse(name in self._refs)
        self._refs[name] = nines
        self.assertEqual(nines, self._refs.read_loose_ref(name))
        # Turning the plain ref into a symref replaces its contents.
        self._refs.set_symbolic_ref(name, 'refs/heads/master')
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref(name))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])

    def test_check_refname(self):
        for good_name in ('HEAD', 'refs/heads/foo'):
            try:
                self._refs._check_refname(good_name)
            except KeyError:
                self.fail()

        for bad_name in ('refs', 'notrefs/foo'):
            self.assertRaises(KeyError, self._refs._check_refname, bad_name)

    def test_contains(self):
        self.assertTrue('refs/heads/master' in self._refs)
        self.assertFalse('refs/heads/bar' in self._refs)

    def test_delitem(self):
        name = 'refs/heads/master'
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])
        del self._refs[name]
        self.assertRaises(KeyError, lambda: self._refs[name])

    def test_remove_if_equals(self):
        # Mismatched expected value: nothing removed.
        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['HEAD'])
        # Matching expected value: the ref disappears.
        tag_name = 'refs/tags/refs-0.2'
        self.assertTrue(self._refs.remove_if_equals(
            tag_name, '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
        self.assertFalse(tag_name in self._refs)
629
630
class DictRefsContainerTests(RefsContainerTests, TestCase):
    """Runs the shared refs-container tests against DictRefsContainer."""

    def setUp(self):
        TestCase.setUp(self)
        # Hand over a shallow copy: several tests add or delete refs, and the
        # container may keep the dict it is given.
        self._refs = DictRefsContainer(_TEST_REFS.copy())
636
637
class DiskRefsContainerTests(RefsContainerTests, TestCase):
    """Runs the shared refs tests against the on-disk refs of refs.git.

    Additional tests inspect the loose ref files, lock files and packed-refs
    file that the on-disk container reads and writes.
    """

    def setUp(self):
        TestCase.setUp(self)
        self._repo = open_repo('refs.git')
        self._refs = self._repo.refs

    def tearDown(self):
        tear_down_repo(self._repo)
        TestCase.tearDown(self)

    def _read_file(self, *path_parts):
        """Return the raw contents of a file below the refs container's path.

        The file is closed even when reading raises, so a failing test does
        not leak an open handle.
        """
        f = open(os.path.join(self._refs.path, *path_parts), 'rb')
        try:
            return f.read()
        finally:
            f.close()

    def test_get_packed_refs(self):
        self.assertEqual({
          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
          }, self._refs.get_packed_refs())

    def test_get_peeled_not_packed(self):
        # not packed
        self.assertEqual(None, self._refs.get_peeled('refs/tags/refs-0.2'))
        self.assertEqual('3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
                         self._refs['refs/tags/refs-0.2'])

        # packed, known not peelable
        self.assertEqual(self._refs['refs/heads/packed'],
                         self._refs.get_peeled('refs/heads/packed'))

        # packed, peeled
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs.get_peeled('refs/tags/refs-0.1'))

    def test_setitem(self):
        RefsContainerTests.test_setitem(self)
        # The new ref is written as a loose file starting with the SHA.
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._read_file('refs', 'some', 'ref')[:40])

    def test_setitem_symbolic(self):
        ones = '1' * 40
        self._refs['HEAD'] = ones
        self.assertEqual(ones, self._refs['HEAD'])

        # ensure HEAD was not modified: its first line is still the symref.
        # (split() instead of the Python-2-only iter(f).next().)
        head_first_line = self._read_file('HEAD').split('\n', 1)[0]
        self.assertEqual('ref: refs/heads/master', head_first_line)

        # ensure the symbolic link was written through
        self.assertEqual(ones, self._read_file('refs', 'heads', 'master')[:40])

    def test_set_if_equals(self):
        RefsContainerTests.test_set_if_equals(self)

        # ensure symref was followed
        self.assertEqual('9' * 40, self._refs['refs/heads/master'])

        # ensure lockfile was deleted
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'HEAD.lock')))

    def test_add_if_new_packed(self):
        # don't overwrite packed ref
        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])

    def test_add_if_new_symbolic(self):
        # Use an empty repo instead of the default.
        tear_down_repo(self._repo)
        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
        os.makedirs(repo_dir)
        self._repo = Repo.init(repo_dir)
        refs = self._repo.refs

        nines = '9' * 40
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertFalse('refs/heads/master' in refs)
        # Adding through HEAD creates the symref's target, not HEAD itself.
        self.assertTrue(refs.add_if_new('HEAD', nines))
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])
        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])

    def test_follow(self):
        # assertEquals is a deprecated alias; use assertEqual.
        self.assertEqual(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('HEAD'))
        self.assertEqual(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('refs/heads/master'))
        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')

    def test_delitem(self):
        RefsContainerTests.test_delitem(self)
        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
        self.assertFalse(os.path.exists(ref_file))
        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())

    def test_delitem_symbolic(self):
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref('HEAD'))
        del self._refs['HEAD']
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        # Deleting the symref leaves its target alone.
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])
        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))

    def test_remove_if_equals_symref(self):
        # HEAD is a symref, so shouldn't equal its dereferenced value
        self.assertFalse(self._refs.remove_if_equals(
          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertTrue(self._refs.remove_if_equals(
          'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])

        # HEAD is now a broken symref
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref('HEAD'))

        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'HEAD.lock')))

    def test_remove_packed_without_peeled(self):
        # Rewrite packed-refs without its '#' header and '^' peeled lines,
        # then reopen the repo so the rewritten file is what gets parsed.
        refs_file = os.path.join(self._repo.path, 'packed-refs')
        f = GitFile(refs_file)
        refs_data = f.read()
        f.close()
        f = GitFile(refs_file, 'wb')
        f.write('\n'.join(l for l in refs_data.split('\n')
                          if not l or l[0] not in '#^'))
        f.close()
        self._repo = Repo(self._repo.path)
        refs = self._repo.refs
        self.assertTrue(refs.remove_if_equals(
          'refs/heads/packed', '42d06bd4b77fed026b154d16493e5deab78f02ec'))

    def test_remove_if_equals_packed(self):
        # test removing ref that is only packed
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])
        self.assertTrue(
          self._refs.remove_if_equals('refs/tags/refs-0.1',
          'df6800012397fb85c56e7418dd4eb9405dee075c'))
        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])

    def test_read_ref(self):
        # read_ref returns a symref's raw contents, a packed ref's SHA, and
        # None for a ref that does not exist.
        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
            self._refs.read_ref("refs/heads/packed"))
        self.assertEqual(None,
            self._refs.read_ref("nonexistant"))