Merge branches 'acpi-tables', 'acpi-osl', 'acpi-misc' and 'acpi-tools'
[sfrench/cifs-2.6.git] / net / netfilter / ipset / ip_set_core.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
3  *                         Patrick Schaaf <bof@bof.de>
4  * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
5  */
6
7 /* Kernel module for IP set management */
8
9 #include <linux/init.h>
10 #include <linux/module.h>
11 #include <linux/moduleparam.h>
12 #include <linux/ip.h>
13 #include <linux/skbuff.h>
14 #include <linux/spinlock.h>
15 #include <linux/rculist.h>
16 #include <net/netlink.h>
17 #include <net/net_namespace.h>
18 #include <net/netns/generic.h>
19
20 #include <linux/netfilter.h>
21 #include <linux/netfilter/x_tables.h>
22 #include <linux/netfilter/nfnetlink.h>
23 #include <linux/netfilter/ipset/ip_set.h>
24
25 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
26 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
27 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
28
29 struct ip_set_net {
30         struct ip_set * __rcu *ip_set_list;     /* all individual sets */
31         ip_set_id_t     ip_set_max;     /* max number of sets */
32         bool            is_deleted;     /* deleted by ip_set_net_exit */
33         bool            is_destroyed;   /* all sets are destroyed */
34 };
35
36 static unsigned int ip_set_net_id __read_mostly;
37
38 static inline struct ip_set_net *ip_set_pernet(struct net *net)
39 {
40         return net_generic(net, ip_set_net_id);
41 }
42
43 #define IP_SET_INC      64
44 #define STRNCMP(a, b)   (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
45
46 static unsigned int max_sets;
47
48 module_param(max_sets, int, 0600);
49 MODULE_PARM_DESC(max_sets, "maximal number of sets");
50 MODULE_LICENSE("GPL");
51 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
52 MODULE_DESCRIPTION("core IP set support");
53 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
54
55 /* When the nfnl mutex or ip_set_ref_lock is held: */
56 #define ip_set_dereference(p)           \
57         rcu_dereference_protected(p,    \
58                 lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
59                 lockdep_is_held(&ip_set_ref_lock))
60 #define ip_set(inst, id)                \
61         ip_set_dereference((inst)->ip_set_list)[id]
62 #define ip_set_ref_netlink(inst,id)     \
63         rcu_dereference_raw((inst)->ip_set_list)[id]
64
65 /* The set types are implemented in modules and registered set types
66  * can be found in ip_set_type_list. Adding/deleting types is
67  * serialized by ip_set_type_mutex.
68  */
69
70 static inline void
71 ip_set_type_lock(void)
72 {
73         mutex_lock(&ip_set_type_mutex);
74 }
75
76 static inline void
77 ip_set_type_unlock(void)
78 {
79         mutex_unlock(&ip_set_type_mutex);
80 }
81
82 /* Register and deregister settype */
83
84 static struct ip_set_type *
85 find_set_type(const char *name, u8 family, u8 revision)
86 {
87         struct ip_set_type *type;
88
89         list_for_each_entry_rcu(type, &ip_set_type_list, list)
90                 if (STRNCMP(type->name, name) &&
91                     (type->family == family ||
92                      type->family == NFPROTO_UNSPEC) &&
93                     revision >= type->revision_min &&
94                     revision <= type->revision_max)
95                         return type;
96         return NULL;
97 }
98
99 /* Unlock, try to load a set type module and lock again */
100 static bool
101 load_settype(const char *name)
102 {
103         nfnl_unlock(NFNL_SUBSYS_IPSET);
104         pr_debug("try to load ip_set_%s\n", name);
105         if (request_module("ip_set_%s", name) < 0) {
106                 pr_warn("Can't find ip_set type %s\n", name);
107                 nfnl_lock(NFNL_SUBSYS_IPSET);
108                 return false;
109         }
110         nfnl_lock(NFNL_SUBSYS_IPSET);
111         return true;
112 }
113
114 /* Find a set type and reference it */
115 #define find_set_type_get(name, family, revision, found)        \
116         __find_set_type_get(name, family, revision, found, false)
117
118 static int
119 __find_set_type_get(const char *name, u8 family, u8 revision,
120                     struct ip_set_type **found, bool retry)
121 {
122         struct ip_set_type *type;
123         int err;
124
125         if (retry && !load_settype(name))
126                 return -IPSET_ERR_FIND_TYPE;
127
128         rcu_read_lock();
129         *found = find_set_type(name, family, revision);
130         if (*found) {
131                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
132                 goto unlock;
133         }
134         /* Make sure the type is already loaded
135          * but we don't support the revision
136          */
137         list_for_each_entry_rcu(type, &ip_set_type_list, list)
138                 if (STRNCMP(type->name, name)) {
139                         err = -IPSET_ERR_FIND_TYPE;
140                         goto unlock;
141                 }
142         rcu_read_unlock();
143
144         return retry ? -IPSET_ERR_FIND_TYPE :
145                 __find_set_type_get(name, family, revision, found, true);
146
147 unlock:
148         rcu_read_unlock();
149         return err;
150 }
151
152 /* Find a given set type by name and family.
153  * If we succeeded, the supported minimal and maximum revisions are
154  * filled out.
155  */
156 #define find_set_type_minmax(name, family, min, max) \
157         __find_set_type_minmax(name, family, min, max, false)
158
159 static int
160 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
161                        bool retry)
162 {
163         struct ip_set_type *type;
164         bool found = false;
165
166         if (retry && !load_settype(name))
167                 return -IPSET_ERR_FIND_TYPE;
168
169         *min = 255; *max = 0;
170         rcu_read_lock();
171         list_for_each_entry_rcu(type, &ip_set_type_list, list)
172                 if (STRNCMP(type->name, name) &&
173                     (type->family == family ||
174                      type->family == NFPROTO_UNSPEC)) {
175                         found = true;
176                         if (type->revision_min < *min)
177                                 *min = type->revision_min;
178                         if (type->revision_max > *max)
179                                 *max = type->revision_max;
180                 }
181         rcu_read_unlock();
182         if (found)
183                 return 0;
184
185         return retry ? -IPSET_ERR_FIND_TYPE :
186                 __find_set_type_minmax(name, family, min, max, true);
187 }
188
189 #define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
190                          (f) == NFPROTO_IPV6 ? "inet6" : "any")
191
192 /* Register a set type structure. The type is identified by
193  * the unique triple of name, family and revision.
194  */
195 int
196 ip_set_type_register(struct ip_set_type *type)
197 {
198         int ret = 0;
199
200         if (type->protocol != IPSET_PROTOCOL) {
201                 pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
202                         type->name, family_name(type->family),
203                         type->revision_min, type->revision_max,
204                         type->protocol, IPSET_PROTOCOL);
205                 return -EINVAL;
206         }
207
208         ip_set_type_lock();
209         if (find_set_type(type->name, type->family, type->revision_min)) {
210                 /* Duplicate! */
211                 pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
212                         type->name, family_name(type->family),
213                         type->revision_min);
214                 ip_set_type_unlock();
215                 return -EINVAL;
216         }
217         list_add_rcu(&type->list, &ip_set_type_list);
218         pr_debug("type %s, family %s, revision %u:%u registered.\n",
219                  type->name, family_name(type->family),
220                  type->revision_min, type->revision_max);
221         ip_set_type_unlock();
222
223         return ret;
224 }
225 EXPORT_SYMBOL_GPL(ip_set_type_register);
226
227 /* Unregister a set type. There's a small race with ip_set_create */
228 void
229 ip_set_type_unregister(struct ip_set_type *type)
230 {
231         ip_set_type_lock();
232         if (!find_set_type(type->name, type->family, type->revision_min)) {
233                 pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
234                         type->name, family_name(type->family),
235                         type->revision_min);
236                 ip_set_type_unlock();
237                 return;
238         }
239         list_del_rcu(&type->list);
240         pr_debug("type %s, family %s with revision min %u unregistered.\n",
241                  type->name, family_name(type->family), type->revision_min);
242         ip_set_type_unlock();
243
244         synchronize_rcu();
245 }
246 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
247
248 /* Utility functions */
249 void *
250 ip_set_alloc(size_t size)
251 {
252         void *members = NULL;
253
254         if (size < KMALLOC_MAX_SIZE)
255                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
256
257         if (members) {
258                 pr_debug("%p: allocated with kmalloc\n", members);
259                 return members;
260         }
261
262         members = vzalloc(size);
263         if (!members)
264                 return NULL;
265         pr_debug("%p: allocated with vmalloc\n", members);
266
267         return members;
268 }
269 EXPORT_SYMBOL_GPL(ip_set_alloc);
270
271 void
272 ip_set_free(void *members)
273 {
274         pr_debug("%p: free with %s\n", members,
275                  is_vmalloc_addr(members) ? "vfree" : "kfree");
276         kvfree(members);
277 }
278 EXPORT_SYMBOL_GPL(ip_set_free);
279
280 static inline bool
281 flag_nested(const struct nlattr *nla)
282 {
283         return nla->nla_type & NLA_F_NESTED;
284 }
285
286 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
287         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
288         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
289                                             .len = sizeof(struct in6_addr) },
290 };
291
292 int
293 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
294 {
295         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
296
297         if (unlikely(!flag_nested(nla)))
298                 return -IPSET_ERR_PROTOCOL;
299         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
300                 return -IPSET_ERR_PROTOCOL;
301         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
302                 return -IPSET_ERR_PROTOCOL;
303
304         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
305         return 0;
306 }
307 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
308
309 int
310 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
311 {
312         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
313
314         if (unlikely(!flag_nested(nla)))
315                 return -IPSET_ERR_PROTOCOL;
316
317         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
318                 return -IPSET_ERR_PROTOCOL;
319         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
320                 return -IPSET_ERR_PROTOCOL;
321
322         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
323                sizeof(struct in6_addr));
324         return 0;
325 }
326 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
327
328 typedef void (*destroyer)(struct ip_set *, void *);
329 /* ipset data extension types, in size order */
330
331 const struct ip_set_ext_type ip_set_extensions[] = {
332         [IPSET_EXT_ID_COUNTER] = {
333                 .type   = IPSET_EXT_COUNTER,
334                 .flag   = IPSET_FLAG_WITH_COUNTERS,
335                 .len    = sizeof(struct ip_set_counter),
336                 .align  = __alignof__(struct ip_set_counter),
337         },
338         [IPSET_EXT_ID_TIMEOUT] = {
339                 .type   = IPSET_EXT_TIMEOUT,
340                 .len    = sizeof(unsigned long),
341                 .align  = __alignof__(unsigned long),
342         },
343         [IPSET_EXT_ID_SKBINFO] = {
344                 .type   = IPSET_EXT_SKBINFO,
345                 .flag   = IPSET_FLAG_WITH_SKBINFO,
346                 .len    = sizeof(struct ip_set_skbinfo),
347                 .align  = __alignof__(struct ip_set_skbinfo),
348         },
349         [IPSET_EXT_ID_COMMENT] = {
350                 .type    = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
351                 .flag    = IPSET_FLAG_WITH_COMMENT,
352                 .len     = sizeof(struct ip_set_comment),
353                 .align   = __alignof__(struct ip_set_comment),
354                 .destroy = (destroyer) ip_set_comment_free,
355         },
356 };
357 EXPORT_SYMBOL_GPL(ip_set_extensions);
358
359 static inline bool
360 add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
361 {
362         return ip_set_extensions[id].flag ?
363                 (flags & ip_set_extensions[id].flag) :
364                 !!tb[IPSET_ATTR_TIMEOUT];
365 }
366
367 size_t
368 ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
369                 size_t align)
370 {
371         enum ip_set_ext_id id;
372         u32 cadt_flags = 0;
373
374         if (tb[IPSET_ATTR_CADT_FLAGS])
375                 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
376         if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
377                 set->flags |= IPSET_CREATE_FLAG_FORCEADD;
378         if (!align)
379                 align = 1;
380         for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
381                 if (!add_extension(id, cadt_flags, tb))
382                         continue;
383                 len = ALIGN(len, ip_set_extensions[id].align);
384                 set->offset[id] = len;
385                 set->extensions |= ip_set_extensions[id].type;
386                 len += ip_set_extensions[id].len;
387         }
388         return ALIGN(len, align);
389 }
390 EXPORT_SYMBOL_GPL(ip_set_elem_len);
391
392 int
393 ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
394                       struct ip_set_ext *ext)
395 {
396         u64 fullmark;
397
398         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
399                      !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
400                      !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
401                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
402                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
403                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
404                 return -IPSET_ERR_PROTOCOL;
405
406         if (tb[IPSET_ATTR_TIMEOUT]) {
407                 if (!SET_WITH_TIMEOUT(set))
408                         return -IPSET_ERR_TIMEOUT;
409                 ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
410         }
411         if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
412                 if (!SET_WITH_COUNTER(set))
413                         return -IPSET_ERR_COUNTER;
414                 if (tb[IPSET_ATTR_BYTES])
415                         ext->bytes = be64_to_cpu(nla_get_be64(
416                                                  tb[IPSET_ATTR_BYTES]));
417                 if (tb[IPSET_ATTR_PACKETS])
418                         ext->packets = be64_to_cpu(nla_get_be64(
419                                                    tb[IPSET_ATTR_PACKETS]));
420         }
421         if (tb[IPSET_ATTR_COMMENT]) {
422                 if (!SET_WITH_COMMENT(set))
423                         return -IPSET_ERR_COMMENT;
424                 ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
425         }
426         if (tb[IPSET_ATTR_SKBMARK]) {
427                 if (!SET_WITH_SKBINFO(set))
428                         return -IPSET_ERR_SKBINFO;
429                 fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
430                 ext->skbinfo.skbmark = fullmark >> 32;
431                 ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
432         }
433         if (tb[IPSET_ATTR_SKBPRIO]) {
434                 if (!SET_WITH_SKBINFO(set))
435                         return -IPSET_ERR_SKBINFO;
436                 ext->skbinfo.skbprio =
437                         be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
438         }
439         if (tb[IPSET_ATTR_SKBQUEUE]) {
440                 if (!SET_WITH_SKBINFO(set))
441                         return -IPSET_ERR_SKBINFO;
442                 ext->skbinfo.skbqueue =
443                         be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
444         }
445         return 0;
446 }
447 EXPORT_SYMBOL_GPL(ip_set_get_extensions);
448
449 int
450 ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
451                       const void *e, bool active)
452 {
453         if (SET_WITH_TIMEOUT(set)) {
454                 unsigned long *timeout = ext_timeout(e, set);
455
456                 if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
457                         htonl(active ? ip_set_timeout_get(timeout)
458                                 : *timeout)))
459                         return -EMSGSIZE;
460         }
461         if (SET_WITH_COUNTER(set) &&
462             ip_set_put_counter(skb, ext_counter(e, set)))
463                 return -EMSGSIZE;
464         if (SET_WITH_COMMENT(set) &&
465             ip_set_put_comment(skb, ext_comment(e, set)))
466                 return -EMSGSIZE;
467         if (SET_WITH_SKBINFO(set) &&
468             ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
469                 return -EMSGSIZE;
470         return 0;
471 }
472 EXPORT_SYMBOL_GPL(ip_set_put_extensions);
473
474 bool
475 ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
476                         struct ip_set_ext *mext, u32 flags, void *data)
477 {
478         if (SET_WITH_TIMEOUT(set) &&
479             ip_set_timeout_expired(ext_timeout(data, set)))
480                 return false;
481         if (SET_WITH_COUNTER(set)) {
482                 struct ip_set_counter *counter = ext_counter(data, set);
483
484                 if (flags & IPSET_FLAG_MATCH_COUNTERS &&
485                     !(ip_set_match_counter(ip_set_get_packets(counter),
486                                 mext->packets, mext->packets_op) &&
487                       ip_set_match_counter(ip_set_get_bytes(counter),
488                                 mext->bytes, mext->bytes_op)))
489                         return false;
490                 ip_set_update_counter(counter, ext, flags);
491         }
492         if (SET_WITH_SKBINFO(set))
493                 ip_set_get_skbinfo(ext_skbinfo(data, set),
494                                    ext, mext, flags);
495         return true;
496 }
497 EXPORT_SYMBOL_GPL(ip_set_match_extensions);
498
499 /* Creating/destroying/renaming/swapping affect the existence and
500  * the properties of a set. All of these can be executed from userspace
501  * only and serialized by the nfnl mutex indirectly from nfnetlink.
502  *
503  * Sets are identified by their index in ip_set_list and the index
504  * is used by the external references (set/SET netfilter modules).
505  *
506  * The set behind an index may change by swapping only, from userspace.
507  */
508
509 static inline void
510 __ip_set_get(struct ip_set *set)
511 {
512         write_lock_bh(&ip_set_ref_lock);
513         set->ref++;
514         write_unlock_bh(&ip_set_ref_lock);
515 }
516
517 static inline void
518 __ip_set_put(struct ip_set *set)
519 {
520         write_lock_bh(&ip_set_ref_lock);
521         BUG_ON(set->ref == 0);
522         set->ref--;
523         write_unlock_bh(&ip_set_ref_lock);
524 }
525
526 /* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
527  * a separate reference counter
528  */
529 static inline void
530 __ip_set_put_netlink(struct ip_set *set)
531 {
532         write_lock_bh(&ip_set_ref_lock);
533         BUG_ON(set->ref_netlink == 0);
534         set->ref_netlink--;
535         write_unlock_bh(&ip_set_ref_lock);
536 }
537
538 /* Add, del and test set entries from kernel.
539  *
540  * The set behind the index must exist and must be referenced
541  * so it can't be destroyed (or changed) under our foot.
542  */
543
544 static inline struct ip_set *
545 ip_set_rcu_get(struct net *net, ip_set_id_t index)
546 {
547         struct ip_set *set;
548         struct ip_set_net *inst = ip_set_pernet(net);
549
550         rcu_read_lock();
551         /* ip_set_list itself needs to be protected */
552         set = rcu_dereference(inst->ip_set_list)[index];
553         rcu_read_unlock();
554
555         return set;
556 }
557
558 int
559 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
560             const struct xt_action_param *par, struct ip_set_adt_opt *opt)
561 {
562         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
563         int ret = 0;
564
565         BUG_ON(!set);
566         pr_debug("set %s, index %u\n", set->name, index);
567
568         if (opt->dim < set->type->dimension ||
569             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
570                 return 0;
571
572         rcu_read_lock_bh();
573         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
574         rcu_read_unlock_bh();
575
576         if (ret == -EAGAIN) {
577                 /* Type requests element to be completed */
578                 pr_debug("element must be completed, ADD is triggered\n");
579                 spin_lock_bh(&set->lock);
580                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
581                 spin_unlock_bh(&set->lock);
582                 ret = 1;
583         } else {
584                 /* --return-nomatch: invert matched element */
585                 if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
586                     (set->type->features & IPSET_TYPE_NOMATCH) &&
587                     (ret > 0 || ret == -ENOTEMPTY))
588                         ret = -ret;
589         }
590
591         /* Convert error codes to nomatch */
592         return (ret < 0 ? 0 : ret);
593 }
594 EXPORT_SYMBOL_GPL(ip_set_test);
595
596 int
597 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
598            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
599 {
600         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
601         int ret;
602
603         BUG_ON(!set);
604         pr_debug("set %s, index %u\n", set->name, index);
605
606         if (opt->dim < set->type->dimension ||
607             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
608                 return -IPSET_ERR_TYPE_MISMATCH;
609
610         spin_lock_bh(&set->lock);
611         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
612         spin_unlock_bh(&set->lock);
613
614         return ret;
615 }
616 EXPORT_SYMBOL_GPL(ip_set_add);
617
618 int
619 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
620            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
621 {
622         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
623         int ret = 0;
624
625         BUG_ON(!set);
626         pr_debug("set %s, index %u\n", set->name, index);
627
628         if (opt->dim < set->type->dimension ||
629             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
630                 return -IPSET_ERR_TYPE_MISMATCH;
631
632         spin_lock_bh(&set->lock);
633         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
634         spin_unlock_bh(&set->lock);
635
636         return ret;
637 }
638 EXPORT_SYMBOL_GPL(ip_set_del);
639
640 /* Find set by name, reference it once. The reference makes sure the
641  * thing pointed to, does not go away under our feet.
642  *
643  */
644 ip_set_id_t
645 ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
646 {
647         ip_set_id_t i, index = IPSET_INVALID_ID;
648         struct ip_set *s;
649         struct ip_set_net *inst = ip_set_pernet(net);
650
651         rcu_read_lock();
652         for (i = 0; i < inst->ip_set_max; i++) {
653                 s = rcu_dereference(inst->ip_set_list)[i];
654                 if (s && STRNCMP(s->name, name)) {
655                         __ip_set_get(s);
656                         index = i;
657                         *set = s;
658                         break;
659                 }
660         }
661         rcu_read_unlock();
662
663         return index;
664 }
665 EXPORT_SYMBOL_GPL(ip_set_get_byname);
666
667 /* If the given set pointer points to a valid set, decrement
668  * reference count by 1. The caller shall not assume the index
669  * to be valid, after calling this function.
670  *
671  */
672
673 static inline void
674 __ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
675 {
676         struct ip_set *set;
677
678         rcu_read_lock();
679         set = rcu_dereference(inst->ip_set_list)[index];
680         if (set)
681                 __ip_set_put(set);
682         rcu_read_unlock();
683 }
684
685 void
686 ip_set_put_byindex(struct net *net, ip_set_id_t index)
687 {
688         struct ip_set_net *inst = ip_set_pernet(net);
689
690         __ip_set_put_byindex(inst, index);
691 }
692 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
693
694 /* Get the name of a set behind a set index.
695  * Set itself is protected by RCU, but its name isn't: to protect against
696  * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
697  * name.
698  */
699 void
700 ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
701 {
702         struct ip_set *set = ip_set_rcu_get(net, index);
703
704         BUG_ON(!set);
705
706         read_lock_bh(&ip_set_ref_lock);
707         strncpy(name, set->name, IPSET_MAXNAMELEN);
708         read_unlock_bh(&ip_set_ref_lock);
709 }
710 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
711
712 /* Routines to call by external subsystems, which do not
713  * call nfnl_lock for us.
714  */
715
716 /* Find set by index, reference it once. The reference makes sure the
717  * thing pointed to, does not go away under our feet.
718  *
719  * The nfnl mutex is used in the function.
720  */
721 ip_set_id_t
722 ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
723 {
724         struct ip_set *set;
725         struct ip_set_net *inst = ip_set_pernet(net);
726
727         if (index >= inst->ip_set_max)
728                 return IPSET_INVALID_ID;
729
730         nfnl_lock(NFNL_SUBSYS_IPSET);
731         set = ip_set(inst, index);
732         if (set)
733                 __ip_set_get(set);
734         else
735                 index = IPSET_INVALID_ID;
736         nfnl_unlock(NFNL_SUBSYS_IPSET);
737
738         return index;
739 }
740 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
741
742 /* If the given set pointer points to a valid set, decrement
743  * reference count by 1. The caller shall not assume the index
744  * to be valid, after calling this function.
745  *
746  * The nfnl mutex is used in the function.
747  */
748 void
749 ip_set_nfnl_put(struct net *net, ip_set_id_t index)
750 {
751         struct ip_set *set;
752         struct ip_set_net *inst = ip_set_pernet(net);
753
754         nfnl_lock(NFNL_SUBSYS_IPSET);
755         if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
756                 set = ip_set(inst, index);
757                 if (set)
758                         __ip_set_put(set);
759         }
760         nfnl_unlock(NFNL_SUBSYS_IPSET);
761 }
762 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
763
764 /* Communication protocol with userspace over netlink.
765  *
766  * The commands are serialized by the nfnl mutex.
767  */
768
769 static inline u8 protocol(const struct nlattr * const tb[])
770 {
771         return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
772 }
773
774 static inline bool
775 protocol_failed(const struct nlattr * const tb[])
776 {
777         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
778 }
779
780 static inline bool
781 protocol_min_failed(const struct nlattr * const tb[])
782 {
783         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
784 }
785
786 static inline u32
787 flag_exist(const struct nlmsghdr *nlh)
788 {
789         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
790 }
791
792 static struct nlmsghdr *
793 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
794           enum ipset_cmd cmd)
795 {
796         struct nlmsghdr *nlh;
797         struct nfgenmsg *nfmsg;
798
799         nlh = nlmsg_put(skb, portid, seq, nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd),
800                         sizeof(*nfmsg), flags);
801         if (!nlh)
802                 return NULL;
803
804         nfmsg = nlmsg_data(nlh);
805         nfmsg->nfgen_family = NFPROTO_IPV4;
806         nfmsg->version = NFNETLINK_V0;
807         nfmsg->res_id = 0;
808
809         return nlh;
810 }
811
812 /* Create a set */
813
814 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
815         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
816         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
817                                     .len = IPSET_MAXNAMELEN - 1 },
818         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
819                                     .len = IPSET_MAXNAMELEN - 1},
820         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
821         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
822         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
823 };
824
825 static struct ip_set *
826 find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
827 {
828         struct ip_set *set = NULL;
829         ip_set_id_t i;
830
831         *id = IPSET_INVALID_ID;
832         for (i = 0; i < inst->ip_set_max; i++) {
833                 set = ip_set(inst, i);
834                 if (set && STRNCMP(set->name, name)) {
835                         *id = i;
836                         break;
837                 }
838         }
839         return (*id == IPSET_INVALID_ID ? NULL : set);
840 }
841
842 static inline struct ip_set *
843 find_set(struct ip_set_net *inst, const char *name)
844 {
845         ip_set_id_t id;
846
847         return find_set_and_id(inst, name, &id);
848 }
849
850 static int
851 find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
852              struct ip_set **set)
853 {
854         struct ip_set *s;
855         ip_set_id_t i;
856
857         *index = IPSET_INVALID_ID;
858         for (i = 0;  i < inst->ip_set_max; i++) {
859                 s = ip_set(inst, i);
860                 if (!s) {
861                         if (*index == IPSET_INVALID_ID)
862                                 *index = i;
863                 } else if (STRNCMP(name, s->name)) {
864                         /* Name clash */
865                         *set = s;
866                         return -EEXIST;
867                 }
868         }
869         if (*index == IPSET_INVALID_ID)
870                 /* No free slot remained */
871                 return -IPSET_ERR_MAX_SETS;
872         return 0;
873 }
874
875 static int ip_set_none(struct net *net, struct sock *ctnl, struct sk_buff *skb,
876                        const struct nlmsghdr *nlh,
877                        const struct nlattr * const attr[],
878                        struct netlink_ext_ack *extack)
879 {
880         return -EOPNOTSUPP;
881 }
882
883 static int ip_set_create(struct net *net, struct sock *ctnl,
884                          struct sk_buff *skb, const struct nlmsghdr *nlh,
885                          const struct nlattr * const attr[],
886                          struct netlink_ext_ack *extack)
887 {
888         struct ip_set_net *inst = ip_set_pernet(net);
889         struct ip_set *set, *clash = NULL;
890         ip_set_id_t index = IPSET_INVALID_ID;
891         struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
892         const char *name, *typename;
893         u8 family, revision;
894         u32 flags = flag_exist(nlh);
895         int ret = 0;
896
897         if (unlikely(protocol_min_failed(attr) ||
898                      !attr[IPSET_ATTR_SETNAME] ||
899                      !attr[IPSET_ATTR_TYPENAME] ||
900                      !attr[IPSET_ATTR_REVISION] ||
901                      !attr[IPSET_ATTR_FAMILY] ||
902                      (attr[IPSET_ATTR_DATA] &&
903                       !flag_nested(attr[IPSET_ATTR_DATA]))))
904                 return -IPSET_ERR_PROTOCOL;
905
906         name = nla_data(attr[IPSET_ATTR_SETNAME]);
907         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
908         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
909         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
910         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
911                  name, typename, family_name(family), revision);
912
913         /* First, and without any locks, allocate and initialize
914          * a normal base set structure.
915          */
916         set = kzalloc(sizeof(*set), GFP_KERNEL);
917         if (!set)
918                 return -ENOMEM;
919         spin_lock_init(&set->lock);
920         strlcpy(set->name, name, IPSET_MAXNAMELEN);
921         set->family = family;
922         set->revision = revision;
923
924         /* Next, check that we know the type, and take
925          * a reference on the type, to make sure it stays available
926          * while constructing our new set.
927          *
928          * After referencing the type, we try to create the type
929          * specific part of the set without holding any locks.
930          */
931         ret = find_set_type_get(typename, family, revision, &set->type);
932         if (ret)
933                 goto out;
934
935         /* Without holding any locks, create private part. */
936         if (attr[IPSET_ATTR_DATA] &&
937             nla_parse_nested_deprecated(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA], set->type->create_policy, NULL)) {
938                 ret = -IPSET_ERR_PROTOCOL;
939                 goto put_out;
940         }
941
942         ret = set->type->create(net, set, tb, flags);
943         if (ret != 0)
944                 goto put_out;
945
946         /* BTW, ret==0 here. */
947
948         /* Here, we have a valid, constructed set and we are protected
949          * by the nfnl mutex. Find the first free index in ip_set_list
950          * and check clashing.
951          */
952         ret = find_free_id(inst, set->name, &index, &clash);
953         if (ret == -EEXIST) {
954                 /* If this is the same set and requested, ignore error */
955                 if ((flags & IPSET_FLAG_EXIST) &&
956                     STRNCMP(set->type->name, clash->type->name) &&
957                     set->type->family == clash->type->family &&
958                     set->type->revision_min == clash->type->revision_min &&
959                     set->type->revision_max == clash->type->revision_max &&
960                     set->variant->same_set(set, clash))
961                         ret = 0;
962                 goto cleanup;
963         } else if (ret == -IPSET_ERR_MAX_SETS) {
964                 struct ip_set **list, **tmp;
965                 ip_set_id_t i = inst->ip_set_max + IP_SET_INC;
966
967                 if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
968                         /* Wraparound */
969                         goto cleanup;
970
971                 list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
972                 if (!list)
973                         goto cleanup;
974                 /* nfnl mutex is held, both lists are valid */
975                 tmp = ip_set_dereference(inst->ip_set_list);
976                 memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
977                 rcu_assign_pointer(inst->ip_set_list, list);
978                 /* Make sure all current packets have passed through */
979                 synchronize_net();
980                 /* Use new list */
981                 index = inst->ip_set_max;
982                 inst->ip_set_max = i;
983                 kvfree(tmp);
984                 ret = 0;
985         } else if (ret) {
986                 goto cleanup;
987         }
988
989         /* Finally! Add our shiny new set to the list, and be done. */
990         pr_debug("create: '%s' created with index %u!\n", set->name, index);
991         ip_set(inst, index) = set;
992
993         return ret;
994
995 cleanup:
996         set->variant->destroy(set);
997 put_out:
998         module_put(set->type->me);
999 out:
1000         kfree(set);
1001         return ret;
1002 }
1003
1004 /* Destroy sets */
1005
1006 static const struct nla_policy
1007 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
1008         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1009         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1010                                     .len = IPSET_MAXNAMELEN - 1 },
1011 };
1012
1013 static void
1014 ip_set_destroy_set(struct ip_set *set)
1015 {
1016         pr_debug("set: %s\n",  set->name);
1017
1018         /* Must call it without holding any lock */
1019         set->variant->destroy(set);
1020         module_put(set->type->me);
1021         kfree(set);
1022 }
1023
1024 static int ip_set_destroy(struct net *net, struct sock *ctnl,
1025                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1026                           const struct nlattr * const attr[],
1027                           struct netlink_ext_ack *extack)
1028 {
1029         struct ip_set_net *inst = ip_set_pernet(net);
1030         struct ip_set *s;
1031         ip_set_id_t i;
1032         int ret = 0;
1033
1034         if (unlikely(protocol_min_failed(attr)))
1035                 return -IPSET_ERR_PROTOCOL;
1036
1037         /* Must wait for flush to be really finished in list:set */
1038         rcu_barrier();
1039
1040         /* Commands are serialized and references are
1041          * protected by the ip_set_ref_lock.
1042          * External systems (i.e. xt_set) must call
1043          * ip_set_put|get_nfnl_* functions, that way we
1044          * can safely check references here.
1045          *
1046          * list:set timer can only decrement the reference
1047          * counter, so if it's already zero, we can proceed
1048          * without holding the lock.
1049          */
1050         read_lock_bh(&ip_set_ref_lock);
1051         if (!attr[IPSET_ATTR_SETNAME]) {
1052                 for (i = 0; i < inst->ip_set_max; i++) {
1053                         s = ip_set(inst, i);
1054                         if (s && (s->ref || s->ref_netlink)) {
1055                                 ret = -IPSET_ERR_BUSY;
1056                                 goto out;
1057                         }
1058                 }
1059                 inst->is_destroyed = true;
1060                 read_unlock_bh(&ip_set_ref_lock);
1061                 for (i = 0; i < inst->ip_set_max; i++) {
1062                         s = ip_set(inst, i);
1063                         if (s) {
1064                                 ip_set(inst, i) = NULL;
1065                                 ip_set_destroy_set(s);
1066                         }
1067                 }
1068                 /* Modified by ip_set_destroy() only, which is serialized */
1069                 inst->is_destroyed = false;
1070         } else {
1071                 s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1072                                     &i);
1073                 if (!s) {
1074                         ret = -ENOENT;
1075                         goto out;
1076                 } else if (s->ref || s->ref_netlink) {
1077                         ret = -IPSET_ERR_BUSY;
1078                         goto out;
1079                 }
1080                 ip_set(inst, i) = NULL;
1081                 read_unlock_bh(&ip_set_ref_lock);
1082
1083                 ip_set_destroy_set(s);
1084         }
1085         return 0;
1086 out:
1087         read_unlock_bh(&ip_set_ref_lock);
1088         return ret;
1089 }
1090
1091 /* Flush sets */
1092
1093 static void
1094 ip_set_flush_set(struct ip_set *set)
1095 {
1096         pr_debug("set: %s\n",  set->name);
1097
1098         spin_lock_bh(&set->lock);
1099         set->variant->flush(set);
1100         spin_unlock_bh(&set->lock);
1101 }
1102
1103 static int ip_set_flush(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1104                         const struct nlmsghdr *nlh,
1105                         const struct nlattr * const attr[],
1106                         struct netlink_ext_ack *extack)
1107 {
1108         struct ip_set_net *inst = ip_set_pernet(net);
1109         struct ip_set *s;
1110         ip_set_id_t i;
1111
1112         if (unlikely(protocol_min_failed(attr)))
1113                 return -IPSET_ERR_PROTOCOL;
1114
1115         if (!attr[IPSET_ATTR_SETNAME]) {
1116                 for (i = 0; i < inst->ip_set_max; i++) {
1117                         s = ip_set(inst, i);
1118                         if (s)
1119                                 ip_set_flush_set(s);
1120                 }
1121         } else {
1122                 s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1123                 if (!s)
1124                         return -ENOENT;
1125
1126                 ip_set_flush_set(s);
1127         }
1128
1129         return 0;
1130 }
1131
1132 /* Rename a set */
1133
1134 static const struct nla_policy
1135 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
1136         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1137         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1138                                     .len = IPSET_MAXNAMELEN - 1 },
1139         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
1140                                     .len = IPSET_MAXNAMELEN - 1 },
1141 };
1142
1143 static int ip_set_rename(struct net *net, struct sock *ctnl,
1144                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1145                          const struct nlattr * const attr[],
1146                          struct netlink_ext_ack *extack)
1147 {
1148         struct ip_set_net *inst = ip_set_pernet(net);
1149         struct ip_set *set, *s;
1150         const char *name2;
1151         ip_set_id_t i;
1152         int ret = 0;
1153
1154         if (unlikely(protocol_min_failed(attr) ||
1155                      !attr[IPSET_ATTR_SETNAME] ||
1156                      !attr[IPSET_ATTR_SETNAME2]))
1157                 return -IPSET_ERR_PROTOCOL;
1158
1159         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1160         if (!set)
1161                 return -ENOENT;
1162
1163         write_lock_bh(&ip_set_ref_lock);
1164         if (set->ref != 0) {
1165                 ret = -IPSET_ERR_REFERENCED;
1166                 goto out;
1167         }
1168
1169         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1170         for (i = 0; i < inst->ip_set_max; i++) {
1171                 s = ip_set(inst, i);
1172                 if (s && STRNCMP(s->name, name2)) {
1173                         ret = -IPSET_ERR_EXIST_SETNAME2;
1174                         goto out;
1175                 }
1176         }
1177         strncpy(set->name, name2, IPSET_MAXNAMELEN);
1178
1179 out:
1180         write_unlock_bh(&ip_set_ref_lock);
1181         return ret;
1182 }
1183
1184 /* Swap two sets so that name/index points to the other.
1185  * References and set names are also swapped.
1186  *
1187  * The commands are serialized by the nfnl mutex and references are
1188  * protected by the ip_set_ref_lock. The kernel interfaces
1189  * do not hold the mutex but the pointer settings are atomic
1190  * so the ip_set_list always contains valid pointers to the sets.
1191  */
1192
1193 static int ip_set_swap(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1194                        const struct nlmsghdr *nlh,
1195                        const struct nlattr * const attr[],
1196                        struct netlink_ext_ack *extack)
1197 {
1198         struct ip_set_net *inst = ip_set_pernet(net);
1199         struct ip_set *from, *to;
1200         ip_set_id_t from_id, to_id;
1201         char from_name[IPSET_MAXNAMELEN];
1202
1203         if (unlikely(protocol_min_failed(attr) ||
1204                      !attr[IPSET_ATTR_SETNAME] ||
1205                      !attr[IPSET_ATTR_SETNAME2]))
1206                 return -IPSET_ERR_PROTOCOL;
1207
1208         from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1209                                &from_id);
1210         if (!from)
1211                 return -ENOENT;
1212
1213         to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1214                              &to_id);
1215         if (!to)
1216                 return -IPSET_ERR_EXIST_SETNAME2;
1217
1218         /* Features must not change.
1219          * Not an artifical restriction anymore, as we must prevent
1220          * possible loops created by swapping in setlist type of sets.
1221          */
1222         if (!(from->type->features == to->type->features &&
1223               from->family == to->family))
1224                 return -IPSET_ERR_TYPE_MISMATCH;
1225
1226         write_lock_bh(&ip_set_ref_lock);
1227
1228         if (from->ref_netlink || to->ref_netlink) {
1229                 write_unlock_bh(&ip_set_ref_lock);
1230                 return -EBUSY;
1231         }
1232
1233         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1234         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1235         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1236
1237         swap(from->ref, to->ref);
1238         ip_set(inst, from_id) = to;
1239         ip_set(inst, to_id) = from;
1240         write_unlock_bh(&ip_set_ref_lock);
1241
1242         return 0;
1243 }
1244
1245 /* List/save set data */
1246
1247 #define DUMP_INIT       0
1248 #define DUMP_ALL        1
1249 #define DUMP_ONE        2
1250 #define DUMP_LAST       3
1251
1252 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
1253 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1254
1255 static int
1256 ip_set_dump_done(struct netlink_callback *cb)
1257 {
1258         if (cb->args[IPSET_CB_ARG0]) {
1259                 struct ip_set_net *inst =
1260                         (struct ip_set_net *)cb->args[IPSET_CB_NET];
1261                 ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1262                 struct ip_set *set = ip_set_ref_netlink(inst, index);
1263
1264                 if (set->variant->uref)
1265                         set->variant->uref(set, cb, false);
1266                 pr_debug("release set %s\n", set->name);
1267                 __ip_set_put_netlink(set);
1268         }
1269         return 0;
1270 }
1271
1272 static inline void
1273 dump_attrs(struct nlmsghdr *nlh)
1274 {
1275         const struct nlattr *attr;
1276         int rem;
1277
1278         pr_debug("dump nlmsg\n");
1279         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1280                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1281         }
1282 }
1283
1284 static int
1285 dump_init(struct netlink_callback *cb, struct ip_set_net *inst)
1286 {
1287         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1288         int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1289         struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1290         struct nlattr *attr = (void *)nlh + min_len;
1291         u32 dump_type;
1292         ip_set_id_t index;
1293
1294         /* Second pass, so parser can't fail */
1295         nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, attr,
1296                              nlh->nlmsg_len - min_len, ip_set_setname_policy,
1297                              NULL);
1298
1299         cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1300         if (cda[IPSET_ATTR_SETNAME]) {
1301                 struct ip_set *set;
1302
1303                 set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1304                                       &index);
1305                 if (!set)
1306                         return -ENOENT;
1307
1308                 dump_type = DUMP_ONE;
1309                 cb->args[IPSET_CB_INDEX] = index;
1310         } else {
1311                 dump_type = DUMP_ALL;
1312         }
1313
1314         if (cda[IPSET_ATTR_FLAGS]) {
1315                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1316
1317                 dump_type |= (f << 16);
1318         }
1319         cb->args[IPSET_CB_NET] = (unsigned long)inst;
1320         cb->args[IPSET_CB_DUMP] = dump_type;
1321
1322         return 0;
1323 }
1324
1325 static int
1326 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1327 {
1328         ip_set_id_t index = IPSET_INVALID_ID, max;
1329         struct ip_set *set = NULL;
1330         struct nlmsghdr *nlh = NULL;
1331         unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
1332         struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
1333         u32 dump_type, dump_flags;
1334         bool is_destroyed;
1335         int ret = 0;
1336
1337         if (!cb->args[IPSET_CB_DUMP]) {
1338                 ret = dump_init(cb, inst);
1339                 if (ret < 0) {
1340                         nlh = nlmsg_hdr(cb->skb);
1341                         /* We have to create and send the error message
1342                          * manually :-(
1343                          */
1344                         if (nlh->nlmsg_flags & NLM_F_ACK)
1345                                 netlink_ack(cb->skb, nlh, ret, NULL);
1346                         return ret;
1347                 }
1348         }
1349
1350         if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
1351                 goto out;
1352
1353         dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
1354         dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
1355         max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
1356                                     : inst->ip_set_max;
1357 dump_last:
1358         pr_debug("dump type, flag: %u %u index: %ld\n",
1359                  dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
1360         for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
1361                 index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1362                 write_lock_bh(&ip_set_ref_lock);
1363                 set = ip_set(inst, index);
1364                 is_destroyed = inst->is_destroyed;
1365                 if (!set || is_destroyed) {
1366                         write_unlock_bh(&ip_set_ref_lock);
1367                         if (dump_type == DUMP_ONE) {
1368                                 ret = -ENOENT;
1369                                 goto out;
1370                         }
1371                         if (is_destroyed) {
1372                                 /* All sets are just being destroyed */
1373                                 ret = 0;
1374                                 goto out;
1375                         }
1376                         continue;
1377                 }
1378                 /* When dumping all sets, we must dump "sorted"
1379                  * so that lists (unions of sets) are dumped last.
1380                  */
1381                 if (dump_type != DUMP_ONE &&
1382                     ((dump_type == DUMP_ALL) ==
1383                      !!(set->type->features & IPSET_DUMP_LAST))) {
1384                         write_unlock_bh(&ip_set_ref_lock);
1385                         continue;
1386                 }
1387                 pr_debug("List set: %s\n", set->name);
1388                 if (!cb->args[IPSET_CB_ARG0]) {
1389                         /* Start listing: make sure set won't be destroyed */
1390                         pr_debug("reference set\n");
1391                         set->ref_netlink++;
1392                 }
1393                 write_unlock_bh(&ip_set_ref_lock);
1394                 nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
1395                                 cb->nlh->nlmsg_seq, flags,
1396                                 IPSET_CMD_LIST);
1397                 if (!nlh) {
1398                         ret = -EMSGSIZE;
1399                         goto release_refcount;
1400                 }
1401                 if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
1402                                cb->args[IPSET_CB_PROTO]) ||
1403                     nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1404                         goto nla_put_failure;
1405                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1406                         goto next_set;
1407                 switch (cb->args[IPSET_CB_ARG0]) {
1408                 case 0:
1409                         /* Core header data */
1410                         if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1411                                            set->type->name) ||
1412                             nla_put_u8(skb, IPSET_ATTR_FAMILY,
1413                                        set->family) ||
1414                             nla_put_u8(skb, IPSET_ATTR_REVISION,
1415                                        set->revision))
1416                                 goto nla_put_failure;
1417                         if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
1418                             nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
1419                                 goto nla_put_failure;
1420                         ret = set->variant->head(set, skb);
1421                         if (ret < 0)
1422                                 goto release_refcount;
1423                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1424                                 goto next_set;
1425                         if (set->variant->uref)
1426                                 set->variant->uref(set, cb, true);
1427                         /* fall through */
1428                 default:
1429                         ret = set->variant->list(set, skb, cb);
1430                         if (!cb->args[IPSET_CB_ARG0])
1431                                 /* Set is done, proceed with next one */
1432                                 goto next_set;
1433                         goto release_refcount;
1434                 }
1435         }
1436         /* If we dump all sets, continue with dumping last ones */
1437         if (dump_type == DUMP_ALL) {
1438                 dump_type = DUMP_LAST;
1439                 cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
1440                 cb->args[IPSET_CB_INDEX] = 0;
1441                 if (set && set->variant->uref)
1442                         set->variant->uref(set, cb, false);
1443                 goto dump_last;
1444         }
1445         goto out;
1446
1447 nla_put_failure:
1448         ret = -EFAULT;
1449 next_set:
1450         if (dump_type == DUMP_ONE)
1451                 cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
1452         else
1453                 cb->args[IPSET_CB_INDEX]++;
1454 release_refcount:
1455         /* If there was an error or set is done, release set */
1456         if (ret || !cb->args[IPSET_CB_ARG0]) {
1457                 set = ip_set_ref_netlink(inst, index);
1458                 if (set->variant->uref)
1459                         set->variant->uref(set, cb, false);
1460                 pr_debug("release set %s\n", set->name);
1461                 __ip_set_put_netlink(set);
1462                 cb->args[IPSET_CB_ARG0] = 0;
1463         }
1464 out:
1465         if (nlh) {
1466                 nlmsg_end(skb, nlh);
1467                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1468                 dump_attrs(nlh);
1469         }
1470
1471         return ret < 0 ? ret : skb->len;
1472 }
1473
1474 static int ip_set_dump(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1475                        const struct nlmsghdr *nlh,
1476                        const struct nlattr * const attr[],
1477                        struct netlink_ext_ack *extack)
1478 {
1479         if (unlikely(protocol_min_failed(attr)))
1480                 return -IPSET_ERR_PROTOCOL;
1481
1482         {
1483                 struct netlink_dump_control c = {
1484                         .dump = ip_set_dump_start,
1485                         .done = ip_set_dump_done,
1486                 };
1487                 return netlink_dump_start(ctnl, skb, nlh, &c);
1488         }
1489 }
1490
1491 /* Add, del and test */
1492
1493 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1494         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1495         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1496                                     .len = IPSET_MAXNAMELEN - 1 },
1497         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1498         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1499         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1500 };
1501
1502 static int
1503 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1504         struct nlattr *tb[], enum ipset_adt adt,
1505         u32 flags, bool use_lineno)
1506 {
1507         int ret;
1508         u32 lineno = 0;
1509         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1510
1511         do {
1512                 spin_lock_bh(&set->lock);
1513                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1514                 spin_unlock_bh(&set->lock);
1515                 retried = true;
1516         } while (ret == -EAGAIN &&
1517                  set->variant->resize &&
1518                  (ret = set->variant->resize(set, retried)) == 0);
1519
1520         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1521                 return 0;
1522         if (lineno && use_lineno) {
1523                 /* Error in restore/batch mode: send back lineno */
1524                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1525                 struct sk_buff *skb2;
1526                 struct nlmsgerr *errmsg;
1527                 size_t payload = min(SIZE_MAX,
1528                                      sizeof(*errmsg) + nlmsg_len(nlh));
1529                 int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1530                 struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1531                 struct nlattr *cmdattr;
1532                 u32 *errline;
1533
1534                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1535                 if (!skb2)
1536                         return -ENOMEM;
1537                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1538                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1539                 errmsg = nlmsg_data(rep);
1540                 errmsg->error = ret;
1541                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1542                 cmdattr = (void *)&errmsg->msg + min_len;
1543
1544                 nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, cmdattr,
1545                                      nlh->nlmsg_len - min_len,
1546                                      ip_set_adt_policy, NULL);
1547
1548                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1549
1550                 *errline = lineno;
1551
1552                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid,
1553                                 MSG_DONTWAIT);
1554                 /* Signal netlink not to send its ACK/errmsg.  */
1555                 return -EINTR;
1556         }
1557
1558         return ret;
1559 }
1560
1561 static int ip_set_uadd(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1562                        const struct nlmsghdr *nlh,
1563                        const struct nlattr * const attr[],
1564                        struct netlink_ext_ack *extack)
1565 {
1566         struct ip_set_net *inst = ip_set_pernet(net);
1567         struct ip_set *set;
1568         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1569         const struct nlattr *nla;
1570         u32 flags = flag_exist(nlh);
1571         bool use_lineno;
1572         int ret = 0;
1573
1574         if (unlikely(protocol_min_failed(attr) ||
1575                      !attr[IPSET_ATTR_SETNAME] ||
1576                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1577                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1578                      (attr[IPSET_ATTR_DATA] &&
1579                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1580                      (attr[IPSET_ATTR_ADT] &&
1581                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1582                        !attr[IPSET_ATTR_LINENO]))))
1583                 return -IPSET_ERR_PROTOCOL;
1584
1585         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1586         if (!set)
1587                 return -ENOENT;
1588
1589         use_lineno = !!attr[IPSET_ATTR_LINENO];
1590         if (attr[IPSET_ATTR_DATA]) {
1591                 if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1592                         return -IPSET_ERR_PROTOCOL;
1593                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1594                               use_lineno);
1595         } else {
1596                 int nla_rem;
1597
1598                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1599                         memset(tb, 0, sizeof(tb));
1600                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1601                             !flag_nested(nla) ||
1602                             nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, nla, set->type->adt_policy, NULL))
1603                                 return -IPSET_ERR_PROTOCOL;
1604                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1605                                       flags, use_lineno);
1606                         if (ret < 0)
1607                                 return ret;
1608                 }
1609         }
1610         return ret;
1611 }
1612
1613 static int ip_set_udel(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1614                        const struct nlmsghdr *nlh,
1615                        const struct nlattr * const attr[],
1616                        struct netlink_ext_ack *extack)
1617 {
1618         struct ip_set_net *inst = ip_set_pernet(net);
1619         struct ip_set *set;
1620         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1621         const struct nlattr *nla;
1622         u32 flags = flag_exist(nlh);
1623         bool use_lineno;
1624         int ret = 0;
1625
1626         if (unlikely(protocol_min_failed(attr) ||
1627                      !attr[IPSET_ATTR_SETNAME] ||
1628                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1629                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1630                      (attr[IPSET_ATTR_DATA] &&
1631                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1632                      (attr[IPSET_ATTR_ADT] &&
1633                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1634                        !attr[IPSET_ATTR_LINENO]))))
1635                 return -IPSET_ERR_PROTOCOL;
1636
1637         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1638         if (!set)
1639                 return -ENOENT;
1640
1641         use_lineno = !!attr[IPSET_ATTR_LINENO];
1642         if (attr[IPSET_ATTR_DATA]) {
1643                 if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1644                         return -IPSET_ERR_PROTOCOL;
1645                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1646                               use_lineno);
1647         } else {
1648                 int nla_rem;
1649
1650                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1651                         memset(tb, 0, sizeof(*tb));
1652                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1653                             !flag_nested(nla) ||
1654                             nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, nla, set->type->adt_policy, NULL))
1655                                 return -IPSET_ERR_PROTOCOL;
1656                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1657                                       flags, use_lineno);
1658                         if (ret < 0)
1659                                 return ret;
1660                 }
1661         }
1662         return ret;
1663 }
1664
1665 static int ip_set_utest(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1666                         const struct nlmsghdr *nlh,
1667                         const struct nlattr * const attr[],
1668                         struct netlink_ext_ack *extack)
1669 {
1670         struct ip_set_net *inst = ip_set_pernet(net);
1671         struct ip_set *set;
1672         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1673         int ret = 0;
1674
1675         if (unlikely(protocol_min_failed(attr) ||
1676                      !attr[IPSET_ATTR_SETNAME] ||
1677                      !attr[IPSET_ATTR_DATA] ||
1678                      !flag_nested(attr[IPSET_ATTR_DATA])))
1679                 return -IPSET_ERR_PROTOCOL;
1680
1681         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1682         if (!set)
1683                 return -ENOENT;
1684
1685         if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1686                 return -IPSET_ERR_PROTOCOL;
1687
1688         rcu_read_lock_bh();
1689         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1690         rcu_read_unlock_bh();
1691         /* Userspace can't trigger element to be re-added */
1692         if (ret == -EAGAIN)
1693                 ret = 1;
1694
1695         return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1696 }
1697
1698 /* Get headed data of a set */
1699
1700 static int ip_set_header(struct net *net, struct sock *ctnl,
1701                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1702                          const struct nlattr * const attr[],
1703                          struct netlink_ext_ack *extack)
1704 {
1705         struct ip_set_net *inst = ip_set_pernet(net);
1706         const struct ip_set *set;
1707         struct sk_buff *skb2;
1708         struct nlmsghdr *nlh2;
1709         int ret = 0;
1710
1711         if (unlikely(protocol_min_failed(attr) ||
1712                      !attr[IPSET_ATTR_SETNAME]))
1713                 return -IPSET_ERR_PROTOCOL;
1714
1715         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1716         if (!set)
1717                 return -ENOENT;
1718
1719         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1720         if (!skb2)
1721                 return -ENOMEM;
1722
1723         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1724                          IPSET_CMD_HEADER);
1725         if (!nlh2)
1726                 goto nlmsg_failure;
1727         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1728             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1729             nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1730             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1731             nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1732                 goto nla_put_failure;
1733         nlmsg_end(skb2, nlh2);
1734
1735         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1736         if (ret < 0)
1737                 return ret;
1738
1739         return 0;
1740
1741 nla_put_failure:
1742         nlmsg_cancel(skb2, nlh2);
1743 nlmsg_failure:
1744         kfree_skb(skb2);
1745         return -EMSGSIZE;
1746 }
1747
1748 /* Get type data */
1749
1750 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1751         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1752         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1753                                     .len = IPSET_MAXNAMELEN - 1 },
1754         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1755 };
1756
1757 static int ip_set_type(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1758                        const struct nlmsghdr *nlh,
1759                        const struct nlattr * const attr[],
1760                        struct netlink_ext_ack *extack)
1761 {
1762         struct sk_buff *skb2;
1763         struct nlmsghdr *nlh2;
1764         u8 family, min, max;
1765         const char *typename;
1766         int ret = 0;
1767
1768         if (unlikely(protocol_min_failed(attr) ||
1769                      !attr[IPSET_ATTR_TYPENAME] ||
1770                      !attr[IPSET_ATTR_FAMILY]))
1771                 return -IPSET_ERR_PROTOCOL;
1772
1773         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1774         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1775         ret = find_set_type_minmax(typename, family, &min, &max);
1776         if (ret)
1777                 return ret;
1778
1779         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1780         if (!skb2)
1781                 return -ENOMEM;
1782
1783         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1784                          IPSET_CMD_TYPE);
1785         if (!nlh2)
1786                 goto nlmsg_failure;
1787         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1788             nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1789             nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1790             nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1791             nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1792                 goto nla_put_failure;
1793         nlmsg_end(skb2, nlh2);
1794
1795         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1796         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1797         if (ret < 0)
1798                 return ret;
1799
1800         return 0;
1801
1802 nla_put_failure:
1803         nlmsg_cancel(skb2, nlh2);
1804 nlmsg_failure:
1805         kfree_skb(skb2);
1806         return -EMSGSIZE;
1807 }
1808
1809 /* Get protocol version */
1810
1811 static const struct nla_policy
1812 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1813         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1814 };
1815
1816 static int ip_set_protocol(struct net *net, struct sock *ctnl,
1817                            struct sk_buff *skb, const struct nlmsghdr *nlh,
1818                            const struct nlattr * const attr[],
1819                            struct netlink_ext_ack *extack)
1820 {
1821         struct sk_buff *skb2;
1822         struct nlmsghdr *nlh2;
1823         int ret = 0;
1824
1825         if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
1826                 return -IPSET_ERR_PROTOCOL;
1827
1828         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1829         if (!skb2)
1830                 return -ENOMEM;
1831
1832         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1833                          IPSET_CMD_PROTOCOL);
1834         if (!nlh2)
1835                 goto nlmsg_failure;
1836         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1837                 goto nla_put_failure;
1838         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
1839                 goto nla_put_failure;
1840         nlmsg_end(skb2, nlh2);
1841
1842         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1843         if (ret < 0)
1844                 return ret;
1845
1846         return 0;
1847
1848 nla_put_failure:
1849         nlmsg_cancel(skb2, nlh2);
1850 nlmsg_failure:
1851         kfree_skb(skb2);
1852         return -EMSGSIZE;
1853 }
1854
1855 /* Get set by name or index, from userspace */
1856
1857 static int ip_set_byname(struct net *net, struct sock *ctnl,
1858                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1859                          const struct nlattr * const attr[],
1860                          struct netlink_ext_ack *extack)
1861 {
1862         struct ip_set_net *inst = ip_set_pernet(net);
1863         struct sk_buff *skb2;
1864         struct nlmsghdr *nlh2;
1865         ip_set_id_t id = IPSET_INVALID_ID;
1866         const struct ip_set *set;
1867         int ret = 0;
1868
1869         if (unlikely(protocol_failed(attr) ||
1870                      !attr[IPSET_ATTR_SETNAME]))
1871                 return -IPSET_ERR_PROTOCOL;
1872
1873         set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
1874         if (id == IPSET_INVALID_ID)
1875                 return -ENOENT;
1876
1877         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1878         if (!skb2)
1879                 return -ENOMEM;
1880
1881         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1882                          IPSET_CMD_GET_BYNAME);
1883         if (!nlh2)
1884                 goto nlmsg_failure;
1885         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1886             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1887             nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
1888                 goto nla_put_failure;
1889         nlmsg_end(skb2, nlh2);
1890
1891         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1892         if (ret < 0)
1893                 return ret;
1894
1895         return 0;
1896
1897 nla_put_failure:
1898         nlmsg_cancel(skb2, nlh2);
1899 nlmsg_failure:
1900         kfree_skb(skb2);
1901         return -EMSGSIZE;
1902 }
1903
1904 static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
1905         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1906         [IPSET_ATTR_INDEX]      = { .type = NLA_U16 },
1907 };
1908
1909 static int ip_set_byindex(struct net *net, struct sock *ctnl,
1910                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1911                           const struct nlattr * const attr[],
1912                           struct netlink_ext_ack *extack)
1913 {
1914         struct ip_set_net *inst = ip_set_pernet(net);
1915         struct sk_buff *skb2;
1916         struct nlmsghdr *nlh2;
1917         ip_set_id_t id = IPSET_INVALID_ID;
1918         const struct ip_set *set;
1919         int ret = 0;
1920
1921         if (unlikely(protocol_failed(attr) ||
1922                      !attr[IPSET_ATTR_INDEX]))
1923                 return -IPSET_ERR_PROTOCOL;
1924
1925         id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
1926         if (id >= inst->ip_set_max)
1927                 return -ENOENT;
1928         set = ip_set(inst, id);
1929         if (set == NULL)
1930                 return -ENOENT;
1931
1932         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1933         if (!skb2)
1934                 return -ENOMEM;
1935
1936         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1937                          IPSET_CMD_GET_BYINDEX);
1938         if (!nlh2)
1939                 goto nlmsg_failure;
1940         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1941             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name))
1942                 goto nla_put_failure;
1943         nlmsg_end(skb2, nlh2);
1944
1945         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1946         if (ret < 0)
1947                 return ret;
1948
1949         return 0;
1950
1951 nla_put_failure:
1952         nlmsg_cancel(skb2, nlh2);
1953 nlmsg_failure:
1954         kfree_skb(skb2);
1955         return -EMSGSIZE;
1956 }
1957
1958 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1959         [IPSET_CMD_NONE]        = {
1960                 .call           = ip_set_none,
1961                 .attr_count     = IPSET_ATTR_CMD_MAX,
1962         },
1963         [IPSET_CMD_CREATE]      = {
1964                 .call           = ip_set_create,
1965                 .attr_count     = IPSET_ATTR_CMD_MAX,
1966                 .policy         = ip_set_create_policy,
1967         },
1968         [IPSET_CMD_DESTROY]     = {
1969                 .call           = ip_set_destroy,
1970                 .attr_count     = IPSET_ATTR_CMD_MAX,
1971                 .policy         = ip_set_setname_policy,
1972         },
1973         [IPSET_CMD_FLUSH]       = {
1974                 .call           = ip_set_flush,
1975                 .attr_count     = IPSET_ATTR_CMD_MAX,
1976                 .policy         = ip_set_setname_policy,
1977         },
1978         [IPSET_CMD_RENAME]      = {
1979                 .call           = ip_set_rename,
1980                 .attr_count     = IPSET_ATTR_CMD_MAX,
1981                 .policy         = ip_set_setname2_policy,
1982         },
1983         [IPSET_CMD_SWAP]        = {
1984                 .call           = ip_set_swap,
1985                 .attr_count     = IPSET_ATTR_CMD_MAX,
1986                 .policy         = ip_set_setname2_policy,
1987         },
1988         [IPSET_CMD_LIST]        = {
1989                 .call           = ip_set_dump,
1990                 .attr_count     = IPSET_ATTR_CMD_MAX,
1991                 .policy         = ip_set_setname_policy,
1992         },
1993         [IPSET_CMD_SAVE]        = {
1994                 .call           = ip_set_dump,
1995                 .attr_count     = IPSET_ATTR_CMD_MAX,
1996                 .policy         = ip_set_setname_policy,
1997         },
1998         [IPSET_CMD_ADD] = {
1999                 .call           = ip_set_uadd,
2000                 .attr_count     = IPSET_ATTR_CMD_MAX,
2001                 .policy         = ip_set_adt_policy,
2002         },
2003         [IPSET_CMD_DEL] = {
2004                 .call           = ip_set_udel,
2005                 .attr_count     = IPSET_ATTR_CMD_MAX,
2006                 .policy         = ip_set_adt_policy,
2007         },
2008         [IPSET_CMD_TEST]        = {
2009                 .call           = ip_set_utest,
2010                 .attr_count     = IPSET_ATTR_CMD_MAX,
2011                 .policy         = ip_set_adt_policy,
2012         },
2013         [IPSET_CMD_HEADER]      = {
2014                 .call           = ip_set_header,
2015                 .attr_count     = IPSET_ATTR_CMD_MAX,
2016                 .policy         = ip_set_setname_policy,
2017         },
2018         [IPSET_CMD_TYPE]        = {
2019                 .call           = ip_set_type,
2020                 .attr_count     = IPSET_ATTR_CMD_MAX,
2021                 .policy         = ip_set_type_policy,
2022         },
2023         [IPSET_CMD_PROTOCOL]    = {
2024                 .call           = ip_set_protocol,
2025                 .attr_count     = IPSET_ATTR_CMD_MAX,
2026                 .policy         = ip_set_protocol_policy,
2027         },
2028         [IPSET_CMD_GET_BYNAME]  = {
2029                 .call           = ip_set_byname,
2030                 .attr_count     = IPSET_ATTR_CMD_MAX,
2031                 .policy         = ip_set_setname_policy,
2032         },
2033         [IPSET_CMD_GET_BYINDEX] = {
2034                 .call           = ip_set_byindex,
2035                 .attr_count     = IPSET_ATTR_CMD_MAX,
2036                 .policy         = ip_set_index_policy,
2037         },
2038 };
2039
2040 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2041         .name           = "ip_set",
2042         .subsys_id      = NFNL_SUBSYS_IPSET,
2043         .cb_count       = IPSET_MSG_MAX,
2044         .cb             = ip_set_netlink_subsys_cb,
2045 };
2046
2047 /* Interface to iptables/ip6tables */
2048
2049 static int
2050 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2051 {
2052         unsigned int *op;
2053         void *data;
2054         int copylen = *len, ret = 0;
2055         struct net *net = sock_net(sk);
2056         struct ip_set_net *inst = ip_set_pernet(net);
2057
2058         if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2059                 return -EPERM;
2060         if (optval != SO_IP_SET)
2061                 return -EBADF;
2062         if (*len < sizeof(unsigned int))
2063                 return -EINVAL;
2064
2065         data = vmalloc(*len);
2066         if (!data)
2067                 return -ENOMEM;
2068         if (copy_from_user(data, user, *len) != 0) {
2069                 ret = -EFAULT;
2070                 goto done;
2071         }
2072         op = data;
2073
2074         if (*op < IP_SET_OP_VERSION) {
2075                 /* Check the version at the beginning of operations */
2076                 struct ip_set_req_version *req_version = data;
2077
2078                 if (*len < sizeof(struct ip_set_req_version)) {
2079                         ret = -EINVAL;
2080                         goto done;
2081                 }
2082
2083                 if (req_version->version < IPSET_PROTOCOL_MIN) {
2084                         ret = -EPROTO;
2085                         goto done;
2086                 }
2087         }
2088
2089         switch (*op) {
2090         case IP_SET_OP_VERSION: {
2091                 struct ip_set_req_version *req_version = data;
2092
2093                 if (*len != sizeof(struct ip_set_req_version)) {
2094                         ret = -EINVAL;
2095                         goto done;
2096                 }
2097
2098                 req_version->version = IPSET_PROTOCOL;
2099                 ret = copy_to_user(user, req_version,
2100                                    sizeof(struct ip_set_req_version));
2101                 goto done;
2102         }
2103         case IP_SET_OP_GET_BYNAME: {
2104                 struct ip_set_req_get_set *req_get = data;
2105                 ip_set_id_t id;
2106
2107                 if (*len != sizeof(struct ip_set_req_get_set)) {
2108                         ret = -EINVAL;
2109                         goto done;
2110                 }
2111                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2112                 nfnl_lock(NFNL_SUBSYS_IPSET);
2113                 find_set_and_id(inst, req_get->set.name, &id);
2114                 req_get->set.index = id;
2115                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2116                 goto copy;
2117         }
2118         case IP_SET_OP_GET_FNAME: {
2119                 struct ip_set_req_get_set_family *req_get = data;
2120                 ip_set_id_t id;
2121
2122                 if (*len != sizeof(struct ip_set_req_get_set_family)) {
2123                         ret = -EINVAL;
2124                         goto done;
2125                 }
2126                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2127                 nfnl_lock(NFNL_SUBSYS_IPSET);
2128                 find_set_and_id(inst, req_get->set.name, &id);
2129                 req_get->set.index = id;
2130                 if (id != IPSET_INVALID_ID)
2131                         req_get->family = ip_set(inst, id)->family;
2132                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2133                 goto copy;
2134         }
2135         case IP_SET_OP_GET_BYINDEX: {
2136                 struct ip_set_req_get_set *req_get = data;
2137                 struct ip_set *set;
2138
2139                 if (*len != sizeof(struct ip_set_req_get_set) ||
2140                     req_get->set.index >= inst->ip_set_max) {
2141                         ret = -EINVAL;
2142                         goto done;
2143                 }
2144                 nfnl_lock(NFNL_SUBSYS_IPSET);
2145                 set = ip_set(inst, req_get->set.index);
2146                 ret = strscpy(req_get->set.name, set ? set->name : "",
2147                               IPSET_MAXNAMELEN);
2148                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2149                 if (ret < 0)
2150                         goto done;
2151                 goto copy;
2152         }
2153         default:
2154                 ret = -EBADMSG;
2155                 goto done;
2156         }       /* end of switch(op) */
2157
2158 copy:
2159         ret = copy_to_user(user, data, copylen);
2160
2161 done:
2162         vfree(data);
2163         if (ret > 0)
2164                 ret = 0;
2165         return ret;
2166 }
2167
2168 static struct nf_sockopt_ops so_set __read_mostly = {
2169         .pf             = PF_INET,
2170         .get_optmin     = SO_IP_SET,
2171         .get_optmax     = SO_IP_SET + 1,
2172         .get            = ip_set_sockfn_get,
2173         .owner          = THIS_MODULE,
2174 };
2175
2176 static int __net_init
2177 ip_set_net_init(struct net *net)
2178 {
2179         struct ip_set_net *inst = ip_set_pernet(net);
2180         struct ip_set **list;
2181
2182         inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2183         if (inst->ip_set_max >= IPSET_INVALID_ID)
2184                 inst->ip_set_max = IPSET_INVALID_ID - 1;
2185
2186         list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2187         if (!list)
2188                 return -ENOMEM;
2189         inst->is_deleted = false;
2190         inst->is_destroyed = false;
2191         rcu_assign_pointer(inst->ip_set_list, list);
2192         return 0;
2193 }
2194
2195 static void __net_exit
2196 ip_set_net_exit(struct net *net)
2197 {
2198         struct ip_set_net *inst = ip_set_pernet(net);
2199
2200         struct ip_set *set = NULL;
2201         ip_set_id_t i;
2202
2203         inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2204
2205         nfnl_lock(NFNL_SUBSYS_IPSET);
2206         for (i = 0; i < inst->ip_set_max; i++) {
2207                 set = ip_set(inst, i);
2208                 if (set) {
2209                         ip_set(inst, i) = NULL;
2210                         ip_set_destroy_set(set);
2211                 }
2212         }
2213         nfnl_unlock(NFNL_SUBSYS_IPSET);
2214         kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2215 }
2216
2217 static struct pernet_operations ip_set_net_ops = {
2218         .init   = ip_set_net_init,
2219         .exit   = ip_set_net_exit,
2220         .id     = &ip_set_net_id,
2221         .size   = sizeof(struct ip_set_net),
2222 };
2223
2224 static int __init
2225 ip_set_init(void)
2226 {
2227         int ret = register_pernet_subsys(&ip_set_net_ops);
2228
2229         if (ret) {
2230                 pr_err("ip_set: cannot register pernet_subsys.\n");
2231                 return ret;
2232         }
2233
2234         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2235         if (ret != 0) {
2236                 pr_err("ip_set: cannot register with nfnetlink.\n");
2237                 unregister_pernet_subsys(&ip_set_net_ops);
2238                 return ret;
2239         }
2240
2241         ret = nf_register_sockopt(&so_set);
2242         if (ret != 0) {
2243                 pr_err("SO_SET registry failed: %d\n", ret);
2244                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2245                 unregister_pernet_subsys(&ip_set_net_ops);
2246                 return ret;
2247         }
2248
2249         return 0;
2250 }
2251
2252 static void __exit
2253 ip_set_fini(void)
2254 {
2255         nf_unregister_sockopt(&so_set);
2256         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2257
2258         unregister_pernet_subsys(&ip_set_net_ops);
2259         pr_debug("these are the famous last words\n");
2260 }
2261
2262 module_init(ip_set_init);
2263 module_exit(ip_set_fini);
2264
2265 MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL));