f043936763f391878c76f7c7f0b6d51ac1694a81
[sfrench/cifs-2.6.git] / net / netfilter / nft_ct.c
1 /*
2  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
3  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  *
9  * Development of this code funded by Astaro AG (http://www.astaro.com/)
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/init.h>
14 #include <linux/module.h>
15 #include <linux/netlink.h>
16 #include <linux/netfilter.h>
17 #include <linux/netfilter/nf_tables.h>
18 #include <net/netfilter/nf_tables.h>
19 #include <net/netfilter/nf_conntrack.h>
20 #include <net/netfilter/nf_conntrack_acct.h>
21 #include <net/netfilter/nf_conntrack_tuple.h>
22 #include <net/netfilter/nf_conntrack_helper.h>
23 #include <net/netfilter/nf_conntrack_ecache.h>
24 #include <net/netfilter/nf_conntrack_labels.h>
25 #include <net/netfilter/nf_conntrack_timeout.h>
26 #include <net/netfilter/nf_conntrack_l4proto.h>
27
/* Private data of one "ct" expression instance (both the get and set flavours). */
struct nft_ct {
        /* conntrack attribute this expression reads or writes */
        enum nft_ct_keys        key:8;
        /* requested direction; the init functions preset it to IP_CT_DIR_MAX
         * and only overwrite it when NFTA_CT_DIRECTION is supplied
         */
        enum ip_conntrack_dir   dir:8;
        /* nft_ct_select_ops() rejects DREG together with SREG, so an
         * expression only ever uses one of the two registers.
         */
        union {
                enum nft_registers      dreg:8; /* destination register (get) */
                enum nft_registers      sreg:8; /* source register (set) */
        };
};
36
/*
 * Private data of a conntrack helper object.
 * NOTE(review): the users of this struct are outside the visible part of the
 * file; helper4/helper6 presumably hold the IPv4 and IPv6 helper variants --
 * confirm against the helper object init/eval code.
 */
struct nft_ct_helper_obj  {
        struct nf_conntrack_helper *helper4;
        struct nf_conntrack_helper *helper6;
        u8 l4proto;
};
42
#ifdef CONFIG_NF_CONNTRACK_ZONES
/* Per-CPU conntrack template that nft_ct_set_zone_eval() attaches to skbs. */
static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
/* Number of live NFT_CT_ZONE set expressions: incremented in nft_ct_set_init(),
 * decremented in __nft_ct_set_destroy(), which drops the templates at zero.
 */
