dsdb: Log the transaction duration.
[nivanova/samba-autobuild/.git] / python / samba / tests / audit_log_base.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2017
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 from __future__ import print_function
19 """Tests for DSDB audit logging.
20 """
21
22 import samba.tests
23 from samba.messaging import Messaging
24 from samba.dcerpc.messaging import MSG_AUTH_LOG, AUTH_EVENT_NAME
25 import time
26 import json
27 import os
28 import re
29
30
def getAudit(message):
    """Return the audit payload of a decoded JSON log message.

    Log messages carry their payload under a key named by the message's
    "type" field, e.g. {"type": "Authorization", "Authorization": {...}}.

    :param message: the decoded JSON message (a dict).
    :return: the value stored under message[message["type"]], or None if the
             message has no "type" field or no payload under that key.
    """
    if "type" not in message:
        return None

    # Not called "type": that would shadow the builtin.
    message_type = message["type"]
    # A message that names a type but carries no payload for it has no
    # audit section; report that as None rather than raising KeyError.
    return message.get(message_type)
38
39
class AuditLogTestBase(samba.tests.TestCase):
    """Base class for tests that consume audit log messages via messaging.

    This class reads, but never sets, ``self.event_type`` (the irpc name
    registered in setUp), ``self.message_type`` (the message type the audit
    handler is registered for) and ``self.remoteAddress`` (compared against
    the address in each message). Presumably subclasses provide these —
    confirm against the concrete test classes.
    """

    def setUp(self):
        super(AuditLogTestBase, self).setUp()
        lp_ctx = self.get_loadparm()
        self.msg_ctx = Messaging((1,), lp_ctx=lp_ctx)
        self.msg_ctx.irpc_add_name(self.event_type)

        #
        # Check the remote address of a message against the one being used
        # for the tests.
        #
        def isRemote(message):
            audit = getAudit(message)
            if audit is None:
                return False

            remote = audit.get("remoteAddress")
            if remote is None:
                return False

            # The address is split on ":" and its second field compared;
            # presumably of the form "<family>:<addr>:<port>" — confirm.
            try:
                addr = remote.split(":")
                return addr[1] == self.remoteAddress
            except IndexError:
                return False

        # Message types collected into context["messages"] when they
        # originate from the test's remote address.
        audit_types = ("passwordChange", "dsdbChange", "groupChange")

        def messageHandler(context, msgType, src, message):
            # This does not look like sub unit output and it
            # makes these tests much easier to debug.
            print(message)
            jsonMsg = json.loads(message)
            if jsonMsg["type"] in audit_types and isRemote(jsonMsg):
                context["messages"].append(jsonMsg)
            elif jsonMsg["type"] == "dsdbTransaction":
                # Only the most recent transaction message is kept, and it
                # is not filtered by remote address.
                context["txnMessage"] = jsonMsg

        self.context = {"messages": [], "txnMessage": None}
        self.msg_handler_and_context = (messageHandler, self.context)
        self.msg_ctx.register(self.msg_handler_and_context,
                              msg_type=self.message_type)

        self.msg_ctx.irpc_add_name(AUTH_EVENT_NAME)

        def authHandler(context, msgType, src, message):
            jsonMsg = json.loads(message)
            if jsonMsg["type"] == "Authorization" and isRemote(jsonMsg):
                # This does not look like sub unit output and it
                # makes these tests much easier to debug.
                print(message)
                context["sessionId"] = jsonMsg["Authorization"]["sessionId"]
                context["serviceDescription"] =\
                    jsonMsg["Authorization"]["serviceDescription"]

        self.auth_context = {"sessionId": "", "serviceDescription": ""}
        self.auth_handler_and_context = (authHandler, self.auth_context)
        self.msg_ctx.register(self.auth_handler_and_context,
                              msg_type=MSG_AUTH_LOG)

        self.discardMessages()

        self.server = os.environ["SERVER"]
        self.connection = None

    def tearDown(self):
        """Drain pending messages, unregister names and handlers."""
        self.discardMessages()
        self.msg_ctx.irpc_remove_name(self.event_type)
        self.msg_ctx.irpc_remove_name(AUTH_EVENT_NAME)
        if self.msg_handler_and_context:
            self.msg_ctx.deregister(self.msg_handler_and_context,
                                    msg_type=self.message_type)
        if self.auth_handler_and_context:
            self.msg_ctx.deregister(self.auth_handler_and_context,
                                    msg_type=MSG_AUTH_LOG)
        # setUp chains to the base class; teardown must do the same.
        super(AuditLogTestBase, self).tearDown()

    def haveExpected(self, expected, dn):
        """Return True once at least `expected` audit messages are queued.

        :param expected: the number of messages required.
        :param dn: if not None, only count messages whose audit "dn"
                   matches this value case-insensitively.
        """
        if dn is None:
            return len(self.context["messages"]) >= expected

        wanted = dn.lower()
        received = 0
        for msg in self.context["messages"]:
            audit = getAudit(msg)
            if audit["dn"].lower() == wanted:
                received += 1
                if received >= expected:
                    return True
        return False

    def waitForMessages(self, number, connection=None, dn=None, timeout=1):
        """Wait for all the expected messages to arrive
        The connection is passed through to keep the connection alive
        until all the logging messages have been received.

        :param number: the number of messages to wait for.
        :param connection: held on self until waiting finishes.
        :param dn: if not None, only wait for (and return) messages whose
                   audit "dn" matches this value case-insensitively.
        :param timeout: seconds to wait before giving up (default 1).
        :return: the received messages (the live context list when dn is
                 None), or [] if the timeout expired first.
        """
        self.connection = connection
        try:
            start_time = time.time()
            while not self.haveExpected(number, dn):
                self.msg_ctx.loop_once(0.1)
                if time.time() - start_time > timeout:
                    print("Timed out")
                    return []
        finally:
            # Release the connection however the wait ended.
            self.connection = None

        if dn is None:
            return self.context["messages"]

        wanted = dn.lower()
        return [msg for msg in self.context["messages"]
                if getAudit(msg)["dn"].lower() == wanted]

    # Discard any previously queued messages.
    def discardMessages(self):
        self.msg_ctx.loop_once(0.001)
        # Keep pumping the loop until a pass delivers nothing new.
        while (len(self.context["messages"]) or
               self.context["txnMessage"] is not None):

            self.context["messages"] = []
            self.context["txnMessage"] = None
            self.msg_ctx.loop_once(0.001)

    GUID_RE = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

    #
    # Is the supplied GUID string correctly formatted
    #
    def is_guid(self, guid):
        # re.match only anchors at the start; \Z rejects trailing text.
        return re.match(self.GUID_RE + r"\Z", guid)

    def get_session(self):
        """Return the sessionId from the last matching Authorization event."""
        return self.auth_context["sessionId"]

    def get_service_description(self):
        """Return the serviceDescription from the last Authorization event."""
        return self.auth_context["serviceDescription"]