+++ /dev/null
-# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
-
-__all__ = [
- 'KeysEqual',
- ]
-
-from ..helpers import (
- dict_subtract,
- filter_values,
- map_values,
- )
-from ._higherorder import (
- AnnotatedMismatch,
- PrefixedMismatch,
- MismatchesAll,
- )
-from ._impl import Matcher, Mismatch
-
-
-def LabelledMismatches(mismatches, details=None):
- """A collection of mismatches, each labelled."""
- return MismatchesAll(
- (PrefixedMismatch(k, v) for (k, v) in sorted(mismatches.items())),
- wrap=False)
-
-
-class MatchesAllDict(Matcher):
- """Matches if all of the matchers it is created with match.
-
- A lot like ``MatchesAll``, but takes a dict of Matchers and labels any
- mismatches with the key of the dictionary.
- """
-
- def __init__(self, matchers):
- super(MatchesAllDict, self).__init__()
- self.matchers = matchers
-
- def __str__(self):
- return 'MatchesAllDict(%s)' % (_format_matcher_dict(self.matchers),)
-
- def match(self, observed):
- mismatches = {}
- for label in self.matchers:
- mismatches[label] = self.matchers[label].match(observed)
- return _dict_to_mismatch(
- mismatches, result_mismatch=LabelledMismatches)
-
-
-class DictMismatches(Mismatch):
- """A mismatch with a dict of child mismatches."""
-
- def __init__(self, mismatches, details=None):
- super(DictMismatches, self).__init__(None, details=details)
- self.mismatches = mismatches
-
- def describe(self):
- lines = ['{']
- lines.extend(
- [' %r: %s,' % (key, mismatch.describe())
- for (key, mismatch) in sorted(self.mismatches.items())])
- lines.append('}')
- return '\n'.join(lines)
-
-
-def _dict_to_mismatch(data, to_mismatch=None,
- result_mismatch=DictMismatches):
- if to_mismatch:
- data = map_values(to_mismatch, data)
- mismatches = filter_values(bool, data)
- if mismatches:
- return result_mismatch(mismatches)
-
-
-class _MatchCommonKeys(Matcher):
- """Match on keys in a dictionary.
-
- Given a dictionary where the values are matchers, this will look for
- common keys in the matched dictionary and match if and only if all common
- keys match the given matchers.
-
- Thus::
-
- >>> structure = {'a': Equals('x'), 'b': Equals('y')}
- >>> _MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'})
- None
- """
-
- def __init__(self, dict_of_matchers):
- super(_MatchCommonKeys, self).__init__()
- self._matchers = dict_of_matchers
-
- def _compare_dicts(self, expected, observed):
- common_keys = set(expected.keys()) & set(observed.keys())
- mismatches = {}
- for key in common_keys:
- mismatch = expected[key].match(observed[key])
- if mismatch:
- mismatches[key] = mismatch
- return mismatches
-
- def match(self, observed):
- mismatches = self._compare_dicts(self._matchers, observed)
- if mismatches:
- return DictMismatches(mismatches)
-
-
-class _SubDictOf(Matcher):
- """Matches if the matched dict only has keys that are in given dict."""
-
- def __init__(self, super_dict, format_value=repr):
- super(_SubDictOf, self).__init__()
- self.super_dict = super_dict
- self.format_value = format_value
-
- def match(self, observed):
- excess = dict_subtract(observed, self.super_dict)
- return _dict_to_mismatch(
- excess, lambda v: Mismatch(self.format_value(v)))
-
-
-class _SuperDictOf(Matcher):
- """Matches if all of the keys in the given dict are in the matched dict.
- """
-
- def __init__(self, sub_dict, format_value=repr):
- super(_SuperDictOf, self).__init__()
- self.sub_dict = sub_dict
- self.format_value = format_value
-
- def match(self, super_dict):
- return _SubDictOf(super_dict, self.format_value).match(self.sub_dict)
-
-
-def _format_matcher_dict(matchers):
- return '{%s}' % (
- ', '.join(sorted('%r: %s' % (k, v) for k, v in matchers.items())))
-
-
-class _CombinedMatcher(Matcher):
- """Many matchers labelled and combined into one uber-matcher.
-
- Subclass this and then specify a dict of matcher factories that take a
- single 'expected' value and return a matcher. The subclass will match
- only if all of the matchers made from factories match.
-
- Not **entirely** dissimilar from ``MatchesAll``.
- """
-
- matcher_factories = {}
-
- def __init__(self, expected):
- super(_CombinedMatcher, self).__init__()
- self._expected = expected
-
- def format_expected(self, expected):
- return repr(expected)
-
- def __str__(self):
- return '%s(%s)' % (
- self.__class__.__name__, self.format_expected(self._expected))
-
- def match(self, observed):
- matchers = dict(
- (k, v(self._expected)) for k, v in self.matcher_factories.items())
- return MatchesAllDict(matchers).match(observed)
-
-
-class MatchesDict(_CombinedMatcher):
- """Match a dictionary exactly, by its keys.
-
- Specify a dictionary mapping keys (often strings) to matchers. This is
- the 'expected' dict. Any dictionary that matches this must have exactly
- the same keys, and the values must match the corresponding matchers in the
- expected dict.
- """
-
- matcher_factories = {
- 'Extra': _SubDictOf,
- 'Missing': lambda m: _SuperDictOf(m, format_value=str),
- 'Differences': _MatchCommonKeys,
- }
-
- format_expected = lambda self, expected: _format_matcher_dict(expected)
-
-
-class ContainsDict(_CombinedMatcher):
- """Match a dictionary for that contains a specified sub-dictionary.
-
- Specify a dictionary mapping keys (often strings) to matchers. This is
- the 'expected' dict. Any dictionary that matches this must have **at
- least** these keys, and the values must match the corresponding matchers
- in the expected dict. Dictionaries that have more keys will also match.
-
- In other words, any matching dictionary must contain the dictionary given
- to the constructor.
-
- Does not check for strict sub-dictionary. That is, equal dictionaries
- match.
- """
-
- matcher_factories = {
- 'Missing': lambda m: _SuperDictOf(m, format_value=str),
- 'Differences': _MatchCommonKeys,
- }
-
- format_expected = lambda self, expected: _format_matcher_dict(expected)
-
-
-class ContainedByDict(_CombinedMatcher):
- """Match a dictionary for which this is a super-dictionary.
-
- Specify a dictionary mapping keys (often strings) to matchers. This is
- the 'expected' dict. Any dictionary that matches this must have **only**
- these keys, and the values must match the corresponding matchers in the
- expected dict. Dictionaries that have fewer keys can also match.
-
- In other words, any matching dictionary must be contained by the
- dictionary given to the constructor.
-
- Does not check for strict super-dictionary. That is, equal dictionaries
- match.
- """
-
- matcher_factories = {
- 'Extra': _SubDictOf,
- 'Differences': _MatchCommonKeys,
- }
-
- format_expected = lambda self, expected: _format_matcher_dict(expected)
-
-
-class KeysEqual(Matcher):
- """Checks whether a dict has particular keys."""
-
- def __init__(self, *expected):
- """Create a `KeysEqual` Matcher.
-
- :param expected: The keys the dict is expected to have. If a dict,
- then we use the keys of that dict, if a collection, we assume it
- is a collection of expected keys.
- """
- super(KeysEqual, self).__init__()
- try:
- self.expected = expected.keys()
- except AttributeError:
- self.expected = list(expected)
-
- def __str__(self):
- return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))
-
- def match(self, matchee):
- from ._basic import _BinaryMismatch, Equals
- expected = sorted(self.expected)
- matched = Equals(expected).match(sorted(matchee.keys()))
- if matched:
- return AnnotatedMismatch(
- 'Keys not equal',
- _BinaryMismatch(expected, 'does not match', matchee))
- return None