Add ObjectStore.close.
[jelmer/dulwich.git] / dulwich / object_store.py
1 # object_store.py -- Object store for git objects
2 # Copyright (C) 2008-2012 Jelmer Vernooij <jelmer@samba.org>
3 #                         and others
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; either version 2
8 # or (at your option) a later version of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Git object store interfaces and implementation."""
22
23
24 from cStringIO import StringIO
25 import errno
26 import itertools
27 import os
28 import stat
29 import tempfile
30
31 from dulwich.diff_tree import (
32     tree_changes,
33     walk_trees,
34     )
35 from dulwich.errors import (
36     NotTreeError,
37     )
38 from dulwich.file import GitFile
39 from dulwich.objects import (
40     Commit,
41     ShaFile,
42     Tag,
43     Tree,
44     ZERO_SHA,
45     hex_to_sha,
46     sha_to_hex,
47     hex_to_filename,
48     S_ISGITLINK,
49     object_class,
50     )
51 from dulwich.pack import (
52     Pack,
53     PackData,
54     PackInflater,
55     iter_sha1,
56     write_pack_header,
57     write_pack_index_v2,
58     write_pack_object,
59     write_pack_objects,
60     compute_file_sha,
61     PackIndexer,
62     PackStreamCopier,
63     )
64
65 INFODIR = 'info'
66 PACKDIR = 'pack'
67
68
69 class BaseObjectStore(object):
70     """Object store interface."""
71
72     def determine_wants_all(self, refs):
73         return [sha for (ref, sha) in refs.iteritems()
74                 if not sha in self and not ref.endswith("^{}") and
75                    not sha == ZERO_SHA]
76
77     def iter_shas(self, shas):
78         """Iterate over the objects for the specified shas.
79
80         :param shas: Iterable object with SHAs
81         :return: Object iterator
82         """
83         return ObjectStoreIterator(self, shas)
84
85     def contains_loose(self, sha):
86         """Check if a particular object is present by SHA1 and is loose."""
87         raise NotImplementedError(self.contains_loose)
88
89     def contains_packed(self, sha):
90         """Check if a particular object is present by SHA1 and is packed."""
91         raise NotImplementedError(self.contains_packed)
92
93     def __contains__(self, sha):
94         """Check if a particular object is present by SHA1.
95
96         This method makes no distinction between loose and packed objects.
97         """
98         return self.contains_packed(sha) or self.contains_loose(sha)
99
100     @property
101     def packs(self):
102         """Iterable of pack objects."""
103         raise NotImplementedError
104
105     def get_raw(self, name):
106         """Obtain the raw text for an object.
107
108         :param name: sha for the object.
109         :return: tuple with numeric type and object contents.
110         """
111         raise NotImplementedError(self.get_raw)
112
113     def __getitem__(self, sha):
114         """Obtain an object by SHA1."""
115         type_num, uncomp = self.get_raw(sha)
116         return ShaFile.from_raw_string(type_num, uncomp)
117
118     def __iter__(self):
119         """Iterate over the SHAs that are present in this store."""
120         raise NotImplementedError(self.__iter__)
121
122     def add_object(self, obj):
123         """Add a single object to this object store.
124
125         """
126         raise NotImplementedError(self.add_object)
127
128     def add_objects(self, objects):
129         """Add a set of objects to this object store.
130
131         :param objects: Iterable over a list of objects.
132         """
133         raise NotImplementedError(self.add_objects)
134
135     def tree_changes(self, source, target, want_unchanged=False):
136         """Find the differences between the contents of two trees
137
138         :param source: SHA1 of the source tree
139         :param target: SHA1 of the target tree
140         :param want_unchanged: Whether unchanged files should be reported
141         :return: Iterator over tuples with
142             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
143         """
144         for change in tree_changes(self, source, target,
145                                    want_unchanged=want_unchanged):
146             yield ((change.old.path, change.new.path),
147                    (change.old.mode, change.new.mode),
148                    (change.old.sha, change.new.sha))
149
150     def iter_tree_contents(self, tree_id, include_trees=False):
151         """Iterate the contents of a tree and all subtrees.
152
153         Iteration is depth-first pre-order, as in e.g. os.walk.
154
155         :param tree_id: SHA1 of the tree.
156         :param include_trees: If True, include tree objects in the iteration.
157         :return: Iterator over TreeEntry namedtuples for all the objects in a
158             tree.
159         """
160         for entry, _ in walk_trees(self, tree_id, None):
161             if not stat.S_ISDIR(entry.mode) or include_trees:
162                 yield entry
163
164     def find_missing_objects(self, haves, wants, progress=None,
165                              get_tagged=None):
166         """Find the missing objects required for a set of revisions.
167
168         :param haves: Iterable over SHAs already in common.
169         :param wants: Iterable over SHAs of objects to fetch.
170         :param progress: Simple progress function that will be called with
171             updated progress strings.
172         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
173             sha for including tags.
174         :return: Iterator over (sha, path) pairs.
175         """
176         finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
177         return iter(finder.next, None)
178
179     def find_common_revisions(self, graphwalker):
180         """Find which revisions this store has in common using graphwalker.
181
182         :param graphwalker: A graphwalker object.
183         :return: List of SHAs that are in common
184         """
185         haves = []
186         sha = graphwalker.next()
187         while sha:
188             if sha in self:
189                 haves.append(sha)
190                 graphwalker.ack(sha)
191             sha = graphwalker.next()
192         return haves
193
194     def get_graph_walker(self, heads):
195         """Obtain a graph walker for this object store.
196
197         :param heads: Local heads to start search with
198         :return: GraphWalker object
199         """
200         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
201
202     def generate_pack_contents(self, have, want, progress=None):
203         """Iterate over the contents of a pack file.
204
205         :param have: List of SHA1s of objects that should not be sent
206         :param want: List of SHA1s of objects that should be sent
207         :param progress: Optional progress reporting method
208         """
209         return self.iter_shas(self.find_missing_objects(have, want, progress))
210
211     def peel_sha(self, sha):
212         """Peel all tags from a SHA.
213
214         :param sha: The object SHA to peel.
215         :return: The fully-peeled SHA1 of a tag object, after peeling all
216             intermediate tags; if the original ref does not point to a tag, this
217             will equal the original SHA1.
218         """
219         obj = self[sha]
220         obj_class = object_class(obj.type_name)
221         while obj_class is Tag:
222             obj_class, sha = obj.object
223             obj = self[sha]
224         return obj
225
226     def _collect_ancestors(self, heads, common=set()):
227         """Collect all ancestors of heads up to (excluding) those in common.
228
229         :param heads: commits to start from
230         :param common: commits to end at, or empty set to walk repository
231             completely
232         :return: a tuple (A, B) where A - all commits reachable
233             from heads but not present in common, B - common (shared) elements
234             that are directly reachable from heads
235         """
236         bases = set()
237         commits = set()
238         queue = []
239         queue.extend(heads)
240         while queue:
241             e = queue.pop(0)
242             if e in common:
243                 bases.add(e)
244             elif e not in commits:
245                 commits.add(e)
246                 cmt = self[e]
247                 queue.extend(cmt.parents)
248         return (commits, bases)
249
250     def close(self):
251         """Close any files opened by this object store."""
252         # Default implementation is a NO-OP
253
254
255 class PackBasedObjectStore(BaseObjectStore):
256
257     def __init__(self):
258         self._pack_cache = None
259
260     @property
261     def alternates(self):
262         return []
263
264     def contains_packed(self, sha):
265         """Check if a particular object is present by SHA1 and is packed.
266
267         This does not check alternates.
268         """
269         for pack in self.packs:
270             if sha in pack:
271                 return True
272         return False
273
274     def __contains__(self, sha):
275         """Check if a particular object is present by SHA1.
276
277         This method makes no distinction between loose and packed objects.
278         """
279         if self.contains_packed(sha) or self.contains_loose(sha):
280             return True
281         for alternate in self.alternates:
282             if sha in alternate:
283                 return True
284         return False
285
286     def _load_packs(self):
287         raise NotImplementedError(self._load_packs)
288
289     def _pack_cache_stale(self):
290         """Check whether the pack cache is stale."""
291         raise NotImplementedError(self._pack_cache_stale)
292
293     def _add_known_pack(self, pack):
294         """Add a newly appeared pack to the cache by path.
295
296         """
297         if self._pack_cache is not None:
298             self._pack_cache.append(pack)
299
300     def close(self):
301         pack_cache = self._pack_cache
302         self._pack_cache = None
303         while pack_cache:
304             pack = pack_cache.pop()
305             pack.close()
306
307     @property
308     def packs(self):
309         """List with pack objects."""
310         if self._pack_cache is None or self._pack_cache_stale():
311             self._pack_cache = self._load_packs()
312         return self._pack_cache
313
314     def _iter_alternate_objects(self):
315         """Iterate over the SHAs of all the objects in alternate stores."""
316         for alternate in self.alternates:
317             for alternate_object in alternate:
318                 yield alternate_object
319
320     def _iter_loose_objects(self):
321         """Iterate over the SHAs of all loose objects."""
322         raise NotImplementedError(self._iter_loose_objects)
323
324     def _get_loose_object(self, sha):
325         raise NotImplementedError(self._get_loose_object)
326
327     def _remove_loose_object(self, sha):
328         raise NotImplementedError(self._remove_loose_object)
329
330     def pack_loose_objects(self):
331         """Pack loose objects.
332
333         :return: Number of objects packed
334         """
335         objects = set()
336         for sha in self._iter_loose_objects():
337             objects.add((self._get_loose_object(sha), None))
338         self.add_objects(list(objects))
339         for obj, path in objects:
340             self._remove_loose_object(obj.id)
341         return len(objects)
342
343     def __iter__(self):
344         """Iterate over the SHAs that are present in this store."""
345         iterables = self.packs + [self._iter_loose_objects()] + [self._iter_alternate_objects()]
346         return itertools.chain(*iterables)
347
348     def contains_loose(self, sha):
349         """Check if a particular object is present by SHA1 and is loose.
350
351         This does not check alternates.
352         """
353         return self._get_loose_object(sha) is not None
354
355     def get_raw(self, name):
356         """Obtain the raw text for an object.
357
358         :param name: sha for the object.
359         :return: tuple with numeric type and object contents.
360         """
361         if len(name) == 40:
362             sha = hex_to_sha(name)
363             hexsha = name
364         elif len(name) == 20:
365             sha = name
366             hexsha = None
367         else:
368             raise AssertionError("Invalid object name %r" % name)
369         for pack in self.packs:
370             try:
371                 return pack.get_raw(sha)
372             except KeyError:
373                 pass
374         if hexsha is None:
375             hexsha = sha_to_hex(name)
376         ret = self._get_loose_object(hexsha)
377         if ret is not None:
378             return ret.type_num, ret.as_raw_string()
379         for alternate in self.alternates:
380             try:
381                 return alternate.get_raw(hexsha)
382             except KeyError:
383                 pass
384         raise KeyError(hexsha)
385
386     def add_objects(self, objects):
387         """Add a set of objects to this object store.
388
389         :param objects: Iterable over objects, should support __len__.
390         :return: Pack object of the objects written.
391         """
392         if len(objects) == 0:
393             # Don't bother writing an empty pack file
394             return
395         f, commit, abort = self.add_pack()
396         try:
397             write_pack_objects(f, objects)
398         except:
399             abort()
400             raise
401         else:
402             return commit()
403
404
405 class DiskObjectStore(PackBasedObjectStore):
406     """Git-style object store that exists on disk."""
407
408     def __init__(self, path):
409         """Open an object store.
410
411         :param path: Path of the object store.
412         """
413         super(DiskObjectStore, self).__init__()
414         self.path = path
415         self.pack_dir = os.path.join(self.path, PACKDIR)
416         self._pack_cache_time = 0
417         self._alternates = None
418
419     @property
420     def alternates(self):
421         if self._alternates is not None:
422             return self._alternates
423         self._alternates = []
424         for path in self._read_alternate_paths():
425             self._alternates.append(DiskObjectStore(path))
426         return self._alternates
427
428     def _read_alternate_paths(self):
429         try:
430             f = GitFile(os.path.join(self.path, "info", "alternates"),
431                     'rb')
432         except (OSError, IOError), e:
433             if e.errno == errno.ENOENT:
434                 return []
435             raise
436         ret = []
437         try:
438             for l in f.readlines():
439                 l = l.rstrip("\n")
440                 if l[0] == "#":
441                     continue
442                 if os.path.isabs(l):
443                     ret.append(l)
444                 else:
445                     ret.append(os.path.join(self.path, l))
446             return ret
447         finally:
448             f.close()
449
450     def add_alternate_path(self, path):
451         """Add an alternate path to this object store.
452         """
453         try:
454             os.mkdir(os.path.join(self.path, "info"))
455         except OSError, e:
456             if e.errno != errno.EEXIST:
457                 raise
458         alternates_path = os.path.join(self.path, "info/alternates")
459         f = GitFile(alternates_path, 'wb')
460         try:
461             try:
462                 orig_f = open(alternates_path, 'rb')
463             except (OSError, IOError), e:
464                 if e.errno != errno.ENOENT:
465                     raise
466             else:
467                 try:
468                     f.write(orig_f.read())
469                 finally:
470                     orig_f.close()
471             f.write("%s\n" % path)
472         finally:
473             f.close()
474
475         if not os.path.isabs(path):
476             path = os.path.join(self.path, path)
477         self.alternates.append(DiskObjectStore(path))
478
479     def _load_packs(self):
480         pack_files = []
481         try:
482             self._pack_cache_time = os.stat(self.pack_dir).st_mtime
483             pack_dir_contents = os.listdir(self.pack_dir)
484             for name in pack_dir_contents:
485                 # TODO: verify that idx exists first
486                 if name.startswith("pack-") and name.endswith(".pack"):
487                     filename = os.path.join(self.pack_dir, name)
488                     pack_files.append((os.stat(filename).st_mtime, filename))
489         except OSError, e:
490             if e.errno == errno.ENOENT:
491                 return []
492             raise
493         pack_files.sort(reverse=True)
494         suffix_len = len(".pack")
495         return [Pack(f[:-suffix_len]) for _, f in pack_files]
496
497     def _pack_cache_stale(self):
498         try:
499             return os.stat(self.pack_dir).st_mtime > self._pack_cache_time
500         except OSError, e:
501             if e.errno == errno.ENOENT:
502                 return True
503             raise
504
505     def _get_shafile_path(self, sha):
506         # Check from object dir
507         return hex_to_filename(self.path, sha)
508
509     def _iter_loose_objects(self):
510         for base in os.listdir(self.path):
511             if len(base) != 2:
512                 continue
513             for rest in os.listdir(os.path.join(self.path, base)):
514                 yield base+rest
515
516     def _get_loose_object(self, sha):
517         path = self._get_shafile_path(sha)
518         try:
519             return ShaFile.from_path(path)
520         except (OSError, IOError), e:
521             if e.errno == errno.ENOENT:
522                 return None
523             raise
524
525     def _remove_loose_object(self, sha):
526         os.remove(self._get_shafile_path(sha))
527
528     def _complete_thin_pack(self, f, path, copier, indexer):
529         """Move a specific file containing a pack into the pack directory.
530
531         :note: The file should be on the same file system as the
532             packs directory.
533
534         :param f: Open file object for the pack.
535         :param path: Path to the pack file.
536         :param copier: A PackStreamCopier to use for writing pack data.
537         :param indexer: A PackIndexer for indexing the pack.
538         """
539         entries = list(indexer)
540
541         # Update the header with the new number of objects.
542         f.seek(0)
543         write_pack_header(f, len(entries) + len(indexer.ext_refs()))
544
545         # Must flush before reading (http://bugs.python.org/issue3207)
546         f.flush()
547
548         # Rescan the rest of the pack, computing the SHA with the new header.
549         new_sha = compute_file_sha(f, end_ofs=-20)
550
551         # Must reposition before writing (http://bugs.python.org/issue3207)
552         f.seek(0, os.SEEK_CUR)
553
554         # Complete the pack.
555         for ext_sha in indexer.ext_refs():
556             assert len(ext_sha) == 20
557             type_num, data = self.get_raw(ext_sha)
558             offset = f.tell()
559             crc32 = write_pack_object(f, type_num, data, sha=new_sha)
560             entries.append((ext_sha, offset, crc32))
561         pack_sha = new_sha.digest()
562         f.write(pack_sha)
563         f.close()
564
565         # Move the pack in.
566         entries.sort()
567         pack_base_name = os.path.join(
568           self.pack_dir, 'pack-' + iter_sha1(e[0] for e in entries))
569         os.rename(path, pack_base_name + '.pack')
570
571         # Write the index.
572         index_file = GitFile(pack_base_name + '.idx', 'wb')
573         try:
574             write_pack_index_v2(index_file, entries, pack_sha)
575             index_file.close()
576         finally:
577             index_file.abort()
578
579         # Add the pack to the store and return it.
580         final_pack = Pack(pack_base_name)
581         final_pack.check_length_and_checksum()
582         self._add_known_pack(final_pack)
583         return final_pack
584
585     def add_thin_pack(self, read_all, read_some):
586         """Add a new thin pack to this object store.
587
588         Thin packs are packs that contain deltas with parents that exist outside
589         the pack. They should never be placed in the object store directly, and
590         always indexed and completed as they are copied.
591
592         :param read_all: Read function that blocks until the number of requested
593             bytes are read.
594         :param read_some: Read function that returns at least one byte, but may
595             not return the number of bytes requested.
596         :return: A Pack object pointing at the now-completed thin pack in the
597             objects/pack directory.
598         """
599         fd, path = tempfile.mkstemp(dir=self.path, prefix='tmp_pack_')
600         f = os.fdopen(fd, 'w+b')
601
602         try:
603             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
604             copier = PackStreamCopier(read_all, read_some, f,
605                                       delta_iter=indexer)
606             copier.verify()
607             return self._complete_thin_pack(f, path, copier, indexer)
608         finally:
609             f.close()
610
611     def move_in_pack(self, path):
612         """Move a specific file containing a pack into the pack directory.
613
614         :note: The file should be on the same file system as the
615             packs directory.
616
617         :param path: Path to the pack file.
618         """
619         p = PackData(path)
620         entries = p.sorted_entries()
621         basename = os.path.join(self.pack_dir,
622             "pack-%s" % iter_sha1(entry[0] for entry in entries))
623         f = GitFile(basename+".idx", "wb")
624         try:
625             write_pack_index_v2(f, entries, p.get_stored_checksum())
626         finally:
627             f.close()
628         p.close()
629         os.rename(path, basename + ".pack")
630         final_pack = Pack(basename)
631         self._add_known_pack(final_pack)
632         return final_pack
633
634     def add_pack(self):
635         """Add a new pack to this object store.
636
637         :return: Fileobject to write to, a commit function to
638             call when the pack is finished and an abort
639             function.
640         """
641         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
642         f = os.fdopen(fd, 'wb')
643         def commit():
644             os.fsync(fd)
645             f.close()
646             if os.path.getsize(path) > 0:
647                 return self.move_in_pack(path)
648             else:
649                 os.remove(path)
650                 return None
651         def abort():
652             f.close()
653             os.remove(path)
654         return f, commit, abort
655
656     def add_object(self, obj):
657         """Add a single object to this object store.
658
659         :param obj: Object to add
660         """
661         dir = os.path.join(self.path, obj.id[:2])
662         try:
663             os.mkdir(dir)
664         except OSError, e:
665             if e.errno != errno.EEXIST:
666                 raise
667         path = os.path.join(dir, obj.id[2:])
668         if os.path.exists(path):
669             return # Already there, no need to write again
670         f = GitFile(path, 'wb')
671         try:
672             f.write(obj.as_legacy_object())
673         finally:
674             f.close()
675
676     @classmethod
677     def init(cls, path):
678         try:
679             os.mkdir(path)
680         except OSError, e:
681             if e.errno != errno.EEXIST:
682                 raise
683         os.mkdir(os.path.join(path, "info"))
684         os.mkdir(os.path.join(path, PACKDIR))
685         return cls(path)
686
687
688 class MemoryObjectStore(BaseObjectStore):
689     """Object store that keeps all objects in memory."""
690
691     def __init__(self):
692         super(MemoryObjectStore, self).__init__()
693         self._data = {}
694
695     def _to_hexsha(self, sha):
696         if len(sha) == 40:
697             return sha
698         elif len(sha) == 20:
699             return sha_to_hex(sha)
700         else:
701             raise ValueError("Invalid sha %r" % (sha,))
702
703     def contains_loose(self, sha):
704         """Check if a particular object is present by SHA1 and is loose."""
705         return self._to_hexsha(sha) in self._data
706
707     def contains_packed(self, sha):
708         """Check if a particular object is present by SHA1 and is packed."""
709         return False
710
711     def __iter__(self):
712         """Iterate over the SHAs that are present in this store."""
713         return self._data.iterkeys()
714
715     @property
716     def packs(self):
717         """List with pack objects."""
718         return []
719
720     def get_raw(self, name):
721         """Obtain the raw text for an object.
722
723         :param name: sha for the object.
724         :return: tuple with numeric type and object contents.
725         """
726         obj = self[self._to_hexsha(name)]
727         return obj.type_num, obj.as_raw_string()
728
729     def __getitem__(self, name):
730         return self._data[self._to_hexsha(name)]
731
732     def __delitem__(self, name):
733         """Delete an object from this store, for testing only."""
734         del self._data[self._to_hexsha(name)]
735
736     def add_object(self, obj):
737         """Add a single object to this object store.
738
739         """
740         self._data[obj.id] = obj
741
742     def add_objects(self, objects):
743         """Add a set of objects to this object store.
744
745         :param objects: Iterable over a list of objects.
746         """
747         for obj, path in objects:
748             self._data[obj.id] = obj
749
750     def add_pack(self):
751         """Add a new pack to this object store.
752
753         Because this object store doesn't support packs, we extract and add the
754         individual objects.
755
756         :return: Fileobject to write to and a commit function to
757             call when the pack is finished.
758         """
759         f = StringIO()
760         def commit():
761             p = PackData.from_file(StringIO(f.getvalue()), f.tell())
762             f.close()
763             for obj in PackInflater.for_pack_data(p):
764                 self._data[obj.id] = obj
765         def abort():
766             pass
767         return f, commit, abort
768
769     def _complete_thin_pack(self, f, indexer):
770         """Complete a thin pack by adding external references.
771
772         :param f: Open file object for the pack.
773         :param indexer: A PackIndexer for indexing the pack.
774         """
775         entries = list(indexer)
776
777         # Update the header with the new number of objects.
778         f.seek(0)
779         write_pack_header(f, len(entries) + len(indexer.ext_refs()))
780
781         # Rescan the rest of the pack, computing the SHA with the new header.
782         new_sha = compute_file_sha(f, end_ofs=-20)
783
784         # Complete the pack.
785         for ext_sha in indexer.ext_refs():
786             assert len(ext_sha) == 20
787             type_num, data = self.get_raw(ext_sha)
788             write_pack_object(f, type_num, data, sha=new_sha)
789         pack_sha = new_sha.digest()
790         f.write(pack_sha)
791
792     def add_thin_pack(self, read_all, read_some):
793         """Add a new thin pack to this object store.
794
795         Thin packs are packs that contain deltas with parents that exist outside
796         the pack. Because this object store doesn't support packs, we extract
797         and add the individual objects.
798
799         :param read_all: Read function that blocks until the number of requested
800             bytes are read.
801         :param read_some: Read function that returns at least one byte, but may
802             not return the number of bytes requested.
803         """
804         f, commit, abort = self.add_pack()
805         try:
806             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
807             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
808             copier.verify()
809             self._complete_thin_pack(f, indexer)
810         except:
811             abort()
812             raise
813         else:
814             commit()
815
816
817 class ObjectImporter(object):
818     """Interface for importing objects."""
819
820     def __init__(self, count):
821         """Create a new ObjectImporter.
822
823         :param count: Number of objects that's going to be imported.
824         """
825         self.count = count
826
827     def add_object(self, object):
828         """Add an object."""
829         raise NotImplementedError(self.add_object)
830
831     def finish(self, object):
832         """Finish the import and write objects to disk."""
833         raise NotImplementedError(self.finish)
834
835
836 class ObjectIterator(object):
837     """Interface for iterating over objects."""
838
839     def iterobjects(self):
840         raise NotImplementedError(self.iterobjects)
841
842
843 class ObjectStoreIterator(ObjectIterator):
844     """ObjectIterator that works on top of an ObjectStore."""
845
846     def __init__(self, store, sha_iter):
847         """Create a new ObjectIterator.
848
849         :param store: Object store to retrieve from
850         :param sha_iter: Iterator over (sha, path) tuples
851         """
852         self.store = store
853         self.sha_iter = sha_iter
854         self._shas = []
855
856     def __iter__(self):
857         """Yield tuple with next object and path."""
858         for sha, path in self.itershas():
859             yield self.store[sha], path
860
861     def iterobjects(self):
862         """Iterate over just the objects."""
863         for o, path in self:
864             yield o
865
866     def itershas(self):
867         """Iterate over the SHAs."""
868         for sha in self._shas:
869             yield sha
870         for sha in self.sha_iter:
871             self._shas.append(sha)
872             yield sha
873
874     def __contains__(self, needle):
875         """Check if an object is present.
876
877         :note: This checks if the object is present in
878             the underlying object store, not if it would
879             be yielded by the iterator.
880
881         :param needle: SHA1 of the object to check for
882         """
883         return needle in self.store
884
885     def __getitem__(self, key):
886         """Find an object by SHA1.
887
888         :note: This retrieves the object from the underlying
889             object store. It will also succeed if the object would
890             not be returned by the iterator.
891         """
892         return self.store[key]
893
894     def __len__(self):
895         """Return the number of objects."""
896         return len(list(self.itershas()))
897
898
899 def tree_lookup_path(lookup_obj, root_sha, path):
900     """Look up an object in a Git tree.
901
902     :param lookup_obj: Callback for retrieving object by SHA1
903     :param root_sha: SHA1 of the root tree
904     :param path: Path to lookup
905     :return: A tuple of (mode, SHA) of the resulting path.
906     """
907     tree = lookup_obj(root_sha)
908     if not isinstance(tree, Tree):
909         raise NotTreeError(root_sha)
910     return tree.lookup_path(lookup_obj, path)
911
912
913 def _collect_filetree_revs(obj_store, tree_sha, kset):
914     """Collect SHA1s of files and directories for specified tree.
915
916     :param obj_store: Object store to get objects by SHA from
917     :param tree_sha: tree reference to walk
918     :param kset: set to fill with references to files and directories
919     """
920     filetree = obj_store[tree_sha]
921     for name, mode, sha in filetree.iteritems():
922        if not S_ISGITLINK(mode) and sha not in kset:
923            kset.add(sha)
924            if stat.S_ISDIR(mode):
925                _collect_filetree_revs(obj_store, sha, kset)
926
927
928 def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
929     """Split object id list into two list with commit SHA1s and tag SHA1s.
930
931     Commits referenced by tags are included into commits
932     list as well. Only SHA1s known in this repository will get
933     through, and unless ignore_unknown argument is True, KeyError
934     is thrown for SHA1 missing in the repository
935
936     :param obj_store: Object store to get objects by SHA1 from
937     :param lst: Collection of commit and tag SHAs
938     :param ignore_unknown: True to skip SHA1 missing in the repository
939         silently.
940     :return: A tuple of (commits, tags) SHA1s
941     """
942     commits = set()
943     tags = set()
944     for e in lst:
945         try:
946             o = obj_store[e]
947         except KeyError:
948             if not ignore_unknown:
949                 raise
950         else:
951             if isinstance(o, Commit):
952                 commits.add(e)
953             elif isinstance(o, Tag):
954                 tags.add(e)
955                 commits.add(o.object[1])
956             else:
957                 raise KeyError('Not a commit or a tag: %s' % e)
958     return (commits, tags)
959
960
961 class MissingObjectFinder(object):
962     """Find the objects missing from another object store.
963
964     :param object_store: Object store containing at least all objects to be
965         sent
966     :param haves: SHA1s of commits not to send (already present in target)
967     :param wants: SHA1s of commits to send
968     :param progress: Optional function to report progress to.
969     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
970         sha for including tags.
971     :param tagged: dict of pointed-to sha -> tag sha for including tags
972     """
973
974     def __init__(self, object_store, haves, wants, progress=None,
975                  get_tagged=None):
976         self.object_store = object_store
977         # process Commits and Tags differently
978         # Note, while haves may list commits/tags not available locally,
979         # and such SHAs would get filtered out by _split_commits_and_tags,
980         # wants shall list only known SHAs, and otherwise
981         # _split_commits_and_tags fails with KeyError
982         have_commits, have_tags = \
983                 _split_commits_and_tags(object_store, haves, True)
984         want_commits, want_tags = \
985                 _split_commits_and_tags(object_store, wants, False)
986         # all_ancestors is a set of commits that shall not be sent
987         # (complete repository up to 'haves')
988         all_ancestors = object_store._collect_ancestors(have_commits)[0]
989         # all_missing - complete set of commits between haves and wants
990         # common - commits from all_ancestors we hit into while
991         # traversing parent hierarchy of wants
992         missing_commits, common_commits = \
993             object_store._collect_ancestors(want_commits, all_ancestors)
994         self.sha_done = set()
995         # Now, fill sha_done with commits and revisions of
996         # files and directories known to be both locally
997         # and on target. Thus these commits and files
998         # won't get selected for fetch
999         for h in common_commits:
1000             self.sha_done.add(h)
1001             cmt = object_store[h]
1002             _collect_filetree_revs(object_store, cmt.tree, self.sha_done)
1003         # record tags we have as visited, too
1004         for t in have_tags:
1005             self.sha_done.add(t)
1006
1007         missing_tags = want_tags.difference(have_tags)
1008         # in fact, what we 'want' is commits and tags
1009         # we've found missing
1010         wants = missing_commits.union(missing_tags)
1011
1012         self.objects_to_send = set([(w, None, False) for w in wants])
1013
1014         if progress is None:
1015             self.progress = lambda x: None
1016         else:
1017             self.progress = progress
1018         self._tagged = get_tagged and get_tagged() or {}
1019
1020     def add_todo(self, entries):
1021         self.objects_to_send.update([e for e in entries
1022                                      if not e[0] in self.sha_done])
1023
1024     def next(self):
1025         while True:
1026             if not self.objects_to_send:
1027                 return None
1028             (sha, name, leaf) = self.objects_to_send.pop()
1029             if sha not in self.sha_done:
1030                 break
1031         if not leaf:
1032             o = self.object_store[sha]
1033             if isinstance(o, Commit):
1034                 self.add_todo([(o.tree, "", False)])
1035             elif isinstance(o, Tree):
1036                 self.add_todo([(s, n, not stat.S_ISDIR(m))
1037                                for n, m, s in o.iteritems()
1038                                if not S_ISGITLINK(m)])
1039             elif isinstance(o, Tag):
1040                 self.add_todo([(o.object[1], None, False)])
1041         if sha in self._tagged:
1042             self.add_todo([(self._tagged[sha], None, True)])
1043         self.sha_done.add(sha)
1044         self.progress("counting objects: %d\r" % len(self.sha_done))
1045         return (sha, name)
1046
1047
1048 class ObjectStoreGraphWalker(object):
1049     """Graph walker that finds what commits are missing from an object store.
1050
1051     :ivar heads: Revisions without descendants in the local repo
1052     :ivar get_parents: Function to retrieve parents in the local repo
1053     """
1054
1055     def __init__(self, local_heads, get_parents):
1056         """Create a new instance.
1057
1058         :param local_heads: Heads to start search with
1059         :param get_parents: Function for finding the parents of a SHA1.
1060         """
1061         self.heads = set(local_heads)
1062         self.get_parents = get_parents
1063         self.parents = {}
1064
1065     def ack(self, sha):
1066         """Ack that a revision and its ancestors are present in the source."""
1067         ancestors = set([sha])
1068
1069         # stop if we run out of heads to remove
1070         while self.heads:
1071             for a in ancestors:
1072                 if a in self.heads:
1073                     self.heads.remove(a)
1074
1075             # collect all ancestors
1076             new_ancestors = set()
1077             for a in ancestors:
1078                 ps = self.parents.get(a)
1079                 if ps is not None:
1080                     new_ancestors.update(ps)
1081                 self.parents[a] = None
1082
1083             # no more ancestors; stop
1084             if not new_ancestors:
1085                 break
1086
1087             ancestors = new_ancestors
1088
1089     def next(self):
1090         """Iterate over ancestors of heads in the target."""
1091         if self.heads:
1092             ret = self.heads.pop()
1093             ps = self.get_parents(ret)
1094             self.parents[ret] = ps
1095             self.heads.update([p for p in ps if not p in self.parents])
1096             return ret
1097         return None