1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler
4  *
5  * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
6  * Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
7  */
8 #include <linux/netdevice.h>
9 #include <linux/filter.h>
10 #include <linux/if_vlan.h>
11 #include <linux/bpf.h>
12 #include <linux/memory.h>
13 #include <linux/sort.h>
14 #include <asm/extable.h>
15 #include <asm/ftrace.h>
16 #include <asm/set_memory.h>
17 #include <asm/nospec-branch.h>
18 #include <asm/text-patching.h>
19 #include <asm/unwind.h>
20 #include <asm/cfi.h>
21
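/*
 * Used where every callee-saved register (rbx, r13, r14, r15) must be
 * pushed/popped regardless of what the program touches, e.g. around an
 * exception boundary (see push_callee_regs()/pop_callee_regs() users below).
 */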
22 static bool all_callee_regs_used[4] = {true, true, true, true};
23
24 static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
25 {
26         if (len == 1)
27                 *ptr = bytes;
28         else if (len == 2)
29                 *(u16 *)ptr = bytes;
30         else {
31                 *(u32 *)ptr = bytes;
32                 barrier();
33         }
34         return ptr + len;
35 }
36
37 #define EMIT(bytes, len) \
38         do { prog = emit_code(prog, bytes, len); } while (0)
39
40 #define EMIT1(b1)               EMIT(b1, 1)
41 #define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
42 #define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
43 #define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
44
45 #define EMIT1_off32(b1, off) \
46         do { EMIT1(b1); EMIT(off, 4); } while (0)
47 #define EMIT2_off32(b1, b2, off) \
48         do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
49 #define EMIT3_off32(b1, b2, b3, off) \
50         do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
51 #define EMIT4_off32(b1, b2, b3, b4, off) \
52         do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
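/*
 * The byte arguments are packed into one little-endian constant, so the first
 * argument is the first byte written to the image: e.g. EMIT3(0x48, 0x89, 0xC7)
 * emits the byte sequence 48 89 C7 ('mov rdi, rax'), and EMIT1_off32(0xE8, off)
 * emits a 5-byte 'call rel32'.
 */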
53
54 #ifdef CONFIG_X86_KERNEL_IBT
55 #define EMIT_ENDBR()            EMIT(gen_endbr(), 4)
56 #define EMIT_ENDBR_POISON()     EMIT(gen_endbr_poison(), 4)
57 #else
58 #define EMIT_ENDBR()
59 #define EMIT_ENDBR_POISON()
60 #endif
61
62 static bool is_imm8(int value)
63 {
64         return value <= 127 && value >= -128;
65 }
66
67 static bool is_simm32(s64 value)
68 {
69         return value == (s64)(s32)value;
70 }
71
72 static bool is_uimm32(u64 value)
73 {
74         return value == (u64)(u32)value;
75 }
76
77 /* mov dst, src */
78 #define EMIT_mov(DST, SRC)                                                               \
79         do {                                                                             \
80                 if (DST != SRC)                                                          \
81                         EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
82         } while (0)
83
84 static int bpf_size_to_x86_bytes(int bpf_size)
85 {
86         if (bpf_size == BPF_W)
87                 return 4;
88         else if (bpf_size == BPF_H)
89                 return 2;
90         else if (bpf_size == BPF_B)
91                 return 1;
92         else if (bpf_size == BPF_DW)
93                 return 4; /* imm32 */
94         else
95                 return 0;
96 }
97
98 /*
99  * List of x86 conditional jump opcodes (. + s8)
100  * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
101  */
102 #define X86_JB  0x72
103 #define X86_JAE 0x73
104 #define X86_JE  0x74
105 #define X86_JNE 0x75
106 #define X86_JBE 0x76
107 #define X86_JA  0x77
108 #define X86_JL  0x7C
109 #define X86_JGE 0x7D
110 #define X86_JLE 0x7E
111 #define X86_JG  0x7F
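/*
 * e.g. X86_JE (0x74, 'je rel8') becomes 0x0F 0x84 ('je rel32') once 0x10 is
 * added and the 0x0F escape byte is prepended.
 */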
112
113 /* Pick a register outside of BPF range for JIT internal work */
114 #define AUX_REG (MAX_BPF_JIT_REG + 1)
115 #define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
116 #define X86_REG_R12 (MAX_BPF_JIT_REG + 3)
117
118 /*
119  * The following table maps BPF registers to x86-64 registers.
120  *
121  * x86-64 register R12 is not mapped to any BPF register; the JIT reserves
122  * it (e.g. to hold the arena kernel base address), since as a load/store
123  * base register it always needs an extra byte of encoding and is callee saved.
124  *
125  * x86-64 register R9 is not used by BPF programs, but can be used by BPF
126  * trampoline. x86-64 register R10 is used for blinding (if enabled).
127  */
128 static const int reg2hex[] = {
129         [BPF_REG_0] = 0,  /* RAX */
130         [BPF_REG_1] = 7,  /* RDI */
131         [BPF_REG_2] = 6,  /* RSI */
132         [BPF_REG_3] = 2,  /* RDX */
133         [BPF_REG_4] = 1,  /* RCX */
134         [BPF_REG_5] = 0,  /* R8  */
135         [BPF_REG_6] = 3,  /* RBX callee saved */
136         [BPF_REG_7] = 5,  /* R13 callee saved */
137         [BPF_REG_8] = 6,  /* R14 callee saved */
138         [BPF_REG_9] = 7,  /* R15 callee saved */
139         [BPF_REG_FP] = 5, /* RBP readonly */
140         [BPF_REG_AX] = 2, /* R10 temp register */
141         [AUX_REG] = 3,    /* R11 temp register */
142         [X86_REG_R9] = 1, /* R9 register, 6th function argument */
143         [X86_REG_R12] = 4, /* R12 callee saved */
144 };
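/*
 * The values above are the low three bits of the x86-64 register number; the
 * fourth bit (r8..r15) is supplied as a REX prefix bit by is_ereg() below.
 * That is why e.g. RBP (BPF_REG_FP) and R13 (BPF_REG_7) share the value 5.
 */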
145
146 static const int reg2pt_regs[] = {
147         [BPF_REG_0] = offsetof(struct pt_regs, ax),
148         [BPF_REG_1] = offsetof(struct pt_regs, di),
149         [BPF_REG_2] = offsetof(struct pt_regs, si),
150         [BPF_REG_3] = offsetof(struct pt_regs, dx),
151         [BPF_REG_4] = offsetof(struct pt_regs, cx),
152         [BPF_REG_5] = offsetof(struct pt_regs, r8),
153         [BPF_REG_6] = offsetof(struct pt_regs, bx),
154         [BPF_REG_7] = offsetof(struct pt_regs, r13),
155         [BPF_REG_8] = offsetof(struct pt_regs, r14),
156         [BPF_REG_9] = offsetof(struct pt_regs, r15),
157 };
158
159 /*
160  * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
161  * which need extra byte of encoding.
162  * rax,rcx,...,rbp have simpler encoding
163  */
164 static bool is_ereg(u32 reg)
165 {
166         return (1 << reg) & (BIT(BPF_REG_5) |
167                              BIT(AUX_REG) |
168                              BIT(BPF_REG_7) |
169                              BIT(BPF_REG_8) |
170                              BIT(BPF_REG_9) |
171                              BIT(X86_REG_R9) |
172                              BIT(X86_REG_R12) |
173                              BIT(BPF_REG_AX));
174 }
175
176 /*
177  * is_ereg_8l() == true if BPF register 'reg' is mapped to access x86-64
178  * lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need extra byte
179  * of encoding. al,cl,dl,bl have simpler encoding.
180  */
181 static bool is_ereg_8l(u32 reg)
182 {
183         return is_ereg(reg) ||
184             (1 << reg) & (BIT(BPF_REG_1) |
185                           BIT(BPF_REG_2) |
186                           BIT(BPF_REG_FP));
187 }
188
189 static bool is_axreg(u32 reg)
190 {
191         return reg == BPF_REG_0;
192 }
193
194 /* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
195 static u8 add_1mod(u8 byte, u32 reg)
196 {
197         if (is_ereg(reg))
198                 byte |= 1;
199         return byte;
200 }
201
202 static u8 add_2mod(u8 byte, u32 r1, u32 r2)
203 {
204         if (is_ereg(r1))
205                 byte |= 1;
206         if (is_ereg(r2))
207                 byte |= 4;
208         return byte;
209 }
210
211 static u8 add_3mod(u8 byte, u32 r1, u32 r2, u32 index)
212 {
213         if (is_ereg(r1))
214                 byte |= 1;
215         if (is_ereg(index))
216                 byte |= 2;
217         if (is_ereg(r2))
218                 byte |= 4;
219         return byte;
220 }
221
222 /* Encode 'dst_reg' register into x86-64 opcode 'byte' */
223 static u8 add_1reg(u8 byte, u32 dst_reg)
224 {
225         return byte + reg2hex[dst_reg];
226 }
227
228 /* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
229 static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
230 {
231         return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
232 }
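/*
 * Worked example: EMIT_mov(BPF_REG_1, BPF_REG_2) expands to
 * EMIT3(add_2mod(0x48, dst, src), 0x89, add_2reg(0xC0, dst, src)) and emits
 * 48 89 F7 ('mov rdi, rsi'); with BPF_REG_8 (R14) as destination, add_2mod()
 * sets REX.B and EMIT_mov(BPF_REG_8, BPF_REG_0) emits 49 89 C6 ('mov r14, rax').
 */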
233
234 /* Some 1-byte opcodes for binary ALU operations */
235 static u8 simple_alu_opcodes[] = {
236         [BPF_ADD] = 0x01,
237         [BPF_SUB] = 0x29,
238         [BPF_AND] = 0x21,
239         [BPF_OR] = 0x09,
240         [BPF_XOR] = 0x31,
241         [BPF_LSH] = 0xE0,
242         [BPF_RSH] = 0xE8,
243         [BPF_ARSH] = 0xF8,
244 };
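/*
 * The ADD/SUB/AND/OR/XOR entries are the one-byte 'op r/m, reg' opcodes; the
 * LSH/RSH/ARSH entries are ModRM bytes (0xC0 | opcode-extension << 3, i.e.
 * shl = /4, shr = /5, sar = /7) used with the 0xC1/0xD3 shift-group opcodes.
 */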
245
246 static void jit_fill_hole(void *area, unsigned int size)
247 {
248         /* Fill whole space with INT3 instructions */
249         memset(area, 0xcc, size);
250 }
251
252 int bpf_arch_text_invalidate(void *dst, size_t len)
253 {
254         return IS_ERR_OR_NULL(text_poke_set(dst, 0xcc, len));
255 }
256
257 struct jit_context {
258         int cleanup_addr; /* Epilogue code offset */
259
260         /*
261          * Program specific offsets of labels in the code; these rely on the
262          * JIT doing at least 2 passes, recording the position on the first
263          * pass, only to generate the correct offset on the second pass.
264          */
265         int tail_call_direct_label;
266         int tail_call_indirect_label;
267 };
268
269 /* Maximum number of bytes emitted while JITing one eBPF insn */
270 #define BPF_MAX_INSN_SIZE       128
271 #define BPF_INSN_SAFETY         64
272
273 /* Number of bytes emit_patch() needs to generate instructions */
274 #define X86_PATCH_SIZE          5
275 /* Number of bytes that will be skipped on tailcall */
276 #define X86_TAIL_CALL_OFFSET    (11 + ENDBR_INSN_SIZE)
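/*
 * The 11 bytes being skipped are the 5-byte patchable nop, the 2-byte
 * tail_call_cnt init (or nop2) and the 4-byte 'push rbp; mov rbp, rsp' of
 * emit_prologue(), plus ENDBR_INSN_SIZE for the entry ENDBR under IBT, so a
 * tail call lands on the ENDBR marked "X86_TAIL_CALL_OFFSET is here" below.
 */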
277
278 static void push_r12(u8 **pprog)
279 {
280         u8 *prog = *pprog;
281
282         EMIT2(0x41, 0x54);   /* push r12 */
283         *pprog = prog;
284 }
285
286 static void push_callee_regs(u8 **pprog, bool *callee_regs_used)
287 {
288         u8 *prog = *pprog;
289
290         if (callee_regs_used[0])
291                 EMIT1(0x53);         /* push rbx */
292         if (callee_regs_used[1])
293                 EMIT2(0x41, 0x55);   /* push r13 */
294         if (callee_regs_used[2])
295                 EMIT2(0x41, 0x56);   /* push r14 */
296         if (callee_regs_used[3])
297                 EMIT2(0x41, 0x57);   /* push r15 */
298         *pprog = prog;
299 }
300
301 static void pop_r12(u8 **pprog)
302 {
303         u8 *prog = *pprog;
304
305         EMIT2(0x41, 0x5C);   /* pop r12 */
306         *pprog = prog;
307 }
308
309 static void pop_callee_regs(u8 **pprog, bool *callee_regs_used)
310 {
311         u8 *prog = *pprog;
312
313         if (callee_regs_used[3])
314                 EMIT2(0x41, 0x5F);   /* pop r15 */
315         if (callee_regs_used[2])
316                 EMIT2(0x41, 0x5E);   /* pop r14 */
317         if (callee_regs_used[1])
318                 EMIT2(0x41, 0x5D);   /* pop r13 */
319         if (callee_regs_used[0])
320                 EMIT1(0x5B);         /* pop rbx */
321         *pprog = prog;
322 }
323
324 static void emit_nops(u8 **pprog, int len)
325 {
326         u8 *prog = *pprog;
327         int i, noplen;
328
329         while (len > 0) {
330                 noplen = len;
331
332                 if (noplen > ASM_NOP_MAX)
333                         noplen = ASM_NOP_MAX;
334
335                 for (i = 0; i < noplen; i++)
336                         EMIT1(x86_nops[noplen][i]);
337                 len -= noplen;
338         }
339
340         *pprog = prog;
341 }
342
343 /*
344  * Emit the various CFI preambles, see asm/cfi.h and the comments about FineIBT
345  * in arch/x86/kernel/alternative.c
346  */
347
348 static void emit_fineibt(u8 **pprog, u32 hash)
349 {
350         u8 *prog = *pprog;
351
352         EMIT_ENDBR();
353         EMIT3_off32(0x41, 0x81, 0xea, hash);            /* subl $hash, %r10d    */
354         EMIT2(0x74, 0x07);                              /* jz.d8 +7             */
355         EMIT2(0x0f, 0x0b);                              /* ud2                  */
356         EMIT1(0x90);                                    /* nop                  */
357         EMIT_ENDBR_POISON();
358
359         *pprog = prog;
360 }
361
362 static void emit_kcfi(u8 **pprog, u32 hash)
363 {
364         u8 *prog = *pprog;
365
366         EMIT1_off32(0xb8, hash);                        /* movl $hash, %eax     */
367 #ifdef CONFIG_CALL_PADDING
368         EMIT1(0x90);
369         EMIT1(0x90);
370         EMIT1(0x90);
371         EMIT1(0x90);
372         EMIT1(0x90);
373         EMIT1(0x90);
374         EMIT1(0x90);
375         EMIT1(0x90);
376         EMIT1(0x90);
377         EMIT1(0x90);
378         EMIT1(0x90);
379 #endif
380         EMIT_ENDBR();
381
382         *pprog = prog;
383 }
384
385 static void emit_cfi(u8 **pprog, u32 hash)
386 {
387         u8 *prog = *pprog;
388
389         switch (cfi_mode) {
390         case CFI_FINEIBT:
391                 emit_fineibt(&prog, hash);
392                 break;
393
394         case CFI_KCFI:
395                 emit_kcfi(&prog, hash);
396                 break;
397
398         default:
399                 EMIT_ENDBR();
400                 break;
401         }
402
403         *pprog = prog;
404 }
405
406 /*
407  * Emit x86-64 prologue code for BPF program.
408  * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes
409  * while jumping to another program
410  */
411 static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
412                           bool tail_call_reachable, bool is_subprog,
413                           bool is_exception_cb)
414 {
415         u8 *prog = *pprog;
416
417         emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash);
418         /* BPF trampoline can be made to work without these nops,
419          * but let's waste 5 bytes for now and optimize later
420          */
421         emit_nops(&prog, X86_PATCH_SIZE);
422         if (!ebpf_from_cbpf) {
423                 if (tail_call_reachable && !is_subprog)
424                         /* When it's the entry of the whole tailcall context,
425                          * zeroing rax means initialising tail_call_cnt.
426                          */
427                         EMIT2(0x31, 0xC0); /* xor eax, eax */
428                 else
429                         /* Keep the same instruction layout. */
430                         EMIT2(0x66, 0x90); /* nop2 */
431         }
432         /* Exception callback receives FP as third parameter */
433         if (is_exception_cb) {
434                 EMIT3(0x48, 0x89, 0xF4); /* mov rsp, rsi */
435                 EMIT3(0x48, 0x89, 0xD5); /* mov rbp, rdx */
436                 /* The main frame must have exception_boundary as true, so we
437                  * first restore those callee-saved regs from stack, before
438                  * reusing the stack frame.
439                  */
440                 pop_callee_regs(&prog, all_callee_regs_used);
441                 pop_r12(&prog);
442                 /* Reset the stack frame. */
443                 EMIT3(0x48, 0x89, 0xEC); /* mov rsp, rbp */
444         } else {
445                 EMIT1(0x55);             /* push rbp */
446                 EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
447         }
448
449         /* X86_TAIL_CALL_OFFSET is here */
450         EMIT_ENDBR();
451
452         /* sub rsp, rounded_stack_depth */
453         if (stack_depth)
454                 EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
455         if (tail_call_reachable)
456                 EMIT1(0x50);         /* push rax */
457         *pprog = prog;
458 }
459
460 static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
461 {
462         u8 *prog = *pprog;
463         s64 offset;
464
465         offset = func - (ip + X86_PATCH_SIZE);
466         if (!is_simm32(offset)) {
467                 pr_err("Target call %p is out of range\n", func);
468                 return -ERANGE;
469         }
470         EMIT1_off32(opcode, offset);
471         *pprog = prog;
472         return 0;
473 }
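/*
 * The displacement is rel32, relative to the end of the 5-byte instruction
 * (ip + X86_PATCH_SIZE); opcode 0xE8 is 'call rel32' and 0xE9 is 'jmp rel32',
 * as used by the wrappers below.
 */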
474
475 static int emit_call(u8 **pprog, void *func, void *ip)
476 {
477         return emit_patch(pprog, func, ip, 0xE8);
478 }
479
480 static int emit_rsb_call(u8 **pprog, void *func, void *ip)
481 {
482         OPTIMIZER_HIDE_VAR(func);
483         x86_call_depth_emit_accounting(pprog, func);
484         return emit_patch(pprog, func, ip, 0xE8);
485 }
486
487 static int emit_jump(u8 **pprog, void *func, void *ip)
488 {
489         return emit_patch(pprog, func, ip, 0xE9);
490 }
491
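/*
 * Returns -EBUSY if the bytes at @ip do not match the expected old
 * instruction (a 5-byte nop when @old_addr is NULL), 1 if they already match
 * the new instruction, and 0 once the new instruction has been patched in
 * via text_poke_bp().
 */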
492 static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
493                                 void *old_addr, void *new_addr)
494 {
495         const u8 *nop_insn = x86_nops[5];
496         u8 old_insn[X86_PATCH_SIZE];
497         u8 new_insn[X86_PATCH_SIZE];
498         u8 *prog;
499         int ret;
500
501         memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
502         if (old_addr) {
503                 prog = old_insn;
504                 ret = t == BPF_MOD_CALL ?
505                       emit_call(&prog, old_addr, ip) :
506                       emit_jump(&prog, old_addr, ip);
507                 if (ret)
508                         return ret;
509         }
510
511         memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
512         if (new_addr) {
513                 prog = new_insn;
514                 ret = t == BPF_MOD_CALL ?
515                       emit_call(&prog, new_addr, ip) :
516                       emit_jump(&prog, new_addr, ip);
517                 if (ret)
518                         return ret;
519         }
520
521         ret = -EBUSY;
522         mutex_lock(&text_mutex);
523         if (memcmp(ip, old_insn, X86_PATCH_SIZE))
524                 goto out;
525         ret = 1;
526         if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
527                 text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
528                 ret = 0;
529         }
530 out:
531         mutex_unlock(&text_mutex);
532         return ret;
533 }
534
535 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
536                        void *old_addr, void *new_addr)
537 {
538         if (!is_kernel_text((long)ip) &&
539             !is_bpf_text_address((long)ip))
540                 /* BPF poking in modules is not supported */
541                 return -EINVAL;
542
543         /*
544          * See emit_prologue(), for IBT builds the trampoline hook is preceded
545          * with an ENDBR instruction.
546          */
547         if (is_endbr(*(u32 *)ip))
548                 ip += ENDBR_INSN_SIZE;
549
550         return __bpf_arch_text_poke(ip, t, old_addr, new_addr);
551 }
552
553 #define EMIT_LFENCE()   EMIT3(0x0F, 0xAE, 0xE8)
554
555 static void emit_indirect_jump(u8 **pprog, int reg, u8 *ip)
556 {
557         u8 *prog = *pprog;
558
559         if (cpu_feature_enabled(X86_FEATURE_RETPOLINE_LFENCE)) {
560                 EMIT_LFENCE();
561                 EMIT2(0xFF, 0xE0 + reg);
562         } else if (cpu_feature_enabled(X86_FEATURE_RETPOLINE)) {
563                 OPTIMIZER_HIDE_VAR(reg);
564                 if (cpu_feature_enabled(X86_FEATURE_CALL_DEPTH))
565                         emit_jump(&prog, &__x86_indirect_jump_thunk_array[reg], ip);
566                 else
567                         emit_jump(&prog, &__x86_indirect_thunk_array[reg], ip);
568         } else {
569                 EMIT2(0xFF, 0xE0 + reg);        /* jmp *%\reg */
570                 if (IS_ENABLED(CONFIG_MITIGATION_RETPOLINE) || IS_ENABLED(CONFIG_MITIGATION_SLS))
571                         EMIT1(0xCC);            /* int3 */
572         }
573
574         *pprog = prog;
575 }
576
577 static void emit_return(u8 **pprog, u8 *ip)
578 {
579         u8 *prog = *pprog;
580
581         if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) {
582                 emit_jump(&prog, x86_return_thunk, ip);
583         } else {
584                 EMIT1(0xC3);            /* ret */
585                 if (IS_ENABLED(CONFIG_MITIGATION_SLS))
586                         EMIT1(0xCC);    /* int3 */
587         }
588
589         *pprog = prog;
590 }
591
592 /*
593  * Generate the following code:
594  *
595  * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
596  *   if (index >= array->map.max_entries)
597  *     goto out;
598  *   if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
599  *     goto out;
600  *   prog = array->ptrs[index];
601  *   if (prog == NULL)
602  *     goto out;
603  *   goto *(prog->bpf_func + prologue_size);
604  * out:
605  */
606 static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog,
607                                         u8 **pprog, bool *callee_regs_used,
608                                         u32 stack_depth, u8 *ip,
609                                         struct jit_context *ctx)
610 {
611         int tcc_off = -4 - round_up(stack_depth, 8);
612         u8 *prog = *pprog, *start = *pprog;
613         int offset;
614
615         /*
616          * rdi - pointer to ctx
617          * rsi - pointer to bpf_array
618          * rdx - index in bpf_array
619          */
620
621         /*
622          * if (index >= array->map.max_entries)
623          *      goto out;
624          */
625         EMIT2(0x89, 0xD2);                        /* mov edx, edx */
626         EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
627               offsetof(struct bpf_array, map.max_entries));
628
629         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
630         EMIT2(X86_JBE, offset);                   /* jbe out */
631
632         /*
633          * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
634          *      goto out;
635          */
636         EMIT2_off32(0x8B, 0x85, tcc_off);         /* mov eax, dword ptr [rbp - tcc_off] */
637         EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
638
639         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
640         EMIT2(X86_JAE, offset);                   /* jae out */
641         EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
642         EMIT2_off32(0x89, 0x85, tcc_off);         /* mov dword ptr [rbp - tcc_off], eax */
643
644         /* prog = array->ptrs[index]; */
645         EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6,       /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */
646                     offsetof(struct bpf_array, ptrs));
647
648         /*
649          * if (prog == NULL)
650          *      goto out;
651          */
652         EMIT3(0x48, 0x85, 0xC9);                  /* test rcx,rcx */
653
654         offset = ctx->tail_call_indirect_label - (prog + 2 - start);
655         EMIT2(X86_JE, offset);                    /* je out */
656
657         if (bpf_prog->aux->exception_boundary) {
658                 pop_callee_regs(&prog, all_callee_regs_used);
659                 pop_r12(&prog);
660         } else {
661                 pop_callee_regs(&prog, callee_regs_used);
662                 if (bpf_arena_get_kern_vm_start(bpf_prog->aux->arena))
663                         pop_r12(&prog);
664         }
665
666         EMIT1(0x58);                              /* pop rax */
667         if (stack_depth)
668                 EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
669                             round_up(stack_depth, 8));
670
671         /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
672         EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
673               offsetof(struct bpf_prog, bpf_func));
674         EMIT4(0x48, 0x83, 0xC1,                   /* add rcx, X86_TAIL_CALL_OFFSET */
675               X86_TAIL_CALL_OFFSET);
676         /*
677          * Now we're ready to jump into next BPF program
678          * rdi == ctx (1st arg)
679          * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET
680          */
681         emit_indirect_jump(&prog, 1 /* rcx */, ip + (prog - start));
682
683         /* out: */
684         ctx->tail_call_indirect_label = prog - start;
685         *pprog = prog;
686 }
687
688 static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog,
689                                       struct bpf_jit_poke_descriptor *poke,
690                                       u8 **pprog, u8 *ip,
691                                       bool *callee_regs_used, u32 stack_depth,
692                                       struct jit_context *ctx)
693 {
694         int tcc_off = -4 - round_up(stack_depth, 8);
695         u8 *prog = *pprog, *start = *pprog;
696         int offset;
697
698         /*
699          * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
700          *      goto out;
701          */
702         EMIT2_off32(0x8B, 0x85, tcc_off);             /* mov eax, dword ptr [rbp - tcc_off] */
703         EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
704
705         offset = ctx->tail_call_direct_label - (prog + 2 - start);
706         EMIT2(X86_JAE, offset);                       /* jae out */
707         EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
708         EMIT2_off32(0x89, 0x85, tcc_off);             /* mov dword ptr [rbp - tcc_off], eax */
709
710         poke->tailcall_bypass = ip + (prog - start);
711         poke->adj_off = X86_TAIL_CALL_OFFSET;
712         poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
713         poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
714
715         emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
716                   poke->tailcall_bypass);
717
718         if (bpf_prog->aux->exception_boundary) {
719                 pop_callee_regs(&prog, all_callee_regs_used);
720                 pop_r12(&prog);
721         } else {
722                 pop_callee_regs(&prog, callee_regs_used);
723                 if (bpf_arena_get_kern_vm_start(bpf_prog->aux->arena))
724                         pop_r12(&prog);
725         }
726
727         EMIT1(0x58);                                  /* pop rax */
728         if (stack_depth)
729                 EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
730
731         emit_nops(&prog, X86_PATCH_SIZE);
732
733         /* out: */
734         ctx->tail_call_direct_label = prog - start;
735
736         *pprog = prog;
737 }
738
739 static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
740 {
741         struct bpf_jit_poke_descriptor *poke;
742         struct bpf_array *array;
743         struct bpf_prog *target;
744         int i, ret;
745
746         for (i = 0; i < prog->aux->size_poke_tab; i++) {
747                 poke = &prog->aux->poke_tab[i];
748                 if (poke->aux && poke->aux != prog->aux)
749                         continue;
750
751                 WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable));
752
753                 if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
754                         continue;
755
756                 array = container_of(poke->tail_call.map, struct bpf_array, map);
757                 mutex_lock(&array->aux->poke_mutex);
758                 target = array->ptrs[poke->tail_call.key];
759                 if (target) {
760                         ret = __bpf_arch_text_poke(poke->tailcall_target,
761                                                    BPF_MOD_JUMP, NULL,
762                                                    (u8 *)target->bpf_func +
763                                                    poke->adj_off);
764                         BUG_ON(ret < 0);
765                         ret = __bpf_arch_text_poke(poke->tailcall_bypass,
766                                                    BPF_MOD_JUMP,
767                                                    (u8 *)poke->tailcall_target +
768                                                    X86_PATCH_SIZE, NULL);
769                         BUG_ON(ret < 0);
770                 }
771                 WRITE_ONCE(poke->tailcall_target_stable, true);
772                 mutex_unlock(&array->aux->poke_mutex);
773         }
774 }
775
776 static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
777                            u32 dst_reg, const u32 imm32)
778 {
779         u8 *prog = *pprog;
780         u8 b1, b2, b3;
781
782         /*
783          * Optimization: if imm32 is positive, use 'mov %eax, imm32'
784          * (which zero-extends imm32) to save 2 bytes.
785          */
786         if (sign_propagate && (s32)imm32 < 0) {
787                 /* 'mov %rax, imm32' sign extends imm32 */
788                 b1 = add_1mod(0x48, dst_reg);
789                 b2 = 0xC7;
790                 b3 = 0xC0;
791                 EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
792                 goto done;
793         }
794
795         /*
796          * Optimization: if imm32 is zero, use 'xor %eax, %eax'
797          * to save 3 bytes.
798          */
799         if (imm32 == 0) {
800                 if (is_ereg(dst_reg))
801                         EMIT1(add_2mod(0x40, dst_reg, dst_reg));
802                 b2 = 0x31; /* xor */
803                 b3 = 0xC0;
804                 EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
805                 goto done;
806         }
807
808         /* mov %eax, imm32 */
809         if (is_ereg(dst_reg))
810                 EMIT1(add_1mod(0x40, dst_reg));
811         EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
812 done:
813         *pprog = prog;
814 }
815
816 static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
817                            const u32 imm32_hi, const u32 imm32_lo)
818 {
819         u8 *prog = *pprog;
820
821         if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
822                 /*
823                  * For emitting a plain u32, where the sign bit must not be
824                  * propagated, LLVM tends to load an imm64 rather than a mov32
825                  * directly, so save a couple of bytes by just doing
826                  * 'mov %eax, imm32' instead.
827                  */
828                 emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
829         } else {
830                 /* movabsq rax, imm64 */
831                 EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
832                 EMIT(imm32_lo, 4);
833                 EMIT(imm32_hi, 4);
834         }
835
836         *pprog = prog;
837 }
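/*
 * e.g. a 64-bit load of 0xdeadbeef takes the 5-byte 'mov eax, imm32' path
 * (upper 32 bits zero), while any constant with bits set above bit 31 needs
 * the full 10-byte 'movabsq rax, imm64' encoding.
 */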
838
839 static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
840 {
841         u8 *prog = *pprog;
842
843         if (is64) {
844                 /* mov dst, src */
845                 EMIT_mov(dst_reg, src_reg);
846         } else {
847                 /* mov32 dst, src */
848                 if (is_ereg(dst_reg) || is_ereg(src_reg))
849                         EMIT1(add_2mod(0x40, dst_reg, src_reg));
850                 EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
851         }
852
853         *pprog = prog;
854 }
855
856 static void emit_movsx_reg(u8 **pprog, int num_bits, bool is64, u32 dst_reg,
857                            u32 src_reg)
858 {
859         u8 *prog = *pprog;
860
861         if (is64) {
862                 /* movs[b,w,l]q dst, src */
863                 if (num_bits == 8)
864                         EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbe,
865                               add_2reg(0xC0, src_reg, dst_reg));
866                 else if (num_bits == 16)
867                         EMIT4(add_2mod(0x48, src_reg, dst_reg), 0x0f, 0xbf,
868                               add_2reg(0xC0, src_reg, dst_reg));
869                 else if (num_bits == 32)
870                         EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x63,
871                               add_2reg(0xC0, src_reg, dst_reg));
872         } else {
873                 /* movs[b,w]l dst, src */
874                 if (num_bits == 8) {
875                         EMIT4(add_2mod(0x40, src_reg, dst_reg), 0x0f, 0xbe,
876                               add_2reg(0xC0, src_reg, dst_reg));
877                 } else if (num_bits == 16) {
878                         if (is_ereg(dst_reg) || is_ereg(src_reg))
879                                 EMIT1(add_2mod(0x40, src_reg, dst_reg));
880                         EMIT3(add_2mod(0x0f, src_reg, dst_reg), 0xbf,
881                               add_2reg(0xC0, src_reg, dst_reg));
882                 }
883         }
884
885         *pprog = prog;
886 }
887
888 /* Emit the suffix (ModR/M etc) for addressing *(ptr_reg + off) and val_reg */
889 static void emit_insn_suffix(u8 **pprog, u32 ptr_reg, u32 val_reg, int off)
890 {
891         u8 *prog = *pprog;
892
893         if (is_imm8(off)) {
894                 /* 1-byte signed displacement.
895                  *
896          * If off == 0 we could skip this and save one extra byte, but the
897          * special case of x86 R13, which always needs an offset, is not
898          * worth the hassle.
899                  */
900                 EMIT2(add_2reg(0x40, ptr_reg, val_reg), off);
901         } else {
902                 /* 4-byte signed displacement */
903                 EMIT1_off32(add_2reg(0x80, ptr_reg, val_reg), off);
904         }
905         *pprog = prog;
906 }
907
908 static void emit_insn_suffix_SIB(u8 **pprog, u32 ptr_reg, u32 val_reg, u32 index_reg, int off)
909 {
910         u8 *prog = *pprog;
911
912         if (is_imm8(off)) {
913                 EMIT3(add_2reg(0x44, BPF_REG_0, val_reg), add_2reg(0, ptr_reg, index_reg) /* SIB */, off);
914         } else {
915                 EMIT2_off32(add_2reg(0x84, BPF_REG_0, val_reg), add_2reg(0, ptr_reg, index_reg) /* SIB */, off);
916         }
917         *pprog = prog;
918 }
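/*
 * The SIB form encodes ptr_reg as the base, index_reg as the index and a
 * scale of 1 (ModRM.rm = 0b100 selects the SIB byte, mod picks disp8 vs
 * disp32); the arena helpers below use it to add R12, which holds the arena
 * kernel base, to every access.
 */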
919
920 /*
921  * Emit a REX byte if it will be necessary to address these registers
922  */
923 static void maybe_emit_mod(u8 **pprog, u32 dst_reg, u32 src_reg, bool is64)
924 {
925         u8 *prog = *pprog;
926
927         if (is64)
928                 EMIT1(add_2mod(0x48, dst_reg, src_reg));
929         else if (is_ereg(dst_reg) || is_ereg(src_reg))
930                 EMIT1(add_2mod(0x40, dst_reg, src_reg));
931         *pprog = prog;
932 }
933
934 /*
935  * Similar version of maybe_emit_mod() for a single register
936  */
937 static void maybe_emit_1mod(u8 **pprog, u32 reg, bool is64)
938 {
939         u8 *prog = *pprog;
940
941         if (is64)
942                 EMIT1(add_1mod(0x48, reg));
943         else if (is_ereg(reg))
944                 EMIT1(add_1mod(0x40, reg));
945         *pprog = prog;
946 }
947
948 /* LDX: dst_reg = *(u8*)(src_reg + off) */
949 static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
950 {
951         u8 *prog = *pprog;
952
953         switch (size) {
954         case BPF_B:
955                 /* Emit 'movzx rax, byte ptr [rax + off]' */
956                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
957                 break;
958         case BPF_H:
959                 /* Emit 'movzx rax, word ptr [rax + off]' */
960                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
961                 break;
962         case BPF_W:
963                 /* Emit 'mov eax, dword ptr [rax+0x14]' */
964                 if (is_ereg(dst_reg) || is_ereg(src_reg))
965                         EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
966                 else
967                         EMIT1(0x8B);
968                 break;
969         case BPF_DW:
970                 /* Emit 'mov rax, qword ptr [rax+0x14]' */
971                 EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
972                 break;
973         }
974         emit_insn_suffix(&prog, src_reg, dst_reg, off);
975         *pprog = prog;
976 }
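/*
 * e.g. emit_ldx(&prog, BPF_W, BPF_REG_0, BPF_REG_1, 8) emits 8B 47 08,
 * i.e. 'mov eax, dword ptr [rdi + 8]'.
 */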
977
978 /* LDSX: dst_reg = *(s8*)(src_reg + off) */
979 static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
980 {
981         u8 *prog = *pprog;
982
983         switch (size) {
984         case BPF_B:
985                 /* Emit 'movsx rax, byte ptr [rax + off]' */
986                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
987                 break;
988         case BPF_H:
989                 /* Emit 'movsx rax, word ptr [rax + off]' */
990                 EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
991                 break;
992         case BPF_W:
993                 /* Emit 'movsx rax, dword ptr [rax+0x14]' */
994                 EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
995                 break;
996         }
997         emit_insn_suffix(&prog, src_reg, dst_reg, off);
998         *pprog = prog;
999 }
1000
1001 static void emit_ldx_index(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, u32 index_reg, int off)
1002 {
1003         u8 *prog = *pprog;
1004
1005         switch (size) {
1006         case BPF_B:
1007                 /* movzx rax, byte ptr [rax + r12 + off] */
1008                 EMIT3(add_3mod(0x40, src_reg, dst_reg, index_reg), 0x0F, 0xB6);
1009                 break;
1010         case BPF_H:
1011                 /* movzx rax, word ptr [rax + r12 + off] */
1012                 EMIT3(add_3mod(0x40, src_reg, dst_reg, index_reg), 0x0F, 0xB7);
1013                 break;
1014         case BPF_W:
1015                 /* mov eax, dword ptr [rax + r12 + off] */
1016                 EMIT2(add_3mod(0x40, src_reg, dst_reg, index_reg), 0x8B);
1017                 break;
1018         case BPF_DW:
1019                 /* mov rax, qword ptr [rax + r12 + off] */
1020                 EMIT2(add_3mod(0x48, src_reg, dst_reg, index_reg), 0x8B);
1021                 break;
1022         }
1023         emit_insn_suffix_SIB(&prog, src_reg, dst_reg, index_reg, off);
1024         *pprog = prog;
1025 }
1026
1027 static void emit_ldx_r12(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
1028 {
1029         emit_ldx_index(pprog, size, dst_reg, src_reg, X86_REG_R12, off);
1030 }
1031
1032 /* STX: *(u8*)(dst_reg + off) = src_reg */
1033 static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
1034 {
1035         u8 *prog = *pprog;
1036
1037         switch (size) {
1038         case BPF_B:
1039                 /* Emit 'mov byte ptr [rax + off], al' */
1040                 if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
1041                         /* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
1042                         EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
1043                 else
1044                         EMIT1(0x88);
1045                 break;
1046         case BPF_H:
1047                 if (is_ereg(dst_reg) || is_ereg(src_reg))
1048                         EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
1049                 else
1050                         EMIT2(0x66, 0x89);
1051                 break;
1052         case BPF_W:
1053                 if (is_ereg(dst_reg) || is_ereg(src_reg))
1054                         EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
1055                 else
1056                         EMIT1(0x89);
1057                 break;
1058         case BPF_DW:
1059                 EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
1060                 break;
1061         }
1062         emit_insn_suffix(&prog, dst_reg, src_reg, off);
1063         *pprog = prog;
1064 }
1065
1066 /* STX: *(u8*)(dst_reg + index_reg + off) = src_reg */
1067 static void emit_stx_index(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, u32 index_reg, int off)
1068 {
1069         u8 *prog = *pprog;
1070
1071         switch (size) {
1072         case BPF_B:
1073                 /* mov byte ptr [rax + r12 + off], al */
1074                 EMIT2(add_3mod(0x40, dst_reg, src_reg, index_reg), 0x88);
1075                 break;
1076         case BPF_H:
1077                 /* mov word ptr [rax + r12 + off], ax */
1078                 EMIT3(0x66, add_3mod(0x40, dst_reg, src_reg, index_reg), 0x89);
1079                 break;
1080         case BPF_W:
1081                 /* mov dword ptr [rax + r12 + off], eax */
1082                 EMIT2(add_3mod(0x40, dst_reg, src_reg, index_reg), 0x89);
1083                 break;
1084         case BPF_DW:
1085                 /* mov qword ptr [rax + r12 + off], rax */
1086                 EMIT2(add_3mod(0x48, dst_reg, src_reg, index_reg), 0x89);
1087                 break;
1088         }
1089         emit_insn_suffix_SIB(&prog, dst_reg, src_reg, index_reg, off);
1090         *pprog = prog;
1091 }
1092
1093 static void emit_stx_r12(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
1094 {
1095         emit_stx_index(pprog, size, dst_reg, src_reg, X86_REG_R12, off);
1096 }
1097
1098 /* ST: *(u8*)(dst_reg + index_reg + off) = imm32 */
1099 static void emit_st_index(u8 **pprog, u32 size, u32 dst_reg, u32 index_reg, int off, int imm)
1100 {
1101         u8 *prog = *pprog;
1102
1103         switch (size) {
1104         case BPF_B:
1105                 /* mov byte ptr [rax + r12 + off], imm8 */
1106                 EMIT2(add_3mod(0x40, dst_reg, 0, index_reg), 0xC6);
1107                 break;
1108         case BPF_H:
1109                 /* mov word ptr [rax + r12 + off], imm16 */
1110                 EMIT3(0x66, add_3mod(0x40, dst_reg, 0, index_reg), 0xC7);
1111                 break;
1112         case BPF_W:
1113                 /* mov dword ptr [rax + r12 + off], imm32 */
1114                 EMIT2(add_3mod(0x40, dst_reg, 0, index_reg), 0xC7);
1115                 break;
1116         case BPF_DW:
1117                 /* mov qword ptr [rax + r12 + off], imm32 */
1118                 EMIT2(add_3mod(0x48, dst_reg, 0, index_reg), 0xC7);
1119                 break;
1120         }
1121         emit_insn_suffix_SIB(&prog, dst_reg, 0, index_reg, off);
1122         EMIT(imm, bpf_size_to_x86_bytes(size));
1123         *pprog = prog;
1124 }
1125
1126 static void emit_st_r12(u8 **pprog, u32 size, u32 dst_reg, int off, int imm)
1127 {
1128         emit_st_index(pprog, size, dst_reg, X86_REG_R12, off, imm);
1129 }
1130
1131 static int emit_atomic(u8 **pprog, u8 atomic_op,
1132                        u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size)
1133 {
1134         u8 *prog = *pprog;
1135
1136         EMIT1(0xF0); /* lock prefix */
1137
1138         maybe_emit_mod(&prog, dst_reg, src_reg, bpf_size == BPF_DW);
1139
1140         /* emit opcode */
1141         switch (atomic_op) {
1142         case BPF_ADD:
1143         case BPF_AND:
1144         case BPF_OR:
1145         case BPF_XOR:
1146                 /* lock *(u32/u64*)(dst_reg + off) <op>= src_reg */
1147                 EMIT1(simple_alu_opcodes[atomic_op]);
1148                 break;
1149         case BPF_ADD | BPF_FETCH:
1150                 /* src_reg = atomic_fetch_add(dst_reg + off, src_reg); */
1151                 EMIT2(0x0F, 0xC1);
1152                 break;
1153         case BPF_XCHG:
1154                 /* src_reg = atomic_xchg(dst_reg + off, src_reg); */
1155                 EMIT1(0x87);
1156                 break;
1157         case BPF_CMPXCHG:
1158                 /* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
1159                 EMIT2(0x0F, 0xB1);
1160                 break;
1161         default:
1162                 pr_err("bpf_jit: unknown atomic opcode %02x\n", atomic_op);
1163                 return -EFAULT;
1164         }
1165
1166         emit_insn_suffix(&prog, dst_reg, src_reg, off);
1167
1168         *pprog = prog;
1169         return 0;
1170 }
1171
1172 #define DONT_CLEAR 1
1173
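/*
 * The fixup word is built by the JIT as (pt_regs byte offset of the
 * destination register << 8) | bytes to skip over the faulting instruction;
 * a register field of DONT_CLEAR means only skip, do not zero a destination.
 */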
1174 bool ex_handler_bpf(const struct exception_table_entry *x, struct pt_regs *regs)
1175 {
1176         u32 reg = x->fixup >> 8;
1177
1178         /* jump over faulting load and clear dest register */
1179         if (reg != DONT_CLEAR)
1180                 *(unsigned long *)((void *)regs + reg) = 0;
1181         regs->ip += x->fixup & 0xff;
1182         return true;
1183 }
1184
1185 static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
1186                              bool *regs_used, bool *tail_call_seen)
1187 {
1188         int i;
1189
1190         for (i = 1; i <= insn_cnt; i++, insn++) {
1191                 if (insn->code == (BPF_JMP | BPF_TAIL_CALL))
1192                         *tail_call_seen = true;
1193                 if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6)
1194                         regs_used[0] = true;
1195                 if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7)
1196                         regs_used[1] = true;
1197                 if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8)
1198                         regs_used[2] = true;
1199                 if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9)
1200                         regs_used[3] = true;
1201         }
1202 }
1203
1204 /* emit the 3-byte VEX prefix
1205  *
1206  * r: same as rex.r, extra bit for ModRM reg field
1207  * x: same as rex.x, extra bit for SIB index field
1208  * b: same as rex.b, extra bit for ModRM r/m, or SIB base
1209  * m: opcode map select, encoding escape bytes e.g. 0x0f38
1210  * w: same as rex.w (32 bit or 64 bit) or opcode specific
1211  * src_reg2: additional source reg (encoded as BPF reg)
1212  * l: vector length (128 bit or 256 bit) or reserved
1213  * pp: opcode prefix (none, 0x66, 0xf2 or 0xf3)
1214  */
1215 static void emit_3vex(u8 **pprog, bool r, bool x, bool b, u8 m,
1216                       bool w, u8 src_reg2, bool l, u8 pp)
1217 {
1218         u8 *prog = *pprog;
1219         const u8 b0 = 0xc4; /* first byte of 3-byte VEX prefix */
1220         u8 b1, b2;
1221         u8 vvvv = reg2hex[src_reg2];
1222
1223         /* reg2hex gives only the lower 3 bits of vvvv */
1224         if (is_ereg(src_reg2))
1225                 vvvv |= 1 << 3;
1226
1227         /*
1228          * 2nd byte of 3-byte VEX prefix
1229          * ~ means bit inverted encoding
1230          *
1231          *    7                           0
1232          *  +---+---+---+---+---+---+---+---+
1233          *  |~R |~X |~B |         m         |
1234          *  +---+---+---+---+---+---+---+---+
1235          */
1236         b1 = (!r << 7) | (!x << 6) | (!b << 5) | (m & 0x1f);
1237         /*
1238          * 3rd byte of 3-byte VEX prefix
1239          *
1240          *    7                           0
1241          *  +---+---+---+---+---+---+---+---+
1242          *  | W |     ~vvvv     | L |   pp  |
1243          *  +---+---+---+---+---+---+---+---+
1244          */
1245         b2 = (w << 7) | ((~vvvv & 0xf) << 3) | (l << 2) | (pp & 3);
1246
1247         EMIT3(b0, b1, b2);
1248         *pprog = prog;
1249 }
1250
1251 /* emit BMI2 shift instruction */
1252 static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
1253 {
1254         u8 *prog = *pprog;
1255         bool r = is_ereg(dst_reg);
1256         u8 m = 2; /* escape code 0f38 */
1257
1258         emit_3vex(&prog, r, false, r, m, is64, src_reg, false, op);
1259         EMIT2(0xf7, add_2reg(0xC0, dst_reg, dst_reg));
1260         *pprog = prog;
1261 }
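/*
 * This emits 'shlx/shrx/sarx dst, dst, src': the shift count register goes in
 * the VEX.vvvv field and the 'op' argument is the VEX pp prefix selector
 * (1 = 0x66 -> shlx, 3 = 0xf2 -> shrx, 2 = 0xf3 -> sarx).
 */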
1262
1263 #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
1264
1265 /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
1266 #define RESTORE_TAIL_CALL_CNT(stack)                            \
1267         EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
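/*
 * The prologue pushes rax (tail_call_cnt) right after 'sub rsp, stack_depth',
 * so the count sits at [rbp - round_up(stack, 8) - 8]; this reloads it into
 * rax before a call that must see the current count.
 */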
1268
1269 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
1270                   int oldproglen, struct jit_context *ctx, bool jmp_padding)
1271 {
1272         bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
1273         struct bpf_insn *insn = bpf_prog->insnsi;
1274         bool callee_regs_used[4] = {};
1275         int insn_cnt = bpf_prog->len;
1276         bool tail_call_seen = false;
1277         bool seen_exit = false;
1278         u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
1279         u64 arena_vm_start, user_vm_start;
1280         int i, excnt = 0;
1281         int ilen, proglen = 0;
1282         u8 *prog = temp;
1283         int err;
1284
1285         arena_vm_start = bpf_arena_get_kern_vm_start(bpf_prog->aux->arena);
1286         user_vm_start = bpf_arena_get_user_vm_start(bpf_prog->aux->arena);
1287
1288         detect_reg_usage(insn, insn_cnt, callee_regs_used,
1289                          &tail_call_seen);
1290
1291         /* tail call's presence in current prog implies it is reachable */
1292         tail_call_reachable |= tail_call_seen;
1293
1294         emit_prologue(&prog, bpf_prog->aux->stack_depth,
1295                       bpf_prog_was_classic(bpf_prog), tail_call_reachable,
1296                       bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb);
1297         /* Exception callback will clobber callee regs for its own use, and
1298          * restore the original callee regs from main prog's stack frame.
1299          */
1300         if (bpf_prog->aux->exception_boundary) {
1301                 /* We also need to save r12, which is not mapped to any BPF
1302                  * register, as we throw after entry into the kernel, which may
1303                  * overwrite r12.
1304                  */
1305                 push_r12(&prog);
1306                 push_callee_regs(&prog, all_callee_regs_used);
1307         } else {
1308                 if (arena_vm_start)
1309                         push_r12(&prog);
1310                 push_callee_regs(&prog, callee_regs_used);
1311         }
1312         if (arena_vm_start)
1313                 emit_mov_imm64(&prog, X86_REG_R12,
1314                                arena_vm_start >> 32, (u32) arena_vm_start);
1315
1316         ilen = prog - temp;
1317         if (rw_image)
1318                 memcpy(rw_image + proglen, temp, ilen);
1319         proglen += ilen;
1320         addrs[0] = proglen;
1321         prog = temp;
1322
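        /*
         * Main encoding loop, one pass over the program. addrs[] holds each
         * instruction's end offset as computed by the previous pass, so jump
         * distances (including forward jumps) can be resolved; do_jit() is
         * called repeatedly until the offsets converge.
         */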
1323         for (i = 1; i <= insn_cnt; i++, insn++) {
1324                 const s32 imm32 = insn->imm;
1325                 u32 dst_reg = insn->dst_reg;
1326                 u32 src_reg = insn->src_reg;
1327                 u8 b2 = 0, b3 = 0;
1328                 u8 *start_of_ldx;
1329                 s64 jmp_offset;
1330                 s16 insn_off;
1331                 u8 jmp_cond;
1332                 u8 *func;
1333                 int nops;
1334
1335                 switch (insn->code) {
1336                         /* ALU */
1337                 case BPF_ALU | BPF_ADD | BPF_X:
1338                 case BPF_ALU | BPF_SUB | BPF_X:
1339                 case BPF_ALU | BPF_AND | BPF_X:
1340                 case BPF_ALU | BPF_OR | BPF_X:
1341                 case BPF_ALU | BPF_XOR | BPF_X:
1342                 case BPF_ALU64 | BPF_ADD | BPF_X:
1343                 case BPF_ALU64 | BPF_SUB | BPF_X:
1344                 case BPF_ALU64 | BPF_AND | BPF_X:
1345                 case BPF_ALU64 | BPF_OR | BPF_X:
1346                 case BPF_ALU64 | BPF_XOR | BPF_X:
1347                         maybe_emit_mod(&prog, dst_reg, src_reg,
1348                                        BPF_CLASS(insn->code) == BPF_ALU64);
1349                         b2 = simple_alu_opcodes[BPF_OP(insn->code)];
1350                         EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
1351                         break;
1352
1353                 case BPF_ALU64 | BPF_MOV | BPF_X:
1354                         if (insn->off == BPF_ADDR_SPACE_CAST &&
1355                             insn->imm == 1U << 16) {
1356                                 if (dst_reg != src_reg)
1357                                         /* 32-bit mov */
1358                                         emit_mov_reg(&prog, false, dst_reg, src_reg);
1359                                 /* shl dst_reg, 32 */
1360                                 maybe_emit_1mod(&prog, dst_reg, true);
1361                                 EMIT3(0xC1, add_1reg(0xE0, dst_reg), 32);
1362
1363                                 /* or dst_reg, user_vm_start */
1364                                 maybe_emit_1mod(&prog, dst_reg, true);
1365                                 if (is_axreg(dst_reg))
1366                                         EMIT1_off32(0x0D,  user_vm_start >> 32);
1367                                 else
1368                                         EMIT2_off32(0x81, add_1reg(0xC8, dst_reg),  user_vm_start >> 32);
1369
1370                                 /* rol dst_reg, 32 */
1371                                 maybe_emit_1mod(&prog, dst_reg, true);
1372                                 EMIT3(0xC1, add_1reg(0xC0, dst_reg), 32);
1373
1374                                 /* xor r11, r11 */
1375                                 EMIT3(0x4D, 0x31, 0xDB);
1376
1377                                 /* test dst_reg32, dst_reg32; check if lower 32-bit are zero */
1378                                 maybe_emit_mod(&prog, dst_reg, dst_reg, false);
1379                                 EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
1380
1381                                 /* cmove r11, dst_reg; if so, set dst_reg to zero */
1382                                 /* WARNING: Intel swapped src/dst register encoding in CMOVcc !!! */
1383                                 maybe_emit_mod(&prog, AUX_REG, dst_reg, true);
1384                                 EMIT3(0x0F, 0x44, add_2reg(0xC0, AUX_REG, dst_reg));
1385                                 break;
1386                         }
1387                         fallthrough;
1388                 case BPF_ALU | BPF_MOV | BPF_X:
1389                         if (insn->off == 0)
1390                                 emit_mov_reg(&prog,
1391                                              BPF_CLASS(insn->code) == BPF_ALU64,
1392                                              dst_reg, src_reg);
1393                         else
1394                                 emit_movsx_reg(&prog, insn->off,
1395                                                BPF_CLASS(insn->code) == BPF_ALU64,
1396                                                dst_reg, src_reg);
1397                         break;
1398
1399                         /* neg dst */
1400                 case BPF_ALU | BPF_NEG:
1401                 case BPF_ALU64 | BPF_NEG:
1402                         maybe_emit_1mod(&prog, dst_reg,
1403                                         BPF_CLASS(insn->code) == BPF_ALU64);
1404                         EMIT2(0xF7, add_1reg(0xD8, dst_reg));
1405                         break;
1406
1407                 case BPF_ALU | BPF_ADD | BPF_K:
1408                 case BPF_ALU | BPF_SUB | BPF_K:
1409                 case BPF_ALU | BPF_AND | BPF_K:
1410                 case BPF_ALU | BPF_OR | BPF_K:
1411                 case BPF_ALU | BPF_XOR | BPF_K:
1412                 case BPF_ALU64 | BPF_ADD | BPF_K:
1413                 case BPF_ALU64 | BPF_SUB | BPF_K:
1414                 case BPF_ALU64 | BPF_AND | BPF_K:
1415                 case BPF_ALU64 | BPF_OR | BPF_K:
1416                 case BPF_ALU64 | BPF_XOR | BPF_K:
1417                         maybe_emit_1mod(&prog, dst_reg,
1418                                         BPF_CLASS(insn->code) == BPF_ALU64);
1419
1420                         /*
1421                          * b3 holds 'normal' opcode, b2 short form only valid
1422                          * in case dst is eax/rax.
1423                          */
1424                         switch (BPF_OP(insn->code)) {
1425                         case BPF_ADD:
1426                                 b3 = 0xC0;
1427                                 b2 = 0x05;
1428                                 break;
1429                         case BPF_SUB:
1430                                 b3 = 0xE8;
1431                                 b2 = 0x2D;
1432                                 break;
1433                         case BPF_AND:
1434                                 b3 = 0xE0;
1435                                 b2 = 0x25;
1436                                 break;
1437                         case BPF_OR:
1438                                 b3 = 0xC8;
1439                                 b2 = 0x0D;
1440                                 break;
1441                         case BPF_XOR:
1442                                 b3 = 0xF0;
1443                                 b2 = 0x35;
1444                                 break;
1445                         }
1446
1447                         if (is_imm8(imm32))
1448                                 EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
1449                         else if (is_axreg(dst_reg))
1450                                 EMIT1_off32(b2, imm32);
1451                         else
1452                                 EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
1453                         break;
1454
1455                 case BPF_ALU64 | BPF_MOV | BPF_K:
1456                 case BPF_ALU | BPF_MOV | BPF_K:
1457                         emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
1458                                        dst_reg, imm32);
1459                         break;
1460
1461                 case BPF_LD | BPF_IMM | BPF_DW:
1462                         emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
1463                         insn++;
1464                         i++;
1465                         break;
1466
1467                         /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
1468                 case BPF_ALU | BPF_MOD | BPF_X:
1469                 case BPF_ALU | BPF_DIV | BPF_X:
1470                 case BPF_ALU | BPF_MOD | BPF_K:
1471                 case BPF_ALU | BPF_DIV | BPF_K:
1472                 case BPF_ALU64 | BPF_MOD | BPF_X:
1473                 case BPF_ALU64 | BPF_DIV | BPF_X:
1474                 case BPF_ALU64 | BPF_MOD | BPF_K:
1475                 case BPF_ALU64 | BPF_DIV | BPF_K: {
1476                         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
1477
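                             /*
                              * Illustrative emitted sequence (a sketch, not literal
                              * bytes) for BPF_ALU64 | BPF_DIV | BPF_X with off == 0,
                              * dst_reg = r1 (rdi), src_reg = r2 (rsi):
                              *   push rax
                              *   push rdx
                              *   mov  rax, rdi
                              *   xor  edx, edx
                              *   div  rsi
                              *   mov  rdi, rax
                              *   pop  rdx
                              *   pop  rax
                              */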
1478                         if (dst_reg != BPF_REG_0)
1479                                 EMIT1(0x50); /* push rax */
1480                         if (dst_reg != BPF_REG_3)
1481                                 EMIT1(0x52); /* push rdx */
1482
1483                         if (BPF_SRC(insn->code) == BPF_X) {
1484                                 if (src_reg == BPF_REG_0 ||
1485                                     src_reg == BPF_REG_3) {
1486                                         /* mov r11, src_reg */
1487                                         EMIT_mov(AUX_REG, src_reg);
1488                                         src_reg = AUX_REG;
1489                                 }
1490                         } else {
1491                                 /* mov r11, imm32 */
1492                                 EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
1493                                 src_reg = AUX_REG;
1494                         }
1495
1496                         if (dst_reg != BPF_REG_0)
1497                                 /* mov rax, dst_reg */
1498                                 emit_mov_reg(&prog, is64, BPF_REG_0, dst_reg);
1499
1500                         if (insn->off == 0) {
1501                                 /*
1502                                  * xor edx, edx
1503                                  * equivalent to 'xor rdx, rdx', but one byte shorter
1504                                  */
1505                                 EMIT2(0x31, 0xd2);
1506
1507                                 /* div src_reg */
1508                                 maybe_emit_1mod(&prog, src_reg, is64);
1509                                 EMIT2(0xF7, add_1reg(0xF0, src_reg));
1510                         } else {
1511                                 if (BPF_CLASS(insn->code) == BPF_ALU)
1512                                         EMIT1(0x99); /* cdq */
1513                                 else
1514                                         EMIT2(0x48, 0x99); /* cqo */
1515
1516                                 /* idiv src_reg */
1517                                 maybe_emit_1mod(&prog, src_reg, is64);
1518                                 EMIT2(0xF7, add_1reg(0xF8, src_reg));
1519                         }
1520
1521                         if (BPF_OP(insn->code) == BPF_MOD &&
1522                             dst_reg != BPF_REG_3)
1523                                 /* mov dst_reg, rdx */
1524                                 emit_mov_reg(&prog, is64, dst_reg, BPF_REG_3);
1525                         else if (BPF_OP(insn->code) == BPF_DIV &&
1526                                  dst_reg != BPF_REG_0)
1527                                 /* mov dst_reg, rax */
1528                                 emit_mov_reg(&prog, is64, dst_reg, BPF_REG_0);
1529
1530                         if (dst_reg != BPF_REG_3)
1531                                 EMIT1(0x5A); /* pop rdx */
1532                         if (dst_reg != BPF_REG_0)
1533                                 EMIT1(0x58); /* pop rax */
1534                         break;
1535                 }
1536
1537                 case BPF_ALU | BPF_MUL | BPF_K:
1538                 case BPF_ALU64 | BPF_MUL | BPF_K:
1539                         maybe_emit_mod(&prog, dst_reg, dst_reg,
1540                                        BPF_CLASS(insn->code) == BPF_ALU64);
1541
1542                         if (is_imm8(imm32))
1543                                 /* imul dst_reg, dst_reg, imm8 */
1544                                 EMIT3(0x6B, add_2reg(0xC0, dst_reg, dst_reg),
1545                                       imm32);
1546                         else
1547                                 /* imul dst_reg, dst_reg, imm32 */
1548                                 EMIT2_off32(0x69,
1549                                             add_2reg(0xC0, dst_reg, dst_reg),
1550                                             imm32);
1551                         break;
1552
1553                 case BPF_ALU | BPF_MUL | BPF_X:
1554                 case BPF_ALU64 | BPF_MUL | BPF_X:
1555                         maybe_emit_mod(&prog, src_reg, dst_reg,
1556                                        BPF_CLASS(insn->code) == BPF_ALU64);
1557
1558                         /* imul dst_reg, src_reg */
1559                         EMIT3(0x0F, 0xAF, add_2reg(0xC0, src_reg, dst_reg));
1560                         break;
1561
1562                         /* Shifts */
1563                 case BPF_ALU | BPF_LSH | BPF_K:
1564                 case BPF_ALU | BPF_RSH | BPF_K:
1565                 case BPF_ALU | BPF_ARSH | BPF_K:
1566                 case BPF_ALU64 | BPF_LSH | BPF_K:
1567                 case BPF_ALU64 | BPF_RSH | BPF_K:
1568                 case BPF_ALU64 | BPF_ARSH | BPF_K:
1569                         maybe_emit_1mod(&prog, dst_reg,
1570                                         BPF_CLASS(insn->code) == BPF_ALU64);
1571
1572                         b3 = simple_alu_opcodes[BPF_OP(insn->code)];
1573                         if (imm32 == 1)
1574                                 EMIT2(0xD1, add_1reg(b3, dst_reg));
1575                         else
1576                                 EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
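                             /*
                              * e.g. BPF_ALU64 | BPF_LSH | BPF_K with dst_reg = r1 (rdi)
                              * and imm32 = 3 encodes as 48 C1 E7 03 (shl rdi, 3).
                              */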
1577                         break;
1578
1579                 case BPF_ALU | BPF_LSH | BPF_X:
1580                 case BPF_ALU | BPF_RSH | BPF_X:
1581                 case BPF_ALU | BPF_ARSH | BPF_X:
1582                 case BPF_ALU64 | BPF_LSH | BPF_X:
1583                 case BPF_ALU64 | BPF_RSH | BPF_X:
1584                 case BPF_ALU64 | BPF_ARSH | BPF_X:
1585                         /* BMI2 shifts aren't better when shift count is already in rcx */
1586                         if (boot_cpu_has(X86_FEATURE_BMI2) && src_reg != BPF_REG_4) {
1587                                 /* shrx/sarx/shlx dst_reg, dst_reg, src_reg */
1588                                 bool w = (BPF_CLASS(insn->code) == BPF_ALU64);
1589                                 u8 op;
1590
1591                                 switch (BPF_OP(insn->code)) {
1592                                 case BPF_LSH:
1593                                         op = 1; /* prefix 0x66 */
1594                                         break;
1595                                 case BPF_RSH:
1596                                         op = 3; /* prefix 0xf2 */
1597                                         break;
1598                                 case BPF_ARSH:
1599                                         op = 2; /* prefix 0xf3 */
1600                                         break;
1601                                 }
1602
1603                                 emit_shiftx(&prog, dst_reg, src_reg, w, op);
1604
1605                                 break;
1606                         }
1607
1608                         if (src_reg != BPF_REG_4) { /* common case */
1609                                 /* Check for bad case when dst_reg == rcx */
1610                                 if (dst_reg == BPF_REG_4) {
1611                                         /* mov r11, dst_reg */
1612                                         EMIT_mov(AUX_REG, dst_reg);
1613                                         dst_reg = AUX_REG;
1614                                 } else {
1615                                         EMIT1(0x51); /* push rcx */
1616                                 }
1617                                 /* mov rcx, src_reg */
1618                                 EMIT_mov(BPF_REG_4, src_reg);
1619                         }
1620
1621                         /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
1622                         maybe_emit_1mod(&prog, dst_reg,
1623                                         BPF_CLASS(insn->code) == BPF_ALU64);
1624
1625                         b3 = simple_alu_opcodes[BPF_OP(insn->code)];
1626                         EMIT2(0xD3, add_1reg(b3, dst_reg));
1627
1628                         if (src_reg != BPF_REG_4) {
1629                                 if (insn->dst_reg == BPF_REG_4)
1630                                         /* mov dst_reg, r11 */
1631                                         EMIT_mov(insn->dst_reg, AUX_REG);
1632                                 else
1633                                         EMIT1(0x59); /* pop rcx */
1634                         }
1635
1636                         break;
1637
1638                 case BPF_ALU | BPF_END | BPF_FROM_BE:
1639                 case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1640                         switch (imm32) {
1641                         case 16:
1642                                 /* Emit 'ror %ax, 8' to swap lower 2 bytes */
1643                                 EMIT1(0x66);
1644                                 if (is_ereg(dst_reg))
1645                                         EMIT1(0x41);
1646                                 EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
1647
1648                                 /* Emit 'movzwl eax, ax' */
1649                                 if (is_ereg(dst_reg))
1650                                         EMIT3(0x45, 0x0F, 0xB7);
1651                                 else
1652                                         EMIT2(0x0F, 0xB7);
1653                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1654                                 break;
1655                         case 32:
1656                                 /* Emit 'bswap eax' to swap lower 4 bytes */
1657                                 if (is_ereg(dst_reg))
1658                                         EMIT2(0x41, 0x0F);
1659                                 else
1660                                         EMIT1(0x0F);
1661                                 EMIT1(add_1reg(0xC8, dst_reg));
1662                                 break;
1663                         case 64:
1664                                 /* Emit 'bswap rax' to swap 8 bytes */
1665                                 EMIT3(add_1mod(0x48, dst_reg), 0x0F,
1666                                       add_1reg(0xC8, dst_reg));
1667                                 break;
1668                         }
1669                         break;
1670
1671                 case BPF_ALU | BPF_END | BPF_FROM_LE:
1672                         switch (imm32) {
1673                         case 16:
1674                                 /*
1675                                  * Emit 'movzwl eax, ax' to zero-extend the low 16 bits
1676                                  * into the full 64-bit register
1677                                  */
1678                                 if (is_ereg(dst_reg))
1679                                         EMIT3(0x45, 0x0F, 0xB7);
1680                                 else
1681                                         EMIT2(0x0F, 0xB7);
1682                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
1683                                 break;
1684                         case 32:
1685                                 /* Emit 'mov eax, eax' to clear upper 32-bits */
1686                                 if (is_ereg(dst_reg))
1687                                         EMIT1(0x45);
1688                                 EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
1689                                 break;
1690                         case 64:
1691                                 /* nop */
1692                                 break;
1693                         }
1694                         break;
1695
1696                         /* speculation barrier */
1697                 case BPF_ST | BPF_NOSPEC:
1698                         EMIT_LFENCE();
1699                         break;
1700
1701                         /* ST: *(u8*)(dst_reg + off) = imm */
1702                 case BPF_ST | BPF_MEM | BPF_B:
1703                         if (is_ereg(dst_reg))
1704                                 EMIT2(0x41, 0xC6);
1705                         else
1706                                 EMIT1(0xC6);
1707                         goto st;
1708                 case BPF_ST | BPF_MEM | BPF_H:
1709                         if (is_ereg(dst_reg))
1710                                 EMIT3(0x66, 0x41, 0xC7);
1711                         else
1712                                 EMIT2(0x66, 0xC7);
1713                         goto st;
1714                 case BPF_ST | BPF_MEM | BPF_W:
1715                         if (is_ereg(dst_reg))
1716                                 EMIT2(0x41, 0xC7);
1717                         else
1718                                 EMIT1(0xC7);
1719                         goto st;
1720                 case BPF_ST | BPF_MEM | BPF_DW:
1721                         EMIT2(add_1mod(0x48, dst_reg), 0xC7);
1722
1723 st:                     if (is_imm8(insn->off))
1724                                 EMIT2(add_1reg(0x40, dst_reg), insn->off);
1725                         else
1726                                 EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
1727
1728                         EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
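                             /*
                              * e.g. BPF_ST | BPF_MEM | BPF_W with dst_reg = r1 (rdi),
                              * off = 8, imm = 42 encodes as C7 47 08 2A 00 00 00
                              * (mov dword ptr [rdi + 8], 42).
                              */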
1729                         break;
1730
1731                         /* STX: *(u8*)(dst_reg + off) = src_reg */
1732                 case BPF_STX | BPF_MEM | BPF_B:
1733                 case BPF_STX | BPF_MEM | BPF_H:
1734                 case BPF_STX | BPF_MEM | BPF_W:
1735                 case BPF_STX | BPF_MEM | BPF_DW:
1736                         emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1737                         break;
1738
1739                 case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
1740                 case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
1741                 case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
1742                 case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
1743                         start_of_ldx = prog;
1744                         emit_st_r12(&prog, BPF_SIZE(insn->code), dst_reg, insn->off, insn->imm);
1745                         goto populate_extable;
1746
1747                         /* LDX: dst_reg = *(u8*)(src_reg + r12 + off) */
1748                 case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
1749                 case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
1750                 case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
1751                 case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
1752                 case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
1753                 case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
1754                 case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
1755                 case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
1756                         start_of_ldx = prog;
1757                         if (BPF_CLASS(insn->code) == BPF_LDX)
1758                                 emit_ldx_r12(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1759                         else
1760                                 emit_stx_r12(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1761 populate_extable:
1762                         {
1763                                 struct exception_table_entry *ex;
1764                                 u8 *_insn = image + proglen + (start_of_ldx - temp);
1765                                 s64 delta;
1766
1767                                 if (!bpf_prog->aux->extable)
1768                                         break;
1769
1770                                 if (excnt >= bpf_prog->aux->num_exentries) {
1771                                         pr_err("mem32 extable bug\n");
1772                                         return -EFAULT;
1773                                 }
1774                                 ex = &bpf_prog->aux->extable[excnt++];
1775
1776                                 delta = _insn - (u8 *)&ex->insn;
1777                                 /* switch ex to rw buffer for writes */
1778                                 ex = (void *)rw_image + ((void *)ex - (void *)image);
1779
1780                                 ex->insn = delta;
1781
1782                                 ex->data = EX_TYPE_BPF;
1783
1784                                 ex->fixup = (prog - start_of_ldx) |
1785                                         ((BPF_CLASS(insn->code) == BPF_LDX ? reg2pt_regs[dst_reg] : DONT_CLEAR) << 8);
1786                         }
1787                         break;
1788
1789                         /* LDX: dst_reg = *(u8*)(src_reg + off) */
1790                 case BPF_LDX | BPF_MEM | BPF_B:
1791                 case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1792                 case BPF_LDX | BPF_MEM | BPF_H:
1793                 case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1794                 case BPF_LDX | BPF_MEM | BPF_W:
1795                 case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1796                 case BPF_LDX | BPF_MEM | BPF_DW:
1797                 case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1798                         /* LDXS: dst_reg = *(s8*)(src_reg + off) */
1799                 case BPF_LDX | BPF_MEMSX | BPF_B:
1800                 case BPF_LDX | BPF_MEMSX | BPF_H:
1801                 case BPF_LDX | BPF_MEMSX | BPF_W:
1802                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1803                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1804                 case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1805                         insn_off = insn->off;
1806
1807                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
1808                             BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
1809                                 /* Conservatively check that src_reg + insn->off is a kernel address:
1810                                  *   src_reg + insn->off >= TASK_SIZE_MAX + PAGE_SIZE
1811                                  * src_reg is used as scratch for src_reg += insn->off and restored
1812                                  * after emit_ldx if necessary
1813                                  */
1814
1815                                 u64 limit = TASK_SIZE_MAX + PAGE_SIZE;
1816                                 u8 *end_of_jmp;
1817
1818                                 /* At the end of these emitted checks, insn->off will have been added
1819                                  * to src_reg, so there is no need to do a relative load with an insn->off offset
1820                                  */
1821                                 insn_off = 0;
1822
1823                                 /* movabsq r11, limit */
1824                                 EMIT2(add_1mod(0x48, AUX_REG), add_1reg(0xB8, AUX_REG));
1825                                 EMIT((u32)limit, 4);
1826                                 EMIT(limit >> 32, 4);
1827
1828                                 if (insn->off) {
1829                                         /* add src_reg, insn->off */
1830                                         maybe_emit_1mod(&prog, src_reg, true);
1831                                         EMIT2_off32(0x81, add_1reg(0xC0, src_reg), insn->off);
1832                                 }
1833
1834                                 /* cmp src_reg, r11 */
1835                                 maybe_emit_mod(&prog, src_reg, AUX_REG, true);
1836                                 EMIT2(0x39, add_2reg(0xC0, src_reg, AUX_REG));
1837
1838                                 /* if unsigned '>=', goto load */
1839                                 EMIT2(X86_JAE, 0);
1840                                 end_of_jmp = prog;
1841
1842                                 /* xor dst_reg, dst_reg */
1843                                 emit_mov_imm32(&prog, false, dst_reg, 0);
1844                                 /* jmp byte_after_ldx */
1845                                 EMIT2(0xEB, 0);
1846
1847                                 /* populate jmp_offset for JAE above to jump to start_of_ldx */
1848                                 start_of_ldx = prog;
1849                                 end_of_jmp[-1] = start_of_ldx - end_of_jmp;
1850                         }
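                             /*
                              * Sketch of the emitted sequence for BPF_PROBE_MEM with a
                              * non-zero off (not literal bytes):
                              *   movabs r11, TASK_SIZE_MAX + PAGE_SIZE
                              *   add    src_reg, off
                              *   cmp    src_reg, r11
                              *   jae    do_load          ; looks like a kernel address
                              *   xor    dst32, dst32     ; user address: dst_reg = 0
                              *   jmp    after_load
                              * do_load:
                              *   mov    dst_reg, [src_reg]   ; faults handled via extable
                              * after_load:
                              */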
1851                         if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX ||
1852                             BPF_MODE(insn->code) == BPF_MEMSX)
1853                                 emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
1854                         else
1855                                 emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
1856                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
1857                             BPF_MODE(insn->code) == BPF_PROBE_MEMSX) {
1858                                 struct exception_table_entry *ex;
1859                                 u8 *_insn = image + proglen + (start_of_ldx - temp);
1860                                 s64 delta;
1861
1862                                 /* populate jmp_offset for JMP above */
1863                                 start_of_ldx[-1] = prog - start_of_ldx;
1864
1865                                 if (insn->off && src_reg != dst_reg) {
1866                                         /* sub src_reg, insn->off
1867                                          * Restore src_reg after "add src_reg, insn->off" in prev
1868                                          * if statement. But if src_reg == dst_reg, emit_ldx
1869                                          * above already clobbered src_reg, so no need to restore.
1870                                          * If add src_reg, insn->off was unnecessary, no need to
1871                                          * restore either.
1872                                          */
1873                                         maybe_emit_1mod(&prog, src_reg, true);
1874                                         EMIT2_off32(0x81, add_1reg(0xE8, src_reg), insn->off);
1875                                 }
1876
1877                                 if (!bpf_prog->aux->extable)
1878                                         break;
1879
1880                                 if (excnt >= bpf_prog->aux->num_exentries) {
1881                                         pr_err("ex gen bug\n");
1882                                         return -EFAULT;
1883                                 }
1884                                 ex = &bpf_prog->aux->extable[excnt++];
1885
1886                                 delta = _insn - (u8 *)&ex->insn;
1887                                 if (!is_simm32(delta)) {
1888                                         pr_err("extable->insn doesn't fit into 32-bit\n");
1889                                         return -EFAULT;
1890                                 }
1891                                 /* switch ex to rw buffer for writes */
1892                                 ex = (void *)rw_image + ((void *)ex - (void *)image);
1893
1894                                 ex->insn = delta;
1895
1896                                 ex->data = EX_TYPE_BPF;
1897
1898                                 if (dst_reg > BPF_REG_9) {
1899                                         pr_err("verifier error\n");
1900                                         return -EFAULT;
1901                                 }
1902                                 /*
1903                                  * Compute the size of the x86 insn and its destination x86 register.
1904                                  * ex_handler_bpf() will use the lower 8 bits to adjust
1905                                  * pt_regs->ip to jump over this x86 instruction
1906                                  * and the upper bits to figure out which pt_regs to zero out.
1907                                  * End result: the 4-byte x86 insn "mov rbx, qword ptr [rax+0x14]"
1908                                  * will be skipped and rbx will be zero-initialized.
1909                                  */
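                                     /*
                                      * e.g. for the 4-byte load into rbx above, the fixup
                                      * becomes 4 | (offsetof(struct pt_regs, bx) << 8);
                                      * reg2pt_regs[] holds the pt_regs byte offsets.
                                      */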
1910                                 ex->fixup = (prog - start_of_ldx) | (reg2pt_regs[dst_reg] << 8);
1911                         }
1912                         break;
1913
1914                 case BPF_STX | BPF_ATOMIC | BPF_W:
1915                 case BPF_STX | BPF_ATOMIC | BPF_DW:
1916                         if (insn->imm == (BPF_AND | BPF_FETCH) ||
1917                             insn->imm == (BPF_OR | BPF_FETCH) ||
1918                             insn->imm == (BPF_XOR | BPF_FETCH)) {
1919                                 bool is64 = BPF_SIZE(insn->code) == BPF_DW;
1920                                 u32 real_src_reg = src_reg;
1921                                 u32 real_dst_reg = dst_reg;
1922                                 u8 *branch_target;
1923
1924                                 /*
1925                                  * Can't be implemented with a single x86 insn.
1926                                  * Need to do a CMPXCHG loop.
1927                                  */
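                                     /*
                                      * Rough shape of the loop emitted below (a sketch,
                                      * not literal instructions):
                                      *   mov  r10, rax            ; save R0
                                      * retry:
                                      *   mov  rax, [dst + off]    ; load old value
                                      *   mov  r11, rax
                                      *   and/or/xor r11, src      ; compute new value
                                      *   lock cmpxchg [dst + off], r11
                                      *   jne  retry               ; lost the race, try again
                                      *   mov  src, rax            ; return old value in src_reg
                                      *   mov  rax, r10            ; restore R0
                                      */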
1928
1929                                 /* Will need RAX as a CMPXCHG operand so save R0 */
1930                                 emit_mov_reg(&prog, true, BPF_REG_AX, BPF_REG_0);
1931                                 if (src_reg == BPF_REG_0)
1932                                         real_src_reg = BPF_REG_AX;
1933                                 if (dst_reg == BPF_REG_0)
1934                                         real_dst_reg = BPF_REG_AX;
1935
1936                                 branch_target = prog;
1937                                 /* Load old value */
1938                                 emit_ldx(&prog, BPF_SIZE(insn->code),
1939                                          BPF_REG_0, real_dst_reg, insn->off);
1940                                 /*
1941                                  * Perform the (commutative) operation locally,
1942                                  * put the result in the AUX_REG.
1943                                  */
1944                                 emit_mov_reg(&prog, is64, AUX_REG, BPF_REG_0);
1945                                 maybe_emit_mod(&prog, AUX_REG, real_src_reg, is64);
1946                                 EMIT2(simple_alu_opcodes[BPF_OP(insn->imm)],
1947                                       add_2reg(0xC0, AUX_REG, real_src_reg));
1948                                 /* Attempt to swap in new value */
1949                                 err = emit_atomic(&prog, BPF_CMPXCHG,
1950                                                   real_dst_reg, AUX_REG,
1951                                                   insn->off,
1952                                                   BPF_SIZE(insn->code));
1953                                 if (WARN_ON(err))
1954                                         return err;
1955                                 /*
1956                                  * ZF tells us whether we won the race. If it's
1957                                  * cleared we need to try again.
1958                                  */
1959                                 EMIT2(X86_JNE, -(prog - branch_target) - 2);
1960                                 /* Return the pre-modification value */
1961                                 emit_mov_reg(&prog, is64, real_src_reg, BPF_REG_0);
1962                                 /* Restore R0 after clobbering RAX */
1963                                 emit_mov_reg(&prog, true, BPF_REG_0, BPF_REG_AX);
1964                                 break;
1965                         }
1966
1967                         err = emit_atomic(&prog, insn->imm, dst_reg, src_reg,
1968                                           insn->off, BPF_SIZE(insn->code));
1969                         if (err)
1970                                 return err;
1971                         break;
1972
1973                         /* call */
1974                 case BPF_JMP | BPF_CALL: {
1975                         int offs;
1976
1977                         func = (u8 *) __bpf_call_base + imm32;
1978                         if (tail_call_reachable) {
1979                                 RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
1980                                 if (!imm32)
1981                                         return -EINVAL;
1982                                 offs = 7 + x86_call_depth_emit_accounting(&prog, func);
1983                         } else {
1984                                 if (!imm32)
1985                                         return -EINVAL;
1986                                 offs = x86_call_depth_emit_accounting(&prog, func);
1987                         }
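                             /*
                              * offs counts the bytes already emitted for this insn
                              * (tail call count restore and call depth accounting), so
                              * that image + addrs[i - 1] + offs is the address of the
                              * call instruction itself when emit_call() computes the
                              * rel32 displacement.
                              */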
1988                         if (emit_call(&prog, func, image + addrs[i - 1] + offs))
1989                                 return -EINVAL;
1990                         break;
1991                 }
1992
1993                 case BPF_JMP | BPF_TAIL_CALL:
1994                         if (imm32)
1995                                 emit_bpf_tail_call_direct(bpf_prog,
1996                                                           &bpf_prog->aux->poke_tab[imm32 - 1],
1997                                                           &prog, image + addrs[i - 1],
1998                                                           callee_regs_used,
1999                                                           bpf_prog->aux->stack_depth,
2000                                                           ctx);
2001                         else
2002                                 emit_bpf_tail_call_indirect(bpf_prog,
2003                                                             &prog,
2004                                                             callee_regs_used,
2005                                                             bpf_prog->aux->stack_depth,
2006                                                             image + addrs[i - 1],
2007                                                             ctx);
2008                         break;
2009
2010                         /* cond jump */
2011                 case BPF_JMP | BPF_JEQ | BPF_X:
2012                 case BPF_JMP | BPF_JNE | BPF_X:
2013                 case BPF_JMP | BPF_JGT | BPF_X:
2014                 case BPF_JMP | BPF_JLT | BPF_X:
2015                 case BPF_JMP | BPF_JGE | BPF_X:
2016                 case BPF_JMP | BPF_JLE | BPF_X:
2017                 case BPF_JMP | BPF_JSGT | BPF_X:
2018                 case BPF_JMP | BPF_JSLT | BPF_X:
2019                 case BPF_JMP | BPF_JSGE | BPF_X:
2020                 case BPF_JMP | BPF_JSLE | BPF_X:
2021                 case BPF_JMP32 | BPF_JEQ | BPF_X:
2022                 case BPF_JMP32 | BPF_JNE | BPF_X:
2023                 case BPF_JMP32 | BPF_JGT | BPF_X:
2024                 case BPF_JMP32 | BPF_JLT | BPF_X:
2025                 case BPF_JMP32 | BPF_JGE | BPF_X:
2026                 case BPF_JMP32 | BPF_JLE | BPF_X:
2027                 case BPF_JMP32 | BPF_JSGT | BPF_X:
2028                 case BPF_JMP32 | BPF_JSLT | BPF_X:
2029                 case BPF_JMP32 | BPF_JSGE | BPF_X:
2030                 case BPF_JMP32 | BPF_JSLE | BPF_X:
2031                         /* cmp dst_reg, src_reg */
2032                         maybe_emit_mod(&prog, dst_reg, src_reg,
2033                                        BPF_CLASS(insn->code) == BPF_JMP);
2034                         EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
2035                         goto emit_cond_jmp;
2036
2037                 case BPF_JMP | BPF_JSET | BPF_X:
2038                 case BPF_JMP32 | BPF_JSET | BPF_X:
2039                         /* test dst_reg, src_reg */
2040                         maybe_emit_mod(&prog, dst_reg, src_reg,
2041                                        BPF_CLASS(insn->code) == BPF_JMP);
2042                         EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
2043                         goto emit_cond_jmp;
2044
2045                 case BPF_JMP | BPF_JSET | BPF_K:
2046                 case BPF_JMP32 | BPF_JSET | BPF_K:
2047                         /* test dst_reg, imm32 */
2048                         maybe_emit_1mod(&prog, dst_reg,
2049                                         BPF_CLASS(insn->code) == BPF_JMP);
2050                         EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
2051                         goto emit_cond_jmp;
2052
2053                 case BPF_JMP | BPF_JEQ | BPF_K:
2054                 case BPF_JMP | BPF_JNE | BPF_K:
2055                 case BPF_JMP | BPF_JGT | BPF_K:
2056                 case BPF_JMP | BPF_JLT | BPF_K:
2057                 case BPF_JMP | BPF_JGE | BPF_K:
2058                 case BPF_JMP | BPF_JLE | BPF_K:
2059                 case BPF_JMP | BPF_JSGT | BPF_K:
2060                 case BPF_JMP | BPF_JSLT | BPF_K:
2061                 case BPF_JMP | BPF_JSGE | BPF_K:
2062                 case BPF_JMP | BPF_JSLE | BPF_K:
2063                 case BPF_JMP32 | BPF_JEQ | BPF_K:
2064                 case BPF_JMP32 | BPF_JNE | BPF_K:
2065                 case BPF_JMP32 | BPF_JGT | BPF_K:
2066                 case BPF_JMP32 | BPF_JLT | BPF_K:
2067                 case BPF_JMP32 | BPF_JGE | BPF_K:
2068                 case BPF_JMP32 | BPF_JLE | BPF_K:
2069                 case BPF_JMP32 | BPF_JSGT | BPF_K:
2070                 case BPF_JMP32 | BPF_JSLT | BPF_K:
2071                 case BPF_JMP32 | BPF_JSGE | BPF_K:
2072                 case BPF_JMP32 | BPF_JSLE | BPF_K:
2073                         /* test dst_reg, dst_reg to save one extra byte */
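                             /*
                              * e.g. with dst_reg = r1 (rdi): test rdi, rdi is 48 85 FF
                              * (3 bytes) vs. cmp rdi, 0 at 48 83 FF 00 (4 bytes).
                              */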
2074                         if (imm32 == 0) {
2075                                 maybe_emit_mod(&prog, dst_reg, dst_reg,
2076                                                BPF_CLASS(insn->code) == BPF_JMP);
2077                                 EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
2078                                 goto emit_cond_jmp;
2079                         }
2080
2081                         /* cmp dst_reg, imm8/32 */
2082                         maybe_emit_1mod(&prog, dst_reg,
2083                                         BPF_CLASS(insn->code) == BPF_JMP);
2084
2085                         if (is_imm8(imm32))
2086                                 EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
2087                         else
2088                                 EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
2089
2090 emit_cond_jmp:          /* Convert BPF opcode to x86 */
2091                         switch (BPF_OP(insn->code)) {
2092                         case BPF_JEQ:
2093                                 jmp_cond = X86_JE;
2094                                 break;
2095                         case BPF_JSET:
2096                         case BPF_JNE:
2097                                 jmp_cond = X86_JNE;
2098                                 break;
2099                         case BPF_JGT:
2100                                 /* GT is unsigned '>', JA in x86 */
2101                                 jmp_cond = X86_JA;
2102                                 break;
2103                         case BPF_JLT:
2104                                 /* LT is unsigned '<', JB in x86 */
2105                                 jmp_cond = X86_JB;
2106                                 break;
2107                         case BPF_JGE:
2108                                 /* GE is unsigned '>=', JAE in x86 */
2109                                 jmp_cond = X86_JAE;
2110                                 break;
2111                         case BPF_JLE:
2112                                 /* LE is unsigned '<=', JBE in x86 */
2113                                 jmp_cond = X86_JBE;
2114                                 break;
2115                         case BPF_JSGT:
2116                                 /* Signed '>', GT in x86 */
2117                                 jmp_cond = X86_JG;
2118                                 break;
2119                         case BPF_JSLT:
2120                                 /* Signed '<', LT in x86 */
2121                                 jmp_cond = X86_JL;
2122                                 break;
2123                         case BPF_JSGE:
2124                                 /* Signed '>=', GE in x86 */
2125                                 jmp_cond = X86_JGE;
2126                                 break;
2127                         case BPF_JSLE:
2128                                 /* Signed '<=', LE in x86 */
2129                                 jmp_cond = X86_JLE;
2130                                 break;
2131                         default: /* to silence GCC warning */
2132                                 return -EFAULT;
2133                         }
2134                         jmp_offset = addrs[i + insn->off] - addrs[i];
2135                         if (is_imm8(jmp_offset)) {
2136                                 if (jmp_padding) {
2137                                         /* To keep the jmp_offset valid, the extra bytes are
2138                                          * padded before the jump insn, so we subtract the
2139                                          * 2 bytes of jmp_cond insn from INSN_SZ_DIFF.
2140                                          *
2141                                          * If the previous pass already emits an imm8
2142                                          * jmp_cond, then this BPF insn won't shrink, so
2143                                          * "nops" is 0.
2144                                          *
2145                                          * On the other hand, if the previous pass emits an
2146                                          * imm32 jmp_cond, the extra 4 bytes(*) are padded to
2147                                          * keep the image from shrinking further.
2148                                          *
2149                                          * (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond
2150                                          *     is 2 bytes, so the size difference is 4 bytes.
2151                                          */
2152                                         nops = INSN_SZ_DIFF - 2;
2153                                         if (nops != 0 && nops != 4) {
2154                                                 pr_err("unexpected jmp_cond padding: %d bytes\n",
2155                                                        nops);
2156                                                 return -EFAULT;
2157                                         }
2158                                         emit_nops(&prog, nops);
2159                                 }
2160                                 EMIT2(jmp_cond, jmp_offset);
2161                         } else if (is_simm32(jmp_offset)) {
2162                                 EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
2163                         } else {
2164                                 pr_err("cond_jmp gen bug %llx\n", jmp_offset);
2165                                 return -EFAULT;
2166                         }
2167
2168                         break;
2169
2170                 case BPF_JMP | BPF_JA:
2171                 case BPF_JMP32 | BPF_JA:
2172                         if (BPF_CLASS(insn->code) == BPF_JMP) {
2173                                 if (insn->off == -1)
2174                                         /* -1 jmp instructions will always jump
2175                                          * backwards two bytes. Explicitly handling
2176                                          * this case avoids wasting too many passes
2177                                          * when there are long sequences of replaced
2178                                          * dead code.
2179                                          */
2180                                         jmp_offset = -2;
2181                                 else
2182                                         jmp_offset = addrs[i + insn->off] - addrs[i];
2183                         } else {
2184                                 if (insn->imm == -1)
2185                                         jmp_offset = -2;
2186                                 else
2187                                         jmp_offset = addrs[i + insn->imm] - addrs[i];
2188                         }
2189
2190                         if (!jmp_offset) {
2191                                 /*
2192                                  * If jmp_padding is enabled, the extra nops will
2193                                  * be inserted. Otherwise, optimize out nop jumps.
2194                                  */
2195                                 if (jmp_padding) {
2196                                         /* There are 3 possible conditions.
2197                                          * (1) This BPF_JA is already optimized out in
2198                                          *     the previous run, so there is no need
2199                                          *     to pad any extra byte (0 byte).
2200                                          * (2) The previous pass emits an imm8 jmp,
2201                                          *     so we pad 2 bytes to match the previous
2202                                          *     insn size.
2203                                          * (3) Similarly, the previous pass emits an
2204                                          *     imm32 jmp, and 5 bytes is padded.
2205                                          */
2206                                         nops = INSN_SZ_DIFF;
2207                                         if (nops != 0 && nops != 2 && nops != 5) {
2208                                                 pr_err("unexpected nop jump padding: %d bytes\n",
2209                                                        nops);
2210                                                 return -EFAULT;
2211                                         }
2212                                         emit_nops(&prog, nops);
2213                                 }
2214                                 break;
2215                         }
2216 emit_jmp:
2217                         if (is_imm8(jmp_offset)) {
2218                                 if (jmp_padding) {
2219                                         /* To avoid breaking jmp_offset, the extra bytes
2220                                          * are padded before the actual jmp insn, so
2221                                          * 2 bytes are subtracted from INSN_SZ_DIFF.
2222                                          *
2223                                          * If the previous pass already emits an imm8
2224                                          * jmp, there is nothing to pad (0 byte).
2225                                          *
2226                                          * If it emits an imm32 jmp (5 bytes) previously
2227                                          * and now an imm8 jmp (2 bytes), then we pad
2228                                          * (5 - 2 = 3) bytes to stop the image from
2229                                          * shrinking further.
2230                                          */
2231                                         nops = INSN_SZ_DIFF - 2;
2232                                         if (nops != 0 && nops != 3) {
2233                                                 pr_err("unexpected jump padding: %d bytes\n",
2234                                                        nops);
2235                                                 return -EFAULT;
2236                                         }
2237                                         emit_nops(&prog, INSN_SZ_DIFF - 2);
2238                                 }
2239                                 EMIT2(0xEB, jmp_offset);
2240                         } else if (is_simm32(jmp_offset)) {
2241                                 EMIT1_off32(0xE9, jmp_offset);
2242                         } else {
2243                                 pr_err("jmp gen bug %llx\n", jmp_offset);
2244                                 return -EFAULT;
2245                         }
2246                         break;
2247
2248                 case BPF_JMP | BPF_EXIT:
2249                         if (seen_exit) {
2250                                 jmp_offset = ctx->cleanup_addr - addrs[i];
2251                                 goto emit_jmp;
2252                         }
2253                         seen_exit = true;
2254                         /* Update cleanup_addr */
2255                         ctx->cleanup_addr = proglen;
2256                         if (bpf_prog->aux->exception_boundary) {
2257                                 pop_callee_regs(&prog, all_callee_regs_used);
2258                                 pop_r12(&prog);
2259                         } else {
2260                                 pop_callee_regs(&prog, callee_regs_used);
2261                                 if (arena_vm_start)
2262                                         pop_r12(&prog);
2263                         }
2264                         EMIT1(0xC9);         /* leave */
2265                         emit_return(&prog, image + addrs[i - 1] + (prog - temp));
2266                         break;
2267
2268                 default:
2269                         /*
2270                          * By design the x86-64 JIT should support all BPF instructions.
2271                          * This error will be seen if a new instruction was added
2272                          * to the interpreter, but not to the JIT, or if there is
2273                          * junk in bpf_prog.
2274                          */
2275                         pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
2276                         return -EINVAL;
2277                 }
2278
2279                 ilen = prog - temp;
2280                 if (ilen > BPF_MAX_INSN_SIZE) {
2281                         pr_err("bpf_jit: fatal insn size error\n");
2282                         return -EFAULT;
2283                 }
2284
2285                 if (image) {
2286                         /*
2287                          * When populating the image, assert that:
2288                          *
2289                          *  i) We do not write beyond the allocated space, and
2290                          * ii) addrs[i] did not change from the prior run, in order
2291                          *     to validate assumptions made for computing branch
2292                          *     displacements.
2293                          */
2294                         if (unlikely(proglen + ilen > oldproglen ||
2295                                      proglen + ilen != addrs[i])) {
2296                                 pr_err("bpf_jit: fatal error\n");
2297                                 return -EFAULT;
2298                         }
2299                         memcpy(rw_image + proglen, temp, ilen);
2300                 }
2301                 proglen += ilen;
2302                 addrs[i] = proglen;
2303                 prog = temp;
2304         }
2305
2306         if (image && excnt != bpf_prog->aux->num_exentries) {
2307                 pr_err("extable is not populated\n");
2308                 return -EFAULT;
2309         }
2310         return proglen;
2311 }
2312
2313 static void clean_stack_garbage(const struct btf_func_model *m,
2314                                 u8 **pprog, int nr_stack_slots,
2315                                 int stack_size)
2316 {
2317         int arg_size, off;
2318         u8 *prog;
2319
2320         /* Generally speaking, the compiler will pass an on-stack argument
2321          * with a "push" instruction, which occupies 8 bytes on the
2322          * stack. In that case there are no garbage values while we copy
2323          * the arguments from the origin stack frame to the current one
2324          * with BPF_DW accesses.
2325          *
2326          * However, sometimes the compiler allocates only 4 bytes on
2327          * the stack for an argument. For now, this can only happen
2328          * when there is a single on-stack argument and its size is
2329          * not more than 4 bytes. In that case there will be garbage
2330          * values in the upper 4 bytes of the slot where we store the
2331          * argument in the current stack frame.
2332          *
2333          * arguments on origin stack:
2334          *
2335          * stack_arg_1(4-byte) xxx(4-byte)
2336          *
2337          * what we copy:
2338          *
2339          * stack_arg_1(8-byte): stack_arg_1(origin) xxx
2340          *
2341          * and xxx is the garbage value that we clean up here.
2342          */
2343         if (nr_stack_slots != 1)
2344                 return;
2345
2346         /* the size of the last argument */
2347         arg_size = m->arg_size[m->nr_args - 1];
2348         if (arg_size <= 4) {
2349                 off = -(stack_size - 4);
2350                 prog = *pprog;
2351                 /* mov DWORD PTR [rbp + off], 0 */
2352                 if (!is_imm8(off))
2353                         EMIT2_off32(0xC7, 0x85, off);
2354                 else
2355                         EMIT3(0xC7, 0x45, off);
2356                 EMIT(0, 4);
2357                 *pprog = prog;
2358         }
2359 }
2360
2361 /* get the count of the regs that are used to pass arguments */
2362 static int get_nr_used_regs(const struct btf_func_model *m)
2363 {
2364         int i, arg_regs, nr_used_regs = 0;
2365
2366         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2367                 arg_regs = (m->arg_size[i] + 7) / 8;
2368                 if (nr_used_regs + arg_regs <= 6)
2369                         nr_used_regs += arg_regs;
2370
2371                 if (nr_used_regs >= 6)
2372                         break;
2373         }
2374
2375         return nr_used_regs;
2376 }
2377
2378 static void save_args(const struct btf_func_model *m, u8 **prog,
2379                       int stack_size, bool for_call_origin)
2380 {
2381         int arg_regs, first_off = 0, nr_regs = 0, nr_stack_slots = 0;
2382         int i, j;
2383
2384         /* Store function arguments to stack.
2385          * For a function that accepts two pointers the sequence will be:
2386          * mov QWORD PTR [rbp-0x10],rdi
2387          * mov QWORD PTR [rbp-0x8],rsi
2388          */
2389         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2390                 arg_regs = (m->arg_size[i] + 7) / 8;
2391
2392                 /* According to Yonghong's analysis, a struct argument's
2393                  * members are passed either all in registers or all on the
2394                  * stack. Meanwhile, the compiler passes an argument in regs
2395                  * only if the remaining regs can hold the whole argument.
2396                  *
2397                  * Register and stack args can thus be interleaved. For example:
2398                  *
2399                  * struct foo_struct {
2400                  *     long a;
2401                  *     int b;
2402                  * };
2403                  * int foo(char, char, char, char, char, struct foo_struct,
2404                  *         char);
2405                  *
2406                  * args 1-5 and arg 7 will be passed in regs, and arg 6 will
2407                  * be passed on the stack.
2408                  */
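                     /*
                      * For the example above (x86-64 SysV ABI): args 1-5 land in
                      * rdi, rsi, rdx, rcx, r8; arg 6 (16 bytes, needing two regs
                      * while only r9 is left) goes on the caller's stack; arg 7
                      * then takes r9.
                      */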
2409                 if (nr_regs + arg_regs > 6) {
2410                         /* Copy function arguments from the origin stack frame
2411                          * into the current stack frame.
2412                          *
2413                          * The starting address of the on-stack arguments
2414                          * is:
2415                          *   rbp + 8 (pushed rbp) +
2416                          *   8 (return addr of the origin call) +
2417                          *   8 (return addr of the caller)
2418                          * i.e. rbp + 24
2419                          */
2420                         for (j = 0; j < arg_regs; j++) {
2421                                 emit_ldx(prog, BPF_DW, BPF_REG_0, BPF_REG_FP,
2422                                          nr_stack_slots * 8 + 0x18);
2423                                 emit_stx(prog, BPF_DW, BPF_REG_FP, BPF_REG_0,
2424                                          -stack_size);
2425
2426                                 if (!nr_stack_slots)
2427                                         first_off = stack_size;
2428                                 stack_size -= 8;
2429                                 nr_stack_slots++;
2430                         }
2431                 } else {
2432                         /* When preparing the on-stack arguments for the origin
2433                          * call, only copy the arguments that are already on the
2434                          * stack and skip the ones passed in registers.
2435                          */
2436                         if (for_call_origin) {
2437                                 nr_regs += arg_regs;
2438                                 continue;
2439                         }
2440
2441                         /* copy the arguments from regs into stack */
2442                         for (j = 0; j < arg_regs; j++) {
2443                                 emit_stx(prog, BPF_DW, BPF_REG_FP,
2444                                          nr_regs == 5 ? X86_REG_R9 : BPF_REG_1 + nr_regs,
2445                                          -stack_size);
2446                                 stack_size -= 8;
2447                                 nr_regs++;
2448                         }
2449                 }
2450         }
2451
2452         clean_stack_garbage(m, prog, nr_stack_slots, first_off);
2453 }
2454
2455 static void restore_regs(const struct btf_func_model *m, u8 **prog,
2456                          int stack_size)
2457 {
2458         int i, j, arg_regs, nr_regs = 0;
2459
2460         /* Restore function arguments from stack.
2461          * For a function that accepts two pointers the sequence will be:
2462          * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
2463          * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
2464          *
2465          * The logic here is similar to what we do in save_args()
2466          */
2467         for (i = 0; i < min_t(int, m->nr_args, MAX_BPF_FUNC_ARGS); i++) {
2468                 arg_regs = (m->arg_size[i] + 7) / 8;
2469                 if (nr_regs + arg_regs <= 6) {
2470                         for (j = 0; j < arg_regs; j++) {
2471                                 emit_ldx(prog, BPF_DW,
2472                                          nr_regs == 5 ? X86_REG_R9 : BPF_REG_1 + nr_regs,
2473                                          BPF_REG_FP,
2474                                          -stack_size);
2475                                 stack_size -= 8;
2476                                 nr_regs++;
2477                         }
2478                 } else {
2479                         stack_size -= 8 * arg_regs;
2480                 }
2481
2482                 if (nr_regs >= 6)
2483                         break;
2484         }
2485 }
2486
2487 static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
2488                            struct bpf_tramp_link *l, int stack_size,
2489                            int run_ctx_off, bool save_ret,
2490                            void *image, void *rw_image)
2491 {
2492         u8 *prog = *pprog;
2493         u8 *jmp_insn;
2494         int ctx_cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
2495         struct bpf_prog *p = l->link.prog;
2496         u64 cookie = l->cookie;
2497
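             /*
              * Rough shape of what this helper emits (a sketch):
              *
              *   rdi = cookie
              *   [rbp - run_ctx_off + ctx_cookie_off] = rdi
              *   rdi = p, rsi = rbp - run_ctx_off
              *   call __bpf_prog_enter*()       ; returns start time, or 0 to skip
              *   rbx = rax
              *   if (rax == 0) goto skip
              *   rdi = rbp - stack_size         ; ctx/args area
              *   rsi = p->insnsi                ; only if !p->jited
              *   call p->bpf_func
              *   if (save_ret) [rbp - 8] = rax
              * skip:
              *   rdi = p, rsi = rbx, rdx = rbp - run_ctx_off
              *   call __bpf_prog_exit*()
              */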
2498         /* mov rdi, cookie */
2499         emit_mov_imm64(&prog, BPF_REG_1, (long) cookie >> 32, (u32) (long) cookie);
2500
2501         /* Prepare struct bpf_tramp_run_ctx.
2502          *
2503          * Stack space for bpf_tramp_run_ctx has already been reserved by
2504          * arch_prepare_bpf_trampoline().
2505          *
2506          * mov QWORD PTR [rbp - run_ctx_off + ctx_cookie_off], rdi
2507          */
2508         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_1, -run_ctx_off + ctx_cookie_off);
2509
2510         /* arg1: mov rdi, progs[i] */
2511         emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
2512         /* arg2: lea rsi, [rbp - run_ctx_off] */
2513         if (!is_imm8(-run_ctx_off))
2514                 EMIT3_off32(0x48, 0x8D, 0xB5, -run_ctx_off);
2515         else
2516                 EMIT4(0x48, 0x8D, 0x75, -run_ctx_off);
2517
2518         if (emit_rsb_call(&prog, bpf_trampoline_enter(p), image + (prog - (u8 *)rw_image)))
2519                 return -EINVAL;
2520         /* remember prog start time returned by __bpf_prog_enter */
2521         emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
2522
2523         /* if (__bpf_prog_enter*(prog) == 0)
2524          *      goto skip_exec_of_prog;
2525          */
2526         EMIT3(0x48, 0x85, 0xC0);  /* test rax,rax */
2527         /* emit 2 nops that will be replaced with JE insn */
2528         jmp_insn = prog;
2529         emit_nops(&prog, 2);
2530
2531         /* arg1: lea rdi, [rbp - stack_size] */
2532         if (!is_imm8(-stack_size))
2533                 EMIT3_off32(0x48, 0x8D, 0xBD, -stack_size);
2534         else
2535                 EMIT4(0x48, 0x8D, 0x7D, -stack_size);
2536         /* arg2: progs[i]->insnsi for interpreter */
2537         if (!p->jited)
2538                 emit_mov_imm64(&prog, BPF_REG_2,
2539                                (long) p->insnsi >> 32,
2540                                (u32) (long) p->insnsi);
2541         /* call JITed bpf program or interpreter */
2542         if (emit_rsb_call(&prog, p->bpf_func, image + (prog - (u8 *)rw_image)))
2543                 return -EINVAL;
2544
2545         /*
2546          * BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
2547          * of the previous call which is then passed on the stack to
2548          * the next BPF program.
2549          *
2550          * BPF_TRAMP_FENTRY trampoline may need to return the return
2551          * value of BPF_PROG_TYPE_STRUCT_OPS prog.
2552          */
2553         if (save_ret)
2554                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2555
2556         /* replace 2 nops with JE insn, since jmp target is known */
2557         jmp_insn[0] = X86_JE;
2558         jmp_insn[1] = prog - jmp_insn - 2;
2559
2560         /* arg1: mov rdi, progs[i] */
2561         emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32, (u32) (long) p);
2562         /* arg2: mov rsi, rbx <- start time in nsec */
2563         emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
2564         /* arg3: lea rdx, [rbp - run_ctx_off] */
2565         if (!is_imm8(-run_ctx_off))
2566                 EMIT3_off32(0x48, 0x8D, 0x95, -run_ctx_off);
2567         else
2568                 EMIT4(0x48, 0x8D, 0x55, -run_ctx_off);
2569         if (emit_rsb_call(&prog, bpf_trampoline_exit(p), image + (prog - (u8 *)rw_image)))
2570                 return -EINVAL;
2571
2572         *pprog = prog;
2573         return 0;
2574 }
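
/*
 * Rough C-level sketch of the sequence emitted by invoke_bpf_prog() above
 * (illustrative only; the real thing is the machine code produced by the
 * EMIT*() calls, and insnsi is only loaded for the interpreter):
 *
 *	run_ctx.bpf_cookie = cookie;
 *	start = bpf_trampoline_enter(p)(p, &run_ctx);
 *	if (start) {
 *		ret = p->bpf_func(args, p->insnsi);	// args == rbp - stack_size
 *		if (save_ret)
 *			*(u64 *)(rbp - 8) = ret;
 *	}
 *	bpf_trampoline_exit(p)(p, start, &run_ctx);
 */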
2575
2576 static void emit_align(u8 **pprog, u32 align)
2577 {
2578         u8 *target, *prog = *pprog;
2579
2580         target = PTR_ALIGN(prog, align);
2581         if (target != prog)
2582                 emit_nops(&prog, target - prog);
2583
2584         *pprog = prog;
2585 }
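
/*
 * Worked example for emit_align() (illustrative): if prog points at an
 * address ending in 0x9 and align is 16, PTR_ALIGN() rounds it up to the
 * next 0x10 boundary and 7 bytes of nops are emitted; if prog is already
 * aligned, nothing is emitted.
 */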
2586
2587 static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
2588 {
2589         u8 *prog = *pprog;
2590         s64 offset;
2591
2592         offset = func - (ip + 2 + 4);
2593         if (!is_simm32(offset)) {
2594                 pr_err("Target %p is out of range\n", func);
2595                 return -EINVAL;
2596         }
2597         EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
2598         *pprog = prog;
2599         return 0;
2600 }
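
/*
 * Encoding note (illustrative): with jmp_cond == X86_JE (0x74), the bytes
 * emitted above are 0F 84 <rel32>, a 6-byte instruction, which is why the
 * displacement is computed as func - (ip + 2 + 4), i.e. relative to the
 * end of the instruction.
 */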
2601
2602 static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
2603                       struct bpf_tramp_links *tl, int stack_size,
2604                       int run_ctx_off, bool save_ret,
2605                       void *image, void *rw_image)
2606 {
2607         int i;
2608         u8 *prog = *pprog;
2609
2610         for (i = 0; i < tl->nr_links; i++) {
2611                 if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size,
2612                                     run_ctx_off, save_ret, image, rw_image))
2613                         return -EINVAL;
2614         }
2615         *pprog = prog;
2616         return 0;
2617 }
2618
2619 static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
2620                               struct bpf_tramp_links *tl, int stack_size,
2621                               int run_ctx_off, u8 **branches,
2622                               void *image, void *rw_image)
2623 {
2624         u8 *prog = *pprog;
2625         int i;
2626
2627         /* The first fmod_ret program will receive a garbage return value.
2628          * Set this to 0 to avoid confusing the program.
2629          */
2630         emit_mov_imm32(&prog, false, BPF_REG_0, 0);
2631         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2632         for (i = 0; i < tl->nr_links; i++) {
2633                 if (invoke_bpf_prog(m, &prog, tl->links[i], stack_size, run_ctx_off, true,
2634                                     image, rw_image))
2635                         return -EINVAL;
2636
2637                 /* mod_ret prog stored return value into [rbp - 8]. Emit:
2638                  * if (*(u64 *)(rbp - 8) !=  0)
2639                  *      goto do_fexit;
2640                  */
2641                 /* cmp QWORD PTR [rbp - 0x8], 0x0 */
2642                 EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
2643
2644                 /* Save the location of the branch and generate 6 nops
2645                  * (4 bytes for an offset and 2 bytes for the jump). These nops
2646                  * are replaced with a conditional jump once do_fexit (i.e. the
2647                  * start of the fexit invocation) is finalized.
2648                  */
2649                 branches[i] = prog;
2650                 emit_nops(&prog, 4 + 2);
2651         }
2652
2653         *pprog = prog;
2654         return 0;
2655 }
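
/*
 * Control-flow sketch of what invoke_bpf_mod_ret() above emits
 * (illustrative C, not the generated code):
 *
 *	*(u64 *)(rbp - 8) = 0;			// no garbage for the first prog
 *	for (i = 0; i < tl->nr_links; i++) {
 *		run fmod_ret prog i;		// stores its return at rbp - 8
 *		if (*(u64 *)(rbp - 8) != 0)
 *			goto do_fexit;		// the 6 nops patched later
 *	}
 */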
2656
2657 /* Example:
2658  * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
2659  * its 'struct btf_func_model' will have nr_args=2.
2660  * The assembly code when eth_type_trans is executing after trampoline:
2661  *
2662  * push rbp
2663  * mov rbp, rsp
2664  * sub rsp, 16                     // space for skb and dev
2665  * push rbx                        // temp regs to pass start time
2666  * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
2667  * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
2668  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2669  * mov rbx, rax                    // remember start time if bpf stats are enabled
2670  * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
2671  * call addr_of_jited_FENTRY_prog
2672  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2673  * mov rsi, rbx                    // prog start time
2674  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2675  * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
2676  * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
2677  * pop rbx
2678  * leave
2679  * ret
2680  *
2681  * eth_type_trans has a 5-byte nop at the beginning. These 5 bytes will be
2682  * replaced with 'call generated_bpf_trampoline'. When it returns
2683  * eth_type_trans will continue executing with original skb and dev pointers.
2684  *
2685  * The assembly code when eth_type_trans is called from trampoline:
2686  *
2687  * push rbp
2688  * mov rbp, rsp
2689  * sub rsp, 24                     // space for skb, dev, return value
2690  * push rbx                        // temp regs to pass start time
2691  * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
2692  * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
2693  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2694  * mov rbx, rax                    // remember start time if bpf stats are enabled
2695  * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
2696  * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
2697  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2698  * mov rsi, rbx                    // prog start time
2699  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2700  * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
2701  * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
2702  * call eth_type_trans+5           // execute body of eth_type_trans
2703  * mov qword ptr [rbp - 8], rax    // save return value
2704  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
2705  * mov rbx, rax                    // remember start time if bpf stats are enabled
2706  * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
2707  * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
2708  * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
2709  * mov rsi, rbx                    // prog start time
2710  * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
2711  * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
2712  * pop rbx
2713  * leave
2714  * add rsp, 8                      // skip eth_type_trans's frame
2715  * ret                             // return to its caller
2716  */
2717 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_image,
2718                                          void *rw_image_end, void *image,
2719                                          const struct btf_func_model *m, u32 flags,
2720                                          struct bpf_tramp_links *tlinks,
2721                                          void *func_addr)
2722 {
2723         int i, ret, nr_regs = m->nr_args, stack_size = 0;
2724         int regs_off, nregs_off, ip_off, run_ctx_off, arg_stack_off, rbx_off;
2725         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
2726         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
2727         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
2728         void *orig_call = func_addr;
2729         u8 **branches = NULL;
2730         u8 *prog;
2731         bool save_ret;
2732
2733         /*
2734          * F_INDIRECT is only compatible with F_RET_FENTRY_RET; it is
2735          * explicitly incompatible with F_CALL_ORIG | F_SKIP_FRAME | F_IP_ARG
2736          * because those flags rely on @func_addr pointing at a traced kernel function.
2737          */
2738         WARN_ON_ONCE((flags & BPF_TRAMP_F_INDIRECT) &&
2739                      (flags & ~(BPF_TRAMP_F_INDIRECT | BPF_TRAMP_F_RET_FENTRY_RET)));
2740
2741         /* extra registers for struct arguments */
2742         for (i = 0; i < m->nr_args; i++) {
2743                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
2744                         nr_regs += (m->arg_size[i] + 7) / 8 - 1;
2745         }
2746
2747         /* x86-64 supports up to MAX_BPF_FUNC_ARGS arguments. 1-6
2748          * are passed through regs, the remaining ones go through the stack.
2749          */
2750         if (nr_regs > MAX_BPF_FUNC_ARGS)
2751                 return -ENOTSUPP;
2752
2753         /* Generated trampoline stack layout:
2754          *
2755          * RBP + 8         [ return address  ]
2756          * RBP + 0         [ RBP             ]
2757          *
2758          * RBP - 8         [ return value    ]  BPF_TRAMP_F_CALL_ORIG or
2759          *                                      BPF_TRAMP_F_RET_FENTRY_RET flags
2760          *
2761          *                 [ reg_argN        ]  always
2762          *                 [ ...             ]
2763          * RBP - regs_off  [ reg_arg1        ]  program's ctx pointer
2764          *
2765          * RBP - nregs_off [ regs count      ]  always
2766          *
2767          * RBP - ip_off    [ traced function ]  BPF_TRAMP_F_IP_ARG flag
2768          *
2769          * RBP - rbx_off   [ rbx value       ]  always
2770          *
2771          * RBP - run_ctx_off [ bpf_tramp_run_ctx ]
2772          *
2773          *                     [ stack_argN ]  BPF_TRAMP_F_CALL_ORIG
2774          *                     [ ...        ]
2775          *                     [ stack_arg2 ]
2776          * RBP - arg_stack_off [ stack_arg1 ]
2777          * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
2778          */
2779
2780         /* room for return value of orig_call or fentry prog */
2781         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
2782         if (save_ret)
2783                 stack_size += 8;
2784
2785         stack_size += nr_regs * 8;
2786         regs_off = stack_size;
2787
2788         /* regs count  */
2789         stack_size += 8;
2790         nregs_off = stack_size;
2791
2792         if (flags & BPF_TRAMP_F_IP_ARG)
2793                 stack_size += 8; /* room for IP address argument */
2794
2795         ip_off = stack_size;
2796
2797         stack_size += 8;
2798         rbx_off = stack_size;
2799
2800         stack_size += (sizeof(struct bpf_tramp_run_ctx) + 7) & ~0x7;
2801         run_ctx_off = stack_size;
2802
2803         if (nr_regs > 6 && (flags & BPF_TRAMP_F_CALL_ORIG)) {
2804                 /* the space used to pass arguments on the stack */
2805                 stack_size += (nr_regs - get_nr_used_regs(m)) * 8;
2806                 /* make sure the stack pointer is 16-byte aligned if we
2807                  * need to pass arguments on the stack, which means
2808                  *  [stack_size + 8(rbp) + 8(rip) + 8(origin rip)]
2809                  * should be 16-byte aligned. The following code depends on
2810                  * stack_size already being 8-byte aligned.
2811                  */
2812                 stack_size += (stack_size % 16) ? 0 : 8;
2813         }
2814
2815         arg_stack_off = stack_size;
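
        /* Worked example of the offsets above (a sketch, assuming nr_args == 2,
         * no struct args, BPF_TRAMP_F_CALL_ORIG set and no BPF_TRAMP_F_IP_ARG):
         *
         *	stack_size    =  8	return value slot (save_ret)
         *	regs_off      = 24	two 8-byte register arguments
         *	nregs_off     = 32	regs count slot
         *	ip_off        = 32	no extra room without BPF_TRAMP_F_IP_ARG
         *	rbx_off       = 40
         *	run_ctx_off   = 40 + round_up(sizeof(struct bpf_tramp_run_ctx), 8)
         *	arg_stack_off = run_ctx_off	all arguments fit in registers
         */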
2816
2817         if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2818                 /* Skip the patched call instruction and point orig_call to the
2819                  * actual body of the kernel function.
2820                  */
2821                 if (is_endbr(*(u32 *)orig_call))
2822                         orig_call += ENDBR_INSN_SIZE;
2823                 orig_call += X86_PATCH_SIZE;
2824         }
2825
2826         prog = rw_image;
2827
2828         if (flags & BPF_TRAMP_F_INDIRECT) {
2829                 /*
2830                  * Indirect call for bpf_struct_ops
2831                  */
2832                 emit_cfi(&prog, cfi_get_func_hash(func_addr));
2833         } else {
2834                 /*
2835                  * Direct-call fentry stub; as such, it needs to account for
2836                  * the __fentry__ call.
2837                  */
2838                 x86_call_depth_emit_accounting(&prog, NULL);
2839         }
2840         EMIT1(0x55);             /* push rbp */
2841         EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
2842         if (!is_imm8(stack_size)) {
2843                 /* sub rsp, stack_size */
2844                 EMIT3_off32(0x48, 0x81, 0xEC, stack_size);
2845         } else {
2846                 /* sub rsp, stack_size */
2847                 EMIT4(0x48, 0x83, 0xEC, stack_size);
2848         }
2849         if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
2850                 EMIT1(0x50);            /* push rax */
2851         /* mov QWORD PTR [rbp - rbx_off], rbx */
2852         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
2853
2854         /* Store number of argument registers of the traced function:
2855          *   mov rax, nr_regs
2856          *   mov QWORD PTR [rbp - nregs_off], rax
2857          */
2858         emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
2859         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);
2860
2861         if (flags & BPF_TRAMP_F_IP_ARG) {
2862                 /* Store IP address of the traced function:
2863                  * movabsq rax, func_addr
2864                  * mov QWORD PTR [rbp - ip_off], rax
2865                  */
2866                 emit_mov_imm64(&prog, BPF_REG_0, (long) func_addr >> 32, (u32) (long) func_addr);
2867                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off);
2868         }
2869
2870         save_args(m, &prog, regs_off, false);
2871
2872         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2873                 /* arg1: mov rdi, im */
2874                 emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
2875                 if (emit_rsb_call(&prog, __bpf_tramp_enter,
2876                                   image + (prog - (u8 *)rw_image))) {
2877                         ret = -EINVAL;
2878                         goto cleanup;
2879                 }
2880         }
2881
2882         if (fentry->nr_links) {
2883                 if (invoke_bpf(m, &prog, fentry, regs_off, run_ctx_off,
2884                                flags & BPF_TRAMP_F_RET_FENTRY_RET, image, rw_image))
2885                         return -EINVAL;
2886         }
2887
2888         if (fmod_ret->nr_links) {
2889                 branches = kcalloc(fmod_ret->nr_links, sizeof(u8 *),
2890                                    GFP_KERNEL);
2891                 if (!branches)
2892                         return -ENOMEM;
2893
2894                 if (invoke_bpf_mod_ret(m, &prog, fmod_ret, regs_off,
2895                                        run_ctx_off, branches, image, rw_image)) {
2896                         ret = -EINVAL;
2897                         goto cleanup;
2898                 }
2899         }
2900
2901         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2902                 restore_regs(m, &prog, regs_off);
2903                 save_args(m, &prog, arg_stack_off, true);
2904
2905                 if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
2906                         /* Before calling the original function, restore the
2907                          * tail_call_cnt from stack to rax.
2908                          */
2909                         RESTORE_TAIL_CALL_CNT(stack_size);
2910                 }
2911
2912                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
2913                         emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
2914                         EMIT2(0xff, 0xd3); /* call *rbx */
2915                 } else {
2916                         /* call original function */
2917                         if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) {
2918                                 ret = -EINVAL;
2919                                 goto cleanup;
2920                         }
2921                 }
2922                 /* remember return value on the stack for bpf prog to access */
2923                 emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
2924                 im->ip_after_call = image + (prog - (u8 *)rw_image);
2925                 emit_nops(&prog, X86_PATCH_SIZE);
2926         }
2927
2928         if (fmod_ret->nr_links) {
2929                 /* From Intel 64 and IA-32 Architectures Optimization
2930                  * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
2931                  * Coding Rule 11: All branch targets should be 16-byte
2932                  * aligned.
2933                  */
2934                 emit_align(&prog, 16);
2935                 /* Update the branches saved in invoke_bpf_mod_ret with the
2936                  * aligned address of do_fexit.
2937                  */
2938                 for (i = 0; i < fmod_ret->nr_links; i++) {
2939                         emit_cond_near_jump(&branches[i], image + (prog - (u8 *)rw_image),
2940                                             image + (branches[i] - (u8 *)rw_image), X86_JNE);
2941                 }
2942         }
2943
2944         if (fexit->nr_links) {
2945                 if (invoke_bpf(m, &prog, fexit, regs_off, run_ctx_off,
2946                                false, image, rw_image)) {
2947                         ret = -EINVAL;
2948                         goto cleanup;
2949                 }
2950         }
2951
2952         if (flags & BPF_TRAMP_F_RESTORE_REGS)
2953                 restore_regs(m, &prog, regs_off);
2954
2955         /* This needs to be done regardless. If there were fmod_ret programs,
2956          * the return value is only updated on the stack and still needs to be
2957          * restored to R0.
2958          */
2959         if (flags & BPF_TRAMP_F_CALL_ORIG) {
2960                 im->ip_epilogue = image + (prog - (u8 *)rw_image);
2961                 /* arg1: mov rdi, im */
2962                 emit_mov_imm64(&prog, BPF_REG_1, (long) im >> 32, (u32) (long) im);
2963                 if (emit_rsb_call(&prog, __bpf_tramp_exit, image + (prog - (u8 *)rw_image))) {
2964                         ret = -EINVAL;
2965                         goto cleanup;
2966                 }
2967         } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
2968                 /* Before running the original function, restore the
2969                  * tail_call_cnt from stack to rax.
2970                  */
2971                 RESTORE_TAIL_CALL_CNT(stack_size);
2972         }
2973
2974         /* restore return value of orig_call or fentry prog back into RAX */
2975         if (save_ret)
2976                 emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
2977
2978         emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, -rbx_off);
2979         EMIT1(0xC9); /* leave */
2980         if (flags & BPF_TRAMP_F_SKIP_FRAME) {
2981                 /* skip our return address and return to parent */
2982                 EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
2983         }
2984         emit_return(&prog, image + (prog - (u8 *)rw_image));
2985         /* Make sure the trampoline generation logic doesn't overflow */
2986         if (WARN_ON_ONCE(prog > (u8 *)rw_image_end - BPF_INSN_SAFETY)) {
2987                 ret = -EFAULT;
2988                 goto cleanup;
2989         }
2990         ret = prog - (u8 *)rw_image + BPF_INSN_SAFETY;
2991
2992 cleanup:
2993         kfree(branches);
2994         return ret;
2995 }
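
/*
 * High-level shape of the trampoline generated above (a sketch; which parts
 * are present depends on the flags):
 *
 *	prologue: push rbp; sub rsp, stack_size; save rbx, nr_regs,
 *		  optionally the traced IP, and the argument registers
 *	call __bpf_tramp_enter(im)		// BPF_TRAMP_F_CALL_ORIG
 *	run fentry progs
 *	run fmod_ret progs, any non-zero return jumps to do_fexit
 *	restore args; call the original function; save its return value
 *						// BPF_TRAMP_F_CALL_ORIG
 * do_fexit:
 *	run fexit progs
 *	call __bpf_tramp_exit(im)		// BPF_TRAMP_F_CALL_ORIG
 *	restore rax and rbx; leave; (add rsp, 8 with BPF_TRAMP_F_SKIP_FRAME);
 *	ret
 */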
2996
2997 void *arch_alloc_bpf_trampoline(unsigned int size)
2998 {
2999         return bpf_prog_pack_alloc(size, jit_fill_hole);
3000 }
3001
3002 void arch_free_bpf_trampoline(void *image, unsigned int size)
3003 {
3004         bpf_prog_pack_free(image, size);
3005 }
3006
3007 void arch_protect_bpf_trampoline(void *image, unsigned int size)
3008 {
3009 }
3010
3011 void arch_unprotect_bpf_trampoline(void *image, unsigned int size)
3012 {
3013 }
3014
3015 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *image_end,
3016                                 const struct btf_func_model *m, u32 flags,
3017                                 struct bpf_tramp_links *tlinks,
3018                                 void *func_addr)
3019 {
3020         void *rw_image, *tmp;
3021         int ret;
3022         u32 size = image_end - image;
3023
3024         /* rw_image doesn't need to be in module memory range, so we can
3025          * use kvmalloc.
3026          */
3027         rw_image = kvmalloc(size, GFP_KERNEL);
3028         if (!rw_image)
3029                 return -ENOMEM;
3030
3031         ret = __arch_prepare_bpf_trampoline(im, rw_image, rw_image + size, image, m,
3032                                             flags, tlinks, func_addr);
3033         if (ret < 0)
3034                 goto out;
3035
3036         tmp = bpf_arch_text_copy(image, rw_image, size);
3037         if (IS_ERR(tmp))
3038                 ret = PTR_ERR(tmp);
3039 out:
3040         kvfree(rw_image);
3041         return ret;
3042 }
3043
3044 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
3045                              struct bpf_tramp_links *tlinks, void *func_addr)
3046 {
3047         struct bpf_tramp_image im;
3048         void *image;
3049         int ret;
3050
3051         /* Allocate a temporary buffer for __arch_prepare_bpf_trampoline().
3052          * This will NOT cause fragmentation in the direct map, as we do not
3053          * call set_memory_*() on this buffer.
3054          *
3055          * We cannot use kvmalloc here, because we need image to be in
3056          * module memory range.
3057          */
3058         image = bpf_jit_alloc_exec(PAGE_SIZE);
3059         if (!image)
3060                 return -ENOMEM;
3061
3062         ret = __arch_prepare_bpf_trampoline(&im, image, image + PAGE_SIZE, image,
3063                                             m, flags, tlinks, func_addr);
3064         bpf_jit_free_exec(image);
3065         return ret;
3066 }
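
/*
 * Illustrative call sequence (an assumption about how the generic trampoline
 * code drives these hooks, not something enforced here): the size returned
 * above is prog - rw_image + BPF_INSN_SAFETY from
 * __arch_prepare_bpf_trampoline(), and is expected to be used roughly as:
 *
 *	size  = arch_bpf_trampoline_size(m, flags, tlinks, func_addr);
 *	image = arch_alloc_bpf_trampoline(size);
 *	err   = arch_prepare_bpf_trampoline(im, image, image + size,
 *					    m, flags, tlinks, func_addr);
 */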
3067
3068 static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs, u8 *image, u8 *buf)
3069 {
3070         u8 *jg_reloc, *prog = *pprog;
3071         int pivot, err, jg_bytes = 1;
3072         s64 jg_offset;
3073
3074         if (a == b) {
3075                 /* Leaf node of recursion, i.e. not a range of indices
3076                  * anymore.
3077                  */
3078                 EMIT1(add_1mod(0x48, BPF_REG_3));       /* cmp rdx,func */
3079                 if (!is_simm32(progs[a]))
3080                         return -1;
3081                 EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
3082                             progs[a]);
3083                 err = emit_cond_near_jump(&prog,        /* je func */
3084                                           (void *)progs[a], image + (prog - buf),
3085                                           X86_JE);
3086                 if (err)
3087                         return err;
3088
3089                 emit_indirect_jump(&prog, 2 /* rdx */, image + (prog - buf));
3090
3091                 *pprog = prog;
3092                 return 0;
3093         }
3094
3095         /* Not a leaf node, so we pivot, and recursively descend into
3096          * the lower and upper ranges.
3097          */
3098         pivot = (b - a) / 2;
3099         EMIT1(add_1mod(0x48, BPF_REG_3));               /* cmp rdx,func */
3100         if (!is_simm32(progs[a + pivot]))
3101                 return -1;
3102         EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);
3103
3104         if (pivot > 2) {                                /* jg upper_part */
3105                 /* Require near jump. */
3106                 jg_bytes = 4;
3107                 EMIT2_off32(0x0F, X86_JG + 0x10, 0);
3108         } else {
3109                 EMIT2(X86_JG, 0);
3110         }
3111         jg_reloc = prog;
3112
3113         err = emit_bpf_dispatcher(&prog, a, a + pivot,  /* emit lower_part */
3114                                   progs, image, buf);
3115         if (err)
3116                 return err;
3117
3118         /* From Intel 64 and IA-32 Architectures Optimization
3119          * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
3120          * Coding Rule 11: All branch targets should be 16-byte
3121          * aligned.
3122          */
3123         emit_align(&prog, 16);
3124         jg_offset = prog - jg_reloc;
3125         emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);
3126
3127         err = emit_bpf_dispatcher(&prog, a + pivot + 1, /* emit upper_part */
3128                                   b, progs, image, buf);
3129         if (err)
3130                 return err;
3131
3132         *pprog = prog;
3133         return 0;
3134 }
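
/*
 * What the emitted dispatcher does, written as recursive C (an illustrative
 * sketch; dispatch(), goto_prog() and indirect_jump() are hypothetical
 * helpers, the generated code is straight-line machine code):
 *
 *	static void dispatch(s64 *progs, int a, int b, s64 rdx)
 *	{
 *		int pivot;
 *
 *		if (a == b) {
 *			if (rdx == progs[a])
 *				goto_prog(progs[a]);	// je func
 *			indirect_jump(rdx);		// no match
 *			return;
 *		}
 *		pivot = (b - a) / 2;
 *		if (rdx > progs[a + pivot])		// jg upper_part
 *			dispatch(progs, a + pivot + 1, b, rdx);
 *		else
 *			dispatch(progs, a, a + pivot, rdx);
 *	}
 */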
3135
3136 static int cmp_ips(const void *a, const void *b)
3137 {
3138         const s64 *ipa = a;
3139         const s64 *ipb = b;
3140
3141         if (*ipa > *ipb)
3142                 return 1;
3143         if (*ipa < *ipb)
3144                 return -1;
3145         return 0;
3146 }
3147
3148 int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_funcs)
3149 {
3150         u8 *prog = buf;
3151
3152         sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
3153         return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs, image, buf);
3154 }
3155
3156 struct x64_jit_data {
3157         struct bpf_binary_header *rw_header;
3158         struct bpf_binary_header *header;
3159         int *addrs;
3160         u8 *image;
3161         int proglen;
3162         struct jit_context ctx;
3163 };
3164
3165 #define MAX_PASSES 20
3166 #define PADDING_PASSES (MAX_PASSES - 5)
3167
3168 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
3169 {
3170         struct bpf_binary_header *rw_header = NULL;
3171         struct bpf_binary_header *header = NULL;
3172         struct bpf_prog *tmp, *orig_prog = prog;
3173         struct x64_jit_data *jit_data;
3174         int proglen, oldproglen = 0;
3175         struct jit_context ctx = {};
3176         bool tmp_blinded = false;
3177         bool extra_pass = false;
3178         bool padding = false;
3179         u8 *rw_image = NULL;
3180         u8 *image = NULL;
3181         int *addrs;
3182         int pass;
3183         int i;
3184
3185         if (!prog->jit_requested)
3186                 return orig_prog;
3187
3188         tmp = bpf_jit_blind_constants(prog);
3189         /*
3190          * If blinding was requested and we failed during blinding,
3191          * we must fall back to the interpreter.
3192          */
3193         if (IS_ERR(tmp))
3194                 return orig_prog;
3195         if (tmp != prog) {
3196                 tmp_blinded = true;
3197                 prog = tmp;
3198         }
3199
3200         jit_data = prog->aux->jit_data;
3201         if (!jit_data) {
3202                 jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
3203                 if (!jit_data) {
3204                         prog = orig_prog;
3205                         goto out;
3206                 }
3207                 prog->aux->jit_data = jit_data;
3208         }
3209         addrs = jit_data->addrs;
3210         if (addrs) {
3211                 ctx = jit_data->ctx;
3212                 oldproglen = jit_data->proglen;
3213                 image = jit_data->image;
3214                 header = jit_data->header;
3215                 rw_header = jit_data->rw_header;
3216                 rw_image = (void *)rw_header + ((void *)image - (void *)header);
3217                 extra_pass = true;
3218                 padding = true;
3219                 goto skip_init_addrs;
3220         }
3221         addrs = kvmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
3222         if (!addrs) {
3223                 prog = orig_prog;
3224                 goto out_addrs;
3225         }
3226
3227         /*
3228          * Before the first pass, make a rough estimation of addrs[]:
3229          * each BPF instruction is translated to less than 64 bytes.
3230          */
3231         for (proglen = 0, i = 0; i <= prog->len; i++) {
3232                 proglen += 64;
3233                 addrs[i] = proglen;
3234         }
3235         ctx.cleanup_addr = proglen;
3236 skip_init_addrs:
3237
3238         /*
3239          * JITed image shrinks with every pass and the loop iterates
3240          * until the image stops shrinking. Very large BPF programs
3241          * may converge on the last pass. In such a case, do one more
3242          * pass to emit the final image.
3243          */
3244         for (pass = 0; pass < MAX_PASSES || image; pass++) {
3245                 if (!padding && pass >= PADDING_PASSES)
3246                         padding = true;
3247                 proglen = do_jit(prog, addrs, image, rw_image, oldproglen, &ctx, padding);
3248                 if (proglen <= 0) {
3249 out_image:
3250                         image = NULL;
3251                         if (header) {
3252                                 bpf_arch_text_copy(&header->size, &rw_header->size,
3253                                                    sizeof(rw_header->size));
3254                                 bpf_jit_binary_pack_free(header, rw_header);
3255                         }
3256                         /* Fall back to interpreter mode */
3257                         prog = orig_prog;
3258                         if (extra_pass) {
3259                                 prog->bpf_func = NULL;
3260                                 prog->jited = 0;
3261                                 prog->jited_len = 0;
3262                         }
3263                         goto out_addrs;
3264                 }
3265                 if (image) {
3266                         if (proglen != oldproglen) {
3267                                 pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
3268                                        proglen, oldproglen);
3269                                 goto out_image;
3270                         }
3271                         break;
3272                 }
3273                 if (proglen == oldproglen) {
3274                         /*
3275                          * The number of entries in extable is the number of BPF_LDX
3276                          * insns that access kernel memory via "pointer to BTF type".
3277                          * The verifier changed their opcode from LDX|MEM|size
3278                          * to LDX|PROBE_MEM|size to make JITing easier.
3279                          */
3280                         u32 align = __alignof__(struct exception_table_entry);
3281                         u32 extable_size = prog->aux->num_exentries *
3282                                 sizeof(struct exception_table_entry);
3283
3284                         /* allocate module memory for x86 insns and extable */
3285                         header = bpf_jit_binary_pack_alloc(roundup(proglen, align) + extable_size,
3286                                                            &image, align, &rw_header, &rw_image,
3287                                                            jit_fill_hole);
3288                         if (!header) {
3289                                 prog = orig_prog;
3290                                 goto out_addrs;
3291                         }
3292                         prog->aux->extable = (void *) image + roundup(proglen, align);
3293                 }
3294                 oldproglen = proglen;
3295                 cond_resched();
3296         }
3297
3298         if (bpf_jit_enable > 1)
3299                 bpf_jit_dump(prog->len, proglen, pass + 1, rw_image);
3300
3301         if (image) {
3302                 if (!prog->is_func || extra_pass) {
3303                         /*
3304                          * bpf_jit_binary_pack_finalize fails in two scenarios:
3305                          *   1) header is not pointing to proper module memory;
3306                          *   2) the arch doesn't support bpf_arch_text_copy().
3307                          *
3308                          * Both cases are serious bugs and justify WARN_ON.
3309                          */
3310                         if (WARN_ON(bpf_jit_binary_pack_finalize(prog, header, rw_header))) {
3311                                 /* header has been freed */
3312                                 header = NULL;
3313                                 goto out_image;
3314                         }
3315
3316                         bpf_tail_call_direct_fixup(prog);
3317                 } else {
3318                         jit_data->addrs = addrs;
3319                         jit_data->ctx = ctx;
3320                         jit_data->proglen = proglen;
3321                         jit_data->image = image;
3322                         jit_data->header = header;
3323                         jit_data->rw_header = rw_header;
3324                 }
3325                 /*
3326                  * ctx.prog_offset is used when CFI preambles put code *before*
3327                  * the function. See emit_cfi(). For FineIBT specifically this code
3328                  * can also be executed and bpf_prog_kallsyms_add() will
3329                  * generate an additional symbol to cover this, hence also
3330                  * decrement proglen.
3331                  */
3332                 prog->bpf_func = (void *)image + cfi_get_offset();
3333                 prog->jited = 1;
3334                 prog->jited_len = proglen - cfi_get_offset();
3335         } else {
3336                 prog = orig_prog;
3337         }
3338
3339         if (!image || !prog->is_func || extra_pass) {
3340                 if (image)
3341                         bpf_prog_fill_jited_linfo(prog, addrs + 1);
3342 out_addrs:
3343                 kvfree(addrs);
3344                 kfree(jit_data);
3345                 prog->aux->jit_data = NULL;
3346         }
3347 out:
3348         if (tmp_blinded)
3349                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
3350                                            tmp : orig_prog);
3351         return prog;
3352 }
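
/*
 * Pass-loop sketch for bpf_int_jit_compile() above (illustrative; padding,
 * error handling and the extra_pass path are omitted):
 *
 *	estimate addrs[] generously (64 bytes per insn);
 *	for (pass = 0; pass < MAX_PASSES || image; pass++) {
 *		proglen = do_jit(prog, addrs, image, rw_image, oldproglen, &ctx, ...);
 *		if (image)
 *			break;			// final image has been written
 *		if (proglen == oldproglen)
 *			allocate header/image (roundup(proglen) + extable_size);
 *		oldproglen = proglen;
 *	}
 */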
3353
3354 bool bpf_jit_supports_kfunc_call(void)
3355 {
3356         return true;
3357 }
3358
3359 void *bpf_arch_text_copy(void *dst, void *src, size_t len)
3360 {
3361         if (text_poke_copy(dst, src, len) == NULL)
3362                 return ERR_PTR(-EINVAL);
3363         return dst;
3364 }
3365
3366 /* Indicate the JIT backend supports mixing bpf2bpf and tailcalls. */
3367 bool bpf_jit_supports_subprog_tailcalls(void)
3368 {
3369         return true;
3370 }
3371
3372 void bpf_jit_free(struct bpf_prog *prog)
3373 {
3374         if (prog->jited) {
3375                 struct x64_jit_data *jit_data = prog->aux->jit_data;
3376                 struct bpf_binary_header *hdr;
3377
3378                 /*
3379                  * If we fail the final pass of JIT (from jit_subprogs),
3380                  * the program may not be finalized yet. Call finalize here
3381                  * before freeing it.
3382                  */
3383                 if (jit_data) {
3384                         bpf_jit_binary_pack_finalize(prog, jit_data->header,
3385                                                      jit_data->rw_header);
3386                         kvfree(jit_data->addrs);
3387                         kfree(jit_data);
3388                 }
3389                 prog->bpf_func = (void *)prog->bpf_func - cfi_get_offset();
3390                 hdr = bpf_jit_binary_pack_hdr(prog);
3391                 bpf_jit_binary_pack_free(hdr, NULL);
3392                 WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
3393         }
3394
3395         bpf_prog_unlock_free(prog);
3396 }
3397
3398 bool bpf_jit_supports_exceptions(void)
3399 {
3400         /* We unwind through both kernel frames (starting from within bpf_throw
3401          * call) and BPF frames. Therefore we require ORC unwinder to be enabled
3402          * to walk kernel frames and reach BPF frames in the stack trace.
3403          */
3404         return IS_ENABLED(CONFIG_UNWINDER_ORC);
3405 }
3406
3407 void arch_bpf_stack_walk(bool (*consume_fn)(void *cookie, u64 ip, u64 sp, u64 bp), void *cookie)
3408 {
3409 #if defined(CONFIG_UNWINDER_ORC)
3410         struct unwind_state state;
3411         unsigned long addr;
3412
3413         for (unwind_start(&state, current, NULL, NULL); !unwind_done(&state);
3414              unwind_next_frame(&state)) {
3415                 addr = unwind_get_return_address(&state);
3416                 if (!addr || !consume_fn(cookie, (u64)addr, (u64)state.sp, (u64)state.bp))
3417                         break;
3418         }
3419         return;
3420 #endif
3421         WARN(1, "verification of programs using bpf_throw should have failed\n");
3422 }
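
/*
 * Minimal usage sketch for arch_bpf_stack_walk() (illustrative; struct
 * frame_buf and record_frame() are hypothetical): the walk stops as soon
 * as consume_fn returns false, matching the loop above.
 *
 *	struct frame_buf { u64 ips[16]; int nr; };
 *
 *	static bool record_frame(void *cookie, u64 ip, u64 sp, u64 bp)
 *	{
 *		struct frame_buf *fb = cookie;
 *
 *		if (fb->nr >= ARRAY_SIZE(fb->ips))
 *			return false;
 *		fb->ips[fb->nr++] = ip;
 *		return true;
 *	}
 *
 *	struct frame_buf fb = {};
 *
 *	arch_bpf_stack_walk(record_frame, &fb);
 */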
3423
3424 void bpf_arch_poke_desc_update(struct bpf_jit_poke_descriptor *poke,
3425                                struct bpf_prog *new, struct bpf_prog *old)
3426 {
3427         u8 *old_addr, *new_addr, *old_bypass_addr;
3428         int ret;
3429
3430         old_bypass_addr = old ? NULL : poke->bypass_addr;
3431         old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL;
3432         new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL;
3433
3434         /*
3435          * On program loading or teardown, the program's kallsym entry
3436          * might not be in place, so we use __bpf_arch_text_poke to skip
3437          * the kallsyms check.
3438          */
3439         if (new) {
3440                 ret = __bpf_arch_text_poke(poke->tailcall_target,
3441                                            BPF_MOD_JUMP,
3442                                            old_addr, new_addr);
3443                 BUG_ON(ret < 0);
3444                 if (!old) {
3445                         ret = __bpf_arch_text_poke(poke->tailcall_bypass,
3446                                                    BPF_MOD_JUMP,
3447                                                    poke->bypass_addr,
3448                                                    NULL);
3449                         BUG_ON(ret < 0);
3450                 }
3451         } else {
3452                 ret = __bpf_arch_text_poke(poke->tailcall_bypass,
3453                                            BPF_MOD_JUMP,
3454                                            old_bypass_addr,
3455                                            poke->bypass_addr);
3456                 BUG_ON(ret < 0);
3457                 /* Let other CPUs finish executing the program so that
3458                  * they cannot be exposed to an invalid nop, stack unwind,
3459                  * or nop state.
3460                  */
3461                 if (!ret)
3462                         synchronize_rcu();
3463                 ret = __bpf_arch_text_poke(poke->tailcall_target,
3464                                            BPF_MOD_JUMP,
3465                                            old_addr, NULL);
3466                 BUG_ON(ret < 0);
3467         }
3468 }
3469
3470 bool bpf_jit_supports_arena(void)
3471 {
3472         return true;
3473 }
3474
3475 bool bpf_jit_supports_ptr_xchg(void)
3476 {
3477         return true;
3478 }