bpf: tcp: Allow bpf prog to write and parse TCP header option
[sfrench/cifs-2.6.git] / net / ipv4 / tcp_input.c
index 184ea556f50e35141a4be5940c692db41e09f464..319cc7fd5117c44d507dad352cb3ac30519a76c0 100644 (file)
@@ -138,6 +138,69 @@ void clean_acked_data_flush(void)
 EXPORT_SYMBOL_GPL(clean_acked_data_flush);
 #endif
 
+#ifdef CONFIG_CGROUP_BPF
+static void bpf_skops_parse_hdr(struct sock *sk, struct sk_buff *skb)
+{
+       bool unknown_opt = tcp_sk(sk)->rx_opt.saw_unknown &&
+               BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk),
+                                      BPF_SOCK_OPS_PARSE_UNKNOWN_HDR_OPT_CB_FLAG);
+       bool parse_all_opt = BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk),
+                                                   BPF_SOCK_OPS_PARSE_ALL_HDR_OPT_CB_FLAG);
+       struct bpf_sock_ops_kern sock_ops;
+
+       if (likely(!unknown_opt && !parse_all_opt))
+               return;
+
+       /* The skb will be handled in the
+        * bpf_skops_established() or
+        * bpf_skops_write_hdr_opt().
+        */
+       switch (sk->sk_state) {
+       case TCP_SYN_RECV:
+       case TCP_SYN_SENT:
+       case TCP_LISTEN:
+               return;
+       }
+
+       sock_owned_by_me(sk);
+
+       memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp));
+       sock_ops.op = BPF_SOCK_OPS_PARSE_HDR_OPT_CB;
+       sock_ops.is_fullsock = 1;
+       sock_ops.sk = sk;
+       bpf_skops_init_skb(&sock_ops, skb, tcp_hdrlen(skb));
+
+       BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops);
+}
+
+static void bpf_skops_established(struct sock *sk, int bpf_op,
+                                 struct sk_buff *skb)
+{
+       struct bpf_sock_ops_kern sock_ops;
+
+       sock_owned_by_me(sk);
+
+       memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp));
+       sock_ops.op = bpf_op;
+       sock_ops.is_fullsock = 1;
+       sock_ops.sk = sk;
+       /* sk with TCP_REPAIR_ON does not have skb in tcp_finish_connect */
+       if (skb)
+               bpf_skops_init_skb(&sock_ops, skb, tcp_hdrlen(skb));
+
+       BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops);
+}
+#else
+static void bpf_skops_parse_hdr(struct sock *sk, struct sk_buff *skb)
+{
+}
+
+static void bpf_skops_established(struct sock *sk, int bpf_op,
+                                 struct sk_buff *skb)
+{
+}
+#endif
+
 static void tcp_gro_dev_warn(struct sock *sk, const struct sk_buff *skb,
                             unsigned int len)
 {
@@ -3801,7 +3864,7 @@ static void tcp_parse_fastopen_option(int len, const unsigned char *cookie,
        foc->exp = exp_opt;
 }
 
-static void smc_parse_options(const struct tcphdr *th,
+static bool smc_parse_options(const struct tcphdr *th,
                              struct tcp_options_received *opt_rx,
                              const unsigned char *ptr,
                              int opsize)
@@ -3810,10 +3873,13 @@ static void smc_parse_options(const struct tcphdr *th,
        if (static_branch_unlikely(&tcp_have_smc)) {
                if (th->syn && !(opsize & 1) &&
                    opsize >= TCPOLEN_EXP_SMC_BASE &&
-                   get_unaligned_be32(ptr) == TCPOPT_SMC_MAGIC)
+                   get_unaligned_be32(ptr) == TCPOPT_SMC_MAGIC) {
                        opt_rx->smc_ok = 1;
+                       return true;
+               }
        }
 #endif
+       return false;
 }
 
 /* Try to parse the MSS option from the TCP header. Return 0 on failure, clamped
@@ -3874,6 +3940,7 @@ void tcp_parse_options(const struct net *net,
 
        ptr = (const unsigned char *)(th + 1);
        opt_rx->saw_tstamp = 0;
+       opt_rx->saw_unknown = 0;
 
        while (length > 0) {
                int opcode = *ptr++;
@@ -3964,15 +4031,21 @@ void tcp_parse_options(const struct net *net,
                                 */
                                if (opsize >= TCPOLEN_EXP_FASTOPEN_BASE &&
                                    get_unaligned_be16(ptr) ==
-                                   TCPOPT_FASTOPEN_MAGIC)
+                                   TCPOPT_FASTOPEN_MAGIC) {
                                        tcp_parse_fastopen_option(opsize -
                                                TCPOLEN_EXP_FASTOPEN_BASE,
                                                ptr + 2, th->syn, foc, true);
-                               else
-                                       smc_parse_options(th, opt_rx, ptr,
-                                                         opsize);
+                                       break;
+                               }
+
+                               if (smc_parse_options(th, opt_rx, ptr, opsize))
+                                       break;
+
+                               opt_rx->saw_unknown = 1;
                                break;
 
