Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[sfrench/cifs-2.6.git] / arch / x86 / kernel / static_call.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/static_call.h>
3 #include <linux/memory.h>
4 #include <linux/bug.h>
5 #include <asm/text-patching.h>
6
/*
 * Instruction forms a static_call slot can be rewritten to.
 *
 * The numbering is significant: __sc_insn() builds these values as
 * (tail << 1) | null, so bit 1 means "tail form" and bit 0 means
 * "NULL target".
 */
enum insn_type {
	CALL	= 0,	/* call site, non-NULL target */
	NOP	= 1,	/* call site, NULL target (conditional call) */
	JMP	= 2,	/* trampoline or tail-call site, non-NULL target */
	RET	= 3,	/* trampoline or tail-call site, NULL target */
};
13
14 static void __ref __static_call_transform(void *insn, enum insn_type type, void *func)
15 {
16         int size = CALL_INSN_SIZE;
17         const void *code;
18
19         switch (type) {
20         case CALL:
21                 code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
22                 break;
23
24         case NOP:
25                 code = ideal_nops[NOP_ATOMIC5];
26                 break;
27
28         case JMP:
29                 code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
30                 break;
31
32         case RET:
33                 code = text_gen_insn(RET_INSN_OPCODE, insn, func);
34                 size = RET_INSN_SIZE;
35                 break;
36         }
37
38         if (memcmp(insn, code, size) == 0)
39                 return;
40
41         if (unlikely(system_state == SYSTEM_BOOTING))
42                 return text_poke_early(insn, code, size);
43
44         text_poke_bp(insn, code, size, NULL);
45 }
46
47 static void __static_call_validate(void *insn, bool tail)
48 {
49         u8 opcode = *(u8 *)insn;
50
51         if (tail) {
52                 if (opcode == JMP32_INSN_OPCODE ||
53                     opcode == RET_INSN_OPCODE)
54                         return;
55         } else {
56                 if (opcode == CALL_INSN_OPCODE ||
57                     !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
58                         return;
59         }
60
61         /*
62          * If we ever trigger this, our text is corrupt, we'll probably not live long.
63          */
64         WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
65 }
66
67 static inline enum insn_type __sc_insn(bool null, bool tail)
68 {
69         /*
70          * Encode the following table without branches:
71          *
72          *      tail    null    insn
73          *      -----+-------+------
74          *        0  |   0   |  CALL
75          *        0  |   1   |  NOP
76          *        1  |   0   |  JMP
77          *        1  |   1   |  RET
78          */
79         return 2*tail + null;
80 }
81
82 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
83 {
84         mutex_lock(&text_mutex);
85
86         if (tramp) {
87                 __static_call_validate(tramp, true);
88                 __static_call_transform(tramp, __sc_insn(!func, true), func);
89         }
90
91         if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
92                 __static_call_validate(site, tail);
93                 __static_call_transform(site, __sc_insn(!func, tail), func);
94         }
95
96         mutex_unlock(&text_mutex);
97 }
98 EXPORT_SYMBOL_GPL(arch_static_call_transform);