Merge tag 'zonefs-6.9-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/dlemoal...
[sfrench/cifs-2.6.git] / arch / x86 / kernel / static_call.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/static_call.h>
3 #include <linux/memory.h>
4 #include <linux/bug.h>
5 #include <asm/text-patching.h>
6
7 enum insn_type {
8         CALL = 0, /* site call */
9         NOP = 1,  /* site cond-call */
10         JMP = 2,  /* tramp / site tail-call */
11         RET = 3,  /* tramp / site cond-tail-call */
12         JCC = 4,
13 };
14
15 /*
16  * ud1 %esp, %ecx - a 3 byte #UD that is unique to trampolines, chosen such
17  * that there is no false-positive trampoline identification while also being a
18  * speculation stop.
19  */
20 static const u8 tramp_ud[] = { 0x0f, 0xb9, 0xcc };
21
22 /*
23  * cs cs cs xorl %eax, %eax - a single 5 byte instruction that clears %[er]ax
24  */
25 static const u8 xor5rax[] = { 0x2e, 0x2e, 0x2e, 0x31, 0xc0 };
26
27 static const u8 retinsn[] = { RET_INSN_OPCODE, 0xcc, 0xcc, 0xcc, 0xcc };
28
29 static u8 __is_Jcc(u8 *insn) /* Jcc.d32 */
30 {
31         u8 ret = 0;
32
33         if (insn[0] == 0x0f) {
34                 u8 tmp = insn[1];
35                 if ((tmp & 0xf0) == 0x80)
36                         ret = tmp;
37         }
38
39         return ret;
40 }
41
42 extern void __static_call_return(void);
43
44 asm (".global __static_call_return\n\t"
45      ".type __static_call_return, @function\n\t"
46      ASM_FUNC_ALIGN "\n\t"
47      "__static_call_return:\n\t"
48      ANNOTATE_NOENDBR
49      ANNOTATE_RETPOLINE_SAFE
50      "ret; int3\n\t"
51      ".size __static_call_return, . - __static_call_return \n\t");
52
53 static void __ref __static_call_transform(void *insn, enum insn_type type,
54                                           void *func, bool modinit)
55 {
56         const void *emulate = NULL;
57         int size = CALL_INSN_SIZE;
58         const void *code;
59         u8 op, buf[6];
60
61         if ((type == JMP || type == RET) && (op = __is_Jcc(insn)))
62                 type = JCC;
63
64         switch (type) {
65         case CALL:
66                 func = callthunks_translate_call_dest(func);
67                 code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
68                 if (func == &__static_call_return0) {
69                         emulate = code;
70                         code = &xor5rax;
71                 }
72
73                 break;
74
75         case NOP:
76                 code = x86_nops[5];
77                 break;
78
79         case JMP:
80                 code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
81                 break;
82
83         case RET:
84                 if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
85                         code = text_gen_insn(JMP32_INSN_OPCODE, insn, x86_return_thunk);
86                 else
87                         code = &retinsn;
88                 break;
89
90         case JCC:
91                 if (!func) {
92                         func = __static_call_return;
93                         if (cpu_feature_enabled(X86_FEATURE_RETHUNK))
94                                 func = x86_return_thunk;
95                 }
96
97                 buf[0] = 0x0f;
98                 __text_gen_insn(buf+1, op, insn+1, func, 5);
99                 code = buf;
100                 size = 6;
101
102                 break;
103         }
104
105         if (memcmp(insn, code, size) == 0)
106                 return;
107
108         if (system_state == SYSTEM_BOOTING || modinit)
109                 return text_poke_early(insn, code, size);
110
111         text_poke_bp(insn, code, size, emulate);
112 }
113
114 static void __static_call_validate(u8 *insn, bool tail, bool tramp)
115 {
116         u8 opcode = insn[0];
117
118         if (tramp && memcmp(insn+5, tramp_ud, 3)) {
119                 pr_err("trampoline signature fail");
120                 BUG();
121         }
122
123         if (tail) {
124                 if (opcode == JMP32_INSN_OPCODE ||
125                     opcode == RET_INSN_OPCODE ||
126                     __is_Jcc(insn))
127                         return;
128         } else {
129                 if (opcode == CALL_INSN_OPCODE ||
130                     !memcmp(insn, x86_nops[5], 5) ||
131                     !memcmp(insn, xor5rax, 5))
132                         return;
133         }
134
135         /*
136          * If we ever trigger this, our text is corrupt, we'll probably not live long.
137          */
138         pr_err("unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
139         BUG();
140 }
141
142 static inline enum insn_type __sc_insn(bool null, bool tail)
143 {
144         /*
145          * Encode the following table without branches:
146          *
147          *      tail    null    insn
148          *      -----+-------+------
149          *        0  |   0   |  CALL
150          *        0  |   1   |  NOP
151          *        1  |   0   |  JMP
152          *        1  |   1   |  RET
153          */
154         return 2*tail + null;
155 }
156
157 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
158 {
159         mutex_lock(&text_mutex);
160
161         if (tramp) {
162                 __static_call_validate(tramp, true, true);
163                 __static_call_transform(tramp, __sc_insn(!func, true), func, false);
164         }
165
166         if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
167                 __static_call_validate(site, tail, false);
168                 __static_call_transform(site, __sc_insn(!func, tail), func, false);
169         }
170
171         mutex_unlock(&text_mutex);
172 }
173 EXPORT_SYMBOL_GPL(arch_static_call_transform);
174
175 #ifdef CONFIG_MITIGATION_RETHUNK
176 /*
177  * This is called by apply_returns() to fix up static call trampolines,
178  * specifically ARCH_DEFINE_STATIC_CALL_NULL_TRAMP which is recorded as
179  * having a return trampoline.
180  *
181  * The problem is that static_call() is available before determining
182  * X86_FEATURE_RETHUNK and, by implication, running alternatives.
183  *
184  * This means that __static_call_transform() above can have overwritten the
185  * return trampoline and we now need to fix things up to be consistent.
186  */
187 bool __static_call_fixup(void *tramp, u8 op, void *dest)
188 {
189         unsigned long addr = (unsigned long)tramp;
190         /*
191          * Not all .return_sites are a static_call trampoline (most are not).
192          * Check if the 3 bytes after the return are still kernel text, if not,
193          * then this definitely is not a trampoline and we need not worry
194          * further.
195          *
196          * This avoids the memcmp() below tripping over pagefaults etc..
197          */
198         if (((addr >> PAGE_SHIFT) != ((addr + 7) >> PAGE_SHIFT)) &&
199             !kernel_text_address(addr + 7))
200                 return false;
201
202         if (memcmp(tramp+5, tramp_ud, 3)) {
203                 /* Not a trampoline site, not our problem. */
204                 return false;
205         }
206
207         mutex_lock(&text_mutex);
208         if (op == RET_INSN_OPCODE || dest == &__x86_return_thunk)
209                 __static_call_transform(tramp, RET, NULL, true);
210         mutex_unlock(&text_mutex);
211
212         return true;
213 }
214 #endif