bpf: arm64: add JIT support for multi-function programs
[sfrench/cifs-2.6.git] / arch / arm64 / net / bpf_jit_comp.c
index ba38d403abb2fc92d8ea6ae9a6c3c38e70979062..396490cf7316b6179a2ab3ccf83a2eb302841c9f 100644 (file)
@@ -99,6 +99,20 @@ static inline void emit_a64_mov_i64(const int reg, const u64 val,
        }
 }
 
+static inline void emit_addr_mov_i64(const int reg, const u64 val,
+                                    struct jit_ctx *ctx)
+{
+       u64 tmp = val;
+       int shift = 0;
+
+       emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
+       for (;shift < 48;) {
+               tmp >>= 16;
+               shift += 16;
+               emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
+       }
+}
+
 static inline void emit_a64_mov_i(const int is64, const int reg,
                                  const s32 val, struct jit_ctx *ctx)
 {
@@ -603,7 +617,10 @@ emit_cond_jmp:
                const u8 r0 = bpf2a64[BPF_REG_0];
                const u64 func = (u64)__bpf_call_base + imm;
 
-               emit_a64_mov_i64(tmp, func, ctx);
+               if (ctx->prog->is_func)
+                       emit_addr_mov_i64(tmp, func, ctx);
+               else
+                       emit_a64_mov_i64(tmp, func, ctx);
                emit(A64_BLR(tmp), ctx);
                emit(A64_MOV(1, r0, A64_R(0)), ctx);
                break;
@@ -835,16 +852,24 @@ static inline void bpf_flush_icache(void *start, void *end)
        flush_icache_range((unsigned long)start, (unsigned long)end);
 }
 
+struct arm64_jit_data {
+       struct bpf_binary_header *header;
+       u8 *image;
+       struct jit_ctx ctx;
+};
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
        struct bpf_prog *tmp, *orig_prog = prog;
        struct bpf_binary_header *header;
+       struct arm64_jit_data *jit_data;
        bool tmp_blinded = false;
+       bool extra_pass = false;
        struct jit_ctx ctx;
        int image_size;
        u8 *image_ptr;
 
-       if (!bpf_jit_enable)
+       if (!prog->jit_requested)
                return orig_prog;
 
        tmp = bpf_jit_blind_constants(prog);
@@ -858,13 +883,29 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = tmp;
        }
 
+       jit_data = prog->aux->jit_data;
+       if (!jit_data) {
+               jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+               if (!jit_data) {
+                       prog = orig_prog;
+                       goto out;
+               }
+               prog->aux->jit_data = jit_data;
+       }
+       if (jit_data->ctx.offset) {
+               ctx = jit_data->ctx;
+               image_ptr = jit_data->image;
+               header = jit_data->header;
+               extra_pass = true;
+               goto skip_init_ctx;
+       }
        memset(&ctx, 0, sizeof(ctx));
        ctx.prog = prog;
 
        ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
        if (ctx.offset == NULL) {
                prog = orig_prog;
-               goto out;
+               goto out_off;
        }
 
        /* 1. Initial fake pass to compute ctx->idx. */
@@ -895,6 +936,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        /* 2. Now, the actual pass. */
 
        ctx.image = (__le32 *)image_ptr;
+skip_init_ctx:
        ctx.idx = 0;
 
        build_prologue(&ctx);
@@ -920,13 +962,31 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
        bpf_flush_icache(header, ctx.image + ctx.idx);
 
-       bpf_jit_binary_lock_ro(header);
+       if (!prog->is_func || extra_pass) {
+               if (extra_pass && ctx.idx != jit_data->ctx.idx) {
+                       pr_err_once("multi-func JIT bug %d != %d\n",
+                                   ctx.idx, jit_data->ctx.idx);
+                       bpf_jit_binary_free(header);
+                       prog->bpf_func = NULL;
+                       prog->jited = 0;
+                       goto out_off;
+               }
+               bpf_jit_binary_lock_ro(header);
+       } else {
+               jit_data->ctx = ctx;
+               jit_data->image = image_ptr;
+               jit_data->header = header;
+       }
        prog->bpf_func = (void *)ctx.image;
        prog->jited = 1;
        prog->jited_len = image_size;
 
+       if (!prog->is_func || extra_pass) {
 out_off:
-       kfree(ctx.offset);
+               kfree(ctx.offset);
+               kfree(jit_data);
+               prog->aux->jit_data = NULL;
+       }
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?