Fix bug in new follow_branch functions.
[jelmer/subvertpy.git] / logwalker.py
index cf4fe7cedcf3cf840b327e5089a721696bde133e..acd981ca7697e817736ff2780f332eee3b7c8d3e 100644 (file)
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
-from bzrlib.config import config_dir
 from bzrlib.errors import NoSuchRevision, BzrError, NotBranchError
 from bzrlib.progress import ProgressBar, DummyProgress
 from bzrlib.trace import mutter
 
 import os
-import shelve
-from cStringIO import StringIO
 
-from svn.core import SubversionException
-import svn.ra
-
-cache_dir = os.path.join(config_dir(), 'svn-cache')
-
-def create_cache_dir(uuid):
-    if not os.path.exists(cache_dir):
-        os.mkdir(cache_dir)
-
-        open(os.path.join(cache_dir, "README"), 'w').write(
-"""This directory contains information cached by the bzr-svn plugin.
-
-It is used for performance reasons only and can be removed 
-without losing data.
-""")
-
-    dir = os.path.join(cache_dir, uuid)
-    if not os.path.exists(dir):
-        os.mkdir(dir)
-    return dir
-
-
-class NotSvnBranchPath(BzrError):
-    def __init__(self, branch_path):
-        BzrError.__init__(self, 
-                "%r is not a valid Svn branch path", 
-                branch_path)
-        self.branch_path = branch_path
+from svn.core import SubversionException, Pool
+from transport import SvnRaTransport
+import svn.core
+
+import base64
+
+try:
+    import sqlite3
+except ImportError:
+    from pysqlite2 import dbapi2 as sqlite3
+
+shelves = {}
+
+def _escape_commit_message(message):
+    """Replace xml-incompatible control characters."""
+    if message is None:
+        return None
+    import re
+    # FIXME: RBC 20060419 this should be done by the revision
+    # serialiser not by commit. Then we can also add an unescaper
+    # in the deserializer and start roundtripping revision messages
+    # precisely. See repository_implementations/test_repository.py
+    
+    # Python strings can include characters that can't be
+    # represented in well-formed XML; escape characters that
+    # aren't listed in the XML specification
+    # (http://www.w3.org/TR/REC-xml/#NT-Char).
+    message, _ = re.subn(
+        u'[^\x09\x0A\x0D\u0020-\uD7FF\uE000-\uFFFD]+',
+        lambda match: match.group(0).encode('unicode_escape'),
+        message)
+    return message
 
 
 class LogWalker(object):
-    def __init__(self, scheme, ra=None, uuid=None, last_revnum=None, repos_url=None, pb=None):
-        if ra is None:
-            callbacks = svn.ra.callbacks2_t()
-            ra = svn.ra.open2(repos_url.encode('utf8'), callbacks, None, None)
-            root = svn.ra.get_repos_root(ra)
-            if root != repos_url:
-                svn.ra.reparent(ra, root.encode('utf8'))
-
-        if not uuid:
-            uuid = svn.ra.get_uuid(ra)
-
-        self.uuid = uuid
+    """Easy way to access the history of a Subversion repository."""
+    def __init__(self, transport=None, cache_db=None, last_revnum=None):
+        """Create a new instance.
+
+        :param transport:   SvnRaTransport to use to access the repository.
+        :param cache_db:    Optional sql database connection to use. Doesn't 
+                            cache if not set.
+        :param last_revnum: Last known revnum in the repository. Will be 
+                            determined if not specified.
+        """
+        assert isinstance(transport, SvnRaTransport)
 
         if last_revnum is None:
-            last_revnum = svn.ra.get_latest_revnum(ra)
+            last_revnum = transport.get_latest_revnum()
 
-        self.cache_file = os.path.join(create_cache_dir(uuid), 'log')
-        self.ra = ra
-        self.scheme = scheme
+        self.last_revnum = last_revnum
 
-        # Try to load cache from file
-        self.revisions = shelve.open(self.cache_file)
-        self.saved_revnum = max(len(self.revisions)-1, 0)
+        self.transport = SvnRaTransport(transport.get_repos_root())
 
-        if self.saved_revnum < last_revnum:
-            self.fetch_revisions(self.saved_revnum, last_revnum, pb)
+        if cache_db is None:
+            self.db = sqlite3.connect(":memory:")
         else:
