c34e902855dbefdd37b86e2f9cb9ef067be87c3d
[sfrench/cifs-2.6.git] / net / psample / psample.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * net/psample/psample.c - Netlink channel for packet sampling
4  * Copyright (c) 2017 Yotam Gigi <yotamg@mellanox.com>
5  */
6
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <linux/skbuff.h>
10 #include <linux/module.h>
11 #include <linux/timekeeping.h>
12 #include <net/net_namespace.h>
13 #include <net/sock.h>
14 #include <net/netlink.h>
15 #include <net/genetlink.h>
16 #include <net/psample.h>
17 #include <linux/spinlock.h>
18 #include <net/ip_tunnels.h>
19 #include <net/dst_metadata.h>
20
21 #define PSAMPLE_MAX_PACKET_SIZE 0xffff
22
23 static LIST_HEAD(psample_groups_list);
24 static DEFINE_SPINLOCK(psample_groups_lock);
25
26 /* multicast groups */
27 enum psample_nl_multicast_groups {
28         PSAMPLE_NL_MCGRP_CONFIG,
29         PSAMPLE_NL_MCGRP_SAMPLE,
30 };
31
32 static const struct genl_multicast_group psample_nl_mcgrps[] = {
33         [PSAMPLE_NL_MCGRP_CONFIG] = { .name = PSAMPLE_NL_MCGRP_CONFIG_NAME },
34         [PSAMPLE_NL_MCGRP_SAMPLE] = { .name = PSAMPLE_NL_MCGRP_SAMPLE_NAME,
35                                       .flags = GENL_UNS_ADMIN_PERM },
36 };
37
38 static struct genl_family psample_nl_family __ro_after_init;
39
40 static int psample_group_nl_fill(struct sk_buff *msg,
41                                  struct psample_group *group,
42                                  enum psample_command cmd, u32 portid, u32 seq,
43                                  int flags)
44 {
45         void *hdr;
46         int ret;
47
48         hdr = genlmsg_put(msg, portid, seq, &psample_nl_family, flags, cmd);
49         if (!hdr)
50                 return -EMSGSIZE;
51
52         ret = nla_put_u32(msg, PSAMPLE_ATTR_SAMPLE_GROUP, group->group_num);
53         if (ret < 0)
54                 goto error;
55
56         ret = nla_put_u32(msg, PSAMPLE_ATTR_GROUP_REFCOUNT, group->refcount);
57         if (ret < 0)
58                 goto error;
59
60         ret = nla_put_u32(msg, PSAMPLE_ATTR_GROUP_SEQ, group->seq);
61         if (ret < 0)
62                 goto error;
63
64         genlmsg_end(msg, hdr);
65         return 0;
66
67 error:
68         genlmsg_cancel(msg, hdr);
69         return -EMSGSIZE;
70 }
71
72 static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg,
73                                            struct netlink_callback *cb)
74 {
75         struct psample_group *group;
76         int start = cb->args[0];
77         int idx = 0;
78         int err;
79
80         spin_lock_bh(&psample_groups_lock);
81         list_for_each_entry(group, &psample_groups_list, list) {
82                 if (!net_eq(group->net, sock_net(msg->sk)))
83                         continue;
84                 if (idx < start) {
85                         idx++;
86                         continue;
87                 }
88                 err = psample_group_nl_fill(msg, group, PSAMPLE_CMD_NEW_GROUP,
89                                             NETLINK_CB(cb->skb).portid,
90                                             cb->nlh->nlmsg_seq, NLM_F_MULTI);
91                 if (err)
92                         break;
93                 idx++;
94         }
95
96         spin_unlock_bh(&psample_groups_lock);
97         cb->args[0] = idx;
98         return msg->len;
99 }
100
101 static const struct genl_small_ops psample_nl_ops[] = {
102         {
103                 .cmd = PSAMPLE_CMD_GET_GROUP,
104                 .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
105                 .dumpit = psample_nl_cmd_get_group_dumpit,
106                 /* can be retrieved by unprivileged users */
107         }
108 };
109
110 static struct genl_family psample_nl_family __ro_after_init = {
111         .name           = PSAMPLE_GENL_NAME,
112         .version        = PSAMPLE_GENL_VERSION,
113         .maxattr        = PSAMPLE_ATTR_MAX,
114         .netnsok        = true,
115         .module         = THIS_MODULE,
116         .mcgrps         = psample_nl_mcgrps,
117         .small_ops      = psample_nl_ops,
118         .n_small_ops    = ARRAY_SIZE(psample_nl_ops),
119         .resv_start_op  = PSAMPLE_CMD_GET_GROUP + 1,
120         .n_mcgrps       = ARRAY_SIZE(psample_nl_mcgrps),
121 };
122
123 static void psample_group_notify(struct psample_group *group,
124                                  enum psample_command cmd)
125 {
126         struct sk_buff *msg;
127         int err;
128
129         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
130         if (!msg)
131                 return;
132
133         err = psample_group_nl_fill(msg, group, cmd, 0, 0, NLM_F_MULTI);
134         if (!err)
135                 genlmsg_multicast_netns(&psample_nl_family, group->net, msg, 0,
136                                         PSAMPLE_NL_MCGRP_CONFIG, GFP_ATOMIC);
137         else
138                 nlmsg_free(msg);
139 }
140
141 static struct psample_group *psample_group_create(struct net *net,
142                                                   u32 group_num)
143 {
144         struct psample_group *group;
145
146         group = kzalloc(sizeof(*group), GFP_ATOMIC);
147         if (!group)
148                 return NULL;
149
150         group->net = net;
151         group->group_num = group_num;
152         list_add_tail(&group->list, &psample_groups_list);
153
154         psample_group_notify(group, PSAMPLE_CMD_NEW_GROUP);
155         return group;
156 }
157
158 static void psample_group_destroy(struct psample_group *group)
159 {
160         psample_group_notify(group, PSAMPLE_CMD_DEL_GROUP);
161         list_del(&group->list);
162         kfree_rcu(group, rcu);
163 }
164
165 static struct psample_group *
166 psample_group_lookup(struct net *net, u32 group_num)
167 {
168         struct psample_group *group;
169
170         list_for_each_entry(group, &psample_groups_list, list)
171                 if ((group->group_num == group_num) && (group->net == net))
172                         return group;
173         return NULL;
174 }
175
176 struct psample_group *psample_group_get(struct net *net, u32 group_num)
177 {
178         struct psample_group *group;
179
180         spin_lock_bh(&psample_groups_lock);
181
182         group = psample_group_lookup(net, group_num);
183         if (!group) {
184                 group = psample_group_create(net, group_num);
185                 if (!group)
186                         goto out;
187         }
188         group->refcount++;
189
190 out:
191         spin_unlock_bh(&psample_groups_lock);
192         return group;
193 }
194 EXPORT_SYMBOL_GPL(psample_group_get);
195
196 void psample_group_take(struct psample_group *group)
197 {
198         spin_lock_bh(&psample_groups_lock);
199         group->refcount++;
200         spin_unlock_bh(&psample_groups_lock);
201 }
202 EXPORT_SYMBOL_GPL(psample_group_take);
203
204 void psample_group_put(struct psample_group *group)
205 {
206         spin_lock_bh(&psample_groups_lock);
207
208         if (--group->refcount == 0)
209                 psample_group_destroy(group);
210
211         spin_unlock_bh(&psample_groups_lock);
212 }
213 EXPORT_SYMBOL_GPL(psample_group_put);
214
215 #ifdef CONFIG_INET
216 static int __psample_ip_tun_to_nlattr(struct sk_buff *skb,
217                               struct ip_tunnel_info *tun_info)
218 {
219         unsigned short tun_proto = ip_tunnel_info_af(tun_info);
220         const void *tun_opts = ip_tunnel_info_opts(tun_info);
221         const struct ip_tunnel_key *tun_key = &tun_info->key;
222         int tun_opts_len = tun_info->options_len;
223
224         if (tun_key->tun_flags & TUNNEL_KEY &&
225             nla_put_be64(skb, PSAMPLE_TUNNEL_KEY_ATTR_ID, tun_key->tun_id,
226                          PSAMPLE_TUNNEL_KEY_ATTR_PAD))
227                 return -EMSGSIZE;
228
229         if (tun_info->mode & IP_TUNNEL_INFO_BRIDGE &&
230             nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_INFO_BRIDGE))
231                 return -EMSGSIZE;
232
233         switch (tun_proto) {
234         case AF_INET:
235                 if (tun_key->u.ipv4.src &&
236                     nla_put_in_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_SRC,
237                                     tun_key->u.ipv4.src))
238                         return -EMSGSIZE;
239                 if (tun_key->u.ipv4.dst &&
240                     nla_put_in_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_DST,
241                                     tun_key->u.ipv4.dst))
242                         return -EMSGSIZE;
243                 break;
244         case AF_INET6:
245                 if (!ipv6_addr_any(&tun_key->u.ipv6.src) &&
246                     nla_put_in6_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV6_SRC,
247                                      &tun_key->u.ipv6.src))
248                         return -EMSGSIZE;
249                 if (!ipv6_addr_any(&tun_key->u.ipv6.dst) &&
250                     nla_put_in6_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV6_DST,
251                                      &tun_key->u.ipv6.dst))
252                         return -EMSGSIZE;
253                 break;
254         }
255         if (tun_key->tos &&
256             nla_put_u8(skb, PSAMPLE_TUNNEL_KEY_ATTR_TOS, tun_key->tos))
257                 return -EMSGSIZE;
258         if (nla_put_u8(skb, PSAMPLE_TUNNEL_KEY_ATTR_TTL, tun_key->ttl))
259                 return -EMSGSIZE;
260         if ((tun_key->tun_flags & TUNNEL_DONT_FRAGMENT) &&
261             nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_DONT_FRAGMENT))
262                 return -EMSGSIZE;
263         if ((tun_key->tun_flags & TUNNEL_CSUM) &&
264             nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_CSUM))
265                 return -EMSGSIZE;
266         if (tun_key->tp_src &&
267             nla_put_be16(skb, PSAMPLE_TUNNEL_KEY_ATTR_TP_SRC, tun_key->tp_src))
268                 return -EMSGSIZE;
269         if (tun_key->tp_dst &&
270             nla_put_be16(skb, PSAMPLE_TUNNEL_KEY_ATTR_TP_DST, tun_key->tp_dst))
271                 return -EMSGSIZE;
272         if ((tun_key->tun_flags & TUNNEL_OAM) &&
273             nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_OAM))
274                 return -EMSGSIZE;
275         if (tun_opts_len) {
276                 if (tun_key->tun_flags & TUNNEL_GENEVE_OPT &&
277                     nla_put(skb, PSAMPLE_TUNNEL_KEY_ATTR_GENEVE_OPTS,
278                             tun_opts_len, tun_opts))
279                         return -EMSGSIZE;
280                 else if (tun_key->tun_flags & TUNNEL_ERSPAN_OPT &&
281                          nla_put(skb, PSAMPLE_TUNNEL_KEY_ATTR_ERSPAN_OPTS,
282                                  tun_opts_len, tun_opts))
283                         return -EMSGSIZE;
284         }
285
286         return 0;
287 }
288
289 static int psample_ip_tun_to_nlattr(struct sk_buff *skb,
290                             struct ip_tunnel_info *tun_info)
291 {
292         struct nlattr *nla;
293         int err;
294
295         nla = nla_nest_start_noflag(skb, PSAMPLE_ATTR_TUNNEL);
296         if (!nla)
297                 return -EMSGSIZE;
298
299         err = __psample_ip_tun_to_nlattr(skb, tun_info);
300         if (err) {
301                 nla_nest_cancel(skb, nla);
302                 return err;
303         }
304
305         nla_nest_end(skb, nla);
306
307         return 0;
308 }
309
310 static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info)
311 {
312         unsigned short tun_proto = ip_tunnel_info_af(tun_info);
313         const struct ip_tunnel_key *tun_key = &tun_info->key;
314         int tun_opts_len = tun_info->options_len;
315         int sum = nla_total_size(0);    /* PSAMPLE_ATTR_TUNNEL */
316
317         if (tun_key->tun_flags & TUNNEL_KEY)
318                 sum += nla_total_size_64bit(sizeof(u64));
319
320         if (tun_info->mode & IP_TUNNEL_INFO_BRIDGE)
321                 sum += nla_total_size(0);
322
323         switch (tun_proto) {
324         case AF_INET:
325                 if (tun_key->u.ipv4.src)
326                         sum += nla_total_size(sizeof(u32));
327                 if (tun_key->u.ipv4.dst)
328                         sum += nla_total_size(sizeof(u32));
329                 break;
330         case AF_INET6:
331                 if (!ipv6_addr_any(&tun_key->u.ipv6.src))
332                         sum += nla_total_size(sizeof(struct in6_addr));
333                 if (!ipv6_addr_any(&tun_key->u.ipv6.dst))
334                         sum += nla_total_size(sizeof(struct in6_addr));
335                 break;
336         }
337         if (tun_key->tos)
338                 sum += nla_total_size(sizeof(u8));
339         sum += nla_total_size(sizeof(u8));      /* TTL */
340         if (tun_key->tun_flags & TUNNEL_DONT_FRAGMENT)
341                 sum += nla_total_size(0);
342         if (tun_key->tun_flags & TUNNEL_CSUM)
343                 sum += nla_total_size(0);
344         if (tun_key->tp_src)
345                 sum += nla_total_size(sizeof(u16));
346         if (tun_key->tp_dst)
347                 sum += nla_total_size(sizeof(u16));
348         if (tun_key->tun_flags & TUNNEL_OAM)
349                 sum += nla_total_size(0);
350         if (tun_opts_len) {
351                 if (tun_key->tun_flags & TUNNEL_GENEVE_OPT)
352                         sum += nla_total_size(tun_opts_len);
353                 else if (tun_key->tun_flags & TUNNEL_ERSPAN_OPT)
354                         sum += nla_total_size(tun_opts_len);
355         }
356
357         return sum;
358 }
359 #endif
360
361 void psample_sample_packet(struct psample_group *group, struct sk_buff *skb,
362                            u32 sample_rate, const struct psample_metadata *md)
363 {
364         ktime_t tstamp = ktime_get_real();
365         int out_ifindex = md->out_ifindex;
366         int in_ifindex = md->in_ifindex;
367         u32 trunc_size = md->trunc_size;
368 #ifdef CONFIG_INET
369         struct ip_tunnel_info *tun_info;
370 #endif
371         struct sk_buff *nl_skb;
372         int data_len;
373         int meta_len;
374         void *data;
375         int ret;
376
377         meta_len = (in_ifindex ? nla_total_size(sizeof(u16)) : 0) +
378                    (out_ifindex ? nla_total_size(sizeof(u16)) : 0) +
379                    (md->out_tc_valid ? nla_total_size(sizeof(u16)) : 0) +
380                    (md->out_tc_occ_valid ? nla_total_size_64bit(sizeof(u64)) : 0) +
381                    (md->latency_valid ? nla_total_size_64bit(sizeof(u64)) : 0) +
382                    nla_total_size(sizeof(u32)) +        /* sample_rate */
383                    nla_total_size(sizeof(u32)) +        /* orig_size */
384                    nla_total_size(sizeof(u32)) +        /* group_num */
385                    nla_total_size(sizeof(u32)) +        /* seq */
386                    nla_total_size_64bit(sizeof(u64)) +  /* timestamp */
387                    nla_total_size(sizeof(u16));         /* protocol */
388
389 #ifdef CONFIG_INET
390         tun_info = skb_tunnel_info(skb);
391         if (tun_info)
392                 meta_len += psample_tunnel_meta_len(tun_info);
393 #endif
394
395         data_len = min(skb->len, trunc_size);
396         if (meta_len + nla_total_size(data_len) > PSAMPLE_MAX_PACKET_SIZE)
397                 data_len = PSAMPLE_MAX_PACKET_SIZE - meta_len - NLA_HDRLEN
398                             - NLA_ALIGNTO;
399
400         nl_skb = genlmsg_new(meta_len + nla_total_size(data_len), GFP_ATOMIC);
401         if (unlikely(!nl_skb))
402                 return;
403
404         data = genlmsg_put(nl_skb, 0, 0, &psample_nl_family, 0,
405                            PSAMPLE_CMD_SAMPLE);
406         if (unlikely(!data))
407                 goto error;
408
409         if (in_ifindex) {
410                 ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_IIFINDEX, in_ifindex);
411                 if (unlikely(ret < 0))
412                         goto error;
413         }
414
415         if (out_ifindex) {
416                 ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_OIFINDEX, out_ifindex);
417                 if (unlikely(ret < 0))
418                         goto error;
419         }
420
421         ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_SAMPLE_RATE, sample_rate);
422         if (unlikely(ret < 0))
423                 goto error;
424
425         ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_ORIGSIZE, skb->len);
426         if (unlikely(ret < 0))
427                 goto error;
428
429         ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_SAMPLE_GROUP, group->group_num);
430         if (unlikely(ret < 0))
431                 goto error;
432
433         ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_GROUP_SEQ, group->seq++);
434         if (unlikely(ret < 0))
435                 goto error;
436
437         if (md->out_tc_valid) {
438                 ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_OUT_TC, md->out_tc);
439                 if (unlikely(ret < 0))
440                         goto error;
441         }
442
443         if (md->out_tc_occ_valid) {
444                 ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_OUT_TC_OCC,
445                                         md->out_tc_occ, PSAMPLE_ATTR_PAD);
446                 if (unlikely(ret < 0))
447                         goto error;
448         }
449
450         if (md->latency_valid) {
451                 ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_LATENCY,
452                                         md->latency, PSAMPLE_ATTR_PAD);
453                 if (unlikely(ret < 0))
454                         goto error;
455         }
456
457         ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_TIMESTAMP,
458                                 ktime_to_ns(tstamp), PSAMPLE_ATTR_PAD);
459         if (unlikely(ret < 0))
460                 goto error;
461
462         ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_PROTO,
463                           be16_to_cpu(skb->protocol));
464         if (unlikely(ret < 0))
465                 goto error;
466
467         if (data_len) {
468                 int nla_len = nla_total_size(data_len);
469                 struct nlattr *nla;
470
471                 nla = skb_put(nl_skb, nla_len);
472                 nla->nla_type = PSAMPLE_ATTR_DATA;
473                 nla->nla_len = nla_attr_size(data_len);
474
475                 if (skb_copy_bits(skb, 0, nla_data(nla), data_len))
476                         goto error;
477         }
478
479 #ifdef CONFIG_INET
480         if (tun_info) {
481                 ret = psample_ip_tun_to_nlattr(nl_skb, tun_info);
482                 if (unlikely(ret < 0))
483                         goto error;
484         }
485 #endif
486
487         genlmsg_end(nl_skb, data);
488         genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0,
489                                 PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC);
490
491         return;
492 error:
493         pr_err_ratelimited("Could not create psample log message\n");
494         nlmsg_free(nl_skb);
495 }
496 EXPORT_SYMBOL_GPL(psample_sample_packet);
497
498 static int __init psample_module_init(void)
499 {
500         return genl_register_family(&psample_nl_family);
501 }
502
503 static void __exit psample_module_exit(void)
504 {
505         genl_unregister_family(&psample_nl_family);
506 }
507
508 module_init(psample_module_init);
509 module_exit(psample_module_exit);
510
511 MODULE_AUTHOR("Yotam Gigi <yotam.gi@gmail.com>");
512 MODULE_DESCRIPTION("netlink channel for packet sampling");
513 MODULE_LICENSE("GPL v2");