Import Samba Testing Framework code from private CVS module.
[samba.git] / source / stf / stf.py
1 #!/usr/bin/python
2 #
3 # Samba Testing Framework for Unit-testing
4 #
5
6 import os, string, re
7 import osver
8
9 def get_server_list_from_string(s):
10
11     server_list = []
12     
13     # Format is a list of server:domain\username%password separated
14     # by commas.
15
16     for entry in string.split(s, ","):
17
18         # Parse entry 
19
20         m = re.match("(.*):(.*)(\\\\|/)(.*)%(.*)", entry)
21         if not m:
22             raise "badly formed server list entry '%s'" % entry
23
24         server = m.group(1)
25         domain = m.group(2)
26         username = m.group(4)
27         password = m.group(5)
28
29         # Categorise servers
30
31         server_list.append({"platform": osver.os_version(server),
32                             "hostname": server,
33                             "administrator": {"username": username,
34                                               "domain": domain,
35                                               "password" : password}})
36
37     return server_list
38
39 def get_server_list():
40     """Iterate through all sources of server info and append them all
41     in one big list."""
42     
43     server_list = []
44
45     # The $STF_SERVERS environment variable
46
47     if os.environ.has_key("STF_SERVERS"):
48         server_list = server_list + \
49                       get_server_list_from_string(os.environ["STF_SERVERS"])
50
51     return server_list
52
53 def get_server(platform = None):
54     """Return configuration information for a server.  The platform
55     argument can be a string either 'nt4' or 'nt5' for Windows NT or
56     Windows 2000 servers, or just 'nt' for Windows NT and higher."""
57     
58     server_list = get_server_list()
59
60     for server in server_list:
61         if platform:
62             p = server["platform"]
63             if platform == "nt":
64                 if (p == osver.PLATFORM_NT4 or p == osver.PLATFORM_NT5):
65                     return server
66             if platform == "nt4" and p == osver.PLATFORM_NT4:
67                 return server
68             if platform == "nt5" and p == osver.PLATFORM_NT5:
69                 return server
70         else:
71             # No filter defined, return first in list
72             return server
73         
74     return None
75
76 def dict_check(sample_dict, real_dict):
77     """Check that real_dict contains all the keys present in sample_dict
78     and no extras.  Also check that common keys are of them same type."""
79     tmp = real_dict.copy()
80     for key in sample_dict.keys():
81         # Check existing key and type
82         if not real_dict.has_key(key):
83             raise ValueError, "dict does not contain key '%s'" % key
84         if type(sample_dict[key]) != type(real_dict[key]):
85             raise ValueError, "dict has differing types (%s vs %s) for key " \
86                   "'%s'" % (type(sample_dict[key]), type(real_dict[key]), key)
87         # Check dictionaries recursively
88         if type(sample_dict[key]) == dict:
89             dict_check(sample_dict[key], real_dict[key])
90         # Delete visited keys from copy
91         del(tmp[key])
92     # Any keys leftover are present in the real dict but not the sample
93     if len(tmp) == 0:
94         return
95     result = "dict has extra keys: "
96     for key in tmp.keys():
97         result = result + key + " "
98     raise ValueError, result
99
100 if __name__ == "__main__":
101     print get_server(platform = "nt")