-            self.last_revnum = self.saved_revnum
-
-    def fetch_revisions(self, from_revnum, to_revnum, pb=None):
+            self.db = cache_db
+
+        self.db.executescript("""
+          create table if not exists revision(revno integer unique, author text, message text, date text);
+          create unique index if not exists revision_revno on revision (revno);
+          create table if not exists changed_path(rev integer, action text, path text, copyfrom_path text, copyfrom_rev integer);
+          create index if not exists path_rev on changed_path(rev);
+          create index if not exists path_rev_path on changed_path(rev, path);
+        """)
+        self.db.commit()
+        self.saved_revnum = self.db.execute("SELECT MAX(revno) FROM revision").fetchone()[0]
+        if self.saved_revnum is None:
+            self.saved_revnum = 0
+
+    def fetch_revisions(self, to_revnum, pb=None):
+        """Fetch information about all revisions in the remote repository
+        until to_revnum.
+
+        :param to_revnum: End of range to fetch information for
+        :param pb: Optional progress bar to use
+        """
         def rcvr(orig_paths, rev, author, date, message, pool):
             pb.update('fetching svn revision info', rev, to_revnum)
             paths = {}
@@ -93,15 +107,19 @@ class LogWalker(object):
                 copyfrom_path = orig_paths[p].copyfrom_path
                 if copyfrom_path:
                     copyfrom_path = copyfrom_path.strip("/")
-                paths[p.strip("/")] = (orig_paths[p].action,
-                            copyfrom_path, orig_paths[p].copyfrom_rev)
 
-            self.revisions[str(rev)] = {
-                    'paths': paths,
-                    'author': author,
-                    'date': date,
-                    'message': message
-                    }
+                self.db.execute(
+                     "insert into changed_path (rev, path, action, copyfrom_path, copyfrom_rev) values (?, ?, ?, ?, ?)", 
+                     (rev, p.strip("/"), orig_paths[p].action, copyfrom_path, orig_paths[p].copyfrom_rev))
+
+            if message is not None:
+                message = base64.b64encode(message)
+
+            self.db.execute("replace into revision (revno, author, date, message) values (?,?,?,?)", (rev, author, date, message))
+
+            self.saved_revnum = rev
+
+        to_revnum = max(self.last_revnum, to_revnum)
 
         # Don't bother for only a few revisions
         if abs(self.saved_revnum-to_revnum) < 10:
@@ -109,12 +127,11 @@ class LogWalker(object):
         else:
             pb = ProgressBar()
 
+        pool = Pool()
         try:
             try:
-                mutter('getting log %r:%r' % (self.saved_revnum, to_revnum))
-                svn.ra.get_log(self.ra, ["/"], self.saved_revnum, to_revnum, 
-                               0, True, True, rcvr)
-                self.last_revnum = to_revnum
+                self.transport.get_log("/", self.saved_revnum, to_revnum, 
+                               0, True, True, rcvr, pool)
             finally:
                 pb.clear()
         except SubversionException, (_, num):
@@ -122,140 +139,151 @@ class LogWalker(object):
                 raise NoSuchRevision(branch=self, 
                     revision="Revision number %d" % to_revnum)
             raise
+        self.db.commit()
+        pool.destroy()
 
