Fetch revisions in chunks of 1000.
[jelmer/subvertpy.git] / logwalker.py
index 8858467e4821e220b555bc5bed08a8d6f95cf391..80cb025bf6bb934284bd260cffdad92b0351df6d 100644 (file)
 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 """Cache of the Subversion history log."""
 
+from bzrlib import urlutils
 from bzrlib.errors import NoSuchRevision
 import bzrlib.ui as ui
-
-import os
+from copy import copy
 
 from svn.core import SubversionException, Pool
 from transport import SvnRaTransport
@@ -28,6 +28,8 @@ import base64
 
 from cache import sqlite3
 
+LOG_CHUNK_LIMIT = 1000
+
 def _escape_commit_message(message):
     """Replace xml-incompatible control characters."""
     if message is None:
@@ -51,23 +53,22 @@ def _escape_commit_message(message):
 
 class LogWalker(object):
     """Easy way to access the history of a Subversion repository."""
-    def __init__(self, transport=None, cache_db=None, last_revnum=None):
+    def __init__(self, transport, cache_db=None, limit=None):
         """Create a new instance.
 
         :param transport:   SvnRaTransport to use to access the repository.
         :param cache_db:    Optional sql database connection to use. Doesn't 
                             cache if not set.
-        :param last_revnum: Last known revnum in the repository. Will be 
-                            determined if not specified.
         """
         assert isinstance(transport, SvnRaTransport)
 
-        if last_revnum is None:
-            last_revnum = transport.get_latest_revnum()
-
-        self.last_revnum = last_revnum
+        self.url = transport.base
+        self._transport = None
 
-        self.transport = SvnRaTransport(transport.base)
+        if limit is not None:
+            self._limit = limit
+        else:
+            self._limit = LOG_CHUNK_LIMIT
 
         if cache_db is None:
             self.db = sqlite3.connect(":memory:")
@@ -79,21 +80,27 @@ class LogWalker(object):
           create unique index if not exists revision_revno on revision (revno);
           create table if not exists changed_path(rev integer, action text, path text, copyfrom_path text, copyfrom_rev integer);
           create index if not exists path_rev on changed_path(rev);
-          create index if not exists path_rev_path on changed_path(rev, path);
-          create index if not exists path_rev_path_action on changed_path(rev, path, action);
+          create unique index if not exists path_rev_path on changed_path(rev, path);
+          create unique index if not exists path_rev_path_action on changed_path(rev, path, action);
         """)
         self.db.commit()
         self.saved_revnum = self.db.execute("SELECT MAX(revno) FROM revision").fetchone()[0]
         if self.saved_revnum is None:
             self.saved_revnum = 0
 
-    def fetch_revisions(self, to_revnum):
+    def _get_transport(self):
+        if self._transport is not None:
+            return self._transport
+        self._transport = SvnRaTransport(self.url)
+        return self._transport
+
+    def fetch_revisions(self, to_revnum=None):
         """Fetch information about all revisions in the remote repository
         until to_revnum.
 
         :param to_revnum: End of range to fetch information for
         """
-        to_revnum = max(self.last_revnum, to_revnum)
+        to_revnum = max(self._get_transport().get_latest_revnum(), to_revnum)
 
         pb = ui.ui_factory.nested_progress_bar()
 
@@ -103,11 +110,11 @@ class LogWalker(object):
                 orig_paths = {}
             for p in orig_paths:
                 copyfrom_path = orig_paths[p].copyfrom_path
-                if copyfrom_path:
+                if copyfrom_path is not None:
                     copyfrom_path = copyfrom_path.strip("/")
 
                 self.db.execute(
-                     "insert into changed_path (rev, path, action, copyfrom_path, copyfrom_rev) values (?, ?, ?, ?, ?)", 
+                     "replace into changed_path (rev, path, action, copyfrom_path, copyfrom_rev) values (?, ?, ?, ?, ?)", 
                      (rev, p.strip("/"), orig_paths[p].action, copyfrom_path, orig_paths[p].copyfrom_rev))
 
             if message is not None:
@@ -119,11 +126,14 @@ class LogWalker(object):
             if self.saved_revnum % 1000 == 0:
                 self.db.commit()
 
-        pool = Pool()
         try:
             try:
-                self.transport.get_log("/", self.saved_revnum, to_revnum, 
-                               0, True, True, rcvr, pool)
+                while self.saved_revnum < to_revnum:
+                    pool = Pool()
+                    self._get_transport().get_log("/", self.saved_revnum, 
+                                             to_revnum, self._limit, True, 
+                                             True, rcvr, pool)
+                    pool.destroy()
             finally:
                 pb.finished()
         except SubversionException, (_, num):
@@ -132,7 +142,6 @@ class LogWalker(object):
                     revision="Revision number %d" % to_revnum)
             raise
         self.db.commit()
-        pool.destroy()
 
     def follow_path(self, path, revnum):
         """Return iterator over all the revisions between revnum and 
@@ -150,13 +159,20 @@ class LogWalker(object):
         if revnum == 0 and path == "":
             return
 
+        recurse = (path != "")
+
         path = path.strip("/")
 
         while revnum >= 0:
-            revpaths = self.get_revision_paths(revnum, path)
+            assert revnum > 0 or path == ""
+            revpaths = self.get_revision_paths(revnum, path, recurse=recurse)
 
             if revpaths != {}:
-                yield (path, revpaths, revnum)
+                yield (path, copy(revpaths), revnum)
+
+            if path == "":
+                revnum -= 1
+                continue
 
             if revpaths.has_key(path):
                 if revpaths[path][1] is None:
