Merge tag 'net-5.13-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[sfrench/cifs-2.6.git] / net / ipv6 / seg6_local.c
index bd7140885e60ec06c269fd7b1b014dc832aedf16..4ff38cb08f4bb227117482a93a597d72f9e19bf4 100644 (file)
@@ -93,6 +93,35 @@ struct seg6_end_dt_info {
        int hdrlen;
 };
 
+struct pcpu_seg6_local_counters {
+       u64_stats_t packets;
+       u64_stats_t bytes;
+       u64_stats_t errors;
+
+       struct u64_stats_sync syncp;
+};
+
+/* This struct groups all the SRv6 Behavior counters supported so far.
+ *
+ * put_nla_counters() makes use of this data structure to collect all counter
+ * values after the per-CPU counter evaluation has been performed.
+ * Finally, each counter value (in seg6_local_counters) is stored in the
+ * corresponding netlink attribute and sent to user space.
+ *
+ * NB: we don't want to expose this structure to user space!
+ */
+struct seg6_local_counters {
+       __u64 packets;
+       __u64 bytes;
+       __u64 errors;
+};
+
+#define seg6_local_alloc_pcpu_counters(__gfp)                          \
+       __netdev_alloc_pcpu_stats(struct pcpu_seg6_local_counters,      \
+                                 ((__gfp) | __GFP_ZERO))
+
+#define SEG6_F_LOCAL_COUNTERS  SEG6_F_ATTR(SEG6_LOCAL_COUNTERS)
+
 struct seg6_local_lwt {
        int action;
        struct ipv6_sr_hdr *srh;
@@ -105,6 +134,7 @@ struct seg6_local_lwt {
 #ifdef CONFIG_NET_L3_MASTER_DEV
        struct seg6_end_dt_info dt_info;
 #endif
+       struct pcpu_seg6_local_counters __percpu *pcpu_counters;
 
        int headroom;
        struct seg6_action_desc *desc;
@@ -878,36 +908,43 @@ static struct seg6_action_desc seg6_action_table[] = {
        {
                .action         = SEG6_LOCAL_ACTION_END,
                .attrs          = 0,
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_X,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_x,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_T,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_t,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_DX2,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_OIF),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_dx2,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_DX6,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_dx6,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_DX4,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH4),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_dx4,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_DT4,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
 #ifdef CONFIG_NET_L3_MASTER_DEV
                .input          = input_action_end_dt4,
                .slwt_ops       = {
@@ -919,30 +956,35 @@ static struct seg6_action_desc seg6_action_table[] = {
                .action         = SEG6_LOCAL_ACTION_END_DT6,
 #ifdef CONFIG_NET_L3_MASTER_DEV
                .attrs          = 0,
-               .optattrs       = SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
+               .optattrs       = SEG6_F_LOCAL_COUNTERS         |
+                                 SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
                                  SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
                .slwt_ops       = {
                                        .build_state = seg6_end_dt6_build,
                                  },
 #else
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
 #endif
                .input          = input_action_end_dt6,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_B6,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_b6,
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_b6_encap,
                .static_headroom        = sizeof(struct ipv6hdr),
        },
        {
                .action         = SEG6_LOCAL_ACTION_END_BPF,
                .attrs          = SEG6_F_ATTR(SEG6_LOCAL_BPF),
+               .optattrs       = SEG6_F_LOCAL_COUNTERS,
                .input          = input_action_end_bpf,
        },
 
@@ -963,11 +1005,36 @@ static struct seg6_action_desc *__get_action_desc(int action)
        return NULL;
 }
 
+static bool seg6_lwtunnel_counters_enabled(struct seg6_local_lwt *slwt)
+{
+       return slwt->parsed_optattrs & SEG6_F_LOCAL_COUNTERS;
+}
+
+static void seg6_local_update_counters(struct seg6_local_lwt *slwt,
+                                      unsigned int len, int err)
+{
+       struct pcpu_seg6_local_counters *pcounters;
+
+       pcounters = this_cpu_ptr(slwt->pcpu_counters);
+       u64_stats_update_begin(&pcounters->syncp);
+
+       if (likely(!err)) {
+               u64_stats_inc(&pcounters->packets);
+               u64_stats_add(&pcounters->bytes, len);
+       } else {
+               u64_stats_inc(&pcounters->errors);
+       }
+
+       u64_stats_update_end(&pcounters->syncp);
+}
+
 static int seg6_local_input(struct sk_buff *skb)
 {
        struct dst_entry *orig_dst = skb_dst(skb);
        struct seg6_action_desc *desc;
        struct seg6_local_lwt *slwt;
+       unsigned int len = skb->len;
+       int rc;
 
        if (skb->protocol != htons(ETH_P_IPV6)) {
                kfree_skb(skb);
@@ -977,7 +1044,14 @@ static int seg6_local_input(struct sk_buff *skb)
        slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
        desc = slwt->desc;
 
-       return desc->input(skb, slwt);
+       rc = desc->input(skb, slwt);
+
+       if (!seg6_lwtunnel_counters_enabled(slwt))
+               return rc;
+
+       seg6_local_update_counters(slwt, len, rc);
+
+       return rc;
 }
 
 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
@@ -992,6 +1066,7 @@ static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
        [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
        [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
        [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
+       [SEG6_LOCAL_COUNTERS]   = { .type = NLA_NESTED },
 };
 
 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
@@ -1296,6 +1371,112 @@ static void destroy_attr_bpf(struct seg6_local_lwt *slwt)
                bpf_prog_put(slwt->bpf.prog);
 }
 
+static const struct
+nla_policy seg6_local_counters_policy[SEG6_LOCAL_CNT_MAX + 1] = {
+       [SEG6_LOCAL_CNT_PACKETS]        = { .type = NLA_U64 },
+       [SEG6_LOCAL_CNT_BYTES]          = { .type = NLA_U64 },
+       [SEG6_LOCAL_CNT_ERRORS]         = { .type = NLA_U64 },
+};
+
+static int parse_nla_counters(struct nlattr **attrs,
+                             struct seg6_local_lwt *slwt)
+{
+       struct pcpu_seg6_local_counters __percpu *pcounters;
+       struct nlattr *tb[SEG6_LOCAL_CNT_MAX + 1];
+       int ret;
+
+       ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_CNT_MAX,
+                                         attrs[SEG6_LOCAL_COUNTERS],
+                                         seg6_local_counters_policy, NULL);
+       if (ret < 0)
+               return ret;
+
+       /* basic support for SRv6 Behavior counters requires at least:
+        * packets, bytes and errors.
+        */
+       if (!tb[SEG6_LOCAL_CNT_PACKETS] || !tb[SEG6_LOCAL_CNT_BYTES] ||
+           !tb[SEG6_LOCAL_CNT_ERRORS])
+               return -EINVAL;
+
+       /* counters are always zero initialized */
+       pcounters = seg6_local_alloc_pcpu_counters(GFP_KERNEL);
+       if (!pcounters)
+               return -ENOMEM;
+
+       slwt->pcpu_counters = pcounters;
+
+       return 0;
+}
+
+static int seg6_local_fill_nla_counters(struct sk_buff *skb,
+                                       struct seg6_local_counters *counters)
+{
+       if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_PACKETS, counters->packets,
+                             SEG6_LOCAL_CNT_PAD))
+               return -EMSGSIZE;
+
+       if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_BYTES, counters->bytes,
+                             SEG6_LOCAL_CNT_PAD))
+               return -EMSGSIZE;
+
+       if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_ERRORS, counters->errors,
+                             SEG6_LOCAL_CNT_PAD))
+               return -EMSGSIZE;
+
+       return 0;
+}
+
+static int put_nla_counters(struct sk_buff *skb, struct seg6_local_lwt *slwt)
+{
+       struct seg6_local_counters counters = { 0, 0, 0 };
+       struct nlattr *nest;
+       int rc, i;
+
+       nest = nla_nest_start(skb, SEG6_LOCAL_COUNTERS);
+       if (!nest)
+               return -EMSGSIZE;
+
+       for_each_possible_cpu(i) {
+               struct pcpu_seg6_local_counters *pcounters;
+               u64 packets, bytes, errors;
+               unsigned int start;
+
+               pcounters = per_cpu_ptr(slwt->pcpu_counters, i);
+               do {
+                       start = u64_stats_fetch_begin_irq(&pcounters->syncp);
+
+                       packets = u64_stats_read(&pcounters->packets);
+                       bytes = u64_stats_read(&pcounters->bytes);
+                       errors = u64_stats_read(&pcounters->errors);
+
+               } while (u64_stats_fetch_retry_irq(&pcounters->syncp, start));
+
+               counters.packets += packets;
+               counters.bytes += bytes;
+               counters.errors += errors;
+       }
+
+       rc = seg6_local_fill_nla_counters(skb, &counters);
+       if (rc < 0) {
+               nla_nest_cancel(skb, nest);
+               return rc;
+       }
+
+       return nla_nest_end(skb, nest);
+}
+
+static int cmp_nla_counters(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
+{
+       /* a and b are equal if both have pcpu_counters set or not */
+       return (!!((unsigned long)a->pcpu_counters)) ^
+               (!!((unsigned long)b->pcpu_counters));
+}
+
+static void destroy_attr_counters(struct seg6_local_lwt *slwt)
+{
+       free_percpu(slwt->pcpu_counters);
+}
+
 struct seg6_action_param {
        int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
        int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
@@ -1343,6 +1524,10 @@ static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
                                    .put = put_nla_vrftable,
                                    .cmp = cmp_nla_vrftable },
 
+       [SEG6_LOCAL_COUNTERS]   = { .parse = parse_nla_counters,
+                                   .put = put_nla_counters,
+                                   .cmp = cmp_nla_counters,
+                                   .destroy = destroy_attr_counters },
 };
 
 /* call the destroy() callback (if available) for each set attribute in
@@ -1645,6 +1830,15 @@ static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
        if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE))
                nlsize += nla_total_size(4);
 
+       if (attrs & SEG6_F_LOCAL_COUNTERS)
+               nlsize += nla_total_size(0) + /* nest SEG6_LOCAL_COUNTERS */
+                         /* SEG6_LOCAL_CNT_PACKETS */
+                         nla_total_size_64bit(sizeof(__u64)) +
+                         /* SEG6_LOCAL_CNT_BYTES */
+                         nla_total_size_64bit(sizeof(__u64)) +
+                         /* SEG6_LOCAL_CNT_ERRORS */
+                         nla_total_size_64bit(sizeof(__u64));
+
        return nlsize;
 }