Merge tag 'rxrpc-next-20171111' of git://git.kernel.org/pub/scm/linux/kernel/git...
[sfrench/cifs-2.6.git] / net / ipv6 / addrlabel.c
index f664871feca666fb6256990c2fc36fc15940a541..00e1f8ee08f8aa1d70c10ca2941072f7ba0dc9da 100644 (file)
@@ -19,7 +19,6 @@
 #include <linux/if_addrlabel.h>
 #include <linux/netlink.h>
 #include <linux/rtnetlink.h>
-#include <linux/refcount.h>
 
 #if 0
 #define ADDRLABEL(x...) printk(x)
  * Policy Table
  */
 struct ip6addrlbl_entry {
-       possible_net_t lbl_net;
        struct in6_addr prefix;
        int prefixlen;
        int ifindex;
        int addrtype;
        u32 label;
        struct hlist_node list;
-       refcount_t refcnt;
        struct rcu_head rcu;
 };
 
-static struct ip6addrlbl_table
-{
-       struct hlist_head head;
-       spinlock_t lock;
-       u32 seq;
-} ip6addrlbl_table;
-
-static inline
-struct net *ip6addrlbl_net(const struct ip6addrlbl_entry *lbl)
-{
-       return read_pnet(&lbl->lbl_net);
-}
-
 /*
  * Default policy table (RFC6724 + extensions)
  *
@@ -126,36 +110,11 @@ static const __net_initconst struct ip6addrlbl_init_table
        }
 };
 
-/* Object management */
-static inline void ip6addrlbl_free(struct ip6addrlbl_entry *p)
-{
-       kfree(p);
-}
-
-static void ip6addrlbl_free_rcu(struct rcu_head *h)
-{
-       ip6addrlbl_free(container_of(h, struct ip6addrlbl_entry, rcu));
-}
-
-static bool ip6addrlbl_hold(struct ip6addrlbl_entry *p)
-{
-       return refcount_inc_not_zero(&p->refcnt);
-}
-
-static inline void ip6addrlbl_put(struct ip6addrlbl_entry *p)
-{
-       if (refcount_dec_and_test(&p->refcnt))
-               call_rcu(&p->rcu, ip6addrlbl_free_rcu);
-}
-
 /* Find label */
