ldb/python: Allow comparing a MessageElement to a list or a singleton.
authorJelmer Vernooij <jelmer@samba.org>
Fri, 11 Jan 2008 00:55:56 +0000 (01:55 +0100)
committerJelmer Vernooij <jelmer@samba.org>
Fri, 11 Jan 2008 00:55:56 +0000 (01:55 +0100)
source/lib/ldb/ldb.i
source/lib/ldb/ldb.py
source/lib/ldb/tests/python/api.py

index 57fa36584ec0ad6c7f9112eba6ff97694da18048..b6718351f8d6bd95bb236f292fd76b902e13b39a 100644 (file)
@@ -306,12 +306,15 @@ typedef struct ldb_message_element {
             return ret
 
         def __eq__(self, other):
-            if (isinstance(other, str) and 
-                len(set(self)) == 1 and 
-                set(self).pop() == other):
+            if (len(self) == 1 and self.get(0) == other):
                 return True
-            return self.__cmp__(other) == 0
-                
+            if isinstance(other, self.__class__):
+                return self.__cmp__(other) == 0
+            o = iter(other)
+            for i in range(len(self)):
+                if self.get(i) != o.next():
+                    return False
+            return True
     }
 } ldb_msg_element;
 
index 2d037f080c007a5b8545e692df198dee751cfe03..6aacc8c09c0300744902683089ec8cb005a2a318 100644 (file)
@@ -101,12 +101,15 @@ class ldb_msg_element(object):
         return ret
 
     def __eq__(self, other):
-        if (isinstance(other, str) and 
-            len(set(self)) == 1 and 
-            set(self).pop() == other):
+        if (len(self) == 1 and self.get(0) == other):
             return True
-        return self.__cmp__(other) == 0
-            
+        if isinstance(other, self.__class__):
+            return self.__cmp__(other) == 0
+        o = iter(other)
+        for i in range(len(self)):
+            if self.get(i) != o.next():
+                return False
+        return True
 
 ldb_msg_element.__iter__ = new_instancemethod(_ldb.ldb_msg_element___iter__,None,ldb_msg_element)
 ldb_msg_element.__set__ = new_instancemethod(_ldb.ldb_msg_element___set__,None,ldb_msg_element)
index 5ab40106a8da377439fce5b7befe046ba0c6d212..8469e8f3cd3e87fd60a4c68b02c4a357a93318a5 100755 (executable)
@@ -392,6 +392,12 @@ class MessageElementTests(unittest.TestCase):
         x = ldb.MessageElement(["foo", "bar"])
         self.assertEquals(2, len(x))
 
+    def test_eq(self):
+        x = ldb.MessageElement(["foo", "bar"])
+        self.assertEquals(["foo", "bar"], x)
+        x = ldb.MessageElement(["foo"])
+        self.assertEquals("foo", x)
+
 class ExampleModule:
     name = "example"