gpo: avoid quadratic behaviour in guid retrieval
[nivanova/samba-autobuild/.git] / python / samba / gpclass.py
index c419c363611263af187800ef5758374e1a7cea4e..748411f7aba639075fce92fef0ca1e6bdc0a9b4f 100644 (file)
@@ -22,7 +22,7 @@ import tdb
 sys.path.insert(0, "bin/python")
 from samba import NTSTATUSError
 from ConfigParser import ConfigParser
-from StringIO import StringIO
+from samba.compat import StringIO
 from abc import ABCMeta, abstractmethod
 import xml.etree.ElementTree as etree
 import re
@@ -43,6 +43,7 @@ except ImportError:
         ENFORCE = 2
         UNAPPLY = 3
 
+
 class gp_log:
     ''' Log settings overwritten by gpo apply
     The gp_log is an xml file that stores a history of gpo changes (and the
@@ -142,7 +143,7 @@ class gp_log:
             prev = apply_log.find('guid[@value="%s"]' % guid)
             if prev is None:
                 item = etree.SubElement(apply_log, 'guid')
-                item.attrib['count'] = '%d' % (len(apply_log)-1)
+                item.attrib['count'] = '%d' % (len(apply_log) - 1)
                 item.attrib['value'] = guid
 
     def apply_log_pop(self):
@@ -155,7 +156,7 @@ class gp_log:
         user_obj = self.gpdb.find('user[@name="%s"]' % self.user)
         apply_log = user_obj.find('applylog')
         if apply_log is not None:
-            ret = apply_log.find('guid[@count="%d"]' % (len(apply_log)-1))
+            ret = apply_log.find('guid[@count="%d"]' % (len(apply_log) - 1))
             if ret is not None:
                 apply_log.remove(ret)
                 return ret.attrib['value']
@@ -234,6 +235,45 @@ class gp_log:
                     ret.append((attr.attrib['name'], attr.text, func))
         return ret
 
+    def get_applied_guids(self):
+        ''' Return a list of applied ext guids
+        return              - List of guids for gpos that have applied settings
+                              to the system.
+        '''
+        guids = []
+        user_obj = self.gpdb.find('user[@name="%s"]' % self.user)
+        if user_obj is not None:
+            apply_log = user_obj.find('applylog')
+            if apply_log is not None:
+                guid_objs = apply_log.findall('guid[@count]')
+                guids_by_count = [(g.get('count'), g.get('value'))
+                                  for g in guid_objs]
+                guids_by_count.sort(reverse=True)
+                guids.extend(guid for count, guid in guids_by_count)
+        return guids
+
+    def get_applied_settings(self, guids):
+        ''' Return a list of applied ext guids
+        return              - List of tuples containing the guid of a gpo, then
+                              a dictionary of policies and their values prior
+                              policy application. These are sorted so that the
+                              most recently applied settings are removed first.
+        '''
+        ret = []
+        user_obj = self.gpdb.find('user[@name="%s"]' % self.user)
+        for guid in guids:
+            guid_settings = user_obj.find('guid[@value="%s"]' % guid)
+            exts = guid_settings.findall('gp_ext')
+            settings = {}
+            for ext in exts:
+                attr_dict = {}
+                attrs = ext.findall('attribute')
+                for attr in attrs:
+                    attr_dict[attr.attrib['name']] = attr.text
+                settings[ext.attrib['name']] = attr_dict
+            ret.append((guid, settings))
+        return ret
+
     def delete(self, gp_ext_name, attribute):
         ''' Remove an attribute from the gp_log
         param gp_ext_name   - name of extension from which to remove the
@@ -255,12 +295,13 @@ class gp_log:
         ''' Write gp_log changes to disk '''
         self.gpostore.store(self.username, etree.tostring(self.gpdb, 'utf-8'))
 
+
 class GPOStorage:
     def __init__(self, log_file):
         if os.path.isfile(log_file):
             self.log = tdb.open(log_file)
         else:
-            self.log = tdb.Tdb(log_file, 0, tdb.DEFAULT, os.O_CREAT|os.O_RDWR)
+            self.log = tdb.Tdb(log_file, 0, tdb.DEFAULT, os.O_CREAT |os.O_RDWR)
 
     def start(self):
         self.log.transaction_start()
@@ -292,29 +333,25 @@ class GPOStorage:
     def __del__(self):
         self.log.close()
 
+
 class gp_ext(object):
     __metaclass__ = ABCMeta
 
-    def __init__(self, logger):
+    def __init__(self, logger, lp, creds, store):
         self.logger = logger
+        self.lp = lp
+        self.creds = creds
+        self.gp_db = store.get_gplog(creds.get_username())
 
     @abstractmethod
