bpf, arm64: Support struct arguments in the BPF trampoline
[sfrench/cifs-2.6.git] / arch / arm64 / net / bpf_jit_comp.c
index b26da8efa616ec133b23b0d4cacd54192ea7c6af..145b540ec34ffd1dfcb98144f25ab52e2c420dc2 100644 (file)
@@ -1731,21 +1731,21 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
        }
 }
 
-static void save_args(struct jit_ctx *ctx, int args_off, int nargs)
+static void save_args(struct jit_ctx *ctx, int args_off, int nregs)
 {
        int i;
 
-       for (i = 0; i < nargs; i++) {
+       for (i = 0; i < nregs; i++) {
                emit(A64_STR64I(i, A64_SP, args_off), ctx);
                args_off += 8;
        }
 }
 
-static void restore_args(struct jit_ctx *ctx, int args_off, int nargs)
+static void restore_args(struct jit_ctx *ctx, int args_off, int nregs)
 {
        int i;
 
-       for (i = 0; i < nargs; i++) {
+       for (i = 0; i < nregs; i++) {
                emit(A64_LDR64I(i, A64_SP, args_off), ctx);
                args_off += 8;
        }
@@ -1764,7 +1764,7 @@ static void restore_args(struct jit_ctx *ctx, int args_off, int nargs)
  */
 static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
                              struct bpf_tramp_links *tlinks, void *orig_call,
-                             int nargs, u32 flags)
+                             int nregs, u32 flags)
 {
        int i;
        int stack_size;
@@ -1772,7 +1772,7 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
        int regs_off;
        int retval_off;
        int args_off;
-       int nargs_off;
+       int nregs_off;
        int ip_off;
        int run_ctx_off;
        struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
@@ -1795,11 +1795,11 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
         * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
         *                                        BPF_TRAMP_F_RET_FENTRY_RET
         *
-        *                  [ argN              ]
+        *                  [ arg reg N         ]
         *                  [ ...               ]
-        * SP + args_off    [ arg1              ]
+        * SP + args_off    [ arg reg 1         ]
         *
-        * SP + nargs_off   [ args count        ]
+        * SP + nregs_off   [ arg regs count    ]
         *
         * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
         *
@@ -1816,13 +1816,13 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
        if (flags & BPF_TRAMP_F_IP_ARG)
                stack_size += 8;
 
-       nargs_off = stack_size;
+       nregs_off = stack_size;
        /* room for args count */
        stack_size += 8;
 
        args_off = stack_size;
        /* room for args */
-       stack_size += nargs * 8;
+       stack_size += nregs * 8;
 
        /* room for return value */
        retval_off = stack_size;
@@ -1865,12 +1865,12 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
                emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
        }
 
-       /* save args count*/
-       emit(A64_MOVZ(1, A64_R(10), nargs, 0), ctx);
-       emit(A64_STR64I(A64_R(10), A64_SP, nargs_off), ctx);
+       /* save arg regs count*/
+       emit(A64_MOVZ(1, A64_R(10), nregs, 0), ctx);
+       emit(A64_STR64I(A64_R(10), A64_SP, nregs_off), ctx);
 
-       /* save args */
-       save_args(ctx, args_off, nargs);
+       /* save arg regs */
+       save_args(ctx, args_off, nregs);
 
        /* save callee saved registers */
        emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
@@ -1897,7 +1897,7 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
        }
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
-               restore_args(ctx, args_off, nargs);
+               restore_args(ctx, args_off, nregs);
                /* call original func */
                emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
                emit(A64_ADR(A64_LR, AARCH64_INSN_SIZE * 2), ctx);
@@ -1926,7 +1926,7 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
        }
 
        if (flags & BPF_TRAMP_F_RESTORE_REGS)
-               restore_args(ctx, args_off, nargs);
+               restore_args(ctx, args_off, nregs);
 
        /* restore callee saved register x19 and x20 */
        emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
@@ -1967,24 +1967,25 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
                                void *orig_call)
 {
        int i, ret;
-       int nargs = m->nr_args;
+       int nregs = m->nr_args;
        int max_insns = ((long)image_end - (long)image) / AARCH64_INSN_SIZE;
        struct jit_ctx ctx = {
                .image = NULL,
                .idx = 0,
        };
 
-       /* the first 8 arguments are passed by registers */
-       if (nargs > 8)
-               return -ENOTSUPP;
-
-       /* don't support struct argument */
+       /* extra registers needed for struct argument */
        for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) {
+               /* The arg_size is at most 16 bytes, enforced by the verifier. */
                if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
-                       return -ENOTSUPP;
+                       nregs += (m->arg_size[i] + 7) / 8 - 1;
        }
 
-       ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
+       /* the first 8 registers are used for arguments */
+       if (nregs > 8)
+               return -ENOTSUPP;
+
+       ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nregs, flags);
        if (ret < 0)
                return ret;
 
@@ -1995,7 +1996,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
        ctx.idx = 0;
 
        jit_fill_hole(image, (unsigned int)(image_end - image));
-       ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
+       ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nregs, flags);
 
        if (ret > 0 && validate_code(&ctx) < 0)
                ret = -EINVAL;