PEP8: fix E225: missing whitespace around operator
[nivanova/samba-autobuild/.git] / python / samba / gpclass.py
index a4ff22b5e1359e712bd2fd3da5b156bdeeb83d01..c419c363611263af187800ef5758374e1a7cea4e 100644 (file)
@@ -17,6 +17,7 @@
 
 import sys
 import os
+import errno
 import tdb
 sys.path.insert(0, "bin/python")
 from samba import NTSTATUSError
@@ -25,6 +26,13 @@ from StringIO import StringIO
 from abc import ABCMeta, abstractmethod
 import xml.etree.ElementTree as etree
 import re
+from samba.net import Net
+from samba.dcerpc import nbt
+from samba import smb
+import samba.gpo as gpo
+from samba.param import LoadParm
+from uuid import UUID
+from tempfile import NamedTemporaryFile
 
 try:
     from enum import Enum
@@ -131,9 +139,11 @@ class gp_log:
             apply_log = user_obj.find('applylog')
             if apply_log is None:
                 apply_log = etree.SubElement(user_obj, 'applylog')
-            item = etree.SubElement(apply_log, 'guid')
-            item.attrib['count'] = '%d' % (len(apply_log)-1)
-            item.attrib['value'] = guid
+            prev = apply_log.find('guid[@value="%s"]' % guid)
+            if prev is None:
+                item = etree.SubElement(apply_log, 'guid')
+                item.attrib['count'] = '%d' % (len(apply_log)-1)
+                item.attrib['value'] = guid
 
     def apply_log_pop(self):
         ''' Pop a GPO guid from the applylog
@@ -214,8 +224,7 @@ class gp_log:
                 for attr in attrs:
                     func = None
                     if attr.attrib['name'] in data_maps[ext.attrib['name']]:
-                        func = data_maps[ext.attrib['name']]\
-                               [attr.attrib['name']][-1]
+                        func = data_maps[ext.attrib['name']][attr.attrib['name']][-1]
                     else:
                         for dmap in data_maps[ext.attrib['name']].keys():
                             if data_maps[ext.attrib['name']][dmap][0] == \
@@ -301,32 +310,16 @@ class gp_ext(object):
     def read(self, policy):
         pass
 
-    def parse(self, afile, ldb, conn, gp_db, lp):
+    def parse(self, afile, ldb, gp_db, lp):
         self.ldb = ldb
         self.gp_db = gp_db
         self.lp = lp
 
-        # Fixing the bug where only some Linux Boxes capitalize MACHINE
-        try:
-            blist = afile.split('/')
-            idx = afile.lower().split('/').index('machine')
-            for case in [
-                            blist[idx].upper(),
-                            blist[idx].capitalize(),
-                            blist[idx].lower()
-                        ]:
-                bfile = '/'.join(blist[:idx]) + '/' + case + '/' + \
-                    '/'.join(blist[idx+1:])
-                try:
-                    return self.read(conn.loadfile(bfile.replace('/', '\\')))
-                except NTSTATUSError:
-                    continue
-        except ValueError:
-            try:
-                return self.read(conn.loadfile(afile.replace('/', '\\')))
-            except Exception as e:
-                self.logger.error(str(e))
-                return None
+        local_path = self.lp.cache_path('gpo_cache')
+        data_file = os.path.join(local_path, check_safe_path(afile).upper())
+        if os.path.exists(data_file):
+            return self.read(open(data_file, 'r').read())
+        return None
 
     @abstractmethod
     def __str__(self):
@@ -358,94 +351,6 @@ class gp_ext_setter():
     def __str__(self):
         pass
 
-class inf_to_kdc_tdb(gp_ext_setter):
-    def mins_to_hours(self):
-        return '%d' % (int(self.val)/60)
-
-    def days_to_hours(self):
-        return '%d' % (int(self.val)*24)
-
-    def set_kdc_tdb(self, val):
-        old_val = self.gp_db.gpostore.get(self.attribute)
-        self.logger.info('%s was changed from %s to %s' % (self.attribute,
-                                                           old_val, val))
-        if val is not None:
-            self.gp_db.gpostore.store(self.attribute, val)
-            self.gp_db.store(str(self), self.attribute, old_val)
-        else:
-            self.gp_db.gpostore.delete(self.attribute)
-            self.gp_db.delete(str(self), self.attribute)
-
-    def mapper(self):
-        return { 'kdc:user_ticket_lifetime': (self.set_kdc_tdb, self.explicit),
-                 'kdc:service_ticket_lifetime': (self.set_kdc_tdb,
-                                                 self.mins_to_hours),
-                 'kdc:renewal_lifetime': (self.set_kdc_tdb,
-                                          self.days_to_hours),
-               }
-
-    def __str__(self):
-        return 'Kerberos Policy'
-
-class inf_to_ldb(gp_ext_setter):
-    '''This class takes the .inf file parameter (essentially a GPO file mapped
-    to a GUID), hashmaps it to the Samba parameter, which then uses an ldb
-    object to update the parameter to Samba4. Not registry oriented whatsoever.
-    '''
-
-    def ch_minPwdAge(self, val):
-        old_val = self.ldb.get_minPwdAge()
-        self.logger.info('KDC Minimum Password age was changed from %s to %s' \
-                         % (old_val, val))
-        self.gp_db.store(str(self), self.attribute, old_val)
-        self.ldb.set_minPwdAge(val)
-
-    def ch_maxPwdAge(self, val):
-        old_val = self.ldb.get_maxPwdAge()
-        self.logger.info('KDC Maximum Password age was changed from %s to %s' \
-                         % (old_val, val))
-        self.gp_db.store(str(self), self.attribute, old_val)
-        self.ldb.set_maxPwdAge(val)
-
-    def ch_minPwdLength(self, val):
-        old_val = self.ldb.get_minPwdLength()
-        self.logger.info(
-            'KDC Minimum Password length was changed from %s to %s' \
-             % (old_val, val))
-        self.gp_db.store(str(self), self.attribute, old_val)
-        self.ldb.set_minPwdLength(val)
-
-    def ch_pwdProperties(self, val):
-        old_val = self.ldb.get_pwdProperties()
-        self.logger.info('KDC Password Properties were changed from %s to %s' \
-                         % (old_val, val))
-        self.gp_db.store(str(self), self.attribute, old_val)
-        self.ldb.set_pwdProperties(val)
-
-    def days2rel_nttime(self):
-        seconds = 60
-        minutes = 60
-        hours = 24
-        sam_add = 10000000
-        val = (self.val)
-        val = int(val)
-        return  str(-(val * seconds * minutes * hours * sam_add))
-
-    def mapper(self):
-        '''ldap value : samba setter'''
-        return { "minPwdAge" : (self.ch_minPwdAge, self.days2rel_nttime),
-                 "maxPwdAge" : (self.ch_maxPwdAge, self.days2rel_nttime),
-                 # Could be none, but I like the method assignment in
-                 # update_samba
-                 "minPwdLength" : (self.ch_minPwdLength, self.explicit),
-                 "pwdProperties" : (self.ch_pwdProperties, self.explicit),
-
-               }
-
-    def __str__(self):
-        return 'System Access'
-
-
 class gp_inf_ext(gp_ext):
     @abstractmethod
     def list(self, rootpath):
@@ -468,7 +373,7 @@ class gp_inf_ext(gp_ext):
         # then we return that boolean at the end.
 
         inf_conf = ConfigParser()
-        inf_conf.optionxform=str
+        inf_conf.optionxform = str
         try:
             inf_conf.readfp(StringIO(policy))
         except:
@@ -492,49 +397,188 @@ class gp_inf_ext(gp_ext):
     def __str__(self):
         pass
 
-class gp_sec_ext(gp_inf_ext):
-    '''This class does the following two things:
-        1) Identifies the GPO if it has a certain kind of filepath,
-        2) Finally parses it.
-    '''
-
-    count = 0
-
-    def __str__(self):
-        return "Security GPO extension"
-
-    def list(self, rootpath):
-        return os.path.join(rootpath,
-                            "MACHINE/Microsoft/Windows NT/SecEdit/GptTmpl.inf")
-
-    def listmachpol(self, rootpath):
-        return os.path.join(rootpath, "Machine/Registry.pol")
-
-    def listuserpol(self, rootpath):
-        return os.path.join(rootpath, "User/Registry.pol")
-
-    def apply_map(self):
-        return {"System Access": {"MinimumPasswordAge": ("minPwdAge",
-                                                         inf_to_ldb),
-                                  "MaximumPasswordAge": ("maxPwdAge",
-                                                         inf_to_ldb),
-                                  "MinimumPasswordLength": ("minPwdLength",
-                                                            inf_to_ldb),
-                                  "PasswordComplexity": ("pwdProperties",
-                                                         inf_to_ldb),
-                                 },
-                "Kerberos Policy": {"MaxTicketAge": (
-                                        "kdc:user_ticket_lifetime",
-                                        inf_to_kdc_tdb
-                                    ),
-                                    "MaxServiceAge": (
-                                        "kdc:service_ticket_lifetime",
-                                        inf_to_kdc_tdb
-                                    ),
-                                    "MaxRenewAge": (
-                                        "kdc:renewal_lifetime",
-                                        inf_to_kdc_tdb
-                                    ),
-                                   }
-               }
-
+''' Fetch the hostname of a writable DC '''
+def get_dc_hostname(creds, lp):
+    net = Net(creds=creds, lp=lp)
+    cldap_ret = net.finddc(domain=lp.get('realm'), flags=(nbt.NBT_SERVER_LDAP |
+                                                          nbt.NBT_SERVER_DS))
+    return cldap_ret.pdc_dns_name
+
+''' Fetch a list of GUIDs for applicable GPOs '''
+def get_gpo_list(dc_hostname, creds, lp):
+    gpos = []
+    ads = gpo.ADS_STRUCT(dc_hostname, lp, creds)
+    if ads.connect():
+        gpos = ads.get_gpo_list(creds.get_username())
+    return gpos
+
+
+def cache_gpo_dir(conn, cache, sub_dir):
+    loc_sub_dir = sub_dir.upper()
+    local_dir = os.path.join(cache, loc_sub_dir)
+    try:
+        os.makedirs(local_dir, mode=0o755)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+    for fdata in conn.list(sub_dir):
+        if fdata['attrib'] & smb.FILE_ATTRIBUTE_DIRECTORY:
+            cache_gpo_dir(conn, cache, os.path.join(sub_dir, fdata['name']))
+        else:
+            local_name = fdata['name'].upper()
+            f = NamedTemporaryFile(delete=False, dir=local_dir)
+            fname = os.path.join(sub_dir, fdata['name']).replace('/', '\\')
+            f.write(conn.loadfile(fname))
+            f.close()
+            os.rename(f.name, os.path.join(local_dir, local_name))
+
+
+def check_safe_path(path):
+    dirs = re.split('/|\\\\', path)
+    if 'sysvol' in path:
+        dirs = dirs[dirs.index('sysvol')+1:]
+    if not '..' in dirs:
+        return os.path.join(*dirs)
+    raise OSError(path)
+
+def check_refresh_gpo_list(dc_hostname, lp, creds, gpos):
+    conn = smb.SMB(dc_hostname, 'sysvol', lp=lp, creds=creds, sign=True)
+    cache_path = lp.cache_path('gpo_cache')
+    for gpo in gpos:
+        if not gpo.file_sys_path:
+            continue
+        cache_gpo_dir(conn, cache_path, check_safe_path(gpo.file_sys_path))
+
+def gpo_version(lp, path):
+    # gpo.gpo_get_sysvol_gpt_version() reads the GPT.INI from a local file,
+    # read from the gpo client cache.
+    gpt_path = lp.cache_path(os.path.join('gpo_cache', path))
+    return int(gpo.gpo_get_sysvol_gpt_version(gpt_path)[1])
+
+def apply_gp(lp, creds, test_ldb, logger, store, gp_extensions):
+    gp_db = store.get_gplog(creds.get_username())
+    dc_hostname = get_dc_hostname(creds, lp)
+    gpos = get_gpo_list(dc_hostname, creds, lp)
+    try:
+        check_refresh_gpo_list(dc_hostname, lp, creds, gpos)
+    except:
+        logger.error('Failed downloading gpt cache from \'%s\' using SMB' \
+                     % dc_hostname)
+        return
+
+    for gpo_obj in gpos:
+        guid = gpo_obj.name
+        if guid == 'Local Policy':
+            continue
+        path = os.path.join(lp.get('realm'), 'Policies', guid).upper()
+        version = gpo_version(lp, path)
+        if version != store.get_int(guid):
+            logger.info('GPO %s has changed' % guid)
+            gp_db.state(GPOSTATE.APPLY)
+        else:
+            gp_db.state(GPOSTATE.ENFORCE)
+        gp_db.set_guid(guid)
+        store.start()
+        for ext in gp_extensions:
+            try:
+                ext.parse(ext.list(path), test_ldb, gp_db, lp)
+            except Exception as e:
+                logger.error('Failed to parse gpo %s for extension %s' % \
+                             (guid, str(ext)))
+                logger.error('Message was: ' + str(e))
+                store.cancel()
+                continue
+        store.store(guid, '%i' % version)
+        store.commit()
+
+def unapply_log(gp_db):
+    while True:
+        item = gp_db.apply_log_pop()
+        if item:
+            yield item
+        else:
+            break
+
+def unapply_gp(lp, creds, test_ldb, logger, store, gp_extensions):
+    gp_db = store.get_gplog(creds.get_username())
+    gp_db.state(GPOSTATE.UNAPPLY)
+    for gpo_guid in unapply_log(gp_db):
+        gp_db.set_guid(gpo_guid)
+        unapply_attributes = gp_db.list(gp_extensions)
+        for attr in unapply_attributes:
+            attr_obj = attr[-1](logger, test_ldb, gp_db, lp, attr[0], attr[1])
+            attr_obj.mapper()[attr[0]][0](attr[1]) # Set the old value
+            gp_db.delete(str(attr_obj), attr[0])
+        gp_db.commit()
+
+def parse_gpext_conf(smb_conf):
+    lp = LoadParm()
+    if smb_conf is not None:
+        lp.load(smb_conf)
+    else:
+        lp.load_default()
+    ext_conf = lp.state_path('gpext.conf')
+    parser = ConfigParser()
+    parser.read(ext_conf)
+    return lp, parser
+
+def atomic_write_conf(lp, parser):
+    ext_conf = lp.state_path('gpext.conf')
+    with NamedTemporaryFile(delete=False, dir=os.path.dirname(ext_conf)) as f:
+        parser.write(f)
+        os.rename(f.name, ext_conf)
+
+def check_guid(guid):
+    # Check for valid guid with curly braces
+    if guid[0] != '{' or guid[-1] != '}' or len(guid) != 38:
+        return False
+    try:
+        UUID(guid, version=4)
+    except ValueError:
+        return False
+    return True
+
+def register_gp_extension(guid, name, path,
+                          smb_conf=None, machine=True, user=True):
+    # Check that the module exists
+    if not os.path.exists(path):
+        return False
+    if not check_guid(guid):
+        return False
+
+    lp, parser = parse_gpext_conf(smb_conf)
+    if not guid in parser.sections():
+        parser.add_section(guid)
+    parser.set(guid, 'DllName', path)
+    parser.set(guid, 'ProcessGroupPolicy', name)
+    parser.set(guid, 'NoMachinePolicy', 0 if machine else 1)
+    parser.set(guid, 'NoUserPolicy', 0 if user else 1)
+
+    atomic_write_conf(lp, parser)
+
+    return True
+
+def list_gp_extensions(smb_conf=None):
+    _, parser = parse_gpext_conf(smb_conf)
+    results = {}
+    for guid in parser.sections():
+        results[guid] = {}
+        results[guid]['DllName'] = parser.get(guid, 'DllName')
+        results[guid]['ProcessGroupPolicy'] = \
+            parser.get(guid, 'ProcessGroupPolicy')
+        results[guid]['MachinePolicy'] = \
+            not int(parser.get(guid, 'NoMachinePolicy'))
+        results[guid]['UserPolicy'] = not int(parser.get(guid, 'NoUserPolicy'))
+    return results
+
+def unregister_gp_extension(guid, smb_conf=None):
+    if not check_guid(guid):
+        return False
+
+    lp, parser = parse_gpext_conf(smb_conf)
+    if guid in parser.sections():
+        parser.remove_section(guid)
+
+    atomic_write_conf(lp, parser)
+
+    return True