static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
#endif
47
48 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
49                                    enum nft_ct_keys k,
50                                    enum ip_conntrack_dir d)
51 {
52         if (d < IP_CT_DIR_MAX)
53                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
54                                            atomic64_read(&c[d].packets);
55
56         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
57                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
58 }
59
60 static void nft_ct_get_eval(const struct nft_expr *expr,
61                             struct nft_regs *regs,
62                             const struct nft_pktinfo *pkt)
63 {
64         const struct nft_ct *priv = nft_expr_priv(expr);
65         u32 *dest = &regs->data[priv->dreg];
66         enum ip_conntrack_info ctinfo;
67         const struct nf_conn *ct;
68         const struct nf_conn_help *help;
69         const struct nf_conntrack_tuple *tuple;
70         const struct nf_conntrack_helper *helper;
71         unsigned int state;
72
73         ct = nf_ct_get(pkt->skb, &ctinfo);
74
75         switch (priv->key) {
76         case NFT_CT_STATE:
77                 if (ct)
78                         state = NF_CT_STATE_BIT(ctinfo);
79                 else if (ctinfo == IP_CT_UNTRACKED)
80                         state = NF_CT_STATE_UNTRACKED_BIT;
81                 else
82                         state = NF_CT_STATE_INVALID_BIT;
83                 *dest = state;
84                 return;
85         default:
86                 break;
87         }
88
89         if (ct == NULL)
90                 goto err;
91
92         switch (priv->key) {
93         case NFT_CT_DIRECTION:
94                 nft_reg_store8(dest, CTINFO2DIR(ctinfo));
95                 return;
96         case NFT_CT_STATUS:
97                 *dest = ct->status;
98                 return;
99 #ifdef CONFIG_NF_CONNTRACK_MARK
100         case NFT_CT_MARK:
101                 *dest = ct->mark;
102                 return;
103 #endif
104 #ifdef CONFIG_NF_CONNTRACK_SECMARK
105         case NFT_CT_SECMARK:
106                 *dest = ct->secmark;
107                 return;
108 #endif
109         case NFT_CT_EXPIRATION:
110                 *dest = jiffies_to_msecs(nf_ct_expires(ct));
111                 return;
112         case NFT_CT_HELPER:
113                 if (ct->master == NULL)
114                         goto err;
115                 help = nfct_help(ct->master);
116                 if (help == NULL)
117                         goto err;
118                 helper = rcu_dereference(help->helper);
119                 if (helper == NULL)
120                         goto err;
121                 strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
122                 return;
123 #ifdef CONFIG_NF_CONNTRACK_LABELS
124         case NFT_CT_LABELS: {
125                 struct nf_conn_labels *labels = nf_ct_labels_find(ct);
126
127                 if (labels)
128                         memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
129                 else
130                         memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
131                 return;
132         }
133 #endif
134         case NFT_CT_BYTES: /* fallthrough */
135         case NFT_CT_PKTS: {
136                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
137                 u64 count = 0;
138
139                 if (acct)
140                         count = nft_ct_get_eval_counter(acct->counter,
141                                                         priv->key, priv->dir);
142                 memcpy(dest, &count, sizeof(count));
143                 return;
144         }
145         case NFT_CT_AVGPKT: {
146                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
147                 u64 avgcnt = 0, bcnt = 0, pcnt = 0;
148
149                 if (acct) {
150                         pcnt = nft_ct_get_eval_counter(acct->counter,
151                                                        NFT_CT_PKTS, priv->dir);
152                         bcnt = nft_ct_get_eval_counter(acct->counter,
153                                                        NFT_CT_BYTES, priv->dir);
154                         if (pcnt != 0)
155                                 avgcnt = div64_u64(bcnt, pcnt);
156                 }
157
158                 memcpy(dest, &avgcnt, sizeof(avgcnt));
159                 return;
160         }
161         case NFT_CT_L3PROTOCOL:
162                 nft_reg_store8(dest, nf_ct_l3num(ct));
163                 return;
164         case NFT_CT_PROTOCOL:
165                 nft_reg_store8(dest, nf_ct_protonum(ct));
166                 return;
167 #ifdef CONFIG_NF_CONNTRACK_ZONES
168         case NFT_CT_ZONE: {
169                 const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
170                 u16 zoneid;
171
172                 if (priv->dir < IP_CT_DIR_MAX)
173                         zoneid = nf_ct_zone_id(zone, priv->dir);
174                 else
175                         zoneid = zone->id;
176
177                 nft_reg_store16(dest, zoneid);
178                 return;
179         }
180 #endif
181         case NFT_CT_ID:
182                 if (!nf_ct_is_confirmed(ct))
183                         goto err;
184                 *dest = nf_ct_get_id(ct);
185                 return;
186         default:
187                 break;
188         }
189
190         tuple = &ct->tuplehash[priv->dir].tuple;
191         switch (priv->key) {
192         case NFT_CT_SRC:
193                 memcpy(dest, tuple->src.u3.all,
194                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
195                 return;
196         case NFT_CT_DST:
197                 memcpy(dest, tuple->dst.u3.all,
198                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
199                 return;
200         case NFT_CT_PROTO_SRC:
201                 nft_reg_store16(dest, (__force u16)tuple->src.u.all);
202                 return;
203         case NFT_CT_PROTO_DST:
204                 nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
205                 return;
206         case NFT_CT_SRC_IP:
207                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
208                         goto err;
209                 *dest = tuple->src.u3.ip;
210                 return;
211         case NFT_CT_DST_IP:
212                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
213                         goto err;
214                 *dest = tuple->dst.u3.ip;
215                 return;
216         case NFT_CT_SRC_IP6:
217                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
218                         goto err;
219                 memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
220                 return;
221         case NFT_CT_DST_IP6:
222                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
223                         goto err;
224                 memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
225                 return;
226         default:
227                 break;
228         }
229         return;
230 err:
231         regs->verdict.code = NFT_BREAK;
232 }
233
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Evaluate "ct zone set": attach a conntrack template carrying the zone id
 * read from the source register to a packet that has no conntrack entry yet.
 *
 * The per-CPU template is reused while its use count is 1; otherwise a new
 * template is allocated, and the packet is dropped if that allocation fails.
 */
static void nft_ct_set_zone_eval(const struct nft_expr *expr,
                                 struct nft_regs *regs,
                                 const struct nft_pktinfo *pkt)
{
        const struct nft_ct *priv = nft_expr_priv(expr);
        struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
        struct sk_buff *skb = pkt->skb;
        enum ip_conntrack_info ctinfo;
        struct nf_conn *tmpl;

        if (nf_ct_get(skb, &ctinfo))
                return;         /* already tracked */

        zone.id = nft_reg_load16(&regs->data[priv->sreg]);

        /* without an explicit direction the default zone direction stays */
        if (priv->dir == IP_CT_DIR_ORIGINAL)
                zone.dir = NF_CT_ZONE_DIR_ORIG;
        else if (priv->dir == IP_CT_DIR_REPLY)
                zone.dir = NF_CT_ZONE_DIR_REPL;

        tmpl = this_cpu_read(nft_ct_pcpu_template);

        if (unlikely(atomic_read(&tmpl->ct_general.use) != 1)) {
                /* previous skb got queued to userspace */
                tmpl = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
                if (!tmpl) {
                        regs->verdict.code = NF_DROP;
                        return;
                }
        } else {
                nf_ct_zone_add(tmpl, &zone);
        }

        /* take the reference that the skb is going to hold */
        atomic_inc(&tmpl->ct_general.use);
        nf_ct_set(skb, tmpl, IP_CT_NEW);
}
#endif
280
281 static void nft_ct_set_eval(const struct nft_expr *expr,
282                             struct nft_regs *regs,
283                             const struct nft_pktinfo *pkt)
284 {
285         const struct nft_ct *priv = nft_expr_priv(expr);
286         struct sk_buff *skb = pkt->skb;
287 #if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK)
288         u32 value = regs->data[priv->sreg];
289 #endif
290         enum ip_conntrack_info ctinfo;
291         struct nf_conn *ct;
292
293         ct = nf_ct_get(skb, &ctinfo);
294         if (ct == NULL || nf_ct_is_template(ct))
295                 return;
296
297         switch (priv->key) {
298 #ifdef CONFIG_NF_CONNTRACK_MARK
299         case NFT_CT_MARK:
300                 if (ct->mark != value) {
301                         ct->mark = value;
302                         nf_conntrack_event_cache(IPCT_MARK, ct);
303                 }
304                 break;
305 #endif
306 #ifdef CONFIG_NF_CONNTRACK_SECMARK
307         case NFT_CT_SECMARK:
308                 if (ct->secmark != value) {
309                         ct->secmark = value;
310                         nf_conntrack_event_cache(IPCT_SECMARK, ct);
311                 }
312                 break;
313 #endif
314 #ifdef CONFIG_NF_CONNTRACK_LABELS
315         case NFT_CT_LABELS:
316                 nf_connlabels_replace(ct,
317                                       &regs->data[priv->sreg],
318                                       &regs->data[priv->sreg],
319                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
320                 break;
321 #endif
322 #ifdef CONFIG_NF_CONNTRACK_EVENTS
323         case NFT_CT_EVENTMASK: {
324                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
325                 u32 ctmask = regs->data[priv->sreg];
326
327                 if (e) {
328                         if (e->ctmask != ctmask)
329                                 e->ctmask = ctmask;
330                         break;
331                 }
332
333                 if (ctmask && !nf_ct_is_confirmed(ct))
334                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
335                 break;
336         }
337 #endif
338         default:
339                 break;
340         }
341 }
342
/* Netlink attribute policy of the "ct" expression (wired up via nft_ct_type.policy). */
static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
        [NFTA_CT_DREG]          = { .type = NLA_U32 },  /* destination register (get) */
        [NFTA_CT_KEY]           = { .type = NLA_U32 },  /* read with nla_get_be32() */
        [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },   /* IP_CT_DIR_ORIGINAL or IP_CT_DIR_REPLY */
        [NFTA_CT_SREG]          = { .type = NLA_U32 },  /* source register (set) */
};
349
350 #ifdef CONFIG_NF_CONNTRACK_ZONES
351 static void nft_ct_tmpl_put_pcpu(void)
352 {
353         struct nf_conn *ct;
354         int cpu;
355
356         for_each_possible_cpu(cpu) {
357                 ct = per_cpu(nft_ct_pcpu_template, cpu);
358                 if (!ct)
359                         break;
360                 nf_ct_put(ct);
361                 per_cpu(nft_ct_pcpu_template, cpu) = NULL;
362         }
363 }
364
365 static bool nft_ct_tmpl_alloc_pcpu(void)
366 {
367         struct nf_conntrack_zone zone = { .id = 0 };
368         struct nf_conn *tmp;
369         int cpu;
370
371         if (nft_ct_pcpu_template_refcnt)
372                 return true;
373
374         for_each_possible_cpu(cpu) {
375                 tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
376                 if (!tmp) {
377                         nft_ct_tmpl_put_pcpu();
378                         return false;
379                 }
380
381                 atomic_set(&tmp->ct_general.use, 1);
382                 per_cpu(nft_ct_pcpu_template, cpu) = tmp;
383         }
384
385         return true;
386 }
387 #endif
388
389 static int nft_ct_get_init(const struct nft_ctx *ctx,
390                            const struct nft_expr *expr,
391                            const struct nlattr * const tb[])
392 {
393         struct nft_ct *priv = nft_expr_priv(expr);
394         unsigned int len;
395         int err;
396
397         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
398         priv->dir = IP_CT_DIR_MAX;
399         switch (priv->key) {
400         case NFT_CT_DIRECTION:
401                 if (tb[NFTA_CT_DIRECTION] != NULL)
402                         return -EINVAL;
403                 len = sizeof(u8);
404                 break;
405         case NFT_CT_STATE:
406         case NFT_CT_STATUS:
407 #ifdef CONFIG_NF_CONNTRACK_MARK
408         case NFT_CT_MARK:
409 #endif
410 #ifdef CONFIG_NF_CONNTRACK_SECMARK
411         case NFT_CT_SECMARK:
412 #endif
413         case NFT_CT_EXPIRATION:
414                 if (tb[NFTA_CT_DIRECTION] != NULL)
415                         return -EINVAL;
416                 len = sizeof(u32);
417                 break;
418 #ifdef CONFIG_NF_CONNTRACK_LABELS
419         case NFT_CT_LABELS:
420                 if (tb[NFTA_CT_DIRECTION] != NULL)
421                         return -EINVAL;
422                 len = NF_CT_LABELS_MAX_SIZE;
423                 break;
424 #endif
425         case NFT_CT_HELPER:
426                 if (tb[NFTA_CT_DIRECTION] != NULL)
427                         return -EINVAL;
428                 len = NF_CT_HELPER_NAME_LEN;
429                 break;
430
431         case NFT_CT_L3PROTOCOL:
432         case NFT_CT_PROTOCOL:
433                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
434                  * attribute is specified.
435                  */
436                 len = sizeof(u8);
437                 break;
438         case NFT_CT_SRC:
439         case NFT_CT_DST:
440                 if (tb[NFTA_CT_DIRECTION] == NULL)
441                         return -EINVAL;
442
443                 switch (ctx->family) {
444                 case NFPROTO_IPV4:
445                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
446                                            src.u3.ip);
447                         break;
448                 case NFPROTO_IPV6:
449                 case NFPROTO_INET:
450                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
451                                            src.u3.ip6);
452                         break;
453                 default:
454                         return -EAFNOSUPPORT;
455                 }
456                 break;
457         case NFT_CT_SRC_IP:
458         case NFT_CT_DST_IP:
459                 if (tb[NFTA_CT_DIRECTION] == NULL)
460                         return -EINVAL;
461
462                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip);
463                 break;
464         case NFT_CT_SRC_IP6:
465         case NFT_CT_DST_IP6:
466                 if (tb[NFTA_CT_DIRECTION] == NULL)
467                         return -EINVAL;
468
469                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip6);
470                 break;
471         case NFT_CT_PROTO_SRC:
472         case NFT_CT_PROTO_DST:
473                 if (tb[NFTA_CT_DIRECTION] == NULL)
474                         return -EINVAL;
475                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u.all);
476                 break;
477         case NFT_CT_BYTES:
478         case NFT_CT_PKTS:
479         case NFT_CT_AVGPKT:
480                 len = sizeof(u64);
481                 break;
482 #ifdef CONFIG_NF_CONNTRACK_ZONES
483         case NFT_CT_ZONE:
484                 len = sizeof(u16);
485                 break;
486 #endif
487         case NFT_CT_ID:
488                 len = sizeof(u32);
489                 break;
490         default:
491                 return -EOPNOTSUPP;
492         }
493
494         if (tb[NFTA_CT_DIRECTION] != NULL) {
495                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
496                 switch (priv->dir) {
497                 case IP_CT_DIR_ORIGINAL:
498                 case IP_CT_DIR_REPLY:
499                         break;
500                 default:
501                         return -EINVAL;
502                 }
503         }
504
505         priv->dreg = nft_parse_register(tb[NFTA_CT_DREG]);
506         err = nft_validate_register_store(ctx, priv->dreg, NULL,
507                                           NFT_DATA_VALUE, len);
508         if (err < 0)
509                 return err;
510
511         err = nf_ct_netns_get(ctx->net, ctx->family);
512         if (err < 0)
513                 return err;
514
515         if (priv->key == NFT_CT_BYTES ||
516             priv->key == NFT_CT_PKTS  ||
517             priv->key == NFT_CT_AVGPKT)
518                 nf_ct_set_acct(ctx->net, true);
519
520         return 0;
521 }
522
523 static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
524 {
525         switch (priv->key) {
526 #ifdef CONFIG_NF_CONNTRACK_LABELS
527         case NFT_CT_LABELS:
528                 nf_connlabels_put(ctx->net);
529                 break;
530 #endif
531 #ifdef CONFIG_NF_CONNTRACK_ZONES
532         case NFT_CT_ZONE:
533                 if (--nft_ct_pcpu_template_refcnt == 0)
534                         nft_ct_tmpl_put_pcpu();
535 #endif
536         default:
537                 break;
538         }
539 }
540
541 static int nft_ct_set_init(const struct nft_ctx *ctx,
542                            const struct nft_expr *expr,
543                            const struct nlattr * const tb[])
544 {
545         struct nft_ct *priv = nft_expr_priv(expr);
546         unsigned int len;
547         int err;
548
549         priv->dir = IP_CT_DIR_MAX;
550         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
551         switch (priv->key) {
552 #ifdef CONFIG_NF_CONNTRACK_MARK
553         case NFT_CT_MARK:
554                 if (tb[NFTA_CT_DIRECTION])
555                         return -EINVAL;
556                 len = FIELD_SIZEOF(struct nf_conn, mark);
557                 break;
558 #endif
559 #ifdef CONFIG_NF_CONNTRACK_LABELS
560         case NFT_CT_LABELS:
561                 if (tb[NFTA_CT_DIRECTION])
562                         return -EINVAL;
563                 len = NF_CT_LABELS_MAX_SIZE;
564                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
565                 if (err)
566                         return err;
567                 break;
568 #endif
569 #ifdef CONFIG_NF_CONNTRACK_ZONES
570         case NFT_CT_ZONE:
571                 if (!nft_ct_tmpl_alloc_pcpu())
572                         return -ENOMEM;
573                 nft_ct_pcpu_template_refcnt++;
574                 len = sizeof(u16);
575                 break;
576 #endif
577 #ifdef CONFIG_NF_CONNTRACK_EVENTS
578         case NFT_CT_EVENTMASK:
579                 if (tb[NFTA_CT_DIRECTION])
580                         return -EINVAL;
581                 len = sizeof(u32);
582                 break;
583 #endif
584 #ifdef CONFIG_NF_CONNTRACK_SECMARK
585         case NFT_CT_SECMARK:
586                 if (tb[NFTA_CT_DIRECTION])
587                         return -EINVAL;
588                 len = sizeof(u32);
589                 break;
590 #endif
591         default:
592                 return -EOPNOTSUPP;
593         }
594
595         if (tb[NFTA_CT_DIRECTION]) {
596                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
597                 switch (priv->dir) {
598                 case IP_CT_DIR_ORIGINAL:
599                 case IP_CT_DIR_REPLY:
600                         break;
601                 default:
602                         err = -EINVAL;
603                         goto err1;
604                 }
605         }
606
607         priv->sreg = nft_parse_register(tb[NFTA_CT_SREG]);
608         err = nft_validate_register_load(priv->sreg, len);
609         if (err < 0)
610                 goto err1;
611
612         err = nf_ct_netns_get(ctx->net, ctx->family);
613         if (err < 0)
614                 goto err1;
615
616         return 0;
617
618 err1:
619         __nft_ct_set_destroy(ctx, priv);
620         return err;
621 }
622
/* Destroy a "ct" get expression: drop the conntrack netns reference that
 * nft_ct_get_init() took with nf_ct_netns_get().
 */
