BaseObjectStore.determine_wants_all no longer breaks on zero SHAs.
[jelmer/dulwich.git] / dulwich / object_store.py
1 # object_store.py -- Object store for git objects
2 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; either version 2
7 # or (at your option) a later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19
20 """Git object store interfaces and implementation."""
21
22
23 import errno
24 import itertools
25 import os
26 import stat
27 import tempfile
28 import urllib2
29
30 from dulwich.diff_tree import (
31     tree_changes,
32     walk_trees,
33     )
34 from dulwich.errors import (
35     NotTreeError,
36     )
37 from dulwich.file import GitFile
38 from dulwich.objects import (
39     Commit,
40     ShaFile,
41     Tag,
42     Tree,
43     ZERO_SHA,
44     hex_to_sha,
45     sha_to_hex,
46     hex_to_filename,
47     S_ISGITLINK,
48     object_class,
49     )
50 from dulwich.pack import (
51     Pack,
52     PackData,
53     ThinPackData,
54     iter_sha1,
55     load_pack_index,
56     write_pack,
57     write_pack_data,
58     write_pack_index_v2,
59     )
60
61 INFODIR = 'info'
62 PACKDIR = 'pack'
63
64
65 class BaseObjectStore(object):
66     """Object store interface."""
67
68     def determine_wants_all(self, refs):
69         return [sha for (ref, sha) in refs.iteritems()
70                 if not sha in self and not ref.endswith("^{}") and
71                    not sha == ZERO_SHA]
72
73     def iter_shas(self, shas):
74         """Iterate over the objects for the specified shas.
75
76         :param shas: Iterable object with SHAs
77         :return: Object iterator
78         """
79         return ObjectStoreIterator(self, shas)
80
81     def contains_loose(self, sha):
82         """Check if a particular object is present by SHA1 and is loose."""
83         raise NotImplementedError(self.contains_loose)
84
85     def contains_packed(self, sha):
86         """Check if a particular object is present by SHA1 and is packed."""
87         raise NotImplementedError(self.contains_packed)
88
89     def __contains__(self, sha):
90         """Check if a particular object is present by SHA1.
91
92         This method makes no distinction between loose and packed objects.
93         """
94         return self.contains_packed(sha) or self.contains_loose(sha)
95
96     @property
97     def packs(self):
98         """Iterable of pack objects."""
99         raise NotImplementedError
100
101     def get_raw(self, name):
102         """Obtain the raw text for an object.
103
104         :param name: sha for the object.
105         :return: tuple with numeric type and object contents.
106         """
107         raise NotImplementedError(self.get_raw)
108
109     def __getitem__(self, sha):
110         """Obtain an object by SHA1."""
111         type_num, uncomp = self.get_raw(sha)
112         return ShaFile.from_raw_string(type_num, uncomp)
113
114     def __iter__(self):
115         """Iterate over the SHAs that are present in this store."""
116         raise NotImplementedError(self.__iter__)
117
118     def add_object(self, obj):
119         """Add a single object to this object store.
120
121         """
122         raise NotImplementedError(self.add_object)
123
124     def add_objects(self, objects):
125         """Add a set of objects to this object store.
126
127         :param objects: Iterable over a list of objects.
128         """
129         raise NotImplementedError(self.add_objects)
130
131     def tree_changes(self, source, target, want_unchanged=False):
132         """Find the differences between the contents of two trees
133
134         :param object_store: Object store to use for retrieving tree contents
135         :param tree: SHA1 of the root tree
136         :param want_unchanged: Whether unchanged files should be reported
137         :return: Iterator over tuples with
138             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
139         """
140         for change in tree_changes(self, source, target,
141                                    want_unchanged=want_unchanged):
142             yield ((change.old.path, change.new.path),
143                    (change.old.mode, change.new.mode),
144                    (change.old.sha, change.new.sha))
145
146     def iter_tree_contents(self, tree_id, include_trees=False):
147         """Iterate the contents of a tree and all subtrees.
148
149         Iteration is depth-first pre-order, as in e.g. os.walk.
150
151         :param tree_id: SHA1 of the tree.
152         :param include_trees: If True, include tree objects in the iteration.
153         :return: Iterator over TreeEntry namedtuples for all the objects in a
154             tree.
155         """
156         for entry, _ in walk_trees(self, tree_id, None):
157             if not stat.S_ISDIR(entry.mode) or include_trees:
158                 yield entry
159
160     def find_missing_objects(self, haves, wants, progress=None,
161                              get_tagged=None):
162         """Find the missing objects required for a set of revisions.
163
164         :param haves: Iterable over SHAs already in common.
165         :param wants: Iterable over SHAs of objects to fetch.
166         :param progress: Simple progress function that will be called with
167             updated progress strings.
168         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
169             sha for including tags.
170         :return: Iterator over (sha, path) pairs.
171         """
172         finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
173         return iter(finder.next, None)
174
175     def find_common_revisions(self, graphwalker):
176         """Find which revisions this store has in common using graphwalker.
177
178         :param graphwalker: A graphwalker object.
179         :return: List of SHAs that are in common
180         """
181         haves = []
182         sha = graphwalker.next()
183         while sha:
184             if sha in self:
185                 haves.append(sha)
186                 graphwalker.ack(sha)
187             sha = graphwalker.next()
188         return haves
189
190     def get_graph_walker(self, heads):
191         """Obtain a graph walker for this object store.
192
193         :param heads: Local heads to start search with
194         :return: GraphWalker object
195         """
196         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
197
198     def generate_pack_contents(self, have, want, progress=None):
199         """Iterate over the contents of a pack file.
200
201         :param have: List of SHA1s of objects that should not be sent
202         :param want: List of SHA1s of objects that should be sent
203         :param progress: Optional progress reporting method
204         """
205         return self.iter_shas(self.find_missing_objects(have, want, progress))
206
207     def peel_sha(self, sha):
208         """Peel all tags from a SHA.
209
210         :param sha: The object SHA to peel.
211         :return: The fully-peeled SHA1 of a tag object, after peeling all
212             intermediate tags; if the original ref does not point to a tag, this
213             will equal the original SHA1.
214         """
215         obj = self[sha]
216         obj_class = object_class(obj.type_name)
217         while obj_class is Tag:
218             obj_class, sha = obj.object
219             obj = self[sha]
220         return obj
221
222
223 class PackBasedObjectStore(BaseObjectStore):
224
225     def __init__(self):
226         self._pack_cache = None
227
228     def contains_packed(self, sha):
229         """Check if a particular object is present by SHA1 and is packed."""
230         for pack in self.packs:
231             if sha in pack:
232                 return True
233         return False
234
235     def _load_packs(self):
236         raise NotImplementedError(self._load_packs)
237
238     def _pack_cache_stale(self):
239         """Check whether the pack cache is stale."""
240         raise NotImplementedError(self._pack_cache_stale)
241
242     def _add_known_pack(self, pack):
243         """Add a newly appeared pack to the cache by path.
244
245         """
246         if self._pack_cache is not None:
247             self._pack_cache.append(pack)
248
249     @property
250     def packs(self):
251         """List with pack objects."""
252         if self._pack_cache is None or self._pack_cache_stale():
253             self._pack_cache = self._load_packs()
254         return self._pack_cache
255
256     def _iter_loose_objects(self):
257         """Iterate over the SHAs of all loose objects."""
258         raise NotImplementedError(self._iter_loose_objects)
259
260     def _get_loose_object(self, sha):
261         raise NotImplementedError(self._get_loose_object)
262
263     def _remove_loose_object(self, sha):
264         raise NotImplementedError(self._remove_loose_object)
265
266     def pack_loose_objects(self):
267         """Pack loose objects.
268         
269         :return: Number of objects packed
270         """
271         objects = set()
272         for sha in self._iter_loose_objects():
273             objects.add((self._get_loose_object(sha), None))
274         self.add_objects(objects)
275         for obj, path in objects:
276             self._remove_loose_object(obj.id)
277         return len(objects)
278
279     def __iter__(self):
280         """Iterate over the SHAs that are present in this store."""
281         iterables = self.packs + [self._iter_loose_objects()]
282         return itertools.chain(*iterables)
283
284     def contains_loose(self, sha):
285         """Check if a particular object is present by SHA1 and is loose."""
286         return self._get_loose_object(sha) is not None
287
288     def get_raw(self, name):
289         """Obtain the raw text for an object.
290
291         :param name: sha for the object.
292         :return: tuple with numeric type and object contents.
293         """
294         if len(name) == 40:
295             sha = hex_to_sha(name)
296             hexsha = name
297         elif len(name) == 20:
298             sha = name
299             hexsha = None
300         else:
301             raise AssertionError("Invalid object name %r" % name)
302         for pack in self.packs:
303             try:
304                 return pack.get_raw(sha)
305             except KeyError:
306                 pass
307         if hexsha is None:
308             hexsha = sha_to_hex(name)
309         ret = self._get_loose_object(hexsha)
310         if ret is not None:
311             return ret.type_num, ret.as_raw_string()
312         raise KeyError(hexsha)
313
314     def add_objects(self, objects):
315         """Add a set of objects to this object store.
316
317         :param objects: Iterable over objects, should support __len__.
318         :return: Pack object of the objects written.
319         """
320         if len(objects) == 0:
321             # Don't bother writing an empty pack file
322             return
323         f, commit = self.add_pack()
324         write_pack_data(f, objects, len(objects))
325         return commit()
326
327
328 class DiskObjectStore(PackBasedObjectStore):
329     """Git-style object store that exists on disk."""
330
331     def __init__(self, path):
332         """Open an object store.
333
334         :param path: Path of the object store.
335         """
336         super(DiskObjectStore, self).__init__()
337         self.path = path
338         self.pack_dir = os.path.join(self.path, PACKDIR)
339         self._pack_cache_time = 0
340
341     def _load_packs(self):
342         pack_files = []
343         try:
344             self._pack_cache_time = os.stat(self.pack_dir).st_mtime
345             pack_dir_contents = os.listdir(self.pack_dir)
346             for name in pack_dir_contents:
347                 # TODO: verify that idx exists first
348                 if name.startswith("pack-") and name.endswith(".pack"):
349                     filename = os.path.join(self.pack_dir, name)
350                     pack_files.append((os.stat(filename).st_mtime, filename))
351         except OSError, e:
352             if e.errno == errno.ENOENT:
353                 return []
354             raise
355         pack_files.sort(reverse=True)
356         suffix_len = len(".pack")
357         return [Pack(f[:-suffix_len]) for _, f in pack_files]
358
359     def _pack_cache_stale(self):
360         try:
361             return os.stat(self.pack_dir).st_mtime > self._pack_cache_time
362         except OSError, e:
363             if e.errno == errno.ENOENT:
364                 return True
365             raise
366
367     def _get_shafile_path(self, sha):
368         # Check from object dir
369         return hex_to_filename(self.path, sha)
370
371     def _iter_loose_objects(self):
372         for base in os.listdir(self.path):
373             if len(base) != 2:
374                 continue
375             for rest in os.listdir(os.path.join(self.path, base)):
376                 yield base+rest
377
378     def _get_loose_object(self, sha):
379         path = self._get_shafile_path(sha)
380         try:
381             return ShaFile.from_path(path)
382         except (OSError, IOError), e:
383             if e.errno == errno.ENOENT:
384                 return None
385             raise
386
387     def _remove_loose_object(self, sha):
388         os.remove(self._get_shafile_path(sha))
389
390     def move_in_thin_pack(self, path):
391         """Move a specific file containing a pack into the pack directory.
392
393         :note: The file should be on the same file system as the
394             packs directory.
395
396         :param path: Path to the pack file.
397         """
398         data = ThinPackData(self.get_raw, path)
399
400         # Write index for the thin pack (do we really need this?)
401         temppath = os.path.join(self.pack_dir,
402             sha_to_hex(urllib2.randombytes(20))+".tempidx")
403         data.create_index_v2(temppath)
404         p = Pack.from_objects(data, load_pack_index(temppath))
405
406         try:
407             # Write a full pack version
408             temppath = os.path.join(self.pack_dir,
409                 sha_to_hex(urllib2.randombytes(20))+".temppack")
410             write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
411         finally:
412             p.close()
413
414         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
415         newbasename = os.path.join(self.pack_dir, "pack-%s" % pack_sha)
416         os.rename(temppath+".pack", newbasename+".pack")
417         os.rename(temppath+".idx", newbasename+".idx")
418         final_pack = Pack(newbasename)
419         self._add_known_pack(final_pack)
420         return final_pack
421
422     def move_in_pack(self, path):
423         """Move a specific file containing a pack into the pack directory.
424
425         :note: The file should be on the same file system as the
426             packs directory.
427
428         :param path: Path to the pack file.
429         """
430         p = PackData(path)
431         entries = p.sorted_entries()
432         basename = os.path.join(self.pack_dir,
433             "pack-%s" % iter_sha1(entry[0] for entry in entries))
434         f = GitFile(basename+".idx", "wb")
435         try:
436             write_pack_index_v2(f, entries, p.get_stored_checksum())
437         finally:
438             f.close()
439         p.close()
440         os.rename(path, basename + ".pack")
441         final_pack = Pack(basename)
442         self._add_known_pack(final_pack)
443         return final_pack
444
445     def add_thin_pack(self):
446         """Add a new thin pack to this object store.
447
448         Thin packs are packs that contain deltas with parents that exist
449         in a different pack.
450         """
451         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
452         f = os.fdopen(fd, 'wb')
453         def commit():
454             os.fsync(fd)
455             f.close()
456             if os.path.getsize(path) > 0:
457                 return self.move_in_thin_pack(path)
458             else:
459                 os.remove(path)
460                 return None
461         return f, commit
462
463     def add_pack(self):
464         """Add a new pack to this object store.
465
466         :return: Fileobject to write to and a commit function to
467             call when the pack is finished.
468         """
469         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
470         f = os.fdopen(fd, 'wb')
471         def commit():
472             os.fsync(fd)
473             f.close()
474             if os.path.getsize(path) > 0:
475                 return self.move_in_pack(path)
476             else:
477                 os.remove(path)
478                 return None
479         return f, commit
480
481     def add_object(self, obj):
482         """Add a single object to this object store.
483
484         :param obj: Object to add
485         """
486         dir = os.path.join(self.path, obj.id[:2])
487         try:
488             os.mkdir(dir)
489         except OSError, e:
490             if e.errno != errno.EEXIST:
491                 raise
492         path = os.path.join(dir, obj.id[2:])
493         if os.path.exists(path):
494             return # Already there, no need to write again
495         f = GitFile(path, 'wb')
496         try:
497             f.write(obj.as_legacy_object())
498         finally:
499             f.close()
500
501     @classmethod
502     def init(cls, path):
503         try:
504             os.mkdir(path)
505         except OSError, e:
506             if e.errno != errno.EEXIST:
507                 raise
508         os.mkdir(os.path.join(path, "info"))
509         os.mkdir(os.path.join(path, PACKDIR))
510         return cls(path)
511
512
513 class MemoryObjectStore(BaseObjectStore):
514     """Object store that keeps all objects in memory."""
515
516     def __init__(self):
517         super(MemoryObjectStore, self).__init__()
518         self._data = {}
519
520     def contains_loose(self, sha):
521         """Check if a particular object is present by SHA1 and is loose."""
522         return sha in self._data
523
524     def contains_packed(self, sha):
525         """Check if a particular object is present by SHA1 and is packed."""
526         return False
527
528     def __iter__(self):
529         """Iterate over the SHAs that are present in this store."""
530         return self._data.iterkeys()
531
532     @property
533     def packs(self):
534         """List with pack objects."""
535         return []
536
537     def get_raw(self, name):
538         """Obtain the raw text for an object.
539
540         :param name: sha for the object.
541         :return: tuple with numeric type and object contents.
542         """
543         return self[name].as_raw_string()
544
545     def __getitem__(self, name):
546         return self._data[name]
547
548     def __delitem__(self, name):
549         """Delete an object from this store, for testing only."""
550         del self._data[name]
551
552     def add_object(self, obj):
553         """Add a single object to this object store.
554
555         """
556         self._data[obj.id] = obj
557
558     def add_objects(self, objects):
559         """Add a set of objects to this object store.
560
561         :param objects: Iterable over a list of objects.
562         """
563         for obj, path in objects:
564             self._data[obj.id] = obj
565
566
567 class ObjectImporter(object):
568     """Interface for importing objects."""
569
570     def __init__(self, count):
571         """Create a new ObjectImporter.
572
573         :param count: Number of objects that's going to be imported.
574         """
575         self.count = count
576
577     def add_object(self, object):
578         """Add an object."""
579         raise NotImplementedError(self.add_object)
580
581     def finish(self, object):
582         """Finish the import and write objects to disk."""
583         raise NotImplementedError(self.finish)
584
585
586 class ObjectIterator(object):
587     """Interface for iterating over objects."""
588
589     def iterobjects(self):
590         raise NotImplementedError(self.iterobjects)
591
592
593 class ObjectStoreIterator(ObjectIterator):
594     """ObjectIterator that works on top of an ObjectStore."""
595
596     def __init__(self, store, sha_iter):
597         """Create a new ObjectIterator.
598
599         :param store: Object store to retrieve from
600         :param sha_iter: Iterator over (sha, path) tuples
601         """
602         self.store = store
603         self.sha_iter = sha_iter
604         self._shas = []
605
606     def __iter__(self):
607         """Yield tuple with next object and path."""
608         for sha, path in self.itershas():
609             yield self.store[sha], path
610
611     def iterobjects(self):
612         """Iterate over just the objects."""
613         for o, path in self:
614             yield o
615
616     def itershas(self):
617         """Iterate over the SHAs."""
618         for sha in self._shas:
619             yield sha
620         for sha in self.sha_iter:
621             self._shas.append(sha)
622             yield sha
623
624     def __contains__(self, needle):
625         """Check if an object is present.
626
627         :note: This checks if the object is present in
628             the underlying object store, not if it would
629             be yielded by the iterator.
630
631         :param needle: SHA1 of the object to check for
632         """
633         return needle in self.store
634
635     def __getitem__(self, key):
636         """Find an object by SHA1.
637
638         :note: This retrieves the object from the underlying
639             object store. It will also succeed if the object would
640             not be returned by the iterator.
641         """
642         return self.store[key]
643
644     def __len__(self):
645         """Return the number of objects."""
646         return len(list(self.itershas()))
647
648
649 def tree_lookup_path(lookup_obj, root_sha, path):
650     """Lookup an object in a Git tree.
651
652     :param lookup_obj: Callback for retrieving object by SHA1
653     :param root_sha: SHA1 of the root tree
654     :param path: Path to lookup
655     """
656     parts = path.split("/")
657     sha = root_sha
658     mode = None
659     for p in parts:
660         obj = lookup_obj(sha)
661         if not isinstance(obj, Tree):
662             raise NotTreeError(sha)
663         if p == '':
664             continue
665         mode, sha = obj[p]
666     return mode, sha
667
668
669 class MissingObjectFinder(object):
670     """Find the objects missing from another object store.
671
672     :param object_store: Object store containing at least all objects to be
673         sent
674     :param haves: SHA1s of commits not to send (already present in target)
675     :param wants: SHA1s of commits to send
676     :param progress: Optional function to report progress to.
677     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
678         sha for including tags.
679     :param tagged: dict of pointed-to sha -> tag sha for including tags
680     """
681
682     def __init__(self, object_store, haves, wants, progress=None,
683                  get_tagged=None):
684         haves = set(haves)
685         self.sha_done = haves
686         self.objects_to_send = set([(w, None, False) for w in wants
687                                     if w not in haves])
688         self.object_store = object_store
689         if progress is None:
690             self.progress = lambda x: None
691         else:
692             self.progress = progress
693         self._tagged = get_tagged and get_tagged() or {}
694
695     def add_todo(self, entries):
696         self.objects_to_send.update([e for e in entries
697                                      if not e[0] in self.sha_done])
698
699     def parse_tree(self, tree):
700         self.add_todo([(sha, name, not stat.S_ISDIR(mode))
701                        for mode, name, sha in tree.entries()
702                        if not S_ISGITLINK(mode)])
703
704     def parse_commit(self, commit):
705         self.add_todo([(commit.tree, "", False)])
706         self.add_todo([(p, None, False) for p in commit.parents])
707
708     def parse_tag(self, tag):
709         self.add_todo([(tag.object[1], None, False)])
710
711     def next(self):
712         if not self.objects_to_send:
713             return None
714         (sha, name, leaf) = self.objects_to_send.pop()
715         if not leaf:
716             o = self.object_store[sha]
717             if isinstance(o, Commit):
718                 self.parse_commit(o)
719             elif isinstance(o, Tree):
720                 self.parse_tree(o)
721             elif isinstance(o, Tag):
722                 self.parse_tag(o)
723         if sha in self._tagged:
724             self.add_todo([(self._tagged[sha], None, True)])
725         self.sha_done.add(sha)
726         self.progress("counting objects: %d\r" % len(self.sha_done))
727         return (sha, name)
728
729
730 class ObjectStoreGraphWalker(object):
731     """Graph walker that finds what commits are missing from an object store.
732
733     :ivar heads: Revisions without descendants in the local repo
734     :ivar get_parents: Function to retrieve parents in the local repo
735     """
736
737     def __init__(self, local_heads, get_parents):
738         """Create a new instance.
739
740         :param local_heads: Heads to start search with
741         :param get_parents: Function for finding the parents of a SHA1.
742         """
743         self.heads = set(local_heads)
744         self.get_parents = get_parents
745         self.parents = {}
746
747     def ack(self, sha):
748         """Ack that a revision and its ancestors are present in the source."""
749         ancestors = set([sha])
750
751         # stop if we run out of heads to remove
752         while self.heads:
753             for a in ancestors:
754                 if a in self.heads:
755                     self.heads.remove(a)
756
757             # collect all ancestors
758             new_ancestors = set()
759             for a in ancestors:
760                 if a in self.parents:
761                     new_ancestors.update(self.parents[a])
762
763             # no more ancestors; stop
764             if not new_ancestors:
765                 break
766
767             ancestors = new_ancestors
768
769     def next(self):
770         """Iterate over ancestors of heads in the target."""
771         if self.heads:
772             ret = self.heads.pop()
773             ps = self.get_parents(ret)
774             self.parents[ret] = ps
775             self.heads.update(ps)
776             return ret
777         return None