powerpc/tm: Fix restoring FP/VMX facility incorrectly on interrupts
[sfrench/cifs-2.6.git] / net / netfilter / ipset / ip_set_core.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
3  *                         Patrick Schaaf <bof@bof.de>
4  * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@netfilter.org>
5  */
6
7 /* Kernel module for IP set management */
8
9 #include <linux/init.h>
10 #include <linux/module.h>
11 #include <linux/moduleparam.h>
12 #include <linux/ip.h>
13 #include <linux/skbuff.h>
14 #include <linux/spinlock.h>
15 #include <linux/rculist.h>
16 #include <net/netlink.h>
17 #include <net/net_namespace.h>
18 #include <net/netns/generic.h>
19
20 #include <linux/netfilter.h>
21 #include <linux/netfilter/x_tables.h>
22 #include <linux/netfilter/nfnetlink.h>
23 #include <linux/netfilter/ipset/ip_set.h>
24
25 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
26 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
27 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
28
29 struct ip_set_net {
30         struct ip_set * __rcu *ip_set_list;     /* all individual sets */
31         ip_set_id_t     ip_set_max;     /* max number of sets */
32         bool            is_deleted;     /* deleted by ip_set_net_exit */
33         bool            is_destroyed;   /* all sets are destroyed */
34 };
35
36 static unsigned int ip_set_net_id __read_mostly;
37
38 static inline struct ip_set_net *ip_set_pernet(struct net *net)
39 {
40         return net_generic(net, ip_set_net_id);
41 }
42
43 #define IP_SET_INC      64
44 #define STRNCMP(a, b)   (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
45
46 static unsigned int max_sets;
47
48 module_param(max_sets, int, 0600);
49 MODULE_PARM_DESC(max_sets, "maximal number of sets");
50 MODULE_LICENSE("GPL");
51 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>");
52 MODULE_DESCRIPTION("core IP set support");
53 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
54
55 /* When the nfnl mutex or ip_set_ref_lock is held: */
56 #define ip_set_dereference(p)           \
57         rcu_dereference_protected(p,    \
58                 lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
59                 lockdep_is_held(&ip_set_ref_lock))
60 #define ip_set(inst, id)                \
61         ip_set_dereference((inst)->ip_set_list)[id]
62 #define ip_set_ref_netlink(inst,id)     \
63         rcu_dereference_raw((inst)->ip_set_list)[id]
64
65 /* The set types are implemented in modules and registered set types
66  * can be found in ip_set_type_list. Adding/deleting types is
67  * serialized by ip_set_type_mutex.
68  */
69
70 static inline void
71 ip_set_type_lock(void)
72 {
73         mutex_lock(&ip_set_type_mutex);
74 }
75
76 static inline void
77 ip_set_type_unlock(void)
78 {
79         mutex_unlock(&ip_set_type_mutex);
80 }
81
82 /* Register and deregister settype */
83
84 static struct ip_set_type *
85 find_set_type(const char *name, u8 family, u8 revision)
86 {
87         struct ip_set_type *type;
88
89         list_for_each_entry_rcu(type, &ip_set_type_list, list)
90                 if (STRNCMP(type->name, name) &&
91                     (type->family == family ||
92                      type->family == NFPROTO_UNSPEC) &&
93                     revision >= type->revision_min &&
94                     revision <= type->revision_max)
95                         return type;
96         return NULL;
97 }
98
99 /* Unlock, try to load a set type module and lock again */
100 static bool
101 load_settype(const char *name)
102 {
103         nfnl_unlock(NFNL_SUBSYS_IPSET);
104         pr_debug("try to load ip_set_%s\n", name);
105         if (request_module("ip_set_%s", name) < 0) {
106                 pr_warn("Can't find ip_set type %s\n", name);
107                 nfnl_lock(NFNL_SUBSYS_IPSET);
108                 return false;
109         }
110         nfnl_lock(NFNL_SUBSYS_IPSET);
111         return true;
112 }
113
114 /* Find a set type and reference it */
115 #define find_set_type_get(name, family, revision, found)        \
116         __find_set_type_get(name, family, revision, found, false)
117
118 static int
119 __find_set_type_get(const char *name, u8 family, u8 revision,
120                     struct ip_set_type **found, bool retry)
121 {
122         struct ip_set_type *type;
123         int err;
124
125         if (retry && !load_settype(name))
126                 return -IPSET_ERR_FIND_TYPE;
127
128         rcu_read_lock();
129         *found = find_set_type(name, family, revision);
130         if (*found) {
131                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
132                 goto unlock;
133         }
134         /* Make sure the type is already loaded
135          * but we don't support the revision
136          */
137         list_for_each_entry_rcu(type, &ip_set_type_list, list)
138                 if (STRNCMP(type->name, name)) {
139                         err = -IPSET_ERR_FIND_TYPE;
140                         goto unlock;
141                 }
142         rcu_read_unlock();
143
144         return retry ? -IPSET_ERR_FIND_TYPE :
145                 __find_set_type_get(name, family, revision, found, true);
146
147 unlock:
148         rcu_read_unlock();
149         return err;
150 }
151
152 /* Find a given set type by name and family.
153  * If we succeeded, the supported minimal and maximum revisions are
154  * filled out.
155  */
156 #define find_set_type_minmax(name, family, min, max) \
157         __find_set_type_minmax(name, family, min, max, false)
158
159 static int
160 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
161                        bool retry)
162 {
163         struct ip_set_type *type;
164         bool found = false;
165
166         if (retry && !load_settype(name))
167                 return -IPSET_ERR_FIND_TYPE;
168
169         *min = 255; *max = 0;
170         rcu_read_lock();
171         list_for_each_entry_rcu(type, &ip_set_type_list, list)
172                 if (STRNCMP(type->name, name) &&
173                     (type->family == family ||
174                      type->family == NFPROTO_UNSPEC)) {
175                         found = true;
176                         if (type->revision_min < *min)
177                                 *min = type->revision_min;
178                         if (type->revision_max > *max)
179                                 *max = type->revision_max;
180                 }
181         rcu_read_unlock();
182         if (found)
183                 return 0;
184
185         return retry ? -IPSET_ERR_FIND_TYPE :
186                 __find_set_type_minmax(name, family, min, max, true);
187 }
188
189 #define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
190                          (f) == NFPROTO_IPV6 ? "inet6" : "any")
191
192 /* Register a set type structure. The type is identified by
193  * the unique triple of name, family and revision.
194  */
195 int
196 ip_set_type_register(struct ip_set_type *type)
197 {
198         int ret = 0;
199
200         if (type->protocol != IPSET_PROTOCOL) {
201                 pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
202                         type->name, family_name(type->family),
203                         type->revision_min, type->revision_max,
204                         type->protocol, IPSET_PROTOCOL);
205                 return -EINVAL;
206         }
207
208         ip_set_type_lock();
209         if (find_set_type(type->name, type->family, type->revision_min)) {
210                 /* Duplicate! */
211                 pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
212                         type->name, family_name(type->family),
213                         type->revision_min);
214                 ip_set_type_unlock();
215                 return -EINVAL;
216         }
217         list_add_rcu(&type->list, &ip_set_type_list);
218         pr_debug("type %s, family %s, revision %u:%u registered.\n",
219                  type->name, family_name(type->family),
220                  type->revision_min, type->revision_max);
221         ip_set_type_unlock();
222
223         return ret;
224 }
225 EXPORT_SYMBOL_GPL(ip_set_type_register);
226
227 /* Unregister a set type. There's a small race with ip_set_create */
228 void
229 ip_set_type_unregister(struct ip_set_type *type)
230 {
231         ip_set_type_lock();
232         if (!find_set_type(type->name, type->family, type->revision_min)) {
233                 pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
234                         type->name, family_name(type->family),
235                         type->revision_min);
236                 ip_set_type_unlock();
237                 return;
238         }
239         list_del_rcu(&type->list);
240         pr_debug("type %s, family %s with revision min %u unregistered.\n",
241                  type->name, family_name(type->family), type->revision_min);
242         ip_set_type_unlock();
243
244         synchronize_rcu();
245 }
246 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
247
248 /* Utility functions */
249 void *
250 ip_set_alloc(size_t size)
251 {
252         void *members = NULL;
253
254         if (size < KMALLOC_MAX_SIZE)
255                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
256
257         if (members) {
258                 pr_debug("%p: allocated with kmalloc\n", members);
259                 return members;
260         }
261
262         members = vzalloc(size);
263         if (!members)
264                 return NULL;
265         pr_debug("%p: allocated with vmalloc\n", members);
266
267         return members;
268 }
269 EXPORT_SYMBOL_GPL(ip_set_alloc);
270
271 void
272 ip_set_free(void *members)
273 {
274         pr_debug("%p: free with %s\n", members,
275                  is_vmalloc_addr(members) ? "vfree" : "kfree");
276         kvfree(members);
277 }
278 EXPORT_SYMBOL_GPL(ip_set_free);
279
280 static inline bool
281 flag_nested(const struct nlattr *nla)
282 {
283         return nla->nla_type & NLA_F_NESTED;
284 }
285
286 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
287         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
288         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
289                                             .len = sizeof(struct in6_addr) },
290 };
291
292 int
293 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
294 {
295         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
296
297         if (unlikely(!flag_nested(nla)))
298                 return -IPSET_ERR_PROTOCOL;
299         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
300                 return -IPSET_ERR_PROTOCOL;
301         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
302                 return -IPSET_ERR_PROTOCOL;
303
304         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
305         return 0;
306 }
307 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
308
309 int
310 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
311 {
312         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
313
314         if (unlikely(!flag_nested(nla)))
315                 return -IPSET_ERR_PROTOCOL;
316
317         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
318                 return -IPSET_ERR_PROTOCOL;
319         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
320                 return -IPSET_ERR_PROTOCOL;
321
322         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
323                sizeof(struct in6_addr));
324         return 0;
325 }
326 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
327
328 typedef void (*destroyer)(struct ip_set *, void *);
329 /* ipset data extension types, in size order */
330
331 const struct ip_set_ext_type ip_set_extensions[] = {
332         [IPSET_EXT_ID_COUNTER] = {
333                 .type   = IPSET_EXT_COUNTER,
334                 .flag   = IPSET_FLAG_WITH_COUNTERS,
335                 .len    = sizeof(struct ip_set_counter),
336                 .align  = __alignof__(struct ip_set_counter),
337         },
338         [IPSET_EXT_ID_TIMEOUT] = {
339                 .type   = IPSET_EXT_TIMEOUT,
340                 .len    = sizeof(unsigned long),
341                 .align  = __alignof__(unsigned long),
342         },
343         [IPSET_EXT_ID_SKBINFO] = {
344                 .type   = IPSET_EXT_SKBINFO,
345                 .flag   = IPSET_FLAG_WITH_SKBINFO,
346                 .len    = sizeof(struct ip_set_skbinfo),
347                 .align  = __alignof__(struct ip_set_skbinfo),
348         },
349         [IPSET_EXT_ID_COMMENT] = {
350                 .type    = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
351                 .flag    = IPSET_FLAG_WITH_COMMENT,
352                 .len     = sizeof(struct ip_set_comment),
353                 .align   = __alignof__(struct ip_set_comment),
354                 .destroy = (destroyer) ip_set_comment_free,
355         },
356 };
357 EXPORT_SYMBOL_GPL(ip_set_extensions);
358
359 static inline bool
360 add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
361 {
362         return ip_set_extensions[id].flag ?
363                 (flags & ip_set_extensions[id].flag) :
364                 !!tb[IPSET_ATTR_TIMEOUT];
365 }
366
367 size_t
368 ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
369                 size_t align)
370 {
371         enum ip_set_ext_id id;
372         u32 cadt_flags = 0;
373
374         if (tb[IPSET_ATTR_CADT_FLAGS])
375                 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
376         if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
377                 set->flags |= IPSET_CREATE_FLAG_FORCEADD;
378         if (!align)
379                 align = 1;
380         for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
381                 if (!add_extension(id, cadt_flags, tb))
382                         continue;
383                 len = ALIGN(len, ip_set_extensions[id].align);
384                 set->offset[id] = len;
385                 set->extensions |= ip_set_extensions[id].type;
386                 len += ip_set_extensions[id].len;
387         }
388         return ALIGN(len, align);
389 }
390 EXPORT_SYMBOL_GPL(ip_set_elem_len);
391
392 int
393 ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
394                       struct ip_set_ext *ext)
395 {
396         u64 fullmark;
397
398         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
399                      !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
400                      !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
401                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
402                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
403                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
404                 return -IPSET_ERR_PROTOCOL;
405
406         if (tb[IPSET_ATTR_TIMEOUT]) {
407                 if (!SET_WITH_TIMEOUT(set))
408                         return -IPSET_ERR_TIMEOUT;
409                 ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
410         }
411         if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
412                 if (!SET_WITH_COUNTER(set))
413                         return -IPSET_ERR_COUNTER;
414                 if (tb[IPSET_ATTR_BYTES])
415                         ext->bytes = be64_to_cpu(nla_get_be64(
416                                                  tb[IPSET_ATTR_BYTES]));
417                 if (tb[IPSET_ATTR_PACKETS])
418                         ext->packets = be64_to_cpu(nla_get_be64(
419                                                    tb[IPSET_ATTR_PACKETS]));
420         }
421         if (tb[IPSET_ATTR_COMMENT]) {
422                 if (!SET_WITH_COMMENT(set))
423                         return -IPSET_ERR_COMMENT;
424                 ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
425         }
426         if (tb[IPSET_ATTR_SKBMARK]) {
427                 if (!SET_WITH_SKBINFO(set))
428                         return -IPSET_ERR_SKBINFO;
429                 fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
430                 ext->skbinfo.skbmark = fullmark >> 32;
431                 ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
432         }
433         if (tb[IPSET_ATTR_SKBPRIO]) {
434                 if (!SET_WITH_SKBINFO(set))
435                         return -IPSET_ERR_SKBINFO;
436                 ext->skbinfo.skbprio =
437                         be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
438         }
439         if (tb[IPSET_ATTR_SKBQUEUE]) {
440                 if (!SET_WITH_SKBINFO(set))
441                         return -IPSET_ERR_SKBINFO;
442                 ext->skbinfo.skbqueue =
443                         be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
444         }
445         return 0;
446 }
447 EXPORT_SYMBOL_GPL(ip_set_get_extensions);
448
449 int
450 ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
451                       const void *e, bool active)
452 {
453         if (SET_WITH_TIMEOUT(set)) {
454                 unsigned long *timeout = ext_timeout(e, set);
455
456                 if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
457                         htonl(active ? ip_set_timeout_get(timeout)
458                                 : *timeout)))
459                         return -EMSGSIZE;
460         }
461         if (SET_WITH_COUNTER(set) &&
462             ip_set_put_counter(skb, ext_counter(e, set)))
463                 return -EMSGSIZE;
464         if (SET_WITH_COMMENT(set) &&
465             ip_set_put_comment(skb, ext_comment(e, set)))
466                 return -EMSGSIZE;
467         if (SET_WITH_SKBINFO(set) &&
468             ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
469                 return -EMSGSIZE;
470         return 0;
471 }
472 EXPORT_SYMBOL_GPL(ip_set_put_extensions);
473
474 bool
475 ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
476                         struct ip_set_ext *mext, u32 flags, void *data)
477 {
478         if (SET_WITH_TIMEOUT(set) &&
479             ip_set_timeout_expired(ext_timeout(data, set)))
480                 return false;
481         if (SET_WITH_COUNTER(set)) {
482                 struct ip_set_counter *counter = ext_counter(data, set);
483
484                 if (flags & IPSET_FLAG_MATCH_COUNTERS &&
485                     !(ip_set_match_counter(ip_set_get_packets(counter),
486                                 mext->packets, mext->packets_op) &&
487                       ip_set_match_counter(ip_set_get_bytes(counter),
488                                 mext->bytes, mext->bytes_op)))
489                         return false;
490                 ip_set_update_counter(counter, ext, flags);
491         }
492         if (SET_WITH_SKBINFO(set))
493                 ip_set_get_skbinfo(ext_skbinfo(data, set),
494                                    ext, mext, flags);
495         return true;
496 }
497 EXPORT_SYMBOL_GPL(ip_set_match_extensions);
498
499 /* Creating/destroying/renaming/swapping affect the existence and
500  * the properties of a set. All of these can be executed from userspace
501  * only and serialized by the nfnl mutex indirectly from nfnetlink.
502  *
503  * Sets are identified by their index in ip_set_list and the index
504  * is used by the external references (set/SET netfilter modules).
505  *
506  * The set behind an index may change by swapping only, from userspace.
507  */
508
509 static inline void
510 __ip_set_get(struct ip_set *set)
511 {
512         write_lock_bh(&ip_set_ref_lock);
513         set->ref++;
514         write_unlock_bh(&ip_set_ref_lock);
515 }
516
517 static inline void
518 __ip_set_put(struct ip_set *set)
519 {
520         write_lock_bh(&ip_set_ref_lock);
521         BUG_ON(set->ref == 0);
522         set->ref--;
523         write_unlock_bh(&ip_set_ref_lock);
524 }
525
526 /* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
527  * a separate reference counter
528  */
529 static inline void
530 __ip_set_put_netlink(struct ip_set *set)
531 {
532         write_lock_bh(&ip_set_ref_lock);
533         BUG_ON(set->ref_netlink == 0);
534         set->ref_netlink--;
535         write_unlock_bh(&ip_set_ref_lock);
536 }
537
538 /* Add, del and test set entries from kernel.
539  *
540  * The set behind the index must exist and must be referenced
541  * so it can't be destroyed (or changed) under our foot.
542  */
543
544 static inline struct ip_set *
545 ip_set_rcu_get(struct net *net, ip_set_id_t index)
546 {
547         struct ip_set *set;
548         struct ip_set_net *inst = ip_set_pernet(net);
549
550         rcu_read_lock();
551         /* ip_set_list itself needs to be protected */
552         set = rcu_dereference(inst->ip_set_list)[index];
553         rcu_read_unlock();
554
555         return set;
556 }
557
558 int
559 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
560             const struct xt_action_param *par, struct ip_set_adt_opt *opt)
561 {
562         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
563         int ret = 0;
564
565         BUG_ON(!set);
566         pr_debug("set %s, index %u\n", set->name, index);
567
568         if (opt->dim < set->type->dimension ||
569             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
570                 return 0;
571
572         rcu_read_lock_bh();
573         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
574         rcu_read_unlock_bh();
575
576         if (ret == -EAGAIN) {
577                 /* Type requests element to be completed */
578                 pr_debug("element must be completed, ADD is triggered\n");
579                 spin_lock_bh(&set->lock);
580                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
581                 spin_unlock_bh(&set->lock);
582                 ret = 1;
583         } else {
584                 /* --return-nomatch: invert matched element */
585                 if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
586                     (set->type->features & IPSET_TYPE_NOMATCH) &&
587                     (ret > 0 || ret == -ENOTEMPTY))
588                         ret = -ret;
589         }
590
591         /* Convert error codes to nomatch */
592         return (ret < 0 ? 0 : ret);
593 }
594 EXPORT_SYMBOL_GPL(ip_set_test);
595
596 int
597 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
598            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
599 {
600         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
601         int ret;
602
603         BUG_ON(!set);
604         pr_debug("set %s, index %u\n", set->name, index);
605
606         if (opt->dim < set->type->dimension ||
607             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
608                 return -IPSET_ERR_TYPE_MISMATCH;
609
610         spin_lock_bh(&set->lock);
611         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
612         spin_unlock_bh(&set->lock);
613
614         return ret;
615 }
616 EXPORT_SYMBOL_GPL(ip_set_add);
617
618 int
619 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
620            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
621 {
622         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
623         int ret = 0;
624
625         BUG_ON(!set);
626         pr_debug("set %s, index %u\n", set->name, index);
627
628         if (opt->dim < set->type->dimension ||
629             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
630                 return -IPSET_ERR_TYPE_MISMATCH;
631
632         spin_lock_bh(&set->lock);
633         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
634         spin_unlock_bh(&set->lock);
635
636         return ret;
637 }
638 EXPORT_SYMBOL_GPL(ip_set_del);
639
640 /* Find set by name, reference it once. The reference makes sure the
641  * thing pointed to, does not go away under our feet.
642  *
643  */
644 ip_set_id_t
645 ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
646 {
647         ip_set_id_t i, index = IPSET_INVALID_ID;
648         struct ip_set *s;
649         struct ip_set_net *inst = ip_set_pernet(net);
650
651         rcu_read_lock();
652         for (i = 0; i < inst->ip_set_max; i++) {
653                 s = rcu_dereference(inst->ip_set_list)[i];
654                 if (s && STRNCMP(s->name, name)) {
655                         __ip_set_get(s);
656                         index = i;
657                         *set = s;
658                         break;
659                 }
660         }
661         rcu_read_unlock();
662
663         return index;
664 }
665 EXPORT_SYMBOL_GPL(ip_set_get_byname);
666
667 /* If the given set pointer points to a valid set, decrement
668  * reference count by 1. The caller shall not assume the index
669  * to be valid, after calling this function.
670  *
671  */
672
673 static inline void
674 __ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
675 {
676         struct ip_set *set;
677
678         rcu_read_lock();
679         set = rcu_dereference(inst->ip_set_list)[index];
680         if (set)
681                 __ip_set_put(set);
682         rcu_read_unlock();
683 }
684
685 void
686 ip_set_put_byindex(struct net *net, ip_set_id_t index)
687 {
688         struct ip_set_net *inst = ip_set_pernet(net);
689
690         __ip_set_put_byindex(inst, index);
691 }
692 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
693
694 /* Get the name of a set behind a set index.
695  * Set itself is protected by RCU, but its name isn't: to protect against
696  * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
697  * name.
698  */
699 void
700 ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
701 {
702         struct ip_set *set = ip_set_rcu_get(net, index);
703
704         BUG_ON(!set);
705
706         read_lock_bh(&ip_set_ref_lock);
707         strncpy(name, set->name, IPSET_MAXNAMELEN);
708         read_unlock_bh(&ip_set_ref_lock);
709 }
710 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
711
712 /* Routines to call by external subsystems, which do not
713  * call nfnl_lock for us.
714  */
715
716 /* Find set by index, reference it once. The reference makes sure the
717  * thing pointed to, does not go away under our feet.
718  *
719  * The nfnl mutex is used in the function.
720  */
721 ip_set_id_t
722 ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
723 {
724         struct ip_set *set;
725         struct ip_set_net *inst = ip_set_pernet(net);
726
727         if (index >= inst->ip_set_max)
728                 return IPSET_INVALID_ID;
729
730         nfnl_lock(NFNL_SUBSYS_IPSET);
731         set = ip_set(inst, index);
732         if (set)
733                 __ip_set_get(set);
734         else
735                 index = IPSET_INVALID_ID;
736         nfnl_unlock(NFNL_SUBSYS_IPSET);
737
738         return index;
739 }
740 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
741
742 /* If the given set pointer points to a valid set, decrement
743  * reference count by 1. The caller shall not assume the index
744  * to be valid, after calling this function.
745  *
746  * The nfnl mutex is used in the function.
747  */
748 void
749 ip_set_nfnl_put(struct net *net, ip_set_id_t index)
750 {
751         struct ip_set *set;
752         struct ip_set_net *inst = ip_set_pernet(net);
753
754         nfnl_lock(NFNL_SUBSYS_IPSET);
755         if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
756                 set = ip_set(inst, index);
757                 if (set)
758                         __ip_set_put(set);
759         }
760         nfnl_unlock(NFNL_SUBSYS_IPSET);
761 }
762 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
763
764 /* Communication protocol with userspace over netlink.
765  *
766  * The commands are serialized by the nfnl mutex.
767  */
768
769 static inline u8 protocol(const struct nlattr * const tb[])
770 {
771         return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
772 }
773
774 static inline bool
775 protocol_failed(const struct nlattr * const tb[])
776 {
777         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
778 }
779
780 static inline bool
781 protocol_min_failed(const struct nlattr * const tb[])
782 {
783         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
784 }
785
786 static inline u32
787 flag_exist(const struct nlmsghdr *nlh)
788 {
789         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
790 }
791
792 static struct nlmsghdr *
793 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
794           enum ipset_cmd cmd)
795 {
796         struct nlmsghdr *nlh;
797         struct nfgenmsg *nfmsg;
798
799         nlh = nlmsg_put(skb, portid, seq, nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd),
800                         sizeof(*nfmsg), flags);
801         if (!nlh)
802                 return NULL;
803
804         nfmsg = nlmsg_data(nlh);
805         nfmsg->nfgen_family = NFPROTO_IPV4;
806         nfmsg->version = NFNETLINK_V0;
807         nfmsg->res_id = 0;
808
809         return nlh;
810 }
811
812 /* Create a set */
813
814 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
815         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
816         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
817                                     .len = IPSET_MAXNAMELEN - 1 },
818         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
819                                     .len = IPSET_MAXNAMELEN - 1},
820         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
821         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
822         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
823 };
824
825 static struct ip_set *
826 find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
827 {
828         struct ip_set *set = NULL;
829         ip_set_id_t i;
830
831         *id = IPSET_INVALID_ID;
832         for (i = 0; i < inst->ip_set_max; i++) {
833                 set = ip_set(inst, i);
834                 if (set && STRNCMP(set->name, name)) {
835                         *id = i;
836                         break;
837                 }
838         }
839         return (*id == IPSET_INVALID_ID ? NULL : set);
840 }
841
842 static inline struct ip_set *
843 find_set(struct ip_set_net *inst, const char *name)
844 {
845         ip_set_id_t id;
846
847         return find_set_and_id(inst, name, &id);
848 }
849
850 static int
851 find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
852              struct ip_set **set)
853 {
854         struct ip_set *s;
855         ip_set_id_t i;
856
857         *index = IPSET_INVALID_ID;
858         for (i = 0;  i < inst->ip_set_max; i++) {
859                 s = ip_set(inst, i);
860                 if (!s) {
861                         if (*index == IPSET_INVALID_ID)
862                                 *index = i;
863                 } else if (STRNCMP(name, s->name)) {
864                         /* Name clash */
865                         *set = s;
866                         return -EEXIST;
867                 }
868         }
869         if (*index == IPSET_INVALID_ID)
870                 /* No free slot remained */
871                 return -IPSET_ERR_MAX_SETS;
872         return 0;
873 }
874
875 static int ip_set_none(struct net *net, struct sock *ctnl, struct sk_buff *skb,
876                        const struct nlmsghdr *nlh,
877                        const struct nlattr * const attr[],
878                        struct netlink_ext_ack *extack)
879 {
880         return -EOPNOTSUPP;
881 }
882
883 static int ip_set_create(struct net *net, struct sock *ctnl,
884                          struct sk_buff *skb, const struct nlmsghdr *nlh,
885                          const struct nlattr * const attr[],
886                          struct netlink_ext_ack *extack)
887 {
888         struct ip_set_net *inst = ip_set_pernet(net);
889         struct ip_set *set, *clash = NULL;
890         ip_set_id_t index = IPSET_INVALID_ID;
891         struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
892         const char *name, *typename;
893         u8 family, revision;
894         u32 flags = flag_exist(nlh);
895         int ret = 0;
896
897         if (unlikely(protocol_min_failed(attr) ||
898                      !attr[IPSET_ATTR_SETNAME] ||
899                      !attr[IPSET_ATTR_TYPENAME] ||
900                      !attr[IPSET_ATTR_REVISION] ||
901                      !attr[IPSET_ATTR_FAMILY] ||
902                      (attr[IPSET_ATTR_DATA] &&
903                       !flag_nested(attr[IPSET_ATTR_DATA]))))
904                 return -IPSET_ERR_PROTOCOL;
905
906         name = nla_data(attr[IPSET_ATTR_SETNAME]);
907         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
908         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
909         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
910         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
911                  name, typename, family_name(family), revision);
912
913         /* First, and without any locks, allocate and initialize
914          * a normal base set structure.
915          */
916         set = kzalloc(sizeof(*set), GFP_KERNEL);
917         if (!set)
918                 return -ENOMEM;
919         spin_lock_init(&set->lock);
920         strlcpy(set->name, name, IPSET_MAXNAMELEN);
921         set->family = family;
922         set->revision = revision;
923
924         /* Next, check that we know the type, and take
925          * a reference on the type, to make sure it stays available
926          * while constructing our new set.
927          *
928          * After referencing the type, we try to create the type
929          * specific part of the set without holding any locks.
930          */
931         ret = find_set_type_get(typename, family, revision, &set->type);
932         if (ret)
933                 goto out;
934
935         /* Without holding any locks, create private part. */
936         if (attr[IPSET_ATTR_DATA] &&
937             nla_parse_nested_deprecated(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA], set->type->create_policy, NULL)) {
938                 ret = -IPSET_ERR_PROTOCOL;
939                 goto put_out;
940         }
941
942         ret = set->type->create(net, set, tb, flags);
943         if (ret != 0)
944                 goto put_out;
945
946         /* BTW, ret==0 here. */
947
948         /* Here, we have a valid, constructed set and we are protected
949          * by the nfnl mutex. Find the first free index in ip_set_list
950          * and check clashing.
951          */
952         ret = find_free_id(inst, set->name, &index, &clash);
953         if (ret == -EEXIST) {
954                 /* If this is the same set and requested, ignore error */
955                 if ((flags & IPSET_FLAG_EXIST) &&
956                     STRNCMP(set->type->name, clash->type->name) &&
957                     set->type->family == clash->type->family &&
958                     set->type->revision_min == clash->type->revision_min &&
959                     set->type->revision_max == clash->type->revision_max &&
960                     set->variant->same_set(set, clash))
961                         ret = 0;
962                 goto cleanup;
963         } else if (ret == -IPSET_ERR_MAX_SETS) {
964                 struct ip_set **list, **tmp;
965                 ip_set_id_t i = inst->ip_set_max + IP_SET_INC;
966
967                 if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
968                         /* Wraparound */
969                         goto cleanup;
970
971                 list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
972                 if (!list)
973                         goto cleanup;
974                 /* nfnl mutex is held, both lists are valid */
975                 tmp = ip_set_dereference(inst->ip_set_list);
976                 memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
977                 rcu_assign_pointer(inst->ip_set_list, list);
978                 /* Make sure all current packets have passed through */
979                 synchronize_net();
980                 /* Use new list */
981                 index = inst->ip_set_max;
982                 inst->ip_set_max = i;
983                 kvfree(tmp);
984                 ret = 0;
985         } else if (ret) {
986                 goto cleanup;
987         }
988
989         /* Finally! Add our shiny new set to the list, and be done. */
990         pr_debug("create: '%s' created with index %u!\n", set->name, index);
991         ip_set(inst, index) = set;
992
993         return ret;
994
995 cleanup:
996         set->variant->destroy(set);
997 put_out:
998         module_put(set->type->me);
999 out:
1000         kfree(set);
1001         return ret;
1002 }
1003
1004 /* Destroy sets */
1005
1006 static const struct nla_policy
1007 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
1008         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1009         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1010                                     .len = IPSET_MAXNAMELEN - 1 },
1011 };
1012
1013 static void
1014 ip_set_destroy_set(struct ip_set *set)
1015 {
1016         pr_debug("set: %s\n",  set->name);
1017
1018         /* Must call it without holding any lock */
1019         set->variant->destroy(set);
1020         module_put(set->type->me);
1021         kfree(set);
1022 }
1023
1024 static int ip_set_destroy(struct net *net, struct sock *ctnl,
1025                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1026                           const struct nlattr * const attr[],
1027                           struct netlink_ext_ack *extack)
1028 {
1029         struct ip_set_net *inst = ip_set_pernet(net);
1030         struct ip_set *s;
1031         ip_set_id_t i;
1032         int ret = 0;
1033
1034         if (unlikely(protocol_min_failed(attr)))
1035                 return -IPSET_ERR_PROTOCOL;
1036
1037         /* Must wait for flush to be really finished in list:set */
1038         rcu_barrier();
1039
1040         /* Commands are serialized and references are
1041          * protected by the ip_set_ref_lock.
1042          * External systems (i.e. xt_set) must call
1043          * ip_set_put|get_nfnl_* functions, that way we
1044          * can safely check references here.
1045          *
1046          * list:set timer can only decrement the reference
1047          * counter, so if it's already zero, we can proceed
1048          * without holding the lock.
1049          */
1050         read_lock_bh(&ip_set_ref_lock);
1051         if (!attr[IPSET_ATTR_SETNAME]) {
1052                 for (i = 0; i < inst->ip_set_max; i++) {
1053                         s = ip_set(inst, i);
1054                         if (s && (s->ref || s->ref_netlink)) {
1055                                 ret = -IPSET_ERR_BUSY;
1056                                 goto out;
1057                         }
1058                 }
1059                 inst->is_destroyed = true;
1060                 read_unlock_bh(&ip_set_ref_lock);
1061                 for (i = 0; i < inst->ip_set_max; i++) {
1062                         s = ip_set(inst, i);
1063                         if (s) {
1064                                 ip_set(inst, i) = NULL;
1065                                 ip_set_destroy_set(s);
1066                         }
1067                 }
1068                 /* Modified by ip_set_destroy() only, which is serialized */
1069                 inst->is_destroyed = false;
1070         } else {
1071                 s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1072                                     &i);
1073                 if (!s) {
1074                         ret = -ENOENT;
1075                         goto out;
1076                 } else if (s->ref || s->ref_netlink) {
1077                         ret = -IPSET_ERR_BUSY;
1078                         goto out;
1079                 }
1080                 ip_set(inst, i) = NULL;
1081                 read_unlock_bh(&ip_set_ref_lock);
1082
1083                 ip_set_destroy_set(s);
1084         }
1085         return 0;
1086 out:
1087         read_unlock_bh(&ip_set_ref_lock);
1088         return ret;
1089 }
1090
1091 /* Flush sets */
1092
1093 static void
1094 ip_set_flush_set(struct ip_set *set)
1095 {
1096         pr_debug("set: %s\n",  set->name);
1097
1098         spin_lock_bh(&set->lock);
1099         set->variant->flush(set);
1100         spin_unlock_bh(&set->lock);
1101 }
1102
1103 static int ip_set_flush(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1104                         const struct nlmsghdr *nlh,
1105                         const struct nlattr * const attr[],
1106                         struct netlink_ext_ack *extack)
1107 {
1108         struct ip_set_net *inst = ip_set_pernet(net);
1109         struct ip_set *s;
1110         ip_set_id_t i;
1111
1112         if (unlikely(protocol_min_failed(attr)))
1113                 return -IPSET_ERR_PROTOCOL;
1114
1115         if (!attr[IPSET_ATTR_SETNAME]) {
1116                 for (i = 0; i < inst->ip_set_max; i++) {
1117                         s = ip_set(inst, i);
1118                         if (s)
1119                                 ip_set_flush_set(s);
1120                 }
1121         } else {
1122                 s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1123                 if (!s)
1124                         return -ENOENT;
1125
1126                 ip_set_flush_set(s);
1127         }
1128
1129         return 0;
1130 }
1131
1132 /* Rename a set */
1133
1134 static const struct nla_policy
1135 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
1136         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1137         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1138                                     .len = IPSET_MAXNAMELEN - 1 },
1139         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
1140                                     .len = IPSET_MAXNAMELEN - 1 },
1141 };
1142
1143 static int ip_set_rename(struct net *net, struct sock *ctnl,
1144                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1145                          const struct nlattr * const attr[],
1146                          struct netlink_ext_ack *extack)
1147 {
1148         struct ip_set_net *inst = ip_set_pernet(net);
1149         struct ip_set *set, *s;
1150         const char *name2;
1151         ip_set_id_t i;
1152         int ret = 0;
1153
1154         if (unlikely(protocol_min_failed(attr) ||
1155                      !attr[IPSET_ATTR_SETNAME] ||
1156                      !attr[IPSET_ATTR_SETNAME2]))
1157                 return -IPSET_ERR_PROTOCOL;
1158
1159         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1160         if (!set)
1161                 return -ENOENT;
1162
1163         write_lock_bh(&ip_set_ref_lock);
1164         if (set->ref != 0) {
1165                 ret = -IPSET_ERR_REFERENCED;
1166                 goto out;
1167         }
1168
1169         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1170         for (i = 0; i < inst->ip_set_max; i++) {
1171                 s = ip_set(inst, i);
1172                 if (s && STRNCMP(s->name, name2)) {
1173                         ret = -IPSET_ERR_EXIST_SETNAME2;
1174                         goto out;
1175                 }
1176         }
1177         strncpy(set->name, name2, IPSET_MAXNAMELEN);
1178
1179 out:
1180         write_unlock_bh(&ip_set_ref_lock);
1181         return ret;
1182 }
1183
1184 /* Swap two sets so that name/index points to the other.
1185  * References and set names are also swapped.
1186  *
1187  * The commands are serialized by the nfnl mutex and references are
1188  * protected by the ip_set_ref_lock. The kernel interfaces
1189  * do not hold the mutex but the pointer settings are atomic
1190  * so the ip_set_list always contains valid pointers to the sets.
1191  */
1192
1193 static int ip_set_swap(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1194                        const struct nlmsghdr *nlh,
1195                        const struct nlattr * const attr[],
1196                        struct netlink_ext_ack *extack)
1197 {
1198         struct ip_set_net *inst = ip_set_pernet(net);
1199         struct ip_set *from, *to;
1200         ip_set_id_t from_id, to_id;
1201         char from_name[IPSET_MAXNAMELEN];
1202
1203         if (unlikely(protocol_min_failed(attr) ||
1204                      !attr[IPSET_ATTR_SETNAME] ||
1205                      !attr[IPSET_ATTR_SETNAME2]))
1206                 return -IPSET_ERR_PROTOCOL;
1207
1208         from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1209                                &from_id);
1210         if (!from)
1211                 return -ENOENT;
1212
1213         to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1214                              &to_id);
1215         if (!to)
1216                 return -IPSET_ERR_EXIST_SETNAME2;
1217
1218         /* Features must not change.
1219          * Not an artifical restriction anymore, as we must prevent
1220          * possible loops created by swapping in setlist type of sets.
1221          */
1222         if (!(from->type->features == to->type->features &&
1223               from->family == to->family))
1224                 return -IPSET_ERR_TYPE_MISMATCH;
1225
1226         write_lock_bh(&ip_set_ref_lock);
1227
1228         if (from->ref_netlink || to->ref_netlink) {
1229                 write_unlock_bh(&ip_set_ref_lock);
1230                 return -EBUSY;
1231         }
1232
1233         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1234         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1235         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1236
1237         swap(from->ref, to->ref);
1238         ip_set(inst, from_id) = to;
1239         ip_set(inst, to_id) = from;
1240         write_unlock_bh(&ip_set_ref_lock);
1241
1242         return 0;
1243 }
1244
1245 /* List/save set data */
1246
1247 #define DUMP_INIT       0
1248 #define DUMP_ALL        1
1249 #define DUMP_ONE        2
1250 #define DUMP_LAST       3
1251
1252 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
1253 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1254
1255 static int
1256 ip_set_dump_done(struct netlink_callback *cb)
1257 {
1258         if (cb->args[IPSET_CB_ARG0]) {
1259                 struct ip_set_net *inst =
1260                         (struct ip_set_net *)cb->args[IPSET_CB_NET];
1261                 ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1262                 struct ip_set *set = ip_set_ref_netlink(inst, index);
1263
1264                 if (set->variant->uref)
1265                         set->variant->uref(set, cb, false);
1266                 pr_debug("release set %s\n", set->name);
1267                 __ip_set_put_netlink(set);
1268         }
1269         return 0;
1270 }
1271
1272 static inline void
1273 dump_attrs(struct nlmsghdr *nlh)
1274 {
1275         const struct nlattr *attr;
1276         int rem;
1277
1278         pr_debug("dump nlmsg\n");
1279         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1280                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1281         }
1282 }
1283
1284 static int
1285 dump_init(struct netlink_callback *cb, struct ip_set_net *inst)
1286 {
1287         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1288         int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1289         struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1290         struct nlattr *attr = (void *)nlh + min_len;
1291         u32 dump_type;
1292         ip_set_id_t index;
1293         int ret;
1294
1295         ret = nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, attr,
1296                                    nlh->nlmsg_len - min_len,
1297                                    ip_set_setname_policy, NULL);
1298         if (ret)
1299                 return ret;
1300
1301         cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1302         if (cda[IPSET_ATTR_SETNAME]) {
1303                 struct ip_set *set;
1304
1305                 set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1306                                       &index);
1307                 if (!set)
1308                         return -ENOENT;
1309
1310                 dump_type = DUMP_ONE;
1311                 cb->args[IPSET_CB_INDEX] = index;
1312         } else {
1313                 dump_type = DUMP_ALL;
1314         }
1315
1316         if (cda[IPSET_ATTR_FLAGS]) {
1317                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1318
1319                 dump_type |= (f << 16);
1320         }
1321         cb->args[IPSET_CB_NET] = (unsigned long)inst;
1322         cb->args[IPSET_CB_DUMP] = dump_type;
1323
1324         return 0;
1325 }
1326
1327 static int
1328 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1329 {
1330         ip_set_id_t index = IPSET_INVALID_ID, max;
1331         struct ip_set *set = NULL;
1332         struct nlmsghdr *nlh = NULL;
1333         unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
1334         struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
1335         u32 dump_type, dump_flags;
1336         bool is_destroyed;
1337         int ret = 0;
1338
1339         if (!cb->args[IPSET_CB_DUMP]) {
1340                 ret = dump_init(cb, inst);
1341                 if (ret < 0) {
1342                         nlh = nlmsg_hdr(cb->skb);
1343                         /* We have to create and send the error message
1344                          * manually :-(
1345                          */
1346                         if (nlh->nlmsg_flags & NLM_F_ACK)
1347                                 netlink_ack(cb->skb, nlh, ret, NULL);
1348                         return ret;
1349                 }
1350         }
1351
1352         if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
1353                 goto out;
1354
1355         dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
1356         dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
1357         max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
1358                                     : inst->ip_set_max;
1359 dump_last:
1360         pr_debug("dump type, flag: %u %u index: %ld\n",
1361                  dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
1362         for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
1363                 index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1364                 write_lock_bh(&ip_set_ref_lock);
1365                 set = ip_set(inst, index);
1366                 is_destroyed = inst->is_destroyed;
1367                 if (!set || is_destroyed) {
1368                         write_unlock_bh(&ip_set_ref_lock);
1369                         if (dump_type == DUMP_ONE) {
1370                                 ret = -ENOENT;
1371                                 goto out;
1372                         }
1373                         if (is_destroyed) {
1374                                 /* All sets are just being destroyed */
1375                                 ret = 0;
1376                                 goto out;
1377                         }
1378                         continue;
1379                 }
1380                 /* When dumping all sets, we must dump "sorted"
1381                  * so that lists (unions of sets) are dumped last.
1382                  */
1383                 if (dump_type != DUMP_ONE &&
1384                     ((dump_type == DUMP_ALL) ==
1385                      !!(set->type->features & IPSET_DUMP_LAST))) {
1386                         write_unlock_bh(&ip_set_ref_lock);
1387                         continue;
1388                 }
1389                 pr_debug("List set: %s\n", set->name);
1390                 if (!cb->args[IPSET_CB_ARG0]) {
1391                         /* Start listing: make sure set won't be destroyed */
1392                         pr_debug("reference set\n");
1393                         set->ref_netlink++;
1394                 }
1395                 write_unlock_bh(&ip_set_ref_lock);
1396                 nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
1397                                 cb->nlh->nlmsg_seq, flags,
1398                                 IPSET_CMD_LIST);
1399                 if (!nlh) {
1400                         ret = -EMSGSIZE;
1401                         goto release_refcount;
1402                 }
1403                 if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
1404                                cb->args[IPSET_CB_PROTO]) ||
1405                     nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1406                         goto nla_put_failure;
1407                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1408                         goto next_set;
1409                 switch (cb->args[IPSET_CB_ARG0]) {
1410                 case 0:
1411                         /* Core header data */
1412                         if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1413                                            set->type->name) ||
1414                             nla_put_u8(skb, IPSET_ATTR_FAMILY,
1415                                        set->family) ||
1416                             nla_put_u8(skb, IPSET_ATTR_REVISION,
1417                                        set->revision))
1418                                 goto nla_put_failure;
1419                         if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
1420                             nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
1421                                 goto nla_put_failure;
1422                         ret = set->variant->head(set, skb);
1423                         if (ret < 0)
1424                                 goto release_refcount;
1425                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1426                                 goto next_set;
1427                         if (set->variant->uref)
1428                                 set->variant->uref(set, cb, true);
1429                         /* fall through */
1430                 default:
1431                         ret = set->variant->list(set, skb, cb);
1432                         if (!cb->args[IPSET_CB_ARG0])
1433                                 /* Set is done, proceed with next one */
1434                                 goto next_set;
1435                         goto release_refcount;
1436                 }
1437         }
1438         /* If we dump all sets, continue with dumping last ones */
1439         if (dump_type == DUMP_ALL) {
1440                 dump_type = DUMP_LAST;
1441                 cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
1442                 cb->args[IPSET_CB_INDEX] = 0;
1443                 if (set && set->variant->uref)
1444                         set->variant->uref(set, cb, false);
1445                 goto dump_last;
1446         }
1447         goto out;
1448
1449 nla_put_failure:
1450         ret = -EFAULT;
1451 next_set:
1452         if (dump_type == DUMP_ONE)
1453                 cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
1454         else
1455                 cb->args[IPSET_CB_INDEX]++;
1456 release_refcount:
1457         /* If there was an error or set is done, release set */
1458         if (ret || !cb->args[IPSET_CB_ARG0]) {
1459                 set = ip_set_ref_netlink(inst, index);
1460                 if (set->variant->uref)
1461                         set->variant->uref(set, cb, false);
1462                 pr_debug("release set %s\n", set->name);
1463                 __ip_set_put_netlink(set);
1464                 cb->args[IPSET_CB_ARG0] = 0;
1465         }
1466 out:
1467         if (nlh) {
1468                 nlmsg_end(skb, nlh);
1469                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1470                 dump_attrs(nlh);
1471         }
1472
1473         return ret < 0 ? ret : skb->len;
1474 }
1475
1476 static int ip_set_dump(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1477                        const struct nlmsghdr *nlh,
1478                        const struct nlattr * const attr[],
1479                        struct netlink_ext_ack *extack)
1480 {
1481         if (unlikely(protocol_min_failed(attr)))
1482                 return -IPSET_ERR_PROTOCOL;
1483
1484         {
1485                 struct netlink_dump_control c = {
1486                         .dump = ip_set_dump_start,
1487                         .done = ip_set_dump_done,
1488                 };
1489                 return netlink_dump_start(ctnl, skb, nlh, &c);
1490         }
1491 }
1492
1493 /* Add, del and test */
1494
1495 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1496         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1497         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1498                                     .len = IPSET_MAXNAMELEN - 1 },
1499         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1500         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1501         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1502 };
1503
1504 static int
1505 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1506         struct nlattr *tb[], enum ipset_adt adt,
1507         u32 flags, bool use_lineno)
1508 {
1509         int ret;
1510         u32 lineno = 0;
1511         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1512
1513         do {
1514                 spin_lock_bh(&set->lock);
1515                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1516                 spin_unlock_bh(&set->lock);
1517                 retried = true;
1518         } while (ret == -EAGAIN &&
1519                  set->variant->resize &&
1520                  (ret = set->variant->resize(set, retried)) == 0);
1521
1522         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1523                 return 0;
1524         if (lineno && use_lineno) {
1525                 /* Error in restore/batch mode: send back lineno */
1526                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1527                 struct sk_buff *skb2;
1528                 struct nlmsgerr *errmsg;
1529                 size_t payload = min(SIZE_MAX,
1530                                      sizeof(*errmsg) + nlmsg_len(nlh));
1531                 int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1532                 struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1533                 struct nlattr *cmdattr;
1534                 u32 *errline;
1535
1536                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1537                 if (!skb2)
1538                         return -ENOMEM;
1539                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1540                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1541                 errmsg = nlmsg_data(rep);
1542                 errmsg->error = ret;
1543                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1544                 cmdattr = (void *)&errmsg->msg + min_len;
1545
1546                 ret = nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, cmdattr,
1547                                            nlh->nlmsg_len - min_len,
1548                                            ip_set_adt_policy, NULL);
1549
1550                 if (ret) {
1551                         nlmsg_free(skb2);
1552                         return ret;
1553                 }
1554                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1555
1556                 *errline = lineno;
1557
1558                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid,
1559                                 MSG_DONTWAIT);
1560                 /* Signal netlink not to send its ACK/errmsg.  */
1561                 return -EINTR;
1562         }
1563
1564         return ret;
1565 }
1566
1567 static int ip_set_ad(struct net *net, struct sock *ctnl,
1568                      struct sk_buff *skb,
1569                      enum ipset_adt adt,
1570                      const struct nlmsghdr *nlh,
1571                      const struct nlattr * const attr[],
1572                      struct netlink_ext_ack *extack)
1573 {
1574         struct ip_set_net *inst = ip_set_pernet(net);
1575         struct ip_set *set;
1576         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1577         const struct nlattr *nla;
1578         u32 flags = flag_exist(nlh);
1579         bool use_lineno;
1580         int ret = 0;
1581
1582         if (unlikely(protocol_min_failed(attr) ||
1583                      !attr[IPSET_ATTR_SETNAME] ||
1584                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1585                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1586                      (attr[IPSET_ATTR_DATA] &&
1587                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1588                      (attr[IPSET_ATTR_ADT] &&
1589                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1590                        !attr[IPSET_ATTR_LINENO]))))
1591                 return -IPSET_ERR_PROTOCOL;
1592
1593         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1594         if (!set)
1595                 return -ENOENT;
1596
1597         use_lineno = !!attr[IPSET_ATTR_LINENO];
1598         if (attr[IPSET_ATTR_DATA]) {
1599                 if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1600                         return -IPSET_ERR_PROTOCOL;
1601                 ret = call_ad(ctnl, skb, set, tb, adt, flags,
1602                               use_lineno);
1603         } else {
1604                 int nla_rem;
1605
1606                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1607                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1608                             !flag_nested(nla) ||
1609                             nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, nla, set->type->adt_policy, NULL))
1610                                 return -IPSET_ERR_PROTOCOL;
1611                         ret = call_ad(ctnl, skb, set, tb, adt,
1612                                       flags, use_lineno);
1613                         if (ret < 0)
1614                                 return ret;
1615                 }
1616         }
1617         return ret;
1618 }
1619
1620 static int ip_set_uadd(struct net *net, struct sock *ctnl,
1621                        struct sk_buff *skb, const struct nlmsghdr *nlh,
1622                        const struct nlattr * const attr[],
1623                        struct netlink_ext_ack *extack)
1624 {
1625         return ip_set_ad(net, ctnl, skb,
1626                          IPSET_ADD, nlh, attr, extack);
1627 }
1628
1629 static int ip_set_udel(struct net *net, struct sock *ctnl,
1630                        struct sk_buff *skb, const struct nlmsghdr *nlh,
1631                        const struct nlattr * const attr[],
1632                        struct netlink_ext_ack *extack)
1633 {
1634         return ip_set_ad(net, ctnl, skb,
1635                          IPSET_DEL, nlh, attr, extack);
1636 }
1637
1638 static int ip_set_utest(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1639                         const struct nlmsghdr *nlh,
1640                         const struct nlattr * const attr[],
1641                         struct netlink_ext_ack *extack)
1642 {
1643         struct ip_set_net *inst = ip_set_pernet(net);
1644         struct ip_set *set;
1645         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1646         int ret = 0;
1647
1648         if (unlikely(protocol_min_failed(attr) ||
1649                      !attr[IPSET_ATTR_SETNAME] ||
1650                      !attr[IPSET_ATTR_DATA] ||
1651                      !flag_nested(attr[IPSET_ATTR_DATA])))
1652                 return -IPSET_ERR_PROTOCOL;
1653
1654         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1655         if (!set)
1656                 return -ENOENT;
1657
1658         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1659                 return -IPSET_ERR_PROTOCOL;
1660
1661         rcu_read_lock_bh();
1662         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1663         rcu_read_unlock_bh();
1664         /* Userspace can't trigger element to be re-added */
1665         if (ret == -EAGAIN)
1666                 ret = 1;
1667
1668         return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1669 }
1670
1671 /* Get headed data of a set */
1672
1673 static int ip_set_header(struct net *net, struct sock *ctnl,
1674                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1675                          const struct nlattr * const attr[],
1676                          struct netlink_ext_ack *extack)
1677 {
1678         struct ip_set_net *inst = ip_set_pernet(net);
1679         const struct ip_set *set;
1680         struct sk_buff *skb2;
1681         struct nlmsghdr *nlh2;
1682         int ret = 0;
1683
1684         if (unlikely(protocol_min_failed(attr) ||
1685                      !attr[IPSET_ATTR_SETNAME]))
1686                 return -IPSET_ERR_PROTOCOL;
1687
1688         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1689         if (!set)
1690                 return -ENOENT;
1691
1692         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1693         if (!skb2)
1694                 return -ENOMEM;
1695
1696         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1697                          IPSET_CMD_HEADER);
1698         if (!nlh2)
1699                 goto nlmsg_failure;
1700         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1701             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1702             nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1703             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1704             nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1705                 goto nla_put_failure;
1706         nlmsg_end(skb2, nlh2);
1707
1708         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1709         if (ret < 0)
1710                 return ret;
1711
1712         return 0;
1713
1714 nla_put_failure:
1715         nlmsg_cancel(skb2, nlh2);
1716 nlmsg_failure:
1717         kfree_skb(skb2);
1718         return -EMSGSIZE;
1719 }
1720
1721 /* Get type data */
1722
1723 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1724         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1725         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1726                                     .len = IPSET_MAXNAMELEN - 1 },
1727         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1728 };
1729
1730 static int ip_set_type(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1731                        const struct nlmsghdr *nlh,
1732                        const struct nlattr * const attr[],
1733                        struct netlink_ext_ack *extack)
1734 {
1735         struct sk_buff *skb2;
1736         struct nlmsghdr *nlh2;
1737         u8 family, min, max;
1738         const char *typename;
1739         int ret = 0;
1740
1741         if (unlikely(protocol_min_failed(attr) ||
1742                      !attr[IPSET_ATTR_TYPENAME] ||
1743                      !attr[IPSET_ATTR_FAMILY]))
1744                 return -IPSET_ERR_PROTOCOL;
1745
1746         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1747         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1748         ret = find_set_type_minmax(typename, family, &min, &max);
1749         if (ret)
1750                 return ret;
1751
1752         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1753         if (!skb2)
1754                 return -ENOMEM;
1755
1756         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1757                          IPSET_CMD_TYPE);
1758         if (!nlh2)
1759                 goto nlmsg_failure;
1760         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1761             nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1762             nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1763             nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1764             nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1765                 goto nla_put_failure;
1766         nlmsg_end(skb2, nlh2);
1767
1768         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1769         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1770         if (ret < 0)
1771                 return ret;
1772
1773         return 0;
1774
1775 nla_put_failure:
1776         nlmsg_cancel(skb2, nlh2);
1777 nlmsg_failure:
1778         kfree_skb(skb2);
1779         return -EMSGSIZE;
1780 }
1781
1782 /* Get protocol version */
1783
1784 static const struct nla_policy
1785 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1786         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1787 };
1788
1789 static int ip_set_protocol(struct net *net, struct sock *ctnl,
1790                            struct sk_buff *skb, const struct nlmsghdr *nlh,
1791                            const struct nlattr * const attr[],
1792                            struct netlink_ext_ack *extack)
1793 {
1794         struct sk_buff *skb2;
1795         struct nlmsghdr *nlh2;
1796         int ret = 0;
1797
1798         if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
1799                 return -IPSET_ERR_PROTOCOL;
1800
1801         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1802         if (!skb2)
1803                 return -ENOMEM;
1804
1805         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1806                          IPSET_CMD_PROTOCOL);
1807         if (!nlh2)
1808                 goto nlmsg_failure;
1809         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1810                 goto nla_put_failure;
1811         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
1812                 goto nla_put_failure;
1813         nlmsg_end(skb2, nlh2);
1814
1815         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1816         if (ret < 0)
1817                 return ret;
1818
1819         return 0;
1820
1821 nla_put_failure:
1822         nlmsg_cancel(skb2, nlh2);
1823 nlmsg_failure:
1824         kfree_skb(skb2);
1825         return -EMSGSIZE;
1826 }
1827
1828 /* Get set by name or index, from userspace */
1829
1830 static int ip_set_byname(struct net *net, struct sock *ctnl,
1831                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1832                          const struct nlattr * const attr[],
1833                          struct netlink_ext_ack *extack)
1834 {
1835         struct ip_set_net *inst = ip_set_pernet(net);
1836         struct sk_buff *skb2;
1837         struct nlmsghdr *nlh2;
1838         ip_set_id_t id = IPSET_INVALID_ID;
1839         const struct ip_set *set;
1840         int ret = 0;
1841
1842         if (unlikely(protocol_failed(attr) ||
1843                      !attr[IPSET_ATTR_SETNAME]))
1844                 return -IPSET_ERR_PROTOCOL;
1845
1846         set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
1847         if (id == IPSET_INVALID_ID)
1848                 return -ENOENT;
1849
1850         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1851         if (!skb2)
1852                 return -ENOMEM;
1853
1854         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1855                          IPSET_CMD_GET_BYNAME);
1856         if (!nlh2)
1857                 goto nlmsg_failure;
1858         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1859             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1860             nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
1861                 goto nla_put_failure;
1862         nlmsg_end(skb2, nlh2);
1863
1864         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1865         if (ret < 0)
1866                 return ret;
1867
1868         return 0;
1869
1870 nla_put_failure:
1871         nlmsg_cancel(skb2, nlh2);
1872 nlmsg_failure:
1873         kfree_skb(skb2);
1874         return -EMSGSIZE;
1875 }
1876
1877 static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
1878         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1879         [IPSET_ATTR_INDEX]      = { .type = NLA_U16 },
1880 };
1881
1882 static int ip_set_byindex(struct net *net, struct sock *ctnl,
1883                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1884                           const struct nlattr * const attr[],
1885                           struct netlink_ext_ack *extack)
1886 {
1887         struct ip_set_net *inst = ip_set_pernet(net);
1888         struct sk_buff *skb2;
1889         struct nlmsghdr *nlh2;
1890         ip_set_id_t id = IPSET_INVALID_ID;
1891         const struct ip_set *set;
1892         int ret = 0;
1893
1894         if (unlikely(protocol_failed(attr) ||
1895                      !attr[IPSET_ATTR_INDEX]))
1896                 return -IPSET_ERR_PROTOCOL;
1897
1898         id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
1899         if (id >= inst->ip_set_max)
1900                 return -ENOENT;
1901         set = ip_set(inst, id);
1902         if (set == NULL)
1903                 return -ENOENT;
1904
1905         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1906         if (!skb2)
1907                 return -ENOMEM;
1908
1909         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1910                          IPSET_CMD_GET_BYINDEX);
1911         if (!nlh2)
1912                 goto nlmsg_failure;
1913         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1914             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name))
1915                 goto nla_put_failure;
1916         nlmsg_end(skb2, nlh2);
1917
1918         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1919         if (ret < 0)
1920                 return ret;
1921
1922         return 0;
1923
1924 nla_put_failure:
1925         nlmsg_cancel(skb2, nlh2);
1926 nlmsg_failure:
1927         kfree_skb(skb2);
1928         return -EMSGSIZE;
1929 }
1930
1931 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1932         [IPSET_CMD_NONE]        = {
1933                 .call           = ip_set_none,
1934                 .attr_count     = IPSET_ATTR_CMD_MAX,
1935         },
1936         [IPSET_CMD_CREATE]      = {
1937                 .call           = ip_set_create,
1938                 .attr_count     = IPSET_ATTR_CMD_MAX,
1939                 .policy         = ip_set_create_policy,
1940         },
1941         [IPSET_CMD_DESTROY]     = {
1942                 .call           = ip_set_destroy,
1943                 .attr_count     = IPSET_ATTR_CMD_MAX,
1944                 .policy         = ip_set_setname_policy,
1945         },
1946         [IPSET_CMD_FLUSH]       = {
1947                 .call           = ip_set_flush,
1948                 .attr_count     = IPSET_ATTR_CMD_MAX,
1949                 .policy         = ip_set_setname_policy,
1950         },
1951         [IPSET_CMD_RENAME]      = {
1952                 .call           = ip_set_rename,
1953                 .attr_count     = IPSET_ATTR_CMD_MAX,
1954                 .policy         = ip_set_setname2_policy,
1955         },
1956         [IPSET_CMD_SWAP]        = {
1957                 .call           = ip_set_swap,
1958                 .attr_count     = IPSET_ATTR_CMD_MAX,
1959                 .policy         = ip_set_setname2_policy,
1960         },
1961         [IPSET_CMD_LIST]        = {
1962                 .call           = ip_set_dump,
1963                 .attr_count     = IPSET_ATTR_CMD_MAX,
1964                 .policy         = ip_set_setname_policy,
1965         },
1966         [IPSET_CMD_SAVE]        = {
1967                 .call           = ip_set_dump,
1968                 .attr_count     = IPSET_ATTR_CMD_MAX,
1969                 .policy         = ip_set_setname_policy,
1970         },
1971         [IPSET_CMD_ADD] = {
1972                 .call           = ip_set_uadd,
1973                 .attr_count     = IPSET_ATTR_CMD_MAX,
1974                 .policy         = ip_set_adt_policy,
1975         },
1976         [IPSET_CMD_DEL] = {
1977                 .call           = ip_set_udel,
1978                 .attr_count     = IPSET_ATTR_CMD_MAX,
1979                 .policy         = ip_set_adt_policy,
1980         },
1981         [IPSET_CMD_TEST]        = {
1982                 .call           = ip_set_utest,
1983                 .attr_count     = IPSET_ATTR_CMD_MAX,
1984                 .policy         = ip_set_adt_policy,
1985         },
1986         [IPSET_CMD_HEADER]      = {
1987                 .call           = ip_set_header,
1988                 .attr_count     = IPSET_ATTR_CMD_MAX,
1989                 .policy         = ip_set_setname_policy,
1990         },
1991         [IPSET_CMD_TYPE]        = {
1992                 .call           = ip_set_type,
1993                 .attr_count     = IPSET_ATTR_CMD_MAX,
1994                 .policy         = ip_set_type_policy,
1995         },
1996         [IPSET_CMD_PROTOCOL]    = {
1997                 .call           = ip_set_protocol,
1998                 .attr_count     = IPSET_ATTR_CMD_MAX,
1999                 .policy         = ip_set_protocol_policy,
2000         },
2001         [IPSET_CMD_GET_BYNAME]  = {
2002                 .call           = ip_set_byname,
2003                 .attr_count     = IPSET_ATTR_CMD_MAX,
2004                 .policy         = ip_set_setname_policy,
2005         },
2006         [IPSET_CMD_GET_BYINDEX] = {
2007                 .call           = ip_set_byindex,
2008                 .attr_count     = IPSET_ATTR_CMD_MAX,
2009                 .policy         = ip_set_index_policy,
2010         },
2011 };
2012
2013 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2014         .name           = "ip_set",
2015         .subsys_id      = NFNL_SUBSYS_IPSET,
2016         .cb_count       = IPSET_MSG_MAX,
2017         .cb             = ip_set_netlink_subsys_cb,
2018 };
2019
2020 /* Interface to iptables/ip6tables */
2021
2022 static int
2023 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2024 {
2025         unsigned int *op;
2026         void *data;
2027         int copylen = *len, ret = 0;
2028         struct net *net = sock_net(sk);
2029         struct ip_set_net *inst = ip_set_pernet(net);
2030
2031         if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2032                 return -EPERM;
2033         if (optval != SO_IP_SET)
2034                 return -EBADF;
2035         if (*len < sizeof(unsigned int))
2036                 return -EINVAL;
2037
2038         data = vmalloc(*len);
2039         if (!data)
2040                 return -ENOMEM;
2041         if (copy_from_user(data, user, *len) != 0) {
2042                 ret = -EFAULT;
2043                 goto done;
2044         }
2045         op = data;
2046
2047         if (*op < IP_SET_OP_VERSION) {
2048                 /* Check the version at the beginning of operations */
2049                 struct ip_set_req_version *req_version = data;
2050
2051                 if (*len < sizeof(struct ip_set_req_version)) {
2052                         ret = -EINVAL;
2053                         goto done;
2054                 }
2055
2056                 if (req_version->version < IPSET_PROTOCOL_MIN) {
2057                         ret = -EPROTO;
2058                         goto done;
2059                 }
2060         }
2061
2062         switch (*op) {
2063         case IP_SET_OP_VERSION: {
2064                 struct ip_set_req_version *req_version = data;
2065
2066                 if (*len != sizeof(struct ip_set_req_version)) {
2067                         ret = -EINVAL;
2068                         goto done;
2069                 }
2070
2071                 req_version->version = IPSET_PROTOCOL;
2072                 ret = copy_to_user(user, req_version,
2073                                    sizeof(struct ip_set_req_version));
2074                 goto done;
2075         }
2076         case IP_SET_OP_GET_BYNAME: {
2077                 struct ip_set_req_get_set *req_get = data;
2078                 ip_set_id_t id;
2079
2080                 if (*len != sizeof(struct ip_set_req_get_set)) {
2081                         ret = -EINVAL;
2082                         goto done;
2083                 }
2084                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2085                 nfnl_lock(NFNL_SUBSYS_IPSET);
2086                 find_set_and_id(inst, req_get->set.name, &id);
2087                 req_get->set.index = id;
2088                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2089                 goto copy;
2090         }
2091         case IP_SET_OP_GET_FNAME: {
2092                 struct ip_set_req_get_set_family *req_get = data;
2093                 ip_set_id_t id;
2094
2095                 if (*len != sizeof(struct ip_set_req_get_set_family)) {
2096                         ret = -EINVAL;
2097                         goto done;
2098                 }
2099                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2100                 nfnl_lock(NFNL_SUBSYS_IPSET);
2101                 find_set_and_id(inst, req_get->set.name, &id);
2102                 req_get->set.index = id;
2103                 if (id != IPSET_INVALID_ID)
2104                         req_get->family = ip_set(inst, id)->family;
2105                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2106                 goto copy;
2107         }
2108         case IP_SET_OP_GET_BYINDEX: {
2109                 struct ip_set_req_get_set *req_get = data;
2110                 struct ip_set *set;
2111
2112                 if (*len != sizeof(struct ip_set_req_get_set) ||
2113                     req_get->set.index >= inst->ip_set_max) {
2114                         ret = -EINVAL;
2115                         goto done;
2116                 }
2117                 nfnl_lock(NFNL_SUBSYS_IPSET);
2118                 set = ip_set(inst, req_get->set.index);
2119                 ret = strscpy(req_get->set.name, set ? set->name : "",
2120                               IPSET_MAXNAMELEN);
2121                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2122                 if (ret < 0)
2123                         goto done;
2124                 goto copy;
2125         }
2126         default:
2127                 ret = -EBADMSG;
2128                 goto done;
2129         }       /* end of switch(op) */
2130
2131 copy:
2132         ret = copy_to_user(user, data, copylen);
2133
2134 done:
2135         vfree(data);
2136         if (ret > 0)
2137                 ret = 0;
2138         return ret;
2139 }
2140
2141 static struct nf_sockopt_ops so_set __read_mostly = {
2142         .pf             = PF_INET,
2143         .get_optmin     = SO_IP_SET,
2144         .get_optmax     = SO_IP_SET + 1,
2145         .get            = ip_set_sockfn_get,
2146         .owner          = THIS_MODULE,
2147 };
2148
2149 static int __net_init
2150 ip_set_net_init(struct net *net)
2151 {
2152         struct ip_set_net *inst = ip_set_pernet(net);
2153         struct ip_set **list;
2154
2155         inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2156         if (inst->ip_set_max >= IPSET_INVALID_ID)
2157                 inst->ip_set_max = IPSET_INVALID_ID - 1;
2158
2159         list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2160         if (!list)
2161                 return -ENOMEM;
2162         inst->is_deleted = false;
2163         inst->is_destroyed = false;
2164         rcu_assign_pointer(inst->ip_set_list, list);
2165         return 0;
2166 }
2167
2168 static void __net_exit
2169 ip_set_net_exit(struct net *net)
2170 {
2171         struct ip_set_net *inst = ip_set_pernet(net);
2172
2173         struct ip_set *set = NULL;
2174         ip_set_id_t i;
2175
2176         inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2177
2178         nfnl_lock(NFNL_SUBSYS_IPSET);
2179         for (i = 0; i < inst->ip_set_max; i++) {
2180                 set = ip_set(inst, i);
2181                 if (set) {
2182                         ip_set(inst, i) = NULL;
2183                         ip_set_destroy_set(set);
2184                 }
2185         }
2186         nfnl_unlock(NFNL_SUBSYS_IPSET);
2187         kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2188 }
2189
2190 static struct pernet_operations ip_set_net_ops = {
2191         .init   = ip_set_net_init,
2192         .exit   = ip_set_net_exit,
2193         .id     = &ip_set_net_id,
2194         .size   = sizeof(struct ip_set_net),
2195 };
2196
2197 static int __init
2198 ip_set_init(void)
2199 {
2200         int ret = register_pernet_subsys(&ip_set_net_ops);
2201
2202         if (ret) {
2203                 pr_err("ip_set: cannot register pernet_subsys.\n");
2204                 return ret;
2205         }
2206
2207         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2208         if (ret != 0) {
2209                 pr_err("ip_set: cannot register with nfnetlink.\n");
2210                 unregister_pernet_subsys(&ip_set_net_ops);
2211                 return ret;
2212         }
2213
2214         ret = nf_register_sockopt(&so_set);
2215         if (ret != 0) {
2216                 pr_err("SO_SET registry failed: %d\n", ret);
2217                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2218                 unregister_pernet_subsys(&ip_set_net_ops);
2219                 return ret;
2220         }
2221
2222         return 0;
2223 }
2224
2225 static void __exit
2226 ip_set_fini(void)
2227 {
2228         nf_unregister_sockopt(&so_set);
2229         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2230
2231         unregister_pernet_subsys(&ip_set_net_ops);
2232         pr_debug("these are the famous last words\n");
2233 }
2234
2235 module_init(ip_set_init);
2236 module_exit(ip_set_fini);
2237
2238 MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL));