static_call: Handle tail-calls
authorPeter Zijlstra <peterz@infradead.org>
Tue, 18 Aug 2020 13:57:49 +0000 (15:57 +0200)
committerIngo Molnar <mingo@kernel.org>
Tue, 1 Sep 2020 07:58:06 +0000 (09:58 +0200)
GCC can turn our static_call(name)(args...) into a tail call, in which
case we get a JMP.d32 into the trampoline (which then does a further
tail-call).

Teach objtool to recognise and mark these in .static_call_sites and
adjust the code patching to deal with this.

Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Signed-off-by: Ingo Molnar <mingo@kernel.org>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Link: https://lore.kernel.org/r/20200818135805.101186767@infradead.org
arch/x86/kernel/static_call.c
include/linux/static_call.h
include/linux/static_call_types.h
kernel/static_call.c
tools/include/linux/static_call_types.h
tools/objtool/check.c

index ead6726fb06d511250456dcee18fb8433dc7a95e..60a325c731df13582c5ac11df45907eb72909afd 100644 (file)
@@ -41,15 +41,30 @@ static void __static_call_transform(void *insn, enum insn_type type, void *func)
        text_poke_bp(insn, code, size, NULL);
 }
 
-void arch_static_call_transform(void *site, void *tramp, void *func)
+static inline enum insn_type __sc_insn(bool null, bool tail)
+{
+       /*
+        * Encode the following table without branches:
+        *
+        *      tail    null    insn
+        *      -----+-------+------
+        *        0  |   0   |  CALL
+        *        0  |   1   |  NOP
+        *        1  |   0   |  JMP
+        *        1  |   1   |  RET
+        */
+       return 2*tail + null;
+}
+
+void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
 {
        mutex_lock(&text_mutex);
 
        if (tramp)
-               __static_call_transform(tramp, func ? JMP : RET, func);
+               __static_call_transform(tramp, __sc_insn(!func, true), func);
 
        if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
-               __static_call_transform(site, func ? CALL : NOP, func);
+               __static_call_transform(site, __sc_insn(!func, tail), func);
 
        mutex_unlock(&text_mutex);
 }
index 0f74581e0e2fab318d6716983b782b265a387cc4..519bd666e096bf63efafb08cf050c070f3442ac3 100644 (file)
 /*
  * Either @site or @tramp can be NULL.
  */
-extern void arch_static_call_transform(void *site, void *tramp, void *func);
+extern void arch_static_call_transform(void *site, void *tramp, void *func, bool tail);
 
 #define STATIC_CALL_TRAMP_ADDR(name) &STATIC_CALL_TRAMP(name)
 
@@ -206,7 +206,7 @@ void __static_call_update(struct static_call_key *key, void *tramp, void *func)
 {
        cpus_read_lock();
        WRITE_ONCE(key->func, func);
-       arch_static_call_transform(NULL, tramp, func);
+       arch_static_call_transform(NULL, tramp, func, false);
        cpus_read_unlock();
 }
 
index 408d345d83e1c80c80f385cd38ab4ec039a21774..89135bb35bf7619bd3227612e16b17d7a7a88a43 100644 (file)
 #define STATIC_CALL_TRAMP(name)                __PASTE(STATIC_CALL_TRAMP_PREFIX, name)
 #define STATIC_CALL_TRAMP_STR(name)    __stringify(STATIC_CALL_TRAMP(name))
 
+/*
+ * Flags in the low bits of static_call_site::key.
+ */
+#define STATIC_CALL_SITE_TAIL 1UL      /* tail call */
+#define STATIC_CALL_SITE_INIT 2UL      /* init section */
+#define STATIC_CALL_SITE_FLAGS 3UL
+
 /*
  * The static call site table needs to be created by external tooling (objtool
  * or a compiler plugin).
index 97142cb6bfa660ff6ab607dbafd9017b70708d71..d98e0e4272c147c32bc91749cd2914d37137a28d 100644 (file)
@@ -15,8 +15,6 @@ extern struct static_call_site __start_static_call_sites[],
 
 static bool static_call_initialized;
 
-#define STATIC_CALL_INIT 1UL
-
 /* mutex to protect key modules/sites */
 static DEFINE_MUTEX(static_call_mutex);
 
@@ -39,18 +37,23 @@ static inline void *static_call_addr(struct static_call_site *site)
 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
 {
        return (struct static_call_key *)
-               (((long)site->key + (long)&site->key) & ~STATIC_CALL_INIT);
+               (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
 }
 
 /* These assume the key is word-aligned. */
 static inline bool static_call_is_init(struct static_call_site *site)
 {
-       return ((long)site->key + (long)&site->key) & STATIC_CALL_INIT;
+       return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
+}
+
+static inline bool static_call_is_tail(struct static_call_site *site)
+{
+       return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
 }
 
 static inline void static_call_set_init(struct static_call_site *site)
 {
-       site->key = ((long)static_call_key(site) | STATIC_CALL_INIT) -
+       site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
                    (long)&site->key;
 }
 
@@ -104,7 +107,7 @@ void __static_call_update(struct static_call_key *key, void *tramp, void *func)
 
        key->func = func;
 
-       arch_static_call_transform(NULL, tramp, func);
+       arch_static_call_transform(NULL, tramp, func, false);
 
        /*
         * If uninitialized, we'll not update the callsites, but they still
@@ -154,7 +157,8 @@ void __static_call_update(struct static_call_key *key, void *tramp, void *func)
                                continue;
                        }
 
-                       arch_static_call_transform(site_addr, NULL, func);
+                       arch_static_call_transform(site_addr, NULL, func,
+                               static_call_is_tail(site));
                }
        }
 
@@ -198,7 +202,8 @@ static int __static_call_init(struct module *mod,
                        key->mods = site_mod;
                }
 
-               arch_static_call_transform(site_addr, NULL, key->func);
+               arch_static_call_transform(site_addr, NULL, key->func,
+                               static_call_is_tail(site));
        }
 
        return 0;
index 408d345d83e1c80c80f385cd38ab4ec039a21774..89135bb35bf7619bd3227612e16b17d7a7a88a43 100644 (file)
 #define STATIC_CALL_TRAMP(name)                __PASTE(STATIC_CALL_TRAMP_PREFIX, name)
 #define STATIC_CALL_TRAMP_STR(name)    __stringify(STATIC_CALL_TRAMP(name))
 
+/*
+ * Flags in the low bits of static_call_site::key.
+ */
+#define STATIC_CALL_SITE_TAIL 1UL      /* tail call */
+#define STATIC_CALL_SITE_INIT 2UL      /* init section */
+#define STATIC_CALL_SITE_FLAGS 3UL
+
 /*
  * The static call site table needs to be created by external tooling (objtool
  * or a compiler plugin).
index f8f7a40c6ef359b6c24c55d17ac01f5f73167000..75d0cd2f904443f45fa4e415b779117d6d138040 100644 (file)
@@ -516,7 +516,7 @@ static int create_static_call_sections(struct objtool_file *file)
                }
                memset(reloc, 0, sizeof(*reloc));
                reloc->sym = key_sym;
-               reloc->addend = 0;
+               reloc->addend = is_sibling_call(insn) ? STATIC_CALL_SITE_TAIL : 0;
                reloc->type = R_X86_64_PC32;
                reloc->offset = idx * sizeof(struct static_call_site) + 4;
                reloc->sec = reloc_sec;
@@ -747,6 +747,10 @@ static int add_jump_destinations(struct objtool_file *file)
                } else {
                        /* external sibling call */
                        insn->call_dest = reloc->sym;
+                       if (insn->call_dest->static_call_tramp) {
+                               list_add_tail(&insn->static_call_node,
+                                             &file->static_call_list);
+                       }
                        continue;
                }
 
@@ -798,6 +802,10 @@ static int add_jump_destinations(struct objtool_file *file)
 
                                /* internal sibling call */
                                insn->call_dest = insn->jump_dest->func;
+                               if (insn->call_dest->static_call_tramp) {
+                                       list_add_tail(&insn->static_call_node,
+                                                     &file->static_call_list);
+                               }
                        }
                }
        }
@@ -1684,6 +1692,10 @@ static int decode_sections(struct objtool_file *file)
        if (ret)
                return ret;
 
+       ret = read_static_call_tramps(file);
+       if (ret)
+               return ret;
+
        ret = add_jump_destinations(file);
        if (ret)
                return ret;
@@ -1716,10 +1728,6 @@ static int decode_sections(struct objtool_file *file)
        if (ret)
                return ret;
 
-       ret = read_static_call_tramps(file);
-       if (ret)
-               return ret;
-
        return 0;
 }