-        self.save()
-
-    def save(self):
-        pickle.dump(self.revisions, open(self.cache_file, 'w'))
-
-    def follow_history(self, branch_path, revnum):
-        for (branch, paths, rev, _, _, _) in self.get_branch_log(branch_path, 
-                                                                 revnum):
-            yield (branch, paths, rev)
-
-    def get_branch_log(self, branch_path, from_revnum, to_revnum=0, limit=0):
-        """Return iterator over all the revisions between from_revnum and 
-        to_revnum that touch branch_path."""
-        assert from_revnum >= to_revnum
+    def follow_path(self, path, revnum):
+        """Return iterator over all the revisions between revnum and 
+        0 named path or inside path.
 
-        if not branch_path is None and not self.scheme.is_branch(branch_path):
-            raise NotSvnBranchPath(branch_path)
+        :param path:   Branch path to start reporting (in revnum)
+        :param revnum:        Start revision.
 
-        if branch_path:
-            branch_path = branch_path.strip("/")
+        :return: An iterators that yields tuples with (path, paths, revnum)
+        where paths is a dictionary with all changes that happened in path 
+        in revnum.
+        """
+        assert revnum >= 0
+
+        if revnum == 0 and path == "":
+            return
+
+        path = path.strip("/")
+
+        while revnum > 0:
+            revpaths = self.get_revision_paths(revnum, path)
+
+            if revpaths != {}:
+                yield (path, revpaths, revnum)
+
+            if revpaths.has_key(path):
+                if revpaths[path][1] is None:
+                    if revpaths[path][0] in ('A', 'R'):
+                        # this path didn't exist before this revision
+                        return
+                else:
+                    # In this revision, this path was copied from 
+                    # somewhere else
+                    revnum = revpaths[path][2]
+                    path = revpaths[path][1]
+                    continue
+            revnum-=1
+
+    def get_revision_paths(self, revnum, path=None):
+        """Obtain dictionary with all the changes in a particular revision.
+
+        :param revnum: Subversion revision number
+        :param path: optional path under which to return all entries
+        :returns: dictionary with paths as keys and 
+                  (action, copyfrom_path, copyfrom_rev) as values.
+        """
 
-        if max(from_revnum, to_revnum) > self.last_revnum:
-            self.fetch_revisions(self.last_revnum, max(from_revnum, to_revnum))
+        if revnum == 0:
+            return {'': ('A', None, -1)}
+                
+        if revnum > self.saved_revnum:
+            self.fetch_revisions(revnum)
 
+        query = "select path, action, copyfrom_path, copyfrom_rev from changed_path where rev="+str(revnum)
+        if path is not None and path != "":
+            query += " and (path='%s' or path like '%s/%%')" % (path, path)
 
-        continue_revnum = None
-        num = 0
-        for i in range(abs(from_revnum-to_revnum)+1):
-            if to_revnum < from_revnum:
-                i = from_revnum - i
-            else:
-                i = from_revnum + i
+        paths = {}
+        for p, act, cf, cr in self.db.execute(query):
+            paths[p] = (act, cf, cr)
+        return paths
 
-            if i == 0:
-                continue
+    def get_revision_info(self, revnum, pb=None):
+        """Obtain basic information for a specific revision.
 
-            if not (continue_revnum is None or continue_revnum == i):
-                continue
+        :param revnum: Revision number.
+        :returns: Tuple with author, log message and date of the revision.
+        """
+        assert revnum >= 1
+        if revnum > self.saved_revnum:
+            self.fetch_revisions(revnum, pb)
+        (author, message, date) = self.db.execute("select author, message, date from revision where revno="+ str(revnum)).fetchone()
+        if author is None:
+            author = None
+        return (author, _escape_commit_message(base64.b64decode(message)), date)
+
+    def find_latest_change(self, path, revnum):
+        """Find latest revision that touched path.
+
+        :param path: Path to check for changes
+        :param revnum: First revision to check
+        """
+        if revnum > self.saved_revnum:
+            self.fetch_revisions(revnum)
 
-            continue_revnum = None
+        row = self.db.execute(
+             "select rev from changed_path where path='%s' and rev <= %d order by rev desc limit 1" % (path.strip("/"), revnum)).fetchone()
+        if row is None and path == "":
+            return 0
 
-            rev = self.revisions[str(i)]
-            changed_paths = {}
-            for p in rev['paths']:
-                if (branch_path is None or 
-                    p == branch_path or
-                    branch_path == "" or
-                    p.startswith(branch_path+"/")):
+        assert row is not None, "no latest change for %r:%d" % (path, revnum)
 
-                    try:
-                        (bp, rp) = self.scheme.unprefix(p)
-                        if not changed_paths.has_key(bp):
-                            changed_paths[bp] = {}
-                        changed_paths[bp][p] = rev['paths'][p]
-                    except NotBranchError:
-                        pass
+        return row[0]
 
-            assert branch_path is None or len(changed_paths) <= 1
+    def touches_path(self, path, revnum):
+        """Check whether path was changed in specified revision.
 
-            for bp in changed_paths:
-                num = num + 1
-                yield (bp, changed_paths[bp], i, rev['author'], rev['date'], 
-                       rev['message'])
+        :param path:  Path to check
+        :param revnum:  Revision to check
+        """
+        if revnum > self.saved_revnum:
+            self.fetch_revisions(revnum)
+        if revnum == 0:
+            return (path == "")
+        return (self.db.execute("select 1 from changed_path where path='%s' and rev=%d" % (path, revnum)).fetchone() is not None)
 
