bpf: mark registers in all frames after pkt/null checks
authorPaul Chaignon <paul.chaignon@orange.com>
Wed, 24 Apr 2019 19:50:42 +0000 (21:50 +0200)
committerAlexei Starovoitov <ast@kernel.org>
Fri, 26 Apr 2019 00:20:06 +0000 (17:20 -0700)
In case of a null check on a pointer inside a subprog, we should mark all
registers with this pointer as either safe or unknown, in both the current
and previous frames.  Currently, only spilled registers and registers in
the current frame are marked.  Packet bound checks in subprogs have the
same issue.  This patch fixes it to mark registers in previous frames as
well.

A good reproducer for null checks looks as follow:

1: ptr = bpf_map_lookup_elem(map, &key);
2: ret = subprog(ptr) {
3:   return ptr != NULL;
4: }
5: if (ret)
6:   value = *ptr;

With the above, the verifier will complain on line 6 because it sees ptr
as map_value_or_null despite the null check in subprog 1.

Note that this patch fixes another resulting bug when using
bpf_sk_release():

1: sk = bpf_sk_lookup_tcp(...);
2: subprog(sk) {
3:   if (sk)
4:     bpf_sk_release(sk);
5: }
6: if (!sk)
7:   return 0;
8: return 1;

In the above, mark_ptr_or_null_regs will warn on line 6 because it will
try to free the reference state, even though it was already freed on
line 3.

Fixes: f4d7e40a5b71 ("bpf: introduce function calls (verification)")
Signed-off-by: Paul Chaignon <paul.chaignon@orange.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/verifier.c

index 6c5a41f7f33856d79f641c57767c7c093ec2a831..09d5d972c9ff20c9fe69ca4ffbbbb185a998b56d 100644 (file)
@@ -4138,15 +4138,35 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
        return 0;
 }
 
+static void __find_good_pkt_pointers(struct bpf_func_state *state,
+                                    struct bpf_reg_state *dst_reg,
+                                    enum bpf_reg_type type, u16 new_range)
+{
+       struct bpf_reg_state *reg;
+       int i;
+
+       for (i = 0; i < MAX_BPF_REG; i++) {
+               reg = &state->regs[i];
+               if (reg->type == type && reg->id == dst_reg->id)
+                       /* keep the maximum range already checked */
+                       reg->range = max(reg->range, new_range);
+       }
+
+       bpf_for_each_spilled_reg(i, state, reg) {
+               if (!reg)
+                       continue;
+               if (reg->type == type && reg->id == dst_reg->id)
+                       reg->range = max(reg->range, new_range);
+       }
+}
+
 static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
                                   struct bpf_reg_state *dst_reg,
                                   enum bpf_reg_type type,
                                   bool range_right_open)
 {
-       struct bpf_func_state *state = vstate->frame[vstate->curframe];
-       struct bpf_reg_state *regs = state->regs, *reg;
        u16 new_range;
-       int i, j;
+       int i;
 
        if (dst_reg->off < 0 ||
            (dst_reg->off == 0 && range_right_open))
@@ -4211,20 +4231,9 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
         * the range won't allow anything.
         * dst_reg->off is known < MAX_PACKET_OFF, therefore it fits in a u16.
         */
-       for (i = 0; i < MAX_BPF_REG; i++)
-               if (regs[i].type == type && regs[i].id == dst_reg->id)
-                       /* keep the maximum range already checked */
-                       regs[i].range = max(regs[i].range, new_range);
-
-       for (j = 0; j <= vstate->curframe; j++) {
-               state = vstate->frame[j];
-               bpf_for_each_spilled_reg(i, state, reg) {
-                       if (!reg)
-                               continue;
-                       if (reg->type == type && reg->id == dst_reg->id)
-                               reg->range = max(reg->range, new_range);
-               }
-       }
+       for (i = 0; i <= vstate->curframe; i++)
+               __find_good_pkt_pointers(vstate->frame[i], dst_reg, type,
+                                        new_range);
 }
 
 /* compute branch direction of the expression "if (reg opcode val) goto target;"
@@ -4698,6 +4707,22 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
        }
 }
 
+static void __mark_ptr_or_null_regs(struct bpf_func_state *state, u32 id,
+                                   bool is_null)
+{
+       struct bpf_reg_state *reg;
+       int i;
+
+       for (i = 0; i < MAX_BPF_REG; i++)
+               mark_ptr_or_null_reg(state, &state->regs[i], id, is_null);
+
+       bpf_for_each_spilled_reg(i, state, reg) {
+               if (!reg)
+                       continue;
+               mark_ptr_or_null_reg(state, reg, id, is_null);
+       }
+}
+
 /* The logic is similar to find_good_pkt_pointers(), both could eventually
  * be folded together at some point.
  */
@@ -4705,10 +4730,10 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
                                  bool is_null)
 {
        struct bpf_func_state *state = vstate->frame[vstate->curframe];
-       struct bpf_reg_state *reg, *regs = state->regs;
+       struct bpf_reg_state *regs = state->regs;
        u32 ref_obj_id = regs[regno].ref_obj_id;
        u32 id = regs[regno].id;
-       int i, j;
+       int i;
 
        if (ref_obj_id && ref_obj_id == id && is_null)
                /* regs[regno] is in the " == NULL" branch.
@@ -4717,17 +4742,8 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
                 */
                WARN_ON_ONCE(release_reference_state(state, id));
 
-       for (i = 0; i < MAX_BPF_REG; i++)
-               mark_ptr_or_null_reg(state, &regs[i], id, is_null);
-
-       for (j = 0; j <= vstate->curframe; j++) {
-               state = vstate->frame[j];
-               bpf_for_each_spilled_reg(i, state, reg) {
-                       if (!reg)
-                               continue;
-                       mark_ptr_or_null_reg(state, reg, id, is_null);
-               }
-       }
+       for (i = 0; i <= vstate->curframe; i++)
+               __mark_ptr_or_null_regs(vstate->frame[i], id, is_null);
 }
 
 static bool try_match_pkt_pointers(const struct bpf_insn *insn,