Merge tag 'kbuild-misc-v4.16' of git://git.kernel.org/pub/scm/linux/kernel/git/masahi...
[sfrench/cifs-2.6.git] / arch / arm64 / net / bpf_jit_comp.c
index bb32f7f6dd0f967fd1435ad539d922b69650fb2b..1d4f1da7c58f8d51371947523e91649915e4320d 100644 (file)
@@ -31,8 +31,6 @@
 
 #include "bpf_jit.h"
 
-int bpf_jit_enable __read_mostly;
-
 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
@@ -99,6 +97,20 @@ static inline void emit_a64_mov_i64(const int reg, const u64 val,
        }
 }
 
+static inline void emit_addr_mov_i64(const int reg, const u64 val,
+                                    struct jit_ctx *ctx)
+{
+       u64 tmp = val;
+       int shift = 0;
+
+       emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
+       for (;shift < 48;) {
+               tmp >>= 16;
+               shift += 16;
+               emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
+       }
+}
+
 static inline void emit_a64_mov_i(const int is64, const int reg,
                                  const s32 val, struct jit_ctx *ctx)
 {
@@ -378,18 +390,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
        case BPF_ALU64 | BPF_DIV | BPF_X:
        case BPF_ALU | BPF_MOD | BPF_X:
        case BPF_ALU64 | BPF_MOD | BPF_X:
-       {
-               const u8 r0 = bpf2a64[BPF_REG_0];
-
-               /* if (src == 0) return 0 */
-               jmp_offset = 3; /* skip ahead to else path */
-               check_imm19(jmp_offset);
-               emit(A64_CBNZ(is64, src, jmp_offset), ctx);
-               emit(A64_MOVZ(1, r0, 0, 0), ctx);
-               jmp_offset = epilogue_offset(ctx);
-               check_imm26(jmp_offset);
-               emit(A64_B(jmp_offset), ctx);
-               /* else */
                switch (BPF_OP(code)) {
                case BPF_DIV:
                        emit(A64_UDIV(is64, dst, dst, src), ctx);
@@ -401,7 +401,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
                        break;
                }
                break;
-       }
        case BPF_ALU | BPF_LSH | BPF_X:
        case BPF_ALU64 | BPF_LSH | BPF_X:
                emit(A64_LSLV(is64, dst, dst, src), ctx);
@@ -605,7 +604,10 @@ emit_cond_jmp:
                const u8 r0 = bpf2a64[BPF_REG_0];
                const u64 func = (u64)__bpf_call_base + imm;
 
-               emit_a64_mov_i64(tmp, func, ctx);
+               if (ctx->prog->is_func)
+                       emit_addr_mov_i64(tmp, func, ctx);
+               else
+                       emit_a64_mov_i64(tmp, func, ctx);
                emit(A64_BLR(tmp), ctx);
                emit(A64_MOV(1, r0, A64_R(0)), ctx);
                break;
@@ -837,16 +839,24 @@ static inline void bpf_flush_icache(void *start, void *end)
        flush_icache_range((unsigned long)start, (unsigned long)end);
 }
 
+struct arm64_jit_data {
+       struct bpf_binary_header *header;
+       u8 *image;
+       struct jit_ctx ctx;
+};
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
        struct bpf_prog *tmp, *orig_prog = prog;
        struct bpf_binary_header *header;
+       struct arm64_jit_data *jit_data;
        bool tmp_blinded = false;
+       bool extra_pass = false;
        struct jit_ctx ctx;
        int image_size;
        u8 *image_ptr;
 
-       if (!bpf_jit_enable)
+       if (!prog->jit_requested)
                return orig_prog;
 
        tmp = bpf_jit_blind_constants(prog);
@@ -860,13 +870,30 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = tmp;
        }
 
+       jit_data = prog->aux->jit_data;
+       if (!jit_data) {
+               jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+               if (!jit_data) {
+                       prog = orig_prog;
+                       goto out;
+               }
+               prog->aux->jit_data = jit_data;
+       }
+       if (jit_data->ctx.offset) {
+               ctx = jit_data->ctx;
+               image_ptr = jit_data->image;
+               header = jit_data->header;
+               extra_pass = true;
+               image_size = sizeof(u32) * ctx.idx;
+               goto skip_init_ctx;
+       }
        memset(&ctx, 0, sizeof(ctx));
        ctx.prog = prog;
 
        ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
        if (ctx.offset == NULL) {
                prog = orig_prog;
-               goto out;
+               goto out_off;
        }
 
        /* 1. Initial fake pass to compute ctx->idx. */
@@ -897,6 +924,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        /* 2. Now, the actual pass. */
 
        ctx.image = (__le32 *)image_ptr;
+skip_init_ctx:
        ctx.idx = 0;
 
        build_prologue(&ctx);
@@ -922,13 +950,31 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
        bpf_flush_icache(header, ctx.image + ctx.idx);
 
-       bpf_jit_binary_lock_ro(header);
+       if (!prog->is_func || extra_pass) {
+               if (extra_pass && ctx.idx != jit_data->ctx.idx) {
+                       pr_err_once("multi-func JIT bug %d != %d\n",
+                                   ctx.idx, jit_data->ctx.idx);
+                       bpf_jit_binary_free(header);
+                       prog->bpf_func = NULL;
+                       prog->jited = 0;
+                       goto out_off;
+               }
+               bpf_jit_binary_lock_ro(header);
+       } else {
+               jit_data->ctx = ctx;
+               jit_data->image = image_ptr;
+               jit_data->header = header;
+       }
        prog->bpf_func = (void *)ctx.image;
        prog->jited = 1;
        prog->jited_len = image_size;
 
+       if (!prog->is_func || extra_pass) {
 out_off:
-       kfree(ctx.offset);
+               kfree(ctx.offset);
+               kfree(jit_data);
+               prog->aux->jit_data = NULL;
+       }
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?