-static bool __ip6addrlbl_match(struct net *net,
-                              const struct ip6addrlbl_entry *p,
+static bool __ip6addrlbl_match(const struct ip6addrlbl_entry *p,
                               const struct in6_addr *addr,
                               int addrtype, int ifindex)
 {
-       if (!net_eq(ip6addrlbl_net(p), net))
-               return false;
        if (p->ifindex && p->ifindex != ifindex)
                return false;
        if (p->addrtype && p->addrtype != addrtype)
@@ -170,8 +129,9 @@ static struct ip6addrlbl_entry *__ipv6_addr_label(struct net *net,
                                                  int type, int ifindex)
 {
        struct ip6addrlbl_entry *p;
-       hlist_for_each_entry_rcu(p, &ip6addrlbl_table.head, list) {
-               if (__ip6addrlbl_match(net, p, addr, type, ifindex))
+
+       hlist_for_each_entry_rcu(p, &net->ipv6.ip6addrlbl_table.head, list) {
+               if (__ip6addrlbl_match(p, addr, type, ifindex))
                        return p;
        }
        return NULL;
@@ -197,8 +157,7 @@ u32 ipv6_addr_label(struct net *net,
 }
 
 /* allocate one entry */
-static struct ip6addrlbl_entry *ip6addrlbl_alloc(struct net *net,
-                                                const struct in6_addr *prefix,
+static struct ip6addrlbl_entry *ip6addrlbl_alloc(const struct in6_addr *prefix,
                                                 int prefixlen, int ifindex,
                                                 u32 label)
 {
@@ -237,24 +196,22 @@ static struct ip6addrlbl_entry *ip6addrlbl_alloc(struct net *net,
        newp->addrtype = addrtype;
        newp->label = label;
        INIT_HLIST_NODE(&newp->list);
-       write_pnet(&newp->lbl_net, net);
-       refcount_set(&newp->refcnt, 1);
        return newp;
 }
 
 /* add a label */
-static int __ip6addrlbl_add(struct ip6addrlbl_entry *newp, int replace)
+static int __ip6addrlbl_add(struct net *net, struct ip6addrlbl_entry *newp,
+                           int replace)
 {
-       struct hlist_node *n;
        struct ip6addrlbl_entry *last = NULL, *p = NULL;
+       struct hlist_node *n;
        int ret = 0;
 
        ADDRLABEL(KERN_DEBUG "%s(newp=%p, replace=%d)\n", __func__, newp,
                  replace);
 
-       hlist_for_each_entry_safe(p, n, &ip6addrlbl_table.head, list) {
+       hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) {
                if (p->prefixlen == newp->prefixlen &&
-                   net_eq(ip6addrlbl_net(p), ip6addrlbl_net(newp)) &&
                    p->ifindex == newp->ifindex &&
                    ipv6_addr_equal(&p->prefix, &newp->prefix)) {
                        if (!replace) {
@@ -262,7 +219,7 @@ static int __ip6addrlbl_add(struct ip6addrlbl_entry *newp, int replace)
                                goto out;
                        }
                        hlist_replace_rcu(&p->list, &newp->list);
-                       ip6addrlbl_put(p);
+                       kfree_rcu(p, rcu);
                        goto out;
                } else if ((p->prefixlen == newp->prefixlen && !p->ifindex) ||
                           (p->prefixlen < newp->prefixlen)) {
@@ -274,10 +231,10 @@ static int __ip6addrlbl_add(struct ip6addrlbl_entry *newp, int replace)
        if (last)
                hlist_add_behind_rcu(&newp->list, &last->list);
        else
-               hlist_add_head_rcu(&newp->list, &ip6addrlbl_table.head);
+               hlist_add_head_rcu(&newp->list, &net->ipv6.ip6addrlbl_table.head);
 out:
        if (!ret)
-               ip6addrlbl_table.seq++;
+               net->ipv6.ip6addrlbl_table.seq++;
        return ret;
 }
 
@@ -293,14 +250,14 @@ static int ip6addrlbl_add(struct net *net,
                  __func__, prefix, prefixlen, ifindex, (unsigned int)label,
                  replace);
 
-       newp = ip6addrlbl_alloc(net, prefix, prefixlen, ifindex, label);
+       newp = ip6addrlbl_alloc(prefix, prefixlen, ifindex, label);
        if (IS_ERR(newp))
                return PTR_ERR(newp);
-       spin_lock(&ip6addrlbl_table.lock);
-       ret = __ip6addrlbl_add(newp, replace);
-       spin_unlock(&ip6addrlbl_table.lock);
+       spin_lock(&net->ipv6.ip6addrlbl_table.lock);
+       ret = __ip6addrlbl_add(net, newp, replace);
+       spin_unlock(&net->ipv6.ip6addrlbl_table.lock);
        if (ret)
-               ip6addrlbl_free(newp);
+               kfree(newp);
        return ret;
 }
 
@@ -316,13 +273,12 @@ static int __ip6addrlbl_del(struct net *net,
        ADDRLABEL(KERN_DEBUG "%s(prefix=%pI6, prefixlen=%d, ifindex=%d)\n",
                  __func__, prefix, prefixlen, ifindex);
 
-       hlist_for_each_entry_safe(p, n, &ip6addrlbl_table.head, list) {
+       hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) {
                if (p->prefixlen == prefixlen &&
-                   net_eq(ip6addrlbl_net(p), net) &&
                    p->ifindex == ifindex &&
                    ipv6_addr_equal(&p->prefix, prefix)) {
                        hlist_del_rcu(&p->list);
-                       ip6addrlbl_put(p);
+                       kfree_rcu(p, rcu);
                        ret = 0;
                        break;
                }
@@ -341,9 +297,9 @@ static int ip6addrlbl_del(struct net *net,
                  __func__, prefix, prefixlen, ifindex);
 
        ipv6_addr_prefix(&prefix_buf, prefix, prefixlen);
-       spin_lock(&ip6addrlbl_table.lock);
+       spin_lock(&net->ipv6.ip6addrlbl_table.lock);
        ret = __ip6addrlbl_del(net, &prefix_buf, prefixlen, ifindex);
-       spin_unlock(&ip6addrlbl_table.lock);
+       spin_unlock(&net->ipv6.ip6addrlbl_table.lock);
        return ret;
 }
 
@@ -355,6 +311,9 @@ static int __net_init ip6addrlbl_net_init(struct net *net)
 
        ADDRLABEL(KERN_DEBUG "%s\n", __func__);
 
+       spin_lock_init(&net->ipv6.ip6addrlbl_table.lock);
+       INIT_HLIST_HEAD(&net->ipv6.ip6addrlbl_table.head);
+
        for (i = 0; i < ARRAY_SIZE(ip6addrlbl_init_table); i++) {
                int ret = ip6addrlbl_add(net,
                                         ip6addrlbl_init_table[i].prefix,
@@ -374,14 +333,12 @@ static void __net_exit ip6addrlbl_net_exit(struct net *net)
        struct hlist_node *n;
 
        /* Remove all labels belonging to the exiting net */
-       spin_lock(&ip6addrlbl_table.lock);
-       hlist_for_each_entry_safe(p, n, &ip6addrlbl_table.head, list) {
-               if (net_eq(ip6addrlbl_net(p), net)) {
-                       hlist_del_rcu(&p->list);
-                       ip6addrlbl_put(p);
-               }
+       spin_lock(&net->ipv6.ip6addrlbl_table.lock);
+       hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) {
+               hlist_del_rcu(&p->list);
+               kfree_rcu(p, rcu);
        }
-       spin_unlock(&ip6addrlbl_table.lock);
+       spin_unlock(&net->ipv6.ip6addrlbl_table.lock);
 }
 
 static struct pernet_operations ipv6_addr_label_ops = {
@@ -391,8 +348,6 @@ static struct pernet_operations ipv6_addr_label_ops = {
 
 int __init ipv6_addr_label_init(void)
 {
-       spin_lock_init(&ip6addrlbl_table.lock);
-
        return register_pernet_subsys(&ipv6_addr_label_ops);
 }
 
@@ -511,11 +466,10 @@ static int ip6addrlbl_dump(struct sk_buff *skb, struct netlink_callback *cb)
        int err;
 
        rcu_read_lock();
-       hlist_for_each_entry_rcu(p, &ip6addrlbl_table.head, list) {
-               if (idx >= s_idx &&
-                   net_eq(ip6addrlbl_net(p), net)) {
+       hlist_for_each_entry_rcu(p, &net->ipv6.ip6addrlbl_table.head, list) {
+               if (idx >= s_idx) {
                        err = ip6addrlbl_fill(skb, p,
-                                             ip6addrlbl_table.seq,
+                                             net->ipv6.ip6addrlbl_table.seq,
                                              NETLINK_CB(cb->skb).portid,
                                              cb->nlh->nlmsg_seq,
                                              RTM_NEWADDRLABEL,
@@ -568,38 +522,28 @@ static int ip6addrlbl_get(struct sk_buff *in_skb, struct nlmsghdr *nlh,
                return -EINVAL;
        addr = nla_data(tb[IFAL_ADDRESS]);
 
-       rcu_read_lock();
-       p = __ipv6_addr_label(net, addr, ipv6_addr_type(addr), ifal->ifal_index);
-       if (p && !ip6addrlbl_hold(p))
-               p = NULL;
-       lseq = ip6addrlbl_table.seq;
-       rcu_read_unlock();
-
-       if (!p) {
-               err = -ESRCH;
-               goto out;
-       }
-
        skb = nlmsg_new(ip6addrlbl_msgsize(), GFP_KERNEL);
-       if (!skb) {
-               ip6addrlbl_put(p);
+       if (!skb)
                return -ENOBUFS;
-       }
 
-       err = ip6addrlbl_fill(skb, p, lseq,
-                             NETLINK_CB(in_skb).portid, nlh->nlmsg_seq,
-                             RTM_NEWADDRLABEL, 0);
+       err = -ESRCH;
 
-       ip6addrlbl_put(p);
+       rcu_read_lock();
+       p = __ipv6_addr_label(net, addr, ipv6_addr_type(addr), ifal->ifal_index);
+       lseq = net->ipv6.ip6addrlbl_table.seq;
+       if (p)
+               err = ip6addrlbl_fill(skb, p, lseq,
+                                     NETLINK_CB(in_skb).portid,
+                                     nlh->nlmsg_seq,
+                                     RTM_NEWADDRLABEL, 0);
+       rcu_read_unlock();
 
        if (err < 0) {
                WARN_ON(err == -EMSGSIZE);
                kfree_skb(skb);
-               goto out;
+       } else {
+               err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
        }
-
-       err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
-out:
        return err;
 }