static void nft_ct_get_destroy(const struct nft_ctx *ctx,
                               const struct nft_expr *expr)
{
        nf_ct_netns_put(ctx->net, ctx->family);
}
628
629 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
630                                const struct nft_expr *expr)
631 {
632         struct nft_ct *priv = nft_expr_priv(expr);
633
634         __nft_ct_set_destroy(ctx, priv);
635         nf_ct_netns_put(ctx->net, ctx->family);
636 }
637
638 static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
639 {
640         const struct nft_ct *priv = nft_expr_priv(expr);
641
642         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
643                 goto nla_put_failure;
644         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
645                 goto nla_put_failure;
646
647         switch (priv->key) {
648         case NFT_CT_SRC:
649         case NFT_CT_DST:
650         case NFT_CT_SRC_IP:
651         case NFT_CT_DST_IP:
652         case NFT_CT_SRC_IP6:
653         case NFT_CT_DST_IP6:
654         case NFT_CT_PROTO_SRC:
655         case NFT_CT_PROTO_DST:
656                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
657                         goto nla_put_failure;
658                 break;
659         case NFT_CT_BYTES:
660         case NFT_CT_PKTS:
661         case NFT_CT_AVGPKT:
662         case NFT_CT_ZONE:
663                 if (priv->dir < IP_CT_DIR_MAX &&
664                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
665                         goto nla_put_failure;
666                 break;
667         default:
668                 break;
669         }
670
671         return 0;
672
673 nla_put_failure:
674         return -1;
675 }
676
677 static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr)
678 {
679         const struct nft_ct *priv = nft_expr_priv(expr);
680
681         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
682                 goto nla_put_failure;
683         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
684                 goto nla_put_failure;
685
686         switch (priv->key) {
687         case NFT_CT_ZONE:
688                 if (priv->dir < IP_CT_DIR_MAX &&
689                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
690                         goto nla_put_failure;
691                 break;
692         default:
693                 break;
694         }
695
696         return 0;
697
698 nla_put_failure:
699         return -1;
700 }
701
/* forward declaration: the ops tables below point back at the type */
static struct nft_expr_type nft_ct_type;

/* ops of the get flavour, selected when NFTA_CT_DREG is present */
static const struct nft_expr_ops nft_ct_get_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_get_eval,
        .init           = nft_ct_get_init,
        .destroy        = nft_ct_get_destroy,
        .dump           = nft_ct_get_dump,
};

