Merge tag 'pwm/for-4.11-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/thierry...
[sfrench/cifs-2.6.git] / kernel / jump_label.c
index a9b8cf50059151c17f63d35cf4c622ae8b72f131..6c9cb208ac4827ea6141d3d04864fbd4560ec8b9 100644 (file)
@@ -236,12 +236,28 @@ void __weak __init_or_module arch_jump_label_transform_static(struct jump_entry
 
 static inline struct jump_entry *static_key_entries(struct static_key *key)
 {
-       return (struct jump_entry *)((unsigned long)key->entries & ~JUMP_TYPE_MASK);
+       WARN_ON_ONCE(key->type & JUMP_TYPE_LINKED);
+       return (struct jump_entry *)(key->type & ~JUMP_TYPE_MASK);
 }
 
 static inline bool static_key_type(struct static_key *key)
 {
-       return (unsigned long)key->entries & JUMP_TYPE_MASK;
+       return key->type & JUMP_TYPE_TRUE;
+}
+
+static inline bool static_key_linked(struct static_key *key)
+{
+       return key->type & JUMP_TYPE_LINKED;
+}
+
+static inline void static_key_clear_linked(struct static_key *key)
+{
+       key->type &= ~JUMP_TYPE_LINKED;
+}
+
+static inline void static_key_set_linked(struct static_key *key)
+{
+       key->type |= JUMP_TYPE_LINKED;
 }
 
 static inline struct static_key *jump_entry_key(struct jump_entry *entry)
@@ -254,6 +270,26 @@ static bool jump_entry_branch(struct jump_entry *entry)
        return (unsigned long)entry->key & 1UL;
 }
 
+/***
+ * A 'struct static_key' uses a union such that it either points directly
+ * to a table of 'struct jump_entry' or to a linked list of modules which in
+ * turn point to 'struct jump_entry' tables.
+ *
+ * The two lower bits of the pointer are used to keep track of which pointer
+ * type is in use and to store the initial branch direction, we use an access
+ * function which preserves these bits.
+ */
+static void static_key_set_entries(struct static_key *key,
+                                  struct jump_entry *entries)
+{
+       unsigned long type;
+
+       WARN_ON_ONCE((unsigned long)entries & JUMP_TYPE_MASK);
+       type = key->type & JUMP_TYPE_MASK;
+       key->entries = entries;
+       key->type |= type;
+}
+
 static enum jump_label_type jump_label_type(struct jump_entry *entry)
 {
        struct static_key *key = jump_entry_key(entry);
@@ -313,13 +349,7 @@ void __init jump_label_init(void)
                        continue;
 
                key = iterk;
-               /*
-                * Set key->entries to iter, but preserve JUMP_LABEL_TRUE_BRANCH.
-                */
-               *((unsigned long *)&key->entries) += (unsigned long)iter;
-#ifdef CONFIG_MODULES
-               key->next = NULL;
-#endif
+               static_key_set_entries(key, iter);
        }
        static_key_initialized = true;
        jump_label_unlock();
@@ -343,6 +373,29 @@ struct static_key_mod {
        struct module *mod;
 };
 