-            if (not branch_path is None and 
-                branch_path in rev['paths'] and 
-                not rev['paths'][branch_path][1] is None):
-                # In this revision, this branch was copied from 
-                # somewhere else
-                # FIXME: What if copyfrom_path is not a branch path?
-                continue_revnum = rev['paths'][branch_path][2]
-                branch_path = rev['paths'][branch_path][1]
+    def find_children(self, path, revnum):
+        """Find all children of path in revnum."""
+        # TODO: Find children by walking history, or use 
+        # cache?
 
-            if limit and num == limit:
+        try:
+            (dirents, _, _) = self.transport.get_dir(
+                path.lstrip("/").encode('utf8'), revnum, kind=True)
+        except SubversionException, (_, num):
+            if num == svn.core.SVN_ERR_FS_NOT_DIRECTORY:
                 return
+            raise
 
-    def get_offspring(self, path, orig_revnum, revnum):
-        """Check which files in revnum directly descend from path in orig_revnum."""
-        assert orig_revnum <= revnum
-
-        ancestors = [path]
-        dest = (path, orig_revnum)
-
-        for i in range(revnum-orig_revnum):
-            paths = self.revisions[str(i+1+orig_revnum)]['paths']
-            for p in paths:
-                new_ancestors = list(ancestors)
-
-                if paths[p][0] in ('R', 'A') and paths[p][1]:
-                    if paths[p][1:3] == dest:
-                        new_ancestors.append(p)
-
-                    for s in ancestors:
-                        if s.startswith(paths[p][1]+"/"):
-                            new_ancestors.append(s.replace(paths[p][1], p, 1))
-
-                ancestors = new_ancestors
-
-                if paths[p][0] in ('R', 'D'):
-                    for s in ancestors:
-                        if s == p or s.startswith(p+"/"):
-                            new_ancestors.remove(s)
-
-                ancestors = new_ancestors
-
-        return ancestors
-
-    def find_branches(self, revnum):
-        created_branches = {}
-
-        for i in range(revnum):
-            if i == 0:
-                continue
-            rev = self.revisions[str(i)]
-            for p in rev['paths']:
-                if self.scheme.is_branch(p):
-                    if rev['paths'][p][0] in ('R', 'D'):
-                        del created_branches[p]
-                        yield (p, i, False)
-
-                    if rev['paths'][p][0] in ('A', 'R'): 
-                        created_branches[p] = i
-
-        for p in created_branches:
-            yield (p, i, True)
-
-    def get_revision_info(self, revnum, pb=None):
-        """Obtain basic information for a specific revision.
-
-        :param revnum: Revision number.
-        :returns: Tuple with author, log message and date of the revision.
+        for p in dirents:
+            yield os.path.join(path, p)
+            # This needs to be != svn.core.svn_node_file because 
+            # some ra backends seem to return negative values for .kind.
+            # This if statement is just an optimization to make use of this 
+            # property when possible.
+            if dirents[p].kind != svn.core.svn_node_file:
+                for c in self.find_children(os.path.join(path, p), revnum):
+                    yield c
+
+    def get_previous(self, path, revnum):
+        """Return path,revnum pair specified pair was derived from.
+
+        :param path:  Path to check
+        :param revnum:  Revision to check
         """
-        if revnum > self.last_revnum:
-            self.fetch_revisions(self.saved_revnum, revnum, pb)
-        rev = self.revisions[str(revnum)]
-        return (rev['author'], rev['message'], rev['date'])
+        assert revnum >= 0
+        if revnum > self.saved_revnum:
+            self.fetch_revisions(revnum)
+        if revnum == 0:
+            return (None, -1)
+        row = self.db.execute("select action, copyfrom_path, copyfrom_rev from changed_path where path='%s' and rev=%d" % (path, revnum)).fetchone()
+        if row[2] == -1:
+            if row[0] == 'A':
+                return (None, -1)
+            return (path, revnum-1)
+        return (row[1], row[2])