/* ops of the generic set flavour, selected when NFTA_CT_SREG is present */
static const struct nft_expr_ops nft_ct_set_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_set_eval,
        .init           = nft_ct_set_init,
        .destroy        = nft_ct_set_destroy,
        .dump           = nft_ct_set_dump,
};

#ifdef CONFIG_NF_CONNTRACK_ZONES
/* set flavour for NFT_CT_ZONE: same init/destroy/dump as the generic set
 * ops, but with its own eval function
 */
static const struct nft_expr_ops nft_ct_set_zone_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_set_zone_eval,
        .init           = nft_ct_set_init,
        .destroy        = nft_ct_set_destroy,
        .dump           = nft_ct_set_dump,
};
#endif
731
732 static const struct nft_expr_ops *
733 nft_ct_select_ops(const struct nft_ctx *ctx,
734                     const struct nlattr * const tb[])
735 {
736         if (tb[NFTA_CT_KEY] == NULL)
737                 return ERR_PTR(-EINVAL);
738
739         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
740                 return ERR_PTR(-EINVAL);
741
742         if (tb[NFTA_CT_DREG])
743                 return &nft_ct_get_ops;
744
745         if (tb[NFTA_CT_SREG]) {
746 #ifdef CONFIG_NF_CONNTRACK_ZONES
747                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
748                         return &nft_ct_set_zone_ops;
749 #endif
750                 return &nft_ct_set_ops;
751         }
752
753         return ERR_PTR(-EINVAL);
754 }
755
/* The "ct" expression type. It has no fixed ops: nft_ct_select_ops() picks
 * the get, set or zone-set ops for each expression.
 */