+                       default:
+                               opt_rx->saw_unknown = 1;
                        }
                        ptr += opsize-2;
                        length -= opsize;
@@ -5590,6 +5663,8 @@ syn_challenge:
                goto discard;
        }
 
+       bpf_skops_parse_hdr(sk, skb);
+
        return true;
 
 discard:
@@ -5798,7 +5873,7 @@ discard:
 }
 EXPORT_SYMBOL(tcp_rcv_established);
 
-void tcp_init_transfer(struct sock *sk, int bpf_op)
+void tcp_init_transfer(struct sock *sk, int bpf_op, struct sk_buff *skb)
 {
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
@@ -5819,7 +5894,7 @@ void tcp_init_transfer(struct sock *sk, int bpf_op)
                tp->snd_cwnd = tcp_init_cwnd(tp, __sk_dst_get(sk));
        tp->snd_cwnd_stamp = tcp_jiffies32;
 
-       tcp_call_bpf(sk, bpf_op, 0, NULL);
+       bpf_skops_established(sk, bpf_op, skb);
        tcp_init_congestion_control(sk);
        tcp_init_buffer_space(sk);
 }
@@ -5838,7 +5913,7 @@ void tcp_finish_connect(struct sock *sk, struct sk_buff *skb)
                sk_mark_napi_id(sk, skb);
        }
 
-       tcp_init_transfer(sk, BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB);
+       tcp_init_transfer(sk, BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB, skb);
 
        /* Prevent spurious tcp_cwnd_restart() on first data
         * packet.
@@ -6310,7 +6385,8 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb)
                } else {
                        tcp_try_undo_spurious_syn(sk);
                        tp->retrans_stamp = 0;
-                       tcp_init_transfer(sk, BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB);
+                       tcp_init_transfer(sk, BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB,
+                                         skb);
                        WRITE_ONCE(tp->copied_seq, tp->rcv_nxt);
                }
                smp_mb();
@@ -6599,13 +6675,15 @@ static void tcp_reqsk_record_syn(const struct sock *sk,
 {
        if (tcp_sk(sk)->save_syn) {
                u32 len = skb_network_header_len(skb) + tcp_hdrlen(skb);
-               u32 *copy;
-
-               copy = kmalloc(len + sizeof(u32), GFP_ATOMIC);
-               if (copy) {
-                       copy[0] = len;
-                       memcpy(&copy[1], skb_network_header(skb), len);
-                       req->saved_syn = copy;
+               struct saved_syn *saved_syn;
+
+               saved_syn = kmalloc(struct_size(saved_syn, data, len),
+                                   GFP_ATOMIC);
+               if (saved_syn) {
+                       saved_syn->network_hdrlen = skb_network_header_len(skb);
+                       saved_syn->tcp_hdrlen = tcp_hdrlen(skb);
+                       memcpy(saved_syn->data, skb_network_header(skb), len);
+                       req->saved_syn = saved_syn;
                }
        }
 }
@@ -6752,7 +6830,7 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops,
        }
        if (fastopen_sk) {
                af_ops->send_synack(fastopen_sk, dst, &fl, req,
-                                   &foc, TCP_SYNACK_FASTOPEN);
+                                   &foc, TCP_SYNACK_FASTOPEN, skb);
                /* Add the child socket directly into the accept queue */
                if (!inet_csk_reqsk_queue_add(sk, req, fastopen_sk)) {
                        reqsk_fastopen_remove(fastopen_sk, req, false);
@@ -6770,7 +6848,8 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops,
                                tcp_timeout_init((struct sock *)req));
                af_ops->send_synack(sk, dst, &fl, req, &foc,
                                    !want_cookie ? TCP_SYNACK_NORMAL :
-                                                  TCP_SYNACK_COOKIE);
+                                                  TCP_SYNACK_COOKIE,
+                                   skb);
                if (want_cookie) {
                        reqsk_free(req);
                        return 0;