dsdb: audit samdb and password changes
[nivanova/samba-autobuild/.git] / python / samba / tests / audit_log_base.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2017
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
"""Tests for DSDB audit logging.
"""

# The docstring must be the first statement for it to become __doc__;
# a __future__ import may legally follow it.
from __future__ import print_function
21
22 import samba.tests
23 from samba.messaging import Messaging
24 from samba.dcerpc.messaging import MSG_AUTH_LOG, AUTH_EVENT_NAME
25 import time
26 import json
27 import os
28 import re
29
def getAudit(message):
    """Return the type-specific section of a decoded JSON log message.

    The message's "type" value names the key that holds the details, e.g.
    {"type": "dsdbChange", "dsdbChange": {...}}.

    :param message: the decoded JSON message (a dict).
    :return: message[message["type"]], or None when there is no "type" key.
    :raises KeyError: if the section named by "type" is absent.
    """
    if "type" not in message:
        return None

    # Named msg_type rather than type to avoid shadowing the builtin.
    msg_type = message["type"]
    return message[msg_type]
37
class AuditLogTestBase(samba.tests.TestCase):
    """Shared machinery for the DSDB audit logging tests.

    setUp() registers messaging handlers that collect the JSON audit
    messages (passwordChange, dsdbChange, groupChange), the latest
    dsdbTransaction message, and the session details from Authorization
    messages, so subclasses can make assertions about them.

    This class reads, but does not set, the following attributes;
    presumably each subclass provides them (confirm in subclasses):
      * event_type    - irpc name registered for the audit events
      * message_type  - messaging type the audit events are delivered on
      * remoteAddress - address compared against a message's remote address
    """

    def setUp(self):
        super(AuditLogTestBase, self).setUp()
        lp_ctx = self.get_loadparm()
        self.msg_ctx = Messaging((1,), lp_ctx=lp_ctx)
        self.msg_ctx.irpc_add_name(self.event_type)

        #
        # Check the remote address of a message against the one being used
        # for the tests.
        #
        def isRemote(message):
            audit = getAudit(message)
            if audit is None:
                return False

            # .get() so a message without a remoteAddress is ignored rather
            # than raising KeyError inside the messaging callback.
            remote = audit.get("remoteAddress")
            if remote is None:
                return False

            # Only the second ":"-separated field is compared; presumably
            # the format is "<family>:<address>:<port>" - TODO confirm.
            addr = remote.split(":")
            if len(addr) < 2:
                return False
            return addr[1] == self.remoteAddress

        # Audit message types collected into context["messages"].
        audit_types = ("passwordChange", "dsdbChange", "groupChange")

        def messageHandler(context, msgType, src, message):
            # This does not look like sub unit output and it
            # makes these tests much easier to debug.
            print(message)
            jsonMsg = json.loads(message)
            msg_type = jsonMsg.get("type")
            if msg_type in audit_types and isRemote(jsonMsg):
                context["messages"].append(jsonMsg)
            elif msg_type == "dsdbTransaction":
                # Only the most recent transaction message is kept.
                context["txnMessage"] = jsonMsg

        self.context = {"messages": [], "txnMessage": ""}
        self.msg_handler_and_context = (messageHandler, self.context)
        self.msg_ctx.register(self.msg_handler_and_context,
                              msg_type=self.message_type)

        self.msg_ctx.irpc_add_name(AUTH_EVENT_NAME)

        def authHandler(context, msgType, src, message):
            jsonMsg = json.loads(message)
            if jsonMsg.get("type") == "Authorization" and isRemote(jsonMsg):
                # This does not look like sub unit output and it
                # makes these tests much easier to debug.
                print(message)
                authorization = jsonMsg["Authorization"]
                context["sessionId"] = authorization["sessionId"]
                context["serviceDescription"] =\
                    authorization["serviceDescription"]

        self.auth_context = {"sessionId": "", "serviceDescription": ""}
        self.auth_handler_and_context = (authHandler, self.auth_context)
        self.msg_ctx.register(self.auth_handler_and_context,
                              msg_type=MSG_AUTH_LOG)

        self.discardMessages()

        self.server = os.environ["SERVER"]
        self.connection = None

    def tearDown(self):
        self.discardMessages()
        self.msg_ctx.irpc_remove_name(self.event_type)
        self.msg_ctx.irpc_remove_name(AUTH_EVENT_NAME)
        if self.msg_handler_and_context:
            self.msg_ctx.deregister(self.msg_handler_and_context,
                                    msg_type=self.message_type)
        if self.auth_handler_and_context:
            self.msg_ctx.deregister(self.auth_handler_and_context,
                                    msg_type=MSG_AUTH_LOG)
        # setUp() chains to the base class, so tearDown() must as well.
        super(AuditLogTestBase, self).tearDown()

    def _messages_for_dn(self, dn):
        """Return the collected messages whose audited "dn" matches dn.

        The comparison is case-insensitive.
        """
        wanted = dn.lower()
        return [msg for msg in self.context["messages"]
                if getAudit(msg)["dn"].lower() == wanted]

    def haveExpected(self, expected, dn):
        """Return True once at least `expected` messages have been collected.

        :param expected: the number of messages required.
        :param dn: when not None, only messages for this DN are counted
                   (compared case-insensitively).
        """
        if dn is None:
            return len(self.context["messages"]) >= expected
        return len(self._messages_for_dn(dn)) >= expected

    def waitForMessages(self, number, connection=None, dn=None, timeout=1):
        """Wait for all the expected messages to arrive
        The connection is passed through to keep the connection alive
        until all the logging messages have been received.

        :param number: the number of messages to wait for.
        :param connection: held in self.connection while waiting.
        :param dn: when not None, only wait for (and return) messages
                   for this DN, compared case-insensitively.
        :param timeout: seconds to wait before giving up (default 1).
        :return: the collected messages (filtered by dn when given), or
                 an empty list if the timeout expires first.
        """
        self.connection = connection
        try:
            start_time = time.time()
            while not self.haveExpected(number, dn):
                self.msg_ctx.loop_once(0.1)
                if time.time() - start_time > timeout:
                    print("Timed out")
                    return []
        finally:
            # Release the connection however the wait ends.
            self.connection = None

        if dn is None:
            return self.context["messages"]
        return self._messages_for_dn(dn)

    def discardMessages(self):
        """Discard any previously queued messages.

        Keeps pumping the message loop until a pass collects nothing new.
        """
        self.msg_ctx.loop_once(0.001)
        while self.context["messages"]:
            self.context["messages"] = []
            self.msg_ctx.loop_once(0.001)

    # Lower-case textual GUID, e.g. "01234567-89ab-cdef-0123-456789abcdef".
    GUID_RE = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

    def is_guid(self, guid):
        """Return a match object if guid starts with a well-formed GUID.

        re.match anchors only at the start, so trailing text is not rejected.
        """
        return re.match(self.GUID_RE, guid)

    def get_session(self):
        """Return the sessionId from the last matching Authorization message."""
        return self.auth_context["sessionId"]

    def get_service_description(self):
        """Return the serviceDescription from the last Authorization message."""
        return self.auth_context["serviceDescription"]