static struct nft_expr_type nft_ct_type __read_mostly = {
        .name           = "ct",
        .select_ops     = nft_ct_select_ops,
        .policy         = nft_ct_policy,
        .maxattr        = NFTA_CT_MAX,
        .owner          = THIS_MODULE,
};
763
764 static void nft_notrack_eval(const struct nft_expr *expr,
765                              struct nft_regs *regs,
766                              const struct nft_pktinfo *pkt)
767 {
768         struct sk_buff *skb = pkt->skb;
769         enum ip_conntrack_info ctinfo;
770         struct nf_conn *ct;
771
772         ct = nf_ct_get(pkt->skb, &ctinfo);
773         /* Previously seen (loopback or untracked)?  Ignore. */
774         if (ct || ctinfo == IP_CT_UNTRACKED)
775                 return;
776
777         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
778 }
779
/* forward declaration: the ops table below points back at the type */
static struct nft_expr_type nft_notrack_type;

/* "notrack" keeps no private data (size 0) and sets no init/destroy/dump hooks */
static const struct nft_expr_ops nft_notrack_ops = {
        .type           = &nft_notrack_type,
        .size           = NFT_EXPR_SIZE(0),
        .eval           = nft_notrack_eval,
};

/* The "notrack" expression type; it has a single, fixed ops table. */
static struct nft_expr_type nft_notrack_type __read_mostly = {
        .name           = "notrack",
        .ops            = &nft_notrack_ops,
        .owner          = THIS_MODULE,
};
792
793 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
794 static int
795 nft_ct_timeout_parse_policy(void *timeouts,
796                             const struct nf_conntrack_l4proto *l4proto,
797                             struct net *net, const struct nlattr *attr)
798 {
799         struct nlattr **tb;
800         int ret = 0;
801
802         tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb),
803                      GFP_KERNEL);
804
805         if (!tb)
806                 return -ENOMEM;
807
808         ret = nla_parse_nested_deprecated(tb,
809                                           l4proto->ctnl_timeout.nlattr_max,
810                                           attr,
811                                           l4proto->ctnl_timeout.nla_policy,
812                                           NULL);
813         if (ret < 0)
814                 goto err;
815
816         ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts);
817
818 err:
819         kfree(tb);
820         return ret;
821 }
822
/* Private data of a conntrack timeout policy object. */
struct nft_ct_timeout_obj {
        /* policy allocated in nft_ct_timeout_obj_init(), freed in the destroy hook */
        struct nf_ct_timeout    *timeout;
        /* layer 4 protocol number; eval compares it against pkt->tprot */
        u8                      l4proto;
};
827
828 static void nft_ct_timeout_obj_eval(struct nft_object *obj,
829                                     struct nft_regs *regs,
830                                     const struct nft_pktinfo *pkt)
831 {
832         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
833         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
834         struct nf_conn_timeout *timeout;
835         const unsigned int *values;
836
837         if (priv->l4proto != pkt->tprot)
838                 return;
839
840         if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct))
841                 return;
842
843         timeout = nf_ct_timeout_find(ct);
844         if (!timeout) {
845                 timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC);
846                 if (!timeout) {
847                         regs->verdict.code = NF_DROP;
848                         return;
849                 }
850         }
851
852         rcu_assign_pointer(timeout->timeout, priv->timeout);
853
854         /* adjust the timeout as per 'new' state. ct is unconfirmed,
855          * so the current timestamp must not be added.
856          */
857         values = nf_ct_timeout_data(timeout);
858         if (values)
859                 nf_ct_refresh(ct, pkt->skb, values[0]);
860 }
861
862 static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
863                                    const struct nlattr * const tb[],
864                                    struct nft_object *obj)
865 {
866         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
867         const struct nf_conntrack_l4proto *l4proto;
868         struct nf_ct_timeout *timeout;
869         int l3num = ctx->family;
870         __u8 l4num;
871         int ret;
872
873         if (!tb[NFTA_CT_TIMEOUT_L4PROTO] ||
874             !tb[NFTA_CT_TIMEOUT_DATA])
875                 return -EINVAL;
876
877         if (tb[NFTA_CT_TIMEOUT_L3PROTO])
878                 l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO]));
879
880         l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]);
881         priv->l4proto = l4num;
882
883         l4proto = nf_ct_l4proto_find(l4num);
884
885         if (l4proto->l4proto != l4num) {
886                 ret = -EOPNOTSUPP;
887                 goto err_proto_put;
888         }
889
890         timeout = kzalloc(sizeof(struct nf_ct_timeout) +
891                           l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
892         if (timeout == NULL) {
893                 ret = -ENOMEM;
894                 goto err_proto_put;
895         }
896
897         ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net,
898                                           tb[NFTA_CT_TIMEOUT_DATA]);
899         if (ret < 0)
900                 goto err_free_timeout;
901
902         timeout->l3num = l3num;
903         timeout->l4proto = l4proto;
904
905         ret = nf_ct_netns_get(ctx->net, ctx->family);
906         if (ret < 0)
907                 goto err_free_timeout;
908
909         priv->timeout = timeout;
910         return 0;
911
912 err_free_timeout:
913         kfree(timeout);
914 err_proto_put:
915         return ret;
916 }
917
918 static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
919                                        struct nft_object *obj)
920 {
921         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
922         struct nf_ct_timeout *timeout = priv->timeout;
923
924         nf_ct_untimeout(ctx->net, timeout);
925         nf_ct_netns_put(ctx->net, ctx->family);
926         kfree(priv->timeout);
927 }
928
929 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
930                                    struct nft_object *obj, bool reset)
931 {
932         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
933         const struct nf_ct_timeout *timeout = priv->timeout;
934         struct nlattr *nest_params;
935         int ret;
936
937         if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) ||
938             nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num)))
939                 return -1;
940
941         nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA);
942         if (!nest_params)
943                 return -1;
944
945         ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data);
946         if (ret < 0)
947                 return -1;
948         nla_nest_end(skb, nest_params);
949         return 0;
950 }
951
952 static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = {
953         [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 },
954         [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 },
955         [NFTA_CT_TIMEOUT_DATA]    = {.type = NLA_NESTED },
956 };
957
958 static struct nft_object_type nft_ct_timeout_obj_type;
959
960 static const struct nft_object_ops nft_ct_timeout_obj_ops = {
961         .type           = &nft_ct_timeout_obj_type,
962         .size           = sizeof(struct nft_ct_timeout_obj),
963         .eval           = nft_ct_timeout_obj_eval,
964         .init           = nft_ct_timeout_obj_init,
965         .destroy        = nft_ct_timeout_obj_destroy,
966         .dump           = nft_ct_timeout_obj_dump,
967 };
968
969 static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = {
970         .type           = NFT_OBJECT_CT_TIMEOUT,
971         .ops            = &nft_ct_timeout_obj_ops,
972         .maxattr        = NFTA_CT_TIMEOUT_MAX,
973         .policy         = nft_ct_timeout_policy,
974         .owner          = THIS_MODULE,
975 };
976 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
977
978 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
979                                   const struct nlattr * const tb[],
980                                   struct nft_object *obj)
981 {
982         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
983         struct nf_conntrack_helper *help4, *help6;
984         char name[NF_CT_HELPER_NAME_LEN];
985         int family = ctx->family;
986         int err;
987
988         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
989                 return -EINVAL;
990
991         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
992         if (!priv->l4proto)
993                 return -ENOENT;
994
995         nla_strlcpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
996
997         if (tb[NFTA_CT_HELPER_L3PROTO])
998                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
999
1000         help4 = NULL;
1001         help6 = NULL;
1002
1003         switch (family) {
1004         case NFPROTO_IPV4:
1005                 if (ctx->family == NFPROTO_IPV6)
1006                         return -EINVAL;
1007
1008                 help4 = nf_conntrack_helper_try_module_get(name, family,
1009                                                            priv->l4proto);
1010                 break;
1011         case NFPROTO_IPV6:
1012                 if (ctx->family == NFPROTO_IPV4)
1013                         return -EINVAL;
1014
1015                 help6 = nf_conntrack_helper_try_module_get(name, family,
1016                                                            priv->l4proto);
1017                 break;
1018         case NFPROTO_NETDEV: /* fallthrough */
1019         case NFPROTO_BRIDGE: /* same */
1020         case NFPROTO_INET:
1021                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
1022                                                            priv->l4proto);
1023                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
1024                                                            priv->l4proto);
1025                 break;
1026         default:
1027                 return -EAFNOSUPPORT;
1028         }
1029
1030         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
1031         if (!help4 && !help6)
1032                 return -ENOENT;
1033
1034         priv->helper4 = help4;
1035         priv->helper6 = help6;
1036
1037         err = nf_ct_netns_get(ctx->net, ctx->family);
1038         if (err < 0)
1039                 goto err_put_helper;
1040
1041         return 0;
1042
1043 err_put_helper:
1044         if (priv->helper4)
1045                 nf_conntrack_helper_put(priv->helper4);
1046         if (priv->helper6)
1047                 nf_conntrack_helper_put(priv->helper6);
1048         return err;
1049 }
1050
1051 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
1052                                       struct nft_object *obj)
1053 {
1054         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1055
1056         if (priv->helper4)
1057                 nf_conntrack_helper_put(priv->helper4);
1058         if (priv->helper6)
1059                 nf_conntrack_helper_put(priv->helper6);
1060
1061         nf_ct_netns_put(ctx->net, ctx->family);
1062 }
1063
1064 static void nft_ct_helper_obj_eval(struct nft_object *obj,
1065                                    struct nft_regs *regs,
1066                                    const struct nft_pktinfo *pkt)
1067 {
1068         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1069         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
1070         struct nf_conntrack_helper *to_assign = NULL;
1071         struct nf_conn_help *help;
1072
1073         if (!ct ||
1074             nf_ct_is_confirmed(ct) ||
1075             nf_ct_is_template(ct) ||
1076             priv->l4proto != nf_ct_protonum(ct))
1077                 return;
1078
1079         switch (nf_ct_l3num(ct)) {
1080         case NFPROTO_IPV4:
1081                 to_assign = priv->helper4;
1082                 break;
1083         case NFPROTO_IPV6:
1084                 to_assign = priv->helper6;
1085                 break;
1086         default:
1087                 WARN_ON_ONCE(1);
1088                 return;
1089         }
1090
1091         if (!to_assign)
1092                 return;
1093
1094         if (test_bit(IPS_HELPER_BIT, &ct->status))
1095                 return;
1096
1097         help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1098         if (help) {
1099                 rcu_assign_pointer(help->helper, to_assign);
1100                 set_bit(IPS_HELPER_BIT, &ct->status);
1101         }
1102 }
1103
1104 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
1105                                   struct nft_object *obj, bool reset)
1106 {
1107         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1108         const struct nf_conntrack_helper *helper;
1109         u16 family;
1110
1111         if (priv->helper4 && priv->helper6) {
1112                 family = NFPROTO_INET;
1113                 helper = priv->helper4;
1114         } else if (priv->helper6) {
1115                 family = NFPROTO_IPV6;
1116                 helper = priv->helper6;
1117         } else {
1118                 family = NFPROTO_IPV4;
1119                 helper = priv->helper4;
1120         }
1121
1122         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
1123                 return -1;
1124
1125         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
1126                 return -1;
1127
1128         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
1129                 return -1;
1130
1131         return 0;
1132 }
1133
1134 static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
1135         [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
1136                                   .len = NF_CT_HELPER_NAME_LEN - 1 },
1137         [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
1138         [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
1139 };
1140
1141 static struct nft_object_type nft_ct_helper_obj_type;
1142 static const struct nft_object_ops nft_ct_helper_obj_ops = {
1143         .type           = &nft_ct_helper_obj_type,
1144         .size           = sizeof(struct nft_ct_helper_obj),
1145         .eval           = nft_ct_helper_obj_eval,
1146         .init           = nft_ct_helper_obj_init,
1147         .destroy        = nft_ct_helper_obj_destroy,
1148         .dump           = nft_ct_helper_obj_dump,
1149 };
1150
1151 static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
1152         .type           = NFT_OBJECT_CT_HELPER,
1153         .ops            = &nft_ct_helper_obj_ops,
1154         .maxattr        = NFTA_CT_HELPER_MAX,
1155         .policy         = nft_ct_helper_policy,
1156         .owner          = THIS_MODULE,
1157 };
1158
1159 static int __init nft_ct_module_init(void)
1160 {
1161         int err;
1162
1163         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
1164
1165         err = nft_register_expr(&nft_ct_type);
1166         if (err < 0)
1167                 return err;
1168
1169         err = nft_register_expr(&nft_notrack_type);
1170         if (err < 0)
1171                 goto err1;
1172
1173         err = nft_register_obj(&nft_ct_helper_obj_type);
1174         if (err < 0)
1175                 goto err2;
1176 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1177         err = nft_register_obj(&nft_ct_timeout_obj_type);
1178         if (err < 0)
1179                 goto err3;
1180 #endif
1181         return 0;
1182
1183 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1184 err3:
1185         nft_unregister_obj(&nft_ct_helper_obj_type);
1186 #endif
1187 err2:
1188         nft_unregister_expr(&nft_notrack_type);
1189 err1:
1190         nft_unregister_expr(&nft_ct_type);
1191         return err;
1192 }
1193
1194 static void __exit nft_ct_module_exit(void)
1195 {
1196 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1197         nft_unregister_obj(&nft_ct_timeout_obj_type);
1198 #endif
1199         nft_unregister_obj(&nft_ct_helper_obj_type);
1200         nft_unregister_expr(&nft_notrack_type);
1201         nft_unregister_expr(&nft_ct_type);
1202 }
1203
1204 module_init(nft_ct_module_init);
1205 module_exit(nft_ct_module_exit);
1206
1207 MODULE_LICENSE("GPL");
1208 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
1209 MODULE_ALIAS_NFT_EXPR("ct");
1210 MODULE_ALIAS_NFT_EXPR("notrack");
1211 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);
1212 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT);