Fix all flake8 style issues.
[jelmer/dulwich.git] / dulwich / tests / utils.py
1 # utils.py -- Test utilities for Dulwich.
2 # Copyright (C) 2010 Google, Inc.
3 #
4 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
5 # General Public License as public by the Free Software Foundation; version 2.0
6 # or (at your option) any later version. You can redistribute it and/or
7 # modify it under the terms of either of these two licenses.
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15 # You should have received a copy of the licenses; if not, see
16 # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
17 # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
18 # License, Version 2.0.
19 #
20
21 """Utility functions common to Dulwich tests."""
22
23
24 import datetime
25 import os
26 import shutil
27 import tempfile
28 import time
29 import types
30
31 import warnings
32
33 from dulwich.index import (
34     commit_tree,
35     )
36 from dulwich.objects import (
37     FixedSha,
38     Commit,
39     Tag,
40     object_class,
41     )
42 from dulwich.pack import (
43     OFS_DELTA,
44     REF_DELTA,
45     DELTA_TYPES,
46     obj_sha,
47     SHA1Writer,
48     write_pack_header,
49     write_pack_object,
50     create_delta,
51     )
52 from dulwich.repo import Repo
53 from dulwich.tests import (  # noqa: F401
54     skipIf,
55     SkipTest,
56     )
57
58
59 # Plain files are very frequently used in tests, so let the mode be very short.
60 F = 0o100644  # Shorthand mode for Files.
61
62
63 def open_repo(name, temp_dir=None):
64     """Open a copy of a repo in a temporary directory.
65
66     Use this function for accessing repos in dulwich/tests/data/repos to avoid
67     accidentally or intentionally modifying those repos in place. Use
68     tear_down_repo to delete any temp files created.
69
70     :param name: The name of the repository, relative to
71         dulwich/tests/data/repos
72     :param temp_dir: temporary directory to initialize to. If not provided, a
73         temporary directory will be created.
74     :returns: An initialized Repo object that lives in a temporary directory.
75     """
76     if temp_dir is None:
77         temp_dir = tempfile.mkdtemp()
78     repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
79     temp_repo_dir = os.path.join(temp_dir, name)
80     shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
81     return Repo(temp_repo_dir)
82
83
84 def tear_down_repo(repo):
85     """Tear down a test repository."""
86     repo.close()
87     temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
88     shutil.rmtree(temp_dir)
89
90
91 def make_object(cls, **attrs):
92     """Make an object for testing and assign some members.
93
94     This method creates a new subclass to allow arbitrary attribute
95     reassignment, which is not otherwise possible with objects having
96     __slots__.
97
98     :param attrs: dict of attributes to set on the new object.
99     :return: A newly initialized object of type cls.
100     """
101
102     class TestObject(cls):
103         """Class that inherits from the given class, but without __slots__.
104
105         Note that classes with __slots__ can't have arbitrary attributes
106         monkey-patched in, so this is a class that is exactly the same only
107         with a __dict__ instead of __slots__.
108         """
109         pass
110     TestObject.__name__ = 'TestObject_' + cls.__name__
111
112     obj = TestObject()
113     for name, value in attrs.items():
114         if name == 'id':
115             # id property is read-only, so we overwrite sha instead.
116             sha = FixedSha(value)
117             obj.sha = lambda: sha
118         else:
119             setattr(obj, name, value)
120     return obj
121
122
123 def make_commit(**attrs):
124     """Make a Commit object with a default set of members.
125
126     :param attrs: dict of attributes to overwrite from the default values.
127     :return: A newly initialized Commit object.
128     """
129     default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
130     all_attrs = {'author': b'Test Author <test@nodomain.com>',
131                  'author_time': default_time,
132                  'author_timezone': 0,
133                  'committer': b'Test Committer <test@nodomain.com>',
134                  'commit_time': default_time,
135                  'commit_timezone': 0,
136                  'message': b'Test message.',
137                  'parents': [],
138                  'tree': b'0' * 40}
139     all_attrs.update(attrs)
140     return make_object(Commit, **all_attrs)
141
142
143 def make_tag(target, **attrs):
144     """Make a Tag object with a default set of values.
145
146     :param target: object to be tagged (Commit, Blob, Tree, etc)
147     :param attrs: dict of attributes to overwrite from the default values.
148     :return: A newly initialized Tag object.
149     """
150     target_id = target.id
151     target_type = object_class(target.type_name)
152     default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
153     all_attrs = {'tagger': b'Test Author <test@nodomain.com>',
154                  'tag_time': default_time,
155                  'tag_timezone': 0,
156                  'message': b'Test message.',
157                  'object': (target_type, target_id),
158                  'name': b'Test Tag',
159                  }
160     all_attrs.update(attrs)
161     return make_object(Tag, **all_attrs)
162
163
164 def functest_builder(method, func):
165     """Generate a test method that tests the given function."""
166
167     def do_test(self):
168         method(self, func)
169
170     return do_test
171
172
173 def ext_functest_builder(method, func):
174     """Generate a test method that tests the given extension function.
175
176     This is intended to generate test methods that test both a pure-Python
177     version and an extension version using common test code. The extension test
178     will raise SkipTest if the extension is not found.
179
180     Sample usage:
181
182     class MyTest(TestCase);
183         def _do_some_test(self, func_impl):
184             self.assertEqual('foo', func_impl())
185
186         test_foo = functest_builder(_do_some_test, foo_py)
187         test_foo_extension = ext_functest_builder(_do_some_test, _foo_c)
188
189     :param method: The method to run. It must must two parameters, self and the
190         function implementation to test.
191     :param func: The function implementation to pass to method.
192     """
193
194     def do_test(self):
195         if not isinstance(func, types.BuiltinFunctionType):
196             raise SkipTest("%s extension not found" % func)
197         method(self, func)
198
199     return do_test
200
201
202 def build_pack(f, objects_spec, store=None):
203     """Write test pack data from a concise spec.
204
205     :param f: A file-like object to write the pack to.
206     :param objects_spec: A list of (type_num, obj). For non-delta types, obj
207         is the string of that object's data.
208         For delta types, obj is a tuple of (base, data), where:
209
210         * base can be either an index in objects_spec of the base for that
211         * delta; or for a ref delta, a SHA, in which case the resulting pack
212         * will be thin and the base will be an external ref.
213         * data is a string of the full, non-deltified data for that object.
214
215         Note that offsets/refs and deltas are computed within this function.
216     :param store: An optional ObjectStore for looking up external refs.
217     :return: A list of tuples in the order specified by objects_spec:
218         (offset, type num, data, sha, CRC32)
219     """
220     sf = SHA1Writer(f)
221     num_objects = len(objects_spec)
222     write_pack_header(sf, num_objects)
223
224     full_objects = {}
225     offsets = {}
226     crc32s = {}
227
228     while len(full_objects) < num_objects:
229         for i, (type_num, data) in enumerate(objects_spec):
230             if type_num not in DELTA_TYPES:
231                 full_objects[i] = (type_num, data,
232                                    obj_sha(type_num, [data]))
233                 continue
234             base, data = data
235             if isinstance(base, int):
236                 if base not in full_objects:
237                     continue
238                 base_type_num, _, _ = full_objects[base]
239             else:
240                 base_type_num, _ = store.get_raw(base)
241             full_objects[i] = (base_type_num, data,
242                                obj_sha(base_type_num, [data]))
243
244     for i, (type_num, obj) in enumerate(objects_spec):
245         offset = f.tell()
246         if type_num == OFS_DELTA:
247             base_index, data = obj
248             base = offset - offsets[base_index]
249             _, base_data, _ = full_objects[base_index]
250             obj = (base, create_delta(base_data, data))
251         elif type_num == REF_DELTA:
252             base_ref, data = obj
253             if isinstance(base_ref, int):
254                 _, base_data, base = full_objects[base_ref]
255             else:
256                 base_type_num, base_data = store.get_raw(base_ref)
257                 base = obj_sha(base_type_num, base_data)
258             obj = (base, create_delta(base_data, data))
259
260         crc32 = write_pack_object(sf, type_num, obj)
261         offsets[i] = offset
262         crc32s[i] = crc32
263
264     expected = []
265     for i in range(num_objects):
266         type_num, data, sha = full_objects[i]
267         assert len(sha) == 20
268         expected.append((offsets[i], type_num, data, sha, crc32s[i]))
269
270     sf.write_sha()
271     f.seek(0)
272     return expected
273
274
275 def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
276     """Build a commit graph from a concise specification.
277
278     Sample usage:
279     >>> c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
280     >>> store[store[c3].parents[0]] == c1
281     True
282     >>> store[store[c3].parents[1]] == c2
283     True
284
285     If not otherwise specified, commits will refer to the empty tree and have
286     commit times increasing in the same order as the commit spec.
287
288     :param object_store: An ObjectStore to commit objects to.
289     :param commit_spec: An iterable of iterables of ints defining the commit
290         graph. Each entry defines one commit, and entries must be in
291         topological order. The first element of each entry is a commit number,
292         and the remaining elements are its parents. The commit numbers are only
293         meaningful for the call to make_commits; since real commit objects are
294         created, they will get created with real, opaque SHAs.
295     :param trees: An optional dict of commit number -> tree spec for building
296         trees for commits. The tree spec is an iterable of (path, blob, mode)
297         or (path, blob) entries; if mode is omitted, it defaults to the normal
298         file mode (0100644).
299     :param attrs: A dict of commit number -> (dict of attribute -> value) for
300         assigning additional values to the commits.
301     :return: The list of commit objects created.
302     :raise ValueError: If an undefined commit identifier is listed as a parent.
303     """
304     if trees is None:
305         trees = {}
306     if attrs is None:
307         attrs = {}
308     commit_time = 0
309     nums = {}
310     commits = []
311
312     for commit in commit_spec:
313         commit_num = commit[0]
314         try:
315             parent_ids = [nums[pn] for pn in commit[1:]]
316         except KeyError as e:
317             missing_parent, = e.args
318             raise ValueError('Unknown parent %i' % missing_parent)
319
320         blobs = []
321         for entry in trees.get(commit_num, []):
322             if len(entry) == 2:
323                 path, blob = entry
324                 entry = (path, blob, F)
325             path, blob, mode = entry
326             blobs.append((path, blob.id, mode))
327             object_store.add_object(blob)
328         tree_id = commit_tree(object_store, blobs)
329
330         commit_attrs = {
331             'message': ('Commit %i' % commit_num).encode('ascii'),
332             'parents': parent_ids,
333             'tree': tree_id,
334             'commit_time': commit_time,
335             }
336         commit_attrs.update(attrs.get(commit_num, {}))
337         commit_obj = make_commit(**commit_attrs)
338
339         # By default, increment the time by a lot. Out-of-order commits should
340         # be closer together than this because their main cause is clock skew.
341         commit_time = commit_attrs['commit_time'] + 100
342         nums[commit_num] = commit_obj.id
343         object_store.add_object(commit_obj)
344         commits.append(commit_obj)
345
346     return commits
347
348
349 def setup_warning_catcher():
350     """Wrap warnings.showwarning with code that records warnings."""
351
352     caught_warnings = []
353     original_showwarning = warnings.showwarning
354
355     def custom_showwarning(*args,  **kwargs):
356         caught_warnings.append(args[0])
357
358     warnings.showwarning = custom_showwarning
359
360     def restore_showwarning():
361         warnings.showwarning = original_showwarning
362
363     return caught_warnings, restore_showwarning