-    def list(self, rootpath):
-        pass
-
-    @abstractmethod
-    def apply_map(self):
+    def process_group_policy(self, deleted_gpo_list, changed_gpo_list):
         pass
 
     @abstractmethod
     def read(self, policy):
         pass
 
-    def parse(self, afile, ldb, gp_db, lp):
-        self.ldb = ldb
-        self.gp_db = gp_db
-        self.lp = lp
-
+    def parse(self, afile):
         local_path = self.lp.cache_path('gpo_cache')
         data_file = os.path.join(local_path, check_safe_path(afile).upper())
         if os.path.exists(data_file):
@@ -325,15 +362,16 @@ class gp_ext(object):
     def __str__(self):
         pass
 
-class gp_ext_setter():
+
+class gp_ext_setter(object):
     __metaclass__ = ABCMeta
 
-    def __init__(self, logger, ldb, gp_db, lp, attribute, val):
+    def __init__(self, logger, gp_db, lp, creds, attribute, val):
         self.logger = logger
-        self.ldb = ldb
         self.attribute = attribute
         self.val = val
         self.lp = lp
+        self.creds = creds
         self.gp_db = gp_db
 
     def explicit(self):
@@ -351,60 +389,31 @@ class gp_ext_setter():
     def __str__(self):
         pass
 
-class gp_inf_ext(gp_ext):
-    @abstractmethod
-    def list(self, rootpath):
-        pass
-
-    @abstractmethod
-    def apply_map(self):
-        pass
 
+class gp_inf_ext(gp_ext):
     def read(self, policy):
-        ret = False
-        inftable = self.apply_map()
-
-        current_section = None
-
-        # So here we would declare a boolean,
-        # that would get changed to TRUE.
-        #
-        # If at any point in time a GPO was applied,
-        # then we return that boolean at the end.
-
         inf_conf = ConfigParser()
         inf_conf.optionxform = str
         try:
             inf_conf.readfp(StringIO(policy))
         except:
             inf_conf.readfp(StringIO(policy.decode('utf-16')))
+        return inf_conf
 
-        for section in inf_conf.sections():
-            current_section = inftable.get(section)
-            if not current_section:
-                continue
-            for key, value in inf_conf.items(section):
-                if current_section.get(key):
-                    (att, setter) = current_section.get(key)
-                    value = value.encode('ascii', 'ignore')
-                    ret = True
-                    setter(self.logger, self.ldb, self.gp_db, self.lp, att,
-                           value).update_samba()
-                    self.gp_db.commit()
-        return ret
-
-    @abstractmethod
-    def __str__(self):
-        pass
 
 ''' Fetch the hostname of a writable DC '''
+
+
 def get_dc_hostname(creds, lp):
     net = Net(creds=creds, lp=lp)
     cldap_ret = net.finddc(domain=lp.get('realm'), flags=(nbt.NBT_SERVER_LDAP |
                                                           nbt.NBT_SERVER_DS))
     return cldap_ret.pdc_dns_name
 
+
 ''' Fetch a list of GUIDs for applicable GPOs '''
+
+
 def get_gpo_list(dc_hostname, creds, lp):
     gpos = []
     ads = gpo.ADS_STRUCT(dc_hostname, lp, creds)
@@ -436,11 +445,12 @@ def cache_gpo_dir(conn, cache, sub_dir):
 def check_safe_path(path):
     dirs = re.split('/|\\\\', path)
     if 'sysvol' in path:
-        dirs = dirs[dirs.index('sysvol')+1:]
-    if not '..' in dirs:
+        dirs = dirs[dirs.index('sysvol') + 1:]
+    if '..' not in dirs:
         return os.path.join(*dirs)
     raise OSError(path)
 
+
 def check_refresh_gpo_list(dc_hostname, lp, creds, gpos):
     conn = smb.SMB(dc_hostname, 'sysvol', lp=lp, creds=creds, sign=True)
     cache_path = lp.cache_path('gpo_cache')
@@ -449,47 +459,53 @@ def check_refresh_gpo_list(dc_hostname, lp, creds, gpos):
             continue
         cache_gpo_dir(conn, cache_path, check_safe_path(gpo.file_sys_path))
 
+
 def gpo_version(lp, path):
     # gpo.gpo_get_sysvol_gpt_version() reads the GPT.INI from a local file,
     # read from the gpo client cache.
     gpt_path = lp.cache_path(os.path.join('gpo_cache', path))
     return int(gpo.gpo_get_sysvol_gpt_version(gpt_path)[1])
 
-def apply_gp(lp, creds, test_ldb, logger, store, gp_extensions):
+
+def apply_gp(lp, creds, logger, store, gp_extensions):
     gp_db = store.get_gplog(creds.get_username())
     dc_hostname = get_dc_hostname(creds, lp)
     gpos = get_gpo_list(dc_hostname, creds, lp)
     try:
         check_refresh_gpo_list(dc_hostname, lp, creds, gpos)
     except:
