Merge branch 'for-4.15/callbacks' into for-linus
[sfrench/cifs-2.6.git] / kernel / livepatch / core.c
index bf8c8fd72589ddeeae34662a8d76352afa890678..de9e45dca70f6887c50f08ad3d747540f96f50f1 100644 (file)
@@ -54,11 +54,6 @@ static bool klp_is_module(struct klp_object *obj)
        return obj->name;
 }
 
-static bool klp_is_object_loaded(struct klp_object *obj)
-{
-       return !obj->name || obj->mod;
-}
-
 /* sets obj->mod if object is not vmlinux and module is found */
 static void klp_find_object_module(struct klp_object *obj)
 {
@@ -285,6 +280,11 @@ static int klp_write_object_relocations(struct module *pmod,
 
 static int __klp_disable_patch(struct klp_patch *patch)
 {
+       struct klp_object *obj;
+
+       if (WARN_ON(!patch->enabled))
+               return -EINVAL;
+
        if (klp_transition_patch)
                return -EBUSY;
 
@@ -295,6 +295,10 @@ static int __klp_disable_patch(struct klp_patch *patch)
 
        klp_init_transition(patch, KLP_UNPATCHED);
 
+       klp_for_each_object(patch, obj)
+               if (obj->patched)
+                       klp_pre_unpatch_callback(obj);
+
        /*
         * Enforce the order of the func->transition writes in
         * klp_init_transition() and the TIF_PATCH_PENDING writes in
@@ -388,13 +392,18 @@ static int __klp_enable_patch(struct klp_patch *patch)
                if (!klp_is_object_loaded(obj))
                        continue;
 
-               ret = klp_patch_object(obj);
+               ret = klp_pre_patch_callback(obj);
                if (ret) {
-                       pr_warn("failed to enable patch '%s'\n",
-                               patch->mod->name);
+                       pr_warn("pre-patch callback failed for object '%s'\n",
+                               klp_is_module(obj) ? obj->name : "vmlinux");
+                       goto err;
+               }
 
-                       klp_cancel_transition();
-                       return ret;
+               ret = klp_patch_object(obj);
+               if (ret) {
+                       pr_warn("failed to patch object '%s'\n",
+                               klp_is_module(obj) ? obj->name : "vmlinux");
+                       goto err;
                }
        }
 
@@ -403,6 +412,11 @@ static int __klp_enable_patch(struct klp_patch *patch)
        patch->enabled = true;
 
        return 0;
+err:
+       pr_warn("failed to enable patch '%s'\n", patch->mod->name);
+
+       klp_cancel_transition();
+       return ret;
 }
 
 /**
@@ -854,9 +868,15 @@ static void klp_cleanup_module_patches_limited(struct module *mod,
                         * is in transition.
                         */
                        if (patch->enabled || patch == klp_transition_patch) {
+
+                               if (patch != klp_transition_patch)
+                                       klp_pre_unpatch_callback(obj);
+
                                pr_notice("reverting patch '%s' on unloading module '%s'\n",
                                          patch->mod->name, obj->mod->name);
                                klp_unpatch_object(obj);
+
+                               klp_post_unpatch_callback(obj);
                        }
 
                        klp_free_object_loaded(obj);
@@ -906,13 +926,25 @@ int klp_module_coming(struct module *mod)
                        pr_notice("applying patch '%s' to loading module '%s'\n",
                                  patch->mod->name, obj->mod->name);
 
+                       ret = klp_pre_patch_callback(obj);
+                       if (ret) {
+                               pr_warn("pre-patch callback failed for object '%s'\n",
+                                       obj->name);
+                               goto err;
+                       }
+
                        ret = klp_patch_object(obj);
                        if (ret) {
                                pr_warn("failed to apply patch '%s' to module '%s' (%d)\n",
                                        patch->mod->name, obj->mod->name, ret);
+
+                               klp_post_unpatch_callback(obj);
                                goto err;
                        }
 
+                       if (patch != klp_transition_patch)
+                               klp_post_patch_callback(obj);
+
                        break;
                }
        }