@@ -168,19 +184,28 @@ class LogWalker(object):
                     # somewhere else
                     revnum = revpaths[path][2]
                     path = revpaths[path][1]
+                    assert path == "" or revnum > 0
                     continue
             revnum -= 1
-
-    def get_revision_paths(self, revnum, path=None):
+            for p in sorted(revpaths.keys()):
+                if path.startswith(p+"/") and revpaths[p][0] in ('A', 'R'):
+                    assert revpaths[p][1]
+                    path = path.replace(p, revpaths[p][1])
+                    revnum = revpaths[p][2]
+                    break
+
+    def get_revision_paths(self, revnum, path=None, recurse=False):
         """Obtain dictionary with all the changes in a particular revision.
 
         :param revnum: Subversion revision number
         :param path: optional path under which to return all entries
+        :param recurse: Report changes to parents as well
         :returns: dictionary with paths as keys and 
                   (action, copyfrom_path, copyfrom_rev) as values.
         """
 
         if revnum == 0:
+            assert path is None or path == ""
             return {'': ('A', None, -1)}
                 
         if revnum > self.saved_revnum:
@@ -188,11 +213,14 @@ class LogWalker(object):
 
         query = "select path, action, copyfrom_path, copyfrom_rev from changed_path where rev="+str(revnum)
         if path is not None and path != "":
-            query += " and (path='%s' or path like '%s/%%')" % (path, path)
+            query += " and (path='%s' or path like '%s/%%'" % (path, path)
+            if recurse:
+                query += " or ('%s' LIKE path || '/%%')" % path
+            query += ")"
 
         paths = {}
         for p, act, cf, cr in self.db.execute(query):
-            paths[p] = (act, cf, cr)
+            paths[p.encode("utf-8")] = (act, cf, cr)
         return paths
 
     def get_revision_info(self, revnum):
@@ -211,7 +239,8 @@ class LogWalker(object):
             message = _escape_commit_message(base64.b64decode(message))
         return (author, message, date)
 
-    def find_latest_change(self, path, revnum, recurse=False):
+    def find_latest_change(self, path, revnum, include_parents=False,
+                           include_children=False):
         """Find latest revision that touched path.
 
         :param path: Path to check for changes
@@ -222,11 +251,12 @@ class LogWalker(object):
         if revnum > self.saved_revnum:
             self.fetch_revisions(revnum)
 
-        if recurse:
-            extra = " or path like '%s/%%'" % path.strip("/")
-        else:
-            extra = ""
-        query = "select rev from changed_path where (path='%s' or ('%s' like (path || '/%%') and (action = 'R' or action = 'A'))%s) and rev <= %d order by rev desc limit 1" % (path.strip("/"), path.strip("/"), extra, revnum)
+        extra = ""
+        if include_children:
+            extra += " or path like '%s/%%'" % path.strip("/")
+        if include_parents:
+            extra += " or ('%s' like (path || '/%%') and (action = 'R' or action = 'A'))" % path.strip("/")
+        query = "select rev from changed_path where (path='%s'%s) and rev <= %d order by rev desc limit 1" % (path.strip("/"), extra, revnum)
 
         row = self.db.execute(query).fetchone()
         if row is None and path == "":
@@ -252,21 +282,28 @@ class LogWalker(object):
     def find_children(self, path, revnum):
         """Find all children of path in revnum."""
         path = path.strip("/")
-        if self.transport.check_path(path, revnum) == svn.core.svn_node_file:
+        transport = self._get_transport()
+        ft = transport.check_path(path, revnum)
+        if ft == svn.core.svn_node_file:
             return []
+        assert ft == svn.core.svn_node_dir
+
         class TreeLister(svn.delta.Editor):
             def __init__(self, base):
                 self.files = []
                 self.base = base
 
             def set_target_revision(self, revnum):
+                """See Editor.set_target_revision()."""
                 pass
 
             def open_root(self, revnum, baton):
+                """See Editor.open_root()."""
                 return path
 
             def add_directory(self, path, parent_baton, copyfrom_path, copyfrom_revnum, pool):
-                self.files.append(os.path.join(self.base, path))
+                """See Editor.add_directory()."""
+                self.files.append(urlutils.join(self.base, path))
                 return path
 
             def change_dir_prop(self, id, name, value, pool):
@@ -276,7 +313,7 @@ class LogWalker(object):
                 pass
 
             def add_file(self, path, parent_id, copyfrom_path, copyfrom_revnum, baton):
-                self.files.append(os.path.join(self.base, path))
+                self.files.append(urlutils.join(self.base, path))
                 return path
 
             def close_dir(self, id):
@@ -296,12 +333,15 @@ class LogWalker(object):
         pool = Pool()
         editor = TreeLister(path)
         edit, baton = svn.delta.make_editor(editor, pool)
-        root_repos = self.transport.get_repos_root()
-        self.transport.reparent(os.path.join(root_repos, path))
-        reporter = self.transport.do_update(
-                        revnum, "", True, edit, baton, pool)
-        reporter.set_path("", revnum, True, None, pool)
-        reporter.finish_report(pool)
+        old_base = transport.base
+        try:
+            root_repos = transport.get_repos_root()
+            transport.reparent(urlutils.join(root_repos, path))
+            reporter = transport.do_update(revnum,  True, edit, baton, pool)
+            reporter.set_path("", revnum, True, None, pool)
+            reporter.finish_report(pool)
+        finally:
+            transport.reparent(old_base)
         return editor.files
 
     def get_previous(self, path, revnum):