-        logger.error('Failed downloading gpt cache from \'%s\' using SMB' \
+        logger.error('Failed downloading gpt cache from \'%s\' using SMB'
                      % dc_hostname)
         return
 
+    changed_gpos = []
     for gpo_obj in gpos:
-        guid = gpo_obj.name
-        if guid == 'Local Policy':
+        if not gpo_obj.file_sys_path:
             continue
-        path = os.path.join(lp.get('realm'), 'Policies', guid).upper()
+        guid = gpo_obj.name
+        path = check_safe_path(gpo_obj.file_sys_path).upper()
         version = gpo_version(lp, path)
         if version != store.get_int(guid):
             logger.info('GPO %s has changed' % guid)
-            gp_db.state(GPOSTATE.APPLY)
-        else:
-            gp_db.state(GPOSTATE.ENFORCE)
-        gp_db.set_guid(guid)
-        store.start()
-        for ext in gp_extensions:
-            try:
-                ext.parse(ext.list(path), test_ldb, gp_db, lp)
-            except Exception as e:
-                logger.error('Failed to parse gpo %s for extension %s' % \
-                             (guid, str(ext)))
-                logger.error('Message was: ' + str(e))
-                store.cancel()
-                continue
+            changed_gpos.append(gpo_obj)
+
+    store.start()
+    for ext in gp_extensions:
+        try:
+            ext.process_group_policy([], changed_gpos)
+        except Exception as e:
+            logger.error('Failed to apply extension  %s' % str(ext))
+            logger.error('Message was: ' + str(e))
+            continue
+    for gpo_obj in gpos:
+        if not gpo_obj.file_sys_path:
+            continue
+        guid = gpo_obj.name
+        path = check_safe_path(gpo_obj.file_sys_path).upper()
+        version = gpo_version(lp, path)
         store.store(guid, '%i' % version)
-        store.commit()
+    store.commit()
+
 
 def unapply_log(gp_db):
     while True:
@@ -499,18 +515,20 @@ def unapply_log(gp_db):
         else:
             break
 
-def unapply_gp(lp, creds, test_ldb, logger, store, gp_extensions):
+
+def unapply_gp(lp, creds, logger, store, gp_extensions):
     gp_db = store.get_gplog(creds.get_username())
     gp_db.state(GPOSTATE.UNAPPLY)
     for gpo_guid in unapply_log(gp_db):
         gp_db.set_guid(gpo_guid)
         unapply_attributes = gp_db.list(gp_extensions)
         for attr in unapply_attributes:
-            attr_obj = attr[-1](logger, test_ldb, gp_db, lp, attr[0], attr[1])
-            attr_obj.mapper()[attr[0]][0](attr[1]) # Set the old value
+            attr_obj = attr[-1](logger, gp_db, lp, attr[0], attr[1])
+            attr_obj.mapper()[attr[0]][0](attr[1])  # Set the old value
             gp_db.delete(str(attr_obj), attr[0])
         gp_db.commit()
 
+
 def parse_gpext_conf(smb_conf):
     lp = LoadParm()
     if smb_conf is not None:
@@ -522,12 +540,14 @@ def parse_gpext_conf(smb_conf):
     parser.read(ext_conf)
     return lp, parser
 
+
 def atomic_write_conf(lp, parser):
     ext_conf = lp.state_path('gpext.conf')
     with NamedTemporaryFile(delete=False, dir=os.path.dirname(ext_conf)) as f:
         parser.write(f)
         os.rename(f.name, ext_conf)
 
+
 def check_guid(guid):
     # Check for valid guid with curly braces
     if guid[0] != '{' or guid[-1] != '}' or len(guid) != 38:
@@ -538,6 +558,7 @@ def check_guid(guid):
         return False
     return True
 
+
 def register_gp_extension(guid, name, path,
                           smb_conf=None, machine=True, user=True):
     # Check that the module exists
@@ -547,7 +568,7 @@ def register_gp_extension(guid, name, path,
         return False
 
     lp, parser = parse_gpext_conf(smb_conf)
-    if not guid in parser.sections():
+    if guid not in parser.sections():
         parser.add_section(guid)
     parser.set(guid, 'DllName', path)
     parser.set(guid, 'ProcessGroupPolicy', name)
@@ -558,6 +579,7 @@ def register_gp_extension(guid, name, path,
 
     return True
 
+
 def list_gp_extensions(smb_conf=None):
     _, parser = parse_gpext_conf(smb_conf)
     results = {}
@@ -571,6 +593,7 @@ def list_gp_extensions(smb_conf=None):
         results[guid]['UserPolicy'] = not int(parser.get(guid, 'NoUserPolicy'))
     return results
 
+
 def unregister_gp_extension(guid, smb_conf=None):
     if not check_guid(guid):
         return False