def do_commit(self, message, committer=None,
author=None, commit_timestamp=None,
commit_timezone=None, author_timestamp=None,
- author_timezone=None, tree=None, encoding=None):
+ author_timezone=None, tree=None, encoding=None, branch='HEAD'):
"""Create a new commit.
:param message: Commit message
c.encoding = encoding
c.message = message
try:
- old_head = self.refs["HEAD"]
+ old_head = self.refs[branch]
c.parents = [old_head]
self.object_store.add_object(c)
- ok = self.refs.set_if_equals("HEAD", old_head, c.id)
+ ok = self.refs.set_if_equals(branch, old_head, c.id)
except KeyError:
c.parents = []
self.object_store.add_object(c)
- ok = self.refs.add_if_new("HEAD", c.id)
+ ok = self.refs.add_if_new(branch, c.id)
if not ok:
# Fail if the atomic compare-and-swap failed, leaving the commit and
# all its objects as garbage.
- raise CommitError("HEAD changed during commit")
+ raise CommitError("%s changed during commit" % (branch))
return c.id
self.assertEqual(r[self._root_commit].tree, new_commit.tree)
self.assertEqual('failed commit', new_commit.message)
+ def test_commit_branch(self):
+ r = self._repo
+
+ commit_sha = r.do_commit('commit to branch',
+ committer='Test Committer <test@nodomain.com>',
+ author='Test Author <test@nodomain.com>',
+ commit_timestamp=12395, commit_timezone=0,
+ author_timestamp=12395, author_timezone=0,
+ branch="refs/heads/new_branch")
+ self.assertEqual(self._root_commit, r["HEAD"].id)
+ self.assertEqual(commit_sha, r["refs/heads/new_branch"].id)
+ self.assertEqual([], r[commit_sha].parents)
+ self.assertTrue("refs/heads/new_branch" in r)
+
+ new_branch_head = commit_sha
+
+ commit_sha = r.do_commit('commit to branch 2',
+ committer='Test Committer <test@nodomain.com>',
+ author='Test Author <test@nodomain.com>',
+ commit_timestamp=12395, commit_timezone=0,
+ author_timestamp=12395, author_timezone=0,
+ branch="refs/heads/new_branch")
+ self.assertEqual(self._root_commit, r["HEAD"].id)
+ self.assertEqual(commit_sha, r["refs/heads/new_branch"].id)
+ self.assertEqual([new_branch_head], r[commit_sha].parents)
+
def test_stage_deleted(self):
r = self._repo
os.remove(os.path.join(r.path, 'a'))