+static inline struct static_key_mod *static_key_mod(struct static_key *key)
+{
+       WARN_ON_ONCE(!(key->type & JUMP_TYPE_LINKED));
+       return (struct static_key_mod *)(key->type & ~JUMP_TYPE_MASK);
+}
+
+/***
+ * key->type and key->next are the same via union.
+ * This sets key->next and preserves the type bits.
+ *
+ * See additional comments above static_key_set_entries().
+ */
+static void static_key_set_mod(struct static_key *key,
+                              struct static_key_mod *mod)
+{
+       unsigned long type;
+
+       WARN_ON_ONCE((unsigned long)mod & JUMP_TYPE_MASK);
+       type = key->type & JUMP_TYPE_MASK;
+       key->next = mod;
+       key->type |= type;
+}
+
 static int __jump_label_mod_text_reserved(void *start, void *end)
 {
        struct module *mod;
@@ -365,11 +418,23 @@ static void __jump_label_mod_update(struct static_key *key)
 {
        struct static_key_mod *mod;
 
-       for (mod = key->next; mod; mod = mod->next) {
-               struct module *m = mod->mod;
+       for (mod = static_key_mod(key); mod; mod = mod->next) {
+               struct jump_entry *stop;
+               struct module *m;
+
+               /*
+                * NULL if the static_key is defined in a module
+                * that does not use it
+                */
+               if (!mod->entries)
+                       continue;
 
-               __jump_label_update(key, mod->entries,
-                                   m->jump_entries + m->num_jump_entries);
+               m = mod->mod;
+               if (!m)
+                       stop = __stop___jump_table;
+               else
+                       stop = m->jump_entries + m->num_jump_entries;
+               __jump_label_update(key, mod->entries, stop);
        }
 }
 
@@ -404,7 +469,7 @@ static int jump_label_add_module(struct module *mod)
        struct jump_entry *iter_stop = iter_start + mod->num_jump_entries;
        struct jump_entry *iter;
        struct static_key *key = NULL;
-       struct static_key_mod *jlm;
+       struct static_key_mod *jlm, *jlm2;
 
        /* if the module doesn't have jump label entries, just return */
        if (iter_start == iter_stop)
@@ -421,20 +486,32 @@ static int jump_label_add_module(struct module *mod)
 
                key = iterk;
                if (within_module(iter->key, mod)) {
-                       /*
-                        * Set key->entries to iter, but preserve JUMP_LABEL_TRUE_BRANCH.
-                        */
-                       *((unsigned long *)&key->entries) += (unsigned long)iter;
-                       key->next = NULL;
+                       static_key_set_entries(key, iter);
                        continue;
                }
                jlm = kzalloc(sizeof(struct static_key_mod), GFP_KERNEL);
                if (!jlm)
                        return -ENOMEM;
+               if (!static_key_linked(key)) {
+                       jlm2 = kzalloc(sizeof(struct static_key_mod),
+                                      GFP_KERNEL);
+                       if (!jlm2) {
+                               kfree(jlm);
+                               return -ENOMEM;
+                       }
+                       preempt_disable();
+                       jlm2->mod = __module_address((unsigned long)key);
+                       preempt_enable();
+                       jlm2->entries = static_key_entries(key);
+                       jlm2->next = NULL;
+                       static_key_set_mod(key, jlm2);
+                       static_key_set_linked(key);
+               }
                jlm->mod = mod;
                jlm->entries = iter;
-               jlm->next = key->next;
-               key->next = jlm;
+               jlm->next = static_key_mod(key);
+               static_key_set_mod(key, jlm);
+               static_key_set_linked(key);
 
                /* Only update if we've changed from our initial state */
                if (jump_label_type(iter) != jump_label_init_type(iter))
@@ -461,16 +538,34 @@ static void jump_label_del_module(struct module *mod)
                if (within_module(iter->key, mod))
                        continue;
 
+               /* No memory during module load */
+               if (WARN_ON(!static_key_linked(key)))
+                       continue;
+
                prev = &key->next;
-               jlm = key->next;
+               jlm = static_key_mod(key);
 
                while (jlm && jlm->mod != mod) {
                        prev = &jlm->next;
                        jlm = jlm->next;
                }
 
-               if (jlm) {
+               /* No memory during module load */
+               if (WARN_ON(!jlm))
+                       continue;
+
+               if (prev == &key->next)
+                       static_key_set_mod(key, jlm->next);
+               else
                        *prev = jlm->next;
+
+               kfree(jlm);
+
+               jlm = static_key_mod(key);
+               /* if only one etry is left, fold it back into the static_key */
+               if (jlm->next == NULL) {
+                       static_key_set_entries(key, jlm->entries);
+                       static_key_clear_linked(key);
                        kfree(jlm);
                }
        }
@@ -499,8 +594,10 @@ jump_label_module_notify(struct notifier_block *self, unsigned long val,
        case MODULE_STATE_COMING:
                jump_label_lock();
                ret = jump_label_add_module(mod);
-               if (ret)
+               if (ret) {
+                       WARN(1, "Failed to allocatote memory: jump_label may not work properly.\n");
                        jump_label_del_module(mod);
+               }
                jump_label_unlock();
                break;
        case MODULE_STATE_GOING:
@@ -561,11 +658,14 @@ int jump_label_text_reserved(void *start, void *end)
 static void jump_label_update(struct static_key *key)
 {
        struct jump_entry *stop = __stop___jump_table;
-       struct jump_entry *entry = static_key_entries(key);
+       struct jump_entry *entry;
 #ifdef CONFIG_MODULES
        struct module *mod;
 
-       __jump_label_mod_update(key);
+       if (static_key_linked(key)) {
+               __jump_label_mod_update(key);
+               return;
+       }
 
        preempt_disable();
        mod = __module_address((unsigned long)key);
@@ -573,6 +673,7 @@ static void jump_label_update(struct static_key *key)
                stop = mod->jump_entries + mod->num_jump_entries;
        preempt_enable();
 #endif
+       entry = static_key_entries(key);
        /* if there are no users, entry can be NULL */
        if (entry)
                __jump_label